Keras模型预测在使用张量输入时发生变化

Keras模型预测在使用张量输入时发生变化

问题描述:

我想使用来自Keras的预训练Inception-V3模型,与来自Tensorflow的输入管道(即通过张量输入网络输入)配对。 这是我的代码:Keras模型预测在使用张量输入时发生变化

import tensorflow as tf 
from keras.preprocessing.image import load_img, img_to_array 
from keras.applications.inception_v3 import InceptionV3, decode_predictions, preprocess_input 
import numpy as np 

img_sample_filename = 'my_image.jpg' 
img = img_to_array(load_img(img_sample_filename, target_size=(299,299))) 
img = preprocess_input(img) 
img_tensor = tf.constant(img[None,:]) 

# WITH KERAS: 
model = InceptionV3() 
pred = model.predict(img[None,:]) 
pred = decode_predictions(np.asarray(pred)) #<------ correct prediction! 
print(pred) 

# WITH TF: 
model = InceptionV3(input_tensor=img_tensor) 
init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

pred = decode_predictions(np.asarray(pred)[0]) 
print(pred)        #<------ wrong prediction! 

其中my_image.jpg是我要分类的任何图像。

如果我用keras'predict函数来计算预测,结果是正确的。但是,如果我将张量从图像阵列中取出并通过input_tensor=...将张量输入到模型,然后通过sess.run([model.output], ...)计算预测结果是非常错误的。

不同行为的原因是什么?我不能以这种方式使用Keras网络吗?

最后,通过InceptionV3代码挖,我发现这个问题:sess.run(init)覆盖在InceptionV3的构造函数加载weigts。 我发现这个问题的-dirty-修复是在sess.run(init)之后重新加载权重。

from keras.applications.inception_v3 import get_file, WEIGHTS_PATH 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    weights_path = get_file(
       'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', 
       WEIGHTS_PATH, 
       cache_subdir='models', 
       md5_hash='9a0d58056eeedaa3f26cb7ebd46da564') 
    model.load_weights(weights_path) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

注意:为get_file()的参数直接从InceptionV3的构造和拍摄,在我的例子,都仅限于与image_data_format='channels_last'还原整个网络的权重。 我在this Github issue询问是否有更好的解决方法。我会更新这个答案,如果我应该得到更多的信息。

+0

您可以始终初始化变量子集,而不是初始化每个变量(包括模型预先训练的权重)。 – abhinavkulkarni