torch.nn.Module模块简单介绍

torch.nn是专门为神经网络设计的模块化接口,nn.Module是nn中十分重要的类。在介绍该模块前,我们先看下pytorch官方对该模块的注释:

torch.nn.Module模块简单介绍

根据官方注释我们了解到Module类是所有神经网络模块的基类,Module可以以树形结构包含其他的Module。Module类中包含网络各层的定义及forward方法,下面介绍我们如何定义自已的网络:

  1. 需要继承nn.Module类,并实现forward方法;

  2. 一般把网络中具有可学习参数的层放在构造函数__init__()中;

  3. 不具有可学习参数的层(如ReLU)可在forward中使用nn.functional来代替;

  4. 只要在nn.Module的子类中定义了forward函数,利用Autograd自动实现反向求导。

那么这时候就有一些疑问:

  1. 为什么要继承nn.Module?
  2. forward函数什么时候会被调用?

answer:

1、关于第一个问题,我们需要看下Module类的源码,Module初始化后就相当于8个有序字典,因此,当实例化你定义的Net(nn.Module的子类)时,要确保父类的构造函数首先被调用,这样才能确保上述8个OrderedDict被create。

torch.nn.Module模块简单介绍

_modules:桥梁作用,在获取一个net的所有的parameters的时候,是通过递归遍历该net的所有_modules来实现的。

2、forward函数需要通过Net(input)(Net为自己定义的类)来调用,而非Net.forward(input),因为前者实现了额外的功能:

      a) 先执行_forward_pre_hooks里的所有hooks

      b) 再调用forward函数

      c) 执行_forward_hooks中所有hooks

      d) 执行_backward_hooks中所有hooks

_forward_pre_hooks通常只有一些Norm操作会定义_forward_pre_hooks,这种hook不能改变input的内容;_forward_hooks不改变input和output,目前就是方便自己测试的时候用;_backward_hooks和_forward_hooks类似。所以,网络中没有Norm操作,使用Net(input)和Net.forward(input)是等价的。

以上是针对torch.nn.Module模块的介绍,构建模型过程中的一些问题总结及理解后续更新。