如何在Java中

问题描述:

使用TensorFlow LinearClassifier在Python我已经训练了TensorFlow LinearClassifier并保存它喜欢:如何在Java中

model = tf.contrib.learn.LinearClassifier(feature_columns=columns) 
model.fit(input_fn=train_input_fn, steps=100) 
model.export_savedmodel(export_dir, parsing_serving_input_fn) 

通过使用TensorFlow的Java API我可以用在Java中加载这个模型:

model = SavedModelBundle.load(export_dir, "serve"); 

看来我应该能够运行使用的东西的图形像

model.session().runner().feed(???, ???).fetch(???, ???).run() 

但是我应该从图表中提取/提取哪些变量名称/数据以提供其功能并获取类别的概率?据我所知,Java文档缺少这些信息。

要馈入的节点的名称取决于parsing_serving_input_fn的作用,特别是它们应该是由parsing_serving_input_fn返回的Tensor对象的名称。要获取的节点名称取决于您预测的内容(如果使用来自Python的模型,则参数为model.predict())。

也就是说,TensorFlow保存的模型格式确实包含模型的“签名”(即可以提供或提取的所有Tensors的名称)作为可以提供提示的元数据。

从Python中可以加载保存的模型,并使用类似列出其签名:

with tf.Session() as sess: 
    md = tf.saved_model.loader.load(sess, ['serve'], export_dir) 
    sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 
    print(sig) 

这将打印出类似这样:

inputs { 
    key: "inputs" 
    value { 
    name: "input_example_tensor:0" 
    dtype: DT_STRING 
    tensor_shape { 
     dim { 
     size: -1 
     } 
    } 
    } 
} 
outputs { 
    key: "scores" 
    value { 
    name: "linear/binary_logistic_head/predictions/probabilities:0" 
    dtype: DT_FLOAT 
    tensor_shape { 
     dim { 
     size: -1 
     } 
     dim { 
     size: 2 
     } 
    } 
    } 
} 
method_name: "tensorflow/serving/classify" 

暗示你想用Java做什么是:

Tensor t = /* Tensor object to be fed */ 
model.session().runner().feed("input_example_tensor", t).fetch("linear/binary_logistic_head/predictions/probabilities").run() 

您还可以纯粹在Java中提取此信息如果y我们的计划包括TensorFlow协议缓冲区生成的Java代码(封装在org.tensorflow:proto artifact)使用这样的事情:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API. 
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig = 
     MetaGraphDef.parseFrom(model.metaGraphDef()) 
      .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY); 

你将不得不补充:

import org.tensorflow.framework.MetaGraphDef; 
import org.tensorflow.framework.SignatureDef; 

因为Java API和saved-模型格式有点新,在文档中有很大的改进空间。

希望有所帮助。

+0

感谢您的回答!这看起来有希望。但是,我必须为input_example_tensor提供什么?例如,考虑[TensorFlow Iris分类教程](https://www.tensorflow.org/get_started/tflearn):导出该模型会得到与您提供的相同的签名(输入,dtype:DT_STRING),但我需要以某种方式喂这个模型4个数字。 –

+0

现在我明白了,模型需要一个序列化的示例协议缓冲区,但是此时(1)协议缓冲区在Java中不可用,(2)使用DataType字符串创建Tensors(这是序列化示例所需的)是尚未支持。 :( –

+0

仅供参考:在[org.tensorflow:proto](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/proto)中提供了Java中的协议缓冲区maven artifact([javadoc]( http://javadoc.io/doc/org.tensorflow/proto/)) DataType.STRING张量支持标量(即,一个字符串),但不是多维数组呢(https://github.com/tensorflow/tensorflow/issues/8531) 希望有帮助。 – ash