其中()需要1到2个位置参数,但有3个被给出

问题描述:

我有三个数组,X,YZ。如果Z的对应元素为真,我想要放入resX的元素;否则,我会放入一个来自Y的元素。其中()需要1到2个位置参数,但有3个被给出

我实现这样的:

X = tf.constant([[1, 2], [3, 4]]) 
Y = tf.constant([[5, 6], [7, 8]]) 
Z = tf.constant([[True, False], [False, True]], tf.bool) 
res = tf.where(Z, X, Y) 
print(res.eval()) 

不过,我得到这个错误:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given 

我看着tf.where的definiton从here和我的使用似乎罚款。

任何想法可能是什么问题?

+0

你可以试试'tf.where(Z,X = X,Y = Y)' – pramod

+0

您的代码工作正常TensorFlow 1.0.1,所以我很好奇:这你使用TF版本? – npf

我怀疑你使用的是旧版本的TensorFlow:

在r0.10 tf.where过去只有2个参数。

tf.where(input, name=None)

https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops/sequence_comparison_and_indexing#where

+0

我使用'0.8.0',可能是因为我用'pip'安装了它。 – octavian

+0

那有道理。您应该安装最新版本:https://www.tensorflow.org/install/ – npf