TensorFlow(神经网络)FC输出大小
问题描述:
不确定我的问题是否是特定于TF或仅仅是NNs,但是我使用tensorflow创建了CNN。而且我很难理解为什么我的完全连接层上的输出的大小就是这样。TensorFlow(神经网络)FC输出大小
X = tf.placeholder(tf.float32, [None, 32, 32, 3])
y = tf.placeholder(tf.int64, [None])
is_training = tf.placeholder(tf.bool)
# define model
def complex_model(X,y,is_training):
# conv layer
wconv_1 = tf.get_variable('wconv_1', [7 ,7 ,3, 32])
bconv_1 = tf.get_variable('bconv_1', [32])
# affine layer 1
w1 = tf.get_variable('w1', [26*26*32//4, 1024]) #LINE 13
b1 = tf.get_variable('b1', [1024])
# batchnorm params
bn_gamma = tf.get_variable('bn_gamma', shape=[32]) #scale
bn_beta = tf.get_variable('bn_beta', shape=[32]) #shift
# affine layer 2
w2 = tf.get_variable('w2', [1024, 10])
b2 = tf.get_variable('b2', [10])
c1_out = tf.nn.conv2d(X, wconv_1, strides=[1, 1, 1, 1], padding="VALID") + bconv_1
activ_1 = tf.nn.relu(c1_out)
mean, var = tf.nn.moments(activ_1, axes=[0,1,2], keep_dims=False)
bn = tf.nn.batch_normalization(act_1, mean, var, bn_gamma, bn_beta, 1e-6)
mp = tf.nn.max_pool(bn, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
affine_in_flat = tf.reshape(mp, [-1, 26*26*32//4])
affine_1 = tf.matmul(affine_in_flat, w1) + b1
activ_2 = tf.nn.relu(affine_1)
affine_2 = tf.matmul(activ_2, w2) + b2
return affine_2
#print(affine_2.shape)
在线路13,其中i设定w1的值i本来期望只是把:
w1 = tf.get_variable('w1', [26*26*32, 1024])
但是如果我运行与线上面的代码示出和
affine_in_flat = tf.reshape(mp, [-1, 26*26*32])
我的输出大小是16,10而不是64,10这是我期望给予以下初始化:
x = np.random.randn(64, 32, 32,3)
with tf.Session() as sess:
with tf.device("/cpu:0"): #"/cpu:0" or "/gpu:0"
tf.global_variables_initializer().run()
#print("train", x.size, is_training, y_out)
ans = sess.run(y_out,feed_dict={X:x,is_training:True})
%timeit sess.run(y_out,feed_dict={X:x,is_training:True})
print(ans.shape)
print(np.array_equal(ans.shape, np.array([64, 10])))
有人可以告诉我为什么我需要将w1 [0]的大小除以4吗?
答
添加为bn
和mp
print
语句,我得到:
bn
:<tf.Tensor 'batchnorm/add_1:0' shape=(?, 26, 26, 32) dtype=float32>
mp
:<tf.Tensor 'MaxPool:0' shape=(?, 13, 13, 32) dtype=float32>
这似乎是由于上最大的strides=[1, 2, 2, 1]
池(但要保持26, 26
你还需要padding='SAME'
)。