Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

一 为什么要用变量共享

当我们有一个非常庞大的模型的时候免不了需要进行大量的变量共享,而且有时候还希望能够在一个地方初始化所有的变量;


假设定义一个图片滤波器,my_image_filter(input_image)

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

有两个image想要使用同一组参数的同一个滤波器,可以调用该滤波器函数两次,但这样做会产生两组变量,造成资源浪费

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

一种变量共享的方式是:通过独立的代码来定义并使用,如可以通过外部定义字典,实现变量的共享;在字典中定义并初始化变量,多次调用共用字典中定义的变量;但这样做缺点是破坏了模块的封装性

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

tensorflow中提供了tf.variable_scope()和tf.get_variable()方法来实现变量共享。

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

下面将分别介绍tf.variable_scope()和tf.get_variable()

二  tf.get_variable()

tf.get_variable()的作用:①定义一个新变量

                                          ②得到一个之前定义过的变量

case 1 :定义一个新变量

        -      当tf.get_variable_scope().reuse == False时,变量作用域就是为创建新变量所设置的。

      w = tf.get_variable("weights", kernel_shape,initializer=tf.random_normal_initializer())

      当创建重名的新变量时,会引发ValueError

Tips:Tensorflow中定义变量的方法只有两个tf.get_variable()和 tf.Variable(),这两个方法的区别会写在文章最后。

case 2:得到一个之前定义过的变量

       -       当tf.get_variable_scope().reuse == True时,变量作用域就是为重用变量所设置的。

      w = tf.get_variable("weight")

      如果想要重用的变量不存在时,会抛出ValueError,如果找到就返回这个变量。

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】

三 tf.variable_scope()

tf.name_scope()及tf.variable_scope()的作用都是为了不传引用而访问跨代码区域变量(不同变量作用域)的一种方式,其内部功能是在其代码块内显式创建的变量,计算节点(如Add)都会带上scope前缀;

1  tf.name_scope()和tf.variable_scope()区别

tf.name_scope()不会为tf.get_variable()方法创建或复用的变量加scope前缀,会给tf.Variable()方法创建的变量加scope前缀;

tf.variable_scope()都会给tf.get_variable()和tf.Variable()方法加scope前缀;

Tensoeflow 中使用tf.variable_scope()和tf.get_variable()实现变量共享【精】