Mxnet :调用基础网络(vgg,resnet等)的某一层输出作为后续网络的输入

1.ipython 示例 读取vgg网络的relu7的输出:

主要函数有

get_symbol(num_layers,num_classes):

得到网络所有的结构。

list_outputs():(给出符号的输出变量)的说明

list_arguments (给出当前符号的输入变量)的说明

get_internals():

获取中间层结果

Mxnet :调用基础网络(vgg,resnet等)的某一层输出作为后续网络的输入

 

 

Mxnet :调用基础网络(vgg,resnet等)的某一层输出作为后续网络的输入

完整代码:

import mxnet as mx
from importlib import import_module

def get_bonenet(num_classes, bonenet):
    """调用基础网络作为输入"""
    if bonenet.startswith('vgg19'):
        net = import_module('network.symbols.vgg')
        sym = net.get_symbol(num_classes, num_layers=19)
        internals = sym.get_internals()
        # print(internals.list_outputs())
        bonenet_layer = internals['drop7_output']
    return bonenet_layer
def nbc_network(num_classes, bonenet):
    """n binary classifiers network"""
    bonenet_layer = get_bonenet(num_classes,bonenet)
    fc = mx.sym.FullyConnected(data=bonenet_layer, num_hidden=num_classes, name="fc")
    label = mx.sym.Variable(name='label')
    symbol = mx.symbol.LogisticRegressionOutput(data=fc, label=label, name='LRO_1')
    return symbol