我想在同一张量流会话中加载两个模型,我想从第一个模型中检索变量,并且想使用第二个模型中的层。这是我的代码:
def extract_img_features_attention(img_paths, demo=False):
"""
- Runs every image in "img_paths" through the pretrained CNN and
saves their respective feature array (the third-to-last layer
of the CNN transformed to 64x300) to disk.
"""
tf.reset_default_graph()
W_img = tf.get_variable("img_transform/W_img", shape=[2048,300])
b_img = tf.get_variable("img_transform/b_img", shape=[1,300])
saver = tf.train.Saver()
tf.train.import_meta_graph('models/LSTMs/best_model/model.meta')
img_id_2_feature_vector = {}
with tf.Session() as sess:
saver.restore(sess, "models/LSTMs/best_model/model")
# load the Inception-V3 model:
load_pretrained_CNN()
# get the third-to-last layer in the Inception-V3 model (a tensor
# of shape (1, 8, 8, 2048)):
img_features_tensor = sess.graph.get_tensor_by_name("mixed_10/join:0")
# reshape the tensor to shape (64, 2048):
img_features_tensor = tf.reshape(img_features_tensor, (64, 2048))
# apply the img transorm (get a tensor of shape (64, 300)):
linear_transform = tf.matmul(img_features_tensor, W_img) + b_img
img_features_tensor = tf.nn.sigmoid(linear_transform)
print img_features_tensor
for step, img_path in enumerate(img_paths):
if step % 10 == 0:
print step
log(str(step))
# read the image:
img_data = gfile.FastGFile(img_path, "rb").read()
# get the img features (np array of shape (64, 300)):
img_features = sess.run(img_features_tensor,
feed_dict={"DecodeJpeg/contents:0": img_data})
#img_features = np.float16(img_features)
else:
if not demo:
# get the image id:
img_name = img_path.split("/")[3]
img_id = img_name.split("_")[2].split(".")[0].lstrip("0")
img_id = int(img_id)
else: # (if demo)
# we're only extracting features for one img, (arbitrarily)
# set the img id to -1:
img_id = -1
# save the img features to disk:
img_id_2_feature_vector[img_id] = img_features
return img_id_2_feature_vector
因此,我想从第一个模型中检索变量W_img
和b_img
,并从第二个模型中使用以下函数load_pretrained_CNN()
def load_pretrained_CNN():
"""
- Loads the pretrained Inception-V3 model.
"""
# define where the pretrained inception model is located:
model_dir = "inception"
path_to_saved_model = os.path.join(model_dir,
"classify_image_graph_def.pb")
with gfile.FastGFile(path_to_saved_model, "rb") as model_file:
# create an empty GraphDef object:
graph_def = tf.GraphDef()
# import the model definitions:
graph_def.ParseFromString(model_file.read())
_ = tf.import_graph_def(graph_def, name="")
我想获得模型的某些层次;不幸的是,运行这段代码我得到了全部零功能,所以我想我做错了什么。我希望有人可以帮助我了解如何解决它。谢谢。