如何在Tensorflow中从张量中获取特定行?
问题描述:
我已经张量的定义如下:如何在Tensorflow中从张量中获取特定行?
idx = tf.constant([0, 2])
现在我想利用temp_var
一个子集在那些:
temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]]))
我也有行索引的阵列,以从张量中获取指标即idx
我知道,要采取单一索引或切片,我们可以做这样的事情
temp_var[single_row_index, :]
或
temp_var[start:end, :]
但如何读取行由idx
阵列表示? 类似于temp_var[idx, :]
?
答
tf.gather()
op正好满足您的需求:它从矩阵(或从N维张量中选择一般(N-1)维片)中选择行。以下是它如何在你的情况下工作:
temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]))
idx = tf.constant([0, 2])
rows = tf.gather(temp_var, idx)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
print(sess.run(rows)) # ==> [[1, 2, 3], [7, 8, 9]]