Tensorflow RNN源码理解

一、阅读源码

这个是TensorflowRNN源码,官方注释解释的比较清楚:

 Tensorflow RNN源码理解

RNNCell是一个抽象类,我们看下下它的属性:

 Tensorflow RNN源码理解

我们可以发现这里用到的是Python内置的@property装饰器就是负责把一个方法变成属性调用的,很像C#中的属性、字段的那种概念。State_sizeOutput_size规定了隐层的大小和输出张量的大小。

 Tensorflow RNN源码理解

下面是重要的__call__方法,有点像USRP中的work()或者general_work()的功能。这里我们注意到输入的参数有Inputs,State,这里其实就是指输入层和隐层了。但是这里有规定Inputs的格式为(batch_size,input_size,State的格式为(batch_zie,state_size,这很容易理解,因为我们进行训练数据会分成很多batch。与普通的神经网络结构一样,输入层、隐层、输出层的size并没有关系,视应用场景而定。

还有一个总要的方法是初始化方法:

Tensorflow RNN源码理解

BasicRNNCellGRUCellBasicLSTMCellLSTMCell都是继承于LayerRNNCell,而LayerRNNCell继承于上面讲的抽象类RNNCell,这就是TensorflowRNN的继承关系。

这里不做过多介绍,但是有意思的一点是这里:

 Tensorflow RNN源码理解

其实BasicRNNCell的输出、隐层状态是一样的。而BasicLSTMCell的隐层状态和输出是不一样的。New_stae = LSTMStateTuple(new_c,new_h)

 Tensorflow RNN源码理解

同样RNN的隐层也可以构建多层MultiRNNCell

根据源码可知:

 Tensorflow RNN源码理解

隐层状态的返回值是元组(tuple)类型

最重要的一个类Dynamic_rnn(batch_size,time_steps,input_size),参数很好理解,但是需要强调的是State是最后一步的隐藏状态,形状是(batch_size,cell.State_size),time_steps是调用RNNCell抽象类中__call__()函数的次数,Output是所有steps的输出。Time_major=True的情况下将Output的格式中batch_szietime_steps位置交换。