Keras Advanced Techniques: Writing Your Own Layer
Overview
When building models quickly with Keras, you will sometimes need an operation that the existing layers cannot express, and you have to write it yourself. There are two ways to do this:
- If the operation has no weights that need to be trained, wrap it in `keras.layers.Lambda()`.
- Otherwise, subclass `keras.layers.Layer` and implement three methods:
  - `build(input_shape)`: this is where you will define your weights. This method must set `self.built = True` at the end, which can be done by calling `super([Layer], self).build()`.
  - `call(x)`: this is where the layer's logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to `call`: the input tensor.
  - `compute_output_shape(input_shape)`: in case your layer modifies the shape of its input, you should specify here the shape transformation logic. This allows Keras to do automatic shape inference.
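To make the division of labour between the three methods concrete, here is a framework-free toy sketch. `ToyLayer` is a hypothetical class written in plain Python/NumPy; it only imitates the contract described above (weights created in `build` once the input shape is known, computation in `call`, the shape rule in `compute_output_shape`) and is not Keras's actual implementation.

```python
import numpy as np


class ToyLayer:
    """Hypothetical stand-in that mimics the Keras Layer contract; not Keras code."""

    def __init__(self, output_dim):
        self.output_dim = output_dim
        self.built = False

    def build(self, input_shape):
        # The weight shape depends on the input's last dimension, which is only
        # known once the layer sees its input; this is why building is deferred.
        rng = np.random.default_rng(0)
        self.kernel = rng.uniform(-0.05, 0.05, size=(input_shape[-1], self.output_dim))
        self.built = True  # the role played by super().build() in Keras

    def call(self, x):
        # The layer's logic: a plain matrix product.
        return x @ self.kernel

    def compute_output_shape(self, input_shape):
        # Shape rule: keep the batch dimension, replace the feature dimension.
        return (input_shape[0], self.output_dim)

    def __call__(self, x):
        # Build lazily on first use, then run the layer's logic.
        if not self.built:
            self.build(x.shape)
        return self.call(x)


layer = ToyLayer(4)
print(layer.built)                                     # False: no weights yet
y = layer(np.ones((2, 3)))
print(layer.built, layer.kernel.shape)                 # True (3, 4)
print(y.shape, layer.compute_output_shape((None, 3)))  # (2, 4) (None, 4)
```

The lazy `build` is the key design point: the layer does not need to be told its input size when it is constructed, only its output size.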
Environment
- Python 3.6
- TensorFlow 1.12
- keras 2.2.4
Examples
- Using Lambda
```python
from keras.models import Sequential
from keras.layers import Dense, Flatten, Lambda, RepeatVector
import keras.backend as K

model = Sequential()
model.add(Dense(32, input_dim=32))
# now: model.output_shape == (None, 32)
# note: `None` is the batch dimension
model.add(RepeatVector(3))
# now: model.output_shape == (None, 3, 32)
model.add(Lambda(lambda x: x ** 2))  # element-wise square, shape unchanged
model.add(Flatten())
# now: model.output_shape == (None, 96); antirectifier below only accepts 2D input


def antirectifier(x):
    # A Lambda layer has no trainable weights: a tf.Variable created in here would
    # not be tracked or updated by Keras, so gamma is a fixed scale. If gamma must
    # be learned, write a custom Layer instead (see the next example).
    gamma = 10.0
    x -= K.mean(x, axis=1, keepdims=True)
    x = K.l2_normalize(x, axis=1)
    pos = K.relu(x)
    neg = K.relu(-x)
    return gamma * K.concatenate([pos, neg], axis=1)


def antirectifier_output_shape(input_shape):
    shape = list(input_shape)
    assert len(shape) == 2  # only valid for 2D tensors
    shape[-1] *= 2
    return tuple(shape)


model.add(Lambda(antirectifier,
                 output_shape=antirectifier_output_shape))
# now: model.output_shape == (None, 192)
model.summary()
```
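To see what `antirectifier` actually computes, the following sketch re-implements it in NumPy for a 2D batch. `antirectifier_np` is a hypothetical helper for illustration only; it mirrors the backend calls step by step (centering, L2 normalization with a small epsilon, two ReLUs, concatenation) with `gamma` fixed at 10.0, and is not part of Keras.

```python
import numpy as np


def antirectifier_np(x, gamma=10.0):
    """NumPy mirror of antirectifier for input of shape (batch, features)."""
    x = x - x.mean(axis=1, keepdims=True)                 # center each sample
    sq_sum = np.sum(x ** 2, axis=1, keepdims=True)
    x = x / np.sqrt(np.maximum(sq_sum, 1e-12))            # L2-normalize each sample
    pos = np.maximum(x, 0.0)                               # relu(x)
    neg = np.maximum(-x, 0.0)                              # relu(-x)
    return gamma * np.concatenate([pos, neg], axis=1)      # feature dim doubles


x = np.array([[1.0, 2.0, 3.0],
              [0.5, -1.0, 4.0]])
out = antirectifier_np(x)
print(out.shape)                    # (2, 6)
print(np.linalg.norm(out, axis=1))  # each row's L2 norm equals gamma (10.0), up to rounding
```

Because `relu(x)**2 + relu(-x)**2 == x**2` element-wise, concatenating the two halves preserves the unit norm produced by `l2_normalize`, so every output row has norm `gamma`. This also shows why `antirectifier_output_shape` doubles only the last dimension.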
- Using a custom Layer
```python
from keras import backend as K
from keras.layers import Dense, Layer
from keras.models import Sequential


class MyLayer(Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # Linear projection without a bias term.
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)


model = Sequential()
model.add(Dense(32, input_dim=32))
model.add(MyLayer(100))  # output shape (None, 100)
model.summary()
```
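The arithmetic behind `MyLayer` can be checked without Keras. The NumPy sketch below uses random arrays as stand-ins for the `Dense(32)` output and for `MyLayer.kernel` (the variable names are illustrative, not Keras API), reproduces what `call` computes, and counts the parameters of both layers.

```python
import numpy as np

batch, units_in, output_dim = 4, 32, 100
rng = np.random.default_rng(0)

# Stand-in for the Dense(32) output that feeds MyLayer: shape (batch, 32).
dense_out = rng.normal(size=(batch, units_in))
# Stand-in for MyLayer.kernel: shape (input_shape[1], output_dim) = (32, 100).
kernel = rng.uniform(-0.05, 0.05, size=(units_in, output_dim))

# MyLayer.call is K.dot(x, self.kernel): a matrix product with no bias term.
y = dense_out @ kernel

# Parameter counts: Dense has kernel + bias, MyLayer has only its kernel.
dense_params = 32 * 32 + 32              # input_dim=32, units=32
mylayer_params = units_in * output_dim

print(y.shape)                       # (4, 100)
print(dense_params, mylayer_params)  # 1056 3200
```

These are the per-layer counts `model.summary()` should report for this model, 4,256 trainable parameters in total; a custom layer's weights only show up there because they were registered through `add_weight`.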