Mxnet :调用基础网络(vgg,resnet等)的某一层输出作为后续网络的输入
1.ipython 示例 读取vgg网络的relu7的输出:
主要函数有
get_symbol(num_layers,num_classes):
得到网络所有的结构。
list_outputs():(给出符号的输出变量)的说明
list_arguments (给出当前符号的输入变量)的说明
get_internals():
获取中间层结果
完整代码:
import mxnet as mx from importlib import import_module def get_bonenet(num_classes, bonenet): """调用基础网络作为输入""" if bonenet.startswith('vgg19'): net = import_module('network.symbols.vgg') sym = net.get_symbol(num_classes, num_layers=19) internals = sym.get_internals() # print(internals.list_outputs()) bonenet_layer = internals['drop7_output'] return bonenet_layer
def nbc_network(num_classes, bonenet): """n binary classifiers network""" bonenet_layer = get_bonenet(num_classes,bonenet) fc = mx.sym.FullyConnected(data=bonenet_layer, num_hidden=num_classes, name="fc") label = mx.sym.Variable(name='label') symbol = mx.symbol.LogisticRegressionOutput(data=fc, label=label, name='LRO_1') return symbol