配件定制层Keras模式的失败

问题描述:

我创建了一个keras我的自定义层(1.1):配件定制层Keras模式的失败

from keras import backend as K 
from keras.engine.topology import Layer 
import numpy as np 

class MyLayer(Layer): 

def __init__(self,input_shape,**kwargs): 
    self.W_init = np.random.rand(input_shape[0], input_shape[1], input_shape[2]) 
    self.input_len = input_shape[0] 
    self.output_dim = 1 
    super(MyLayer, self).__init__(**kwargs) 

def build(self, input_shape): 
    # Create a trainable weight variable for this layer. 
    self.W = K.variable(self.W_init, name="W") 
    self.trainable_weights = [ self.W ] 
    super(MyLayer, self).build(input_shape) # Be sure to call this somewhere! 

def call(self, x, mask=None): 
    res= K.sum(x*self.W,axis=(1,2)) 
    res= K.expand_dims(res, -1) 
    res = K.expand_dims(res, -1) 
    return res 

def get_output_shape_for(self, input_shape): 
    return (input_shape[0], self.input_len, self.output_dim, self.output_dim) 

该模型的成功编译: enter image description here

但是当我尝试以适应它,我得到错误:

ValueError: cannot reshape array of size 64 into shape (1,4) 
Apply node that caused the error: Reshape{2}(HostFromGpu.0, MakeVector{dtype='int64'}.0) 
Toposort index: 895 
Inputs types: [TensorType(float32, vector), TensorType(int64, vector)] 
Inputs shapes: [(64,), (2,)] 
Inputs strides: [(4,), (8,)] 
Inputs values: ['not shown', array([1, 4])] 
Inputs type_num: [11, 7] 
Outputs clients: [[InplaceDimShuffle{0,1,x,x}(Reshape{2}.0)]] 

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer): 
    File "<ipython-input-155-09ee1207017c>", line 22, in get_my_model_2 
    Dense(10, activation='softmax') 
    File "/home/universal/anaconda3/envs/practicecourse2/lib/python2.7/site-packages/keras/models.py", line 255, in __init__ 
    self.add(layer) 

在我的自定义图层中,可训练的权重有问题吗?

由于您使用的是Theano,因此宽度和高度轴是(2,3),而不是(1,2)

也就是说,你应该改变行:

res= K.sum(x*self.W,axis=(1,2)) 

res= K.sum(x*self.W,axis=(2,3)) 

将引发错误,因为你call()函数的输出的形状(None, 4, 1, 1)代替(None, 64, 1, 1),如规定get_output_shape_for()

+0

非常感谢!是因为轴号从1开始,而不是0? – Kseniya

+0

是的,在'call()'里面,数组'x'的轴0是批量大小。 –

+0

我明白了,谢谢! – Kseniya