深入浅出理解tf.transpose
一、tf.transpose()实例
本博客主要讲解tensorflow中的transpose()函数,初次接触此函数时必然对维度变换不太理解,看了很多博客发现只有代码示例却没有解析操作的细节,所以记录下来帮助自己理解。
使用transpose()函数时,需要给出matrix和perm两个参数,matrix是待变换的张量,perm是张量变换方式的具体参数。
tf.transpose(
matrix,
perm=None
)
实现的操作是将matrix进行转置,并且根据perm参数重新排列输出维度。这是对数据的维度的进行操作的形式。
1.1 说明
- 输出数据tensor的第i维将根据perm[i]指定。比如,如果perm没有给定,那么默认是perm = [n-1, n-2, …, 0],其中rank(a) = n。
- 默认情况下,对于二维输入数据,其实就是常规的矩阵转置操作。
1.2 代码演示
import tensorflow as tf
sess = tf.Session()
input_data = tf.constant([[1, 2, 3], [4, 5, 6]])
print(sess.run(tf.transpose(input_data)))
# [[1 4]
# [2 5]
# [3 6]]
print(sess.run(input_data))
# [[1 2 3]
# [4 5 6]]
print(sess.run(tf.transpose(input_data, perm=[1, 0])))
# [[1 4]
# [2 5]
# [3 6]]
input_data = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]])
print('input_data shape: ', sess.run(tf.shape(input_data)))
# [1, 4, 3]
output_data = tf.transpose(input_data, perm=[2, 1, 0])
print('output_data shape: ', sess.run(tf.shape(output_data)))
# [3, 4, 1]
print(sess.run(output_data))
# [[[ 1]
# [ 4]
# [ 7]
# [10]]
#
# [[ 2]
# [ 5]
# [ 8]
# [11]]
#
# [[ 3]
# [ 6]
# [ 9]
# [12]]]
sess.close()
"""输入参数:
● a: 一个Tensor。
● perm: 一个对于a的维度的重排列组合。
● name:(可选)为这个操作取一个名字。
输出参数:
● 一个经过翻转的Tensor。"""
二、理解转置的意义
上例中将一个shape为 [1, 4, 3]的tensor转置成了[3, 4, 1]的shape,那么矩阵中的元素是如何变换的呢?
2.1 通用推导
上例中的操作其实是讲第一维和第三维互换,即将每个元素的下标都按照下式进行互换。
2.2 感性理解
将第一维作X轴,第二维作Y轴,第三维作Z轴。
2.2.1 沿坐标轴翻转
第一种理解方法是,由于只有第一维和第三维互换,因此第二维不变,只需要沿着第二维进行旋转即可。
2.2.2 坐标系各轴互换
第一种理解方法是,由于只有第一维和第三维互换,只需要将X轴和Z轴在坐标系中互换位置,矩阵按照新坐标系读取即可。