Keras Advanced Techniques: Overriding Layer


Overview

When building models quickly with Keras, you will sometimes need behavior that the built-in layers cannot provide and have to write your own. There are two ways to do this:

  1. If the custom operation involves no trainable weights, use keras.layers.Lambda().
  2. Otherwise, subclass keras.layers.Layer and implement the methods below (a minimal sketch contrasting the two approaches follows this list):
  • build(input_shape): this is where you will define your weights. This method must set self.built = True at the end, which can be done by calling super([Layer], self).build().
  • call(x): this is where the layer’s logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to call: the input tensor.
  • compute_output_shape(input_shape): in case your layer modifies the shape of its input, you should specify here the shape transformation logic. This allows Keras to do automatic shape inference.
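As a minimal sketch of that decision (the names double and LearnableScale below are placeholders invented for this sketch, not taken from the examples later in the post): an element-wise doubling needs no weights and fits in a Lambda, while a learnable scale factor needs a Layer subclass.

from keras.layers import Lambda, Layer

# 1. no trainable weights -> a Lambda wrapping a plain function is enough
double = Lambda(lambda x: 2.0 * x)

# 2. a trainable parameter is needed -> subclass Layer
class LearnableScale(Layer):
    def build(self, input_shape):
        # weights are created in build() and registered via add_weight()
        self.scale = self.add_weight(name='scale', shape=(1,),
                                     initializer='ones', trainable=True)
        super(LearnableScale, self).build(input_shape)  # sets self.built = True

    def call(self, x):
        # forward logic: scale every input element
        return self.scale * x

    def compute_output_shape(self, input_shape):
        # the output shape is unchanged
        return input_shape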

Environment

  • Python 3.6
  • TensorFlow 1.12
  • keras 2.2.4

Examples

  1. Using Lambda
from keras.models import Sequential
from keras.layers import Dense, Lambda, RepeatVector
import tensorflow as tf
import keras.backend as K



model = Sequential()
model.add(Dense(32, input_dim=32))
# now: model.output_shape == (None, 32)
# note: `None` is the batch dimension

model.add(RepeatVector(3))
# now: model.output_shape == (None, 3, 32)
model.add(Lambda(lambda x: x ** 2))
# now: model.output_shape is still (None, 3, 32)


def antirectifier(x):
    # note: a tf.Variable created inside a Lambda is not registered as a
    # layer weight, so `gamma` is never updated during training
    gamma = tf.Variable(10.0, trainable=True)
    x -= K.mean(x, axis=-1, keepdims=True)
    x = K.l2_normalize(x, axis=-1)
    pos = K.relu(x)
    neg = K.relu(-x)
    return tf.multiply(gamma, K.concatenate([pos, neg], axis=-1))


def antirectifier_output_shape(input_shape):
    shape = list(input_shape)
    shape[-1] *= 2  # the concatenation doubles the last (feature) axis
    return tuple(shape)

model.add(Lambda(antirectifier,
                 output_shape=antirectifier_output_shape))
model.summary()

Output:

(model.summary() table, included as an image in the original post)
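To see the limitation mentioned in the overview, a quick sanity check is to list the model's trainable weights: the gamma variable created inside antirectifier does not appear there, because Lambda layers register no weights of their own.

# only the Dense layer's kernel and bias are trainable;
# `gamma` from the Lambda function is not tracked by Keras
print(len(model.trainable_weights))  # -> 2
for w in model.trainable_weights:
    print(w.name)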

  2. Using a custom Layer
from keras import backend as K
from keras.layers import Dense, Layer
from keras.models import Sequential


class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    
model = Sequential()
model.add(Dense(32, input_dim=32))
model.add(MyLayer(100))
model.summary()

Output:

(model.summary() table, included as an image in the original post)
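Because MyLayer registers its kernel through add_weight, it is trained like any built-in layer. A short sanity check on randomly generated data (a sketch; the optimizer, loss, and data below are arbitrary choices for illustration):

import numpy as np

# Dense contributes a kernel and a bias, MyLayer contributes its own kernel
print([w.name for w in model.trainable_weights])

model.compile(optimizer='adam', loss='mse')
x = np.random.random((64, 32))
y = np.random.random((64, 100))
model.fit(x, y, epochs=2, batch_size=16, verbose=0)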

References

[1]: Keras Lambda
[2]: Writing your own Keras layers
[3]: 使用Keras编写自定义网络层(layer)