如何在Java中
问题描述:
使用TensorFlow LinearClassifier在Python我已经训练了TensorFlow LinearClassifier并保存它喜欢:如何在Java中
model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)
通过使用TensorFlow的Java API我可以用在Java中加载这个模型:
model = SavedModelBundle.load(export_dir, "serve");
看来我应该能够运行使用的东西的图形像
model.session().runner().feed(???, ???).fetch(???, ???).run()
但是我应该从图表中提取/提取哪些变量名称/数据以提供其功能并获取类别的概率?据我所知,Java文档缺少这些信息。
答
要馈入的节点的名称取决于parsing_serving_input_fn
的作用,特别是它们应该是由parsing_serving_input_fn
返回的Tensor
对象的名称。要获取的节点名称取决于您预测的内容(如果使用来自Python的模型,则参数为model.predict()
)。
也就是说,TensorFlow保存的模型格式确实包含模型的“签名”(即可以提供或提取的所有Tensors的名称)作为可以提供提示的元数据。
从Python中可以加载保存的模型,并使用类似列出其签名:
with tf.Session() as sess:
md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
print(sig)
这将打印出类似这样:
inputs {
key: "inputs"
value {
name: "input_example_tensor:0"
dtype: DT_STRING
tensor_shape {
dim {
size: -1
}
}
}
}
outputs {
key: "scores"
value {
name: "linear/binary_logistic_head/predictions/probabilities:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 2
}
}
}
}
method_name: "tensorflow/serving/classify"
暗示你想用Java做什么是:
Tensor t = /* Tensor object to be fed */
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run()
您还可以纯粹在Java中提取此信息如果y我们的计划包括TensorFlow协议缓冲区生成的Java代码(封装在org.tensorflow:proto
artifact)使用这样的事情:
// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default";
final SignatureDef sig =
MetaGraphDef.parseFrom(model.metaGraphDef())
.getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
你将不得不补充:
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
因为Java API和saved-模型格式有点新,在文档中有很大的改进空间。
希望有所帮助。
感谢您的回答!这看起来有希望。但是,我必须为input_example_tensor提供什么?例如,考虑[TensorFlow Iris分类教程](https://www.tensorflow.org/get_started/tflearn):导出该模型会得到与您提供的相同的签名(输入,dtype:DT_STRING),但我需要以某种方式喂这个模型4个数字。 –
现在我明白了,模型需要一个序列化的示例协议缓冲区,但是此时(1)协议缓冲区在Java中不可用,(2)使用DataType字符串创建Tensors(这是序列化示例所需的)是尚未支持。 :( –
仅供参考:在[org.tensorflow:proto](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/proto)中提供了Java中的协议缓冲区maven artifact([javadoc]( http://javadoc.io/doc/org.tensorflow/proto/)) DataType.STRING张量支持标量(即,一个字符串),但不是多维数组呢(https://github.com/tensorflow/tensorflow/issues/8531) 希望有帮助。 – ash