其中()需要1到2个位置参数,但有3个被给出
问题描述:
我有三个数组,X
,Y
和Z
。如果Z
的对应元素为真,我想要放入res
和X
的元素;否则,我会放入一个来自Y
的元素。其中()需要1到2个位置参数,但有3个被给出
我实现这样的:
X = tf.constant([[1, 2], [3, 4]])
Y = tf.constant([[5, 6], [7, 8]])
Z = tf.constant([[True, False], [False, True]], tf.bool)
res = tf.where(Z, X, Y)
print(res.eval())
不过,我得到这个错误:
TypeError: where() takes from 1 to 2 positional arguments but 3 were given
我看着tf.where
的definiton从here和我的使用似乎罚款。
任何想法可能是什么问题?
答
我怀疑你使用的是旧版本的TensorFlow:
在r0.10 tf.where
过去只有2个参数。
tf.where(input, name=None)
你可以试试'tf.where(Z,X = X,Y = Y)' – pramod
您的代码工作正常TensorFlow 1.0.1,所以我很好奇:这你使用TF版本? – npf