tensorflow - Loading two models in the same session

Time: 2018-07-07 15:27:34

Tags: tensorflow

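I want to load two models in the same TensorFlow session: I want to retrieve variables from the first model and use layers from the second. Here is my code: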

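import os

import tensorflow as tf
from tensorflow.python.platform import gfile

# note: log() is a helper defined elsewhere in my project (not shown here)

def extract_img_features_attention(img_paths, demo=False):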
    """
    - Runs every image in "img_paths" through the pretrained CNN and
    saves their respective feature array (the third-to-last layer
    of the CNN transformed to 64x300) to disk.
    """ 

    tf.reset_default_graph()

    W_img = tf.get_variable("img_transform/W_img", shape=[2048,300])
    b_img = tf.get_variable("img_transform/b_img", shape=[1,300])

    saver = tf.train.Saver()
    tf.train.import_meta_graph('models/LSTMs/best_model/model.meta')
    img_id_2_feature_vector = {}

    with tf.Session() as sess:
        saver.restore(sess, "models/LSTMs/best_model/model")

        # load the Inception-V3 model:
        load_pretrained_CNN()
        # get the third-to-last layer in the Inception-V3 model (a tensor
        # of shape (1, 8, 8, 2048)):
        img_features_tensor = sess.graph.get_tensor_by_name("mixed_10/join:0")
        # reshape the tensor to shape (64, 2048):
        img_features_tensor = tf.reshape(img_features_tensor, (64, 2048))

        # apply the img transform (get a tensor of shape (64, 300)):
        linear_transform = tf.matmul(img_features_tensor, W_img) + b_img
        img_features_tensor = tf.nn.sigmoid(linear_transform)
        print(img_features_tensor)
        for step, img_path in enumerate(img_paths):
            # log progress every 10 images:
            if step % 10 == 0:
                print(step)
                log(str(step))

            # read the image (for every image, not only every 10th):
            img_data = gfile.FastGFile(img_path, "rb").read()
            # get the img features (np array of shape (64, 300)):
            img_features = sess.run(img_features_tensor,
                    feed_dict={"DecodeJpeg/contents:0": img_data})
            #img_features = np.float16(img_features)

            if not demo:
                # get the image id:
                img_name = img_path.split("/")[3]
                img_id = img_name.split("_")[2].split(".")[0].lstrip("0")
                img_id = int(img_id)
            else: # (if demo)
                # we're only extracting features for one img, (arbitrarily)
                # set the img id to -1:
                img_id = -1

            # store the img features in the dict, keyed by image id:
            img_id_2_feature_vector[img_id] = img_features


        return img_id_2_feature_vector
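
For reference, the shape arithmetic described in the comments above can be checked in isolation. This is only a minimal NumPy sketch with random stand-in data (the arrays `features`, `W_img` and `b_img` here are made up, not the real Inception output or the trained weights); it shows that a (1, 8, 8, 2048) tensor reshaped to (64, 2048) and passed through the transform gives a (64, 300) array whose sigmoid values lie strictly between 0 and 1 for moderate inputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in for the "mixed_10/join:0" output (assumed shape (1, 8, 8, 2048)):
features = rng.random((1, 8, 8, 2048))
# stand-ins for the trained transform parameters (random, NOT the real weights):
W_img = rng.normal(scale=0.01, size=(2048, 300))
b_img = np.zeros((1, 300))

# reshape to (64, 2048): one 2048-dim vector per spatial location
reshaped = features.reshape(64, 2048)

# linear transform followed by a sigmoid, as in the function above
linear_transform = reshaped @ W_img + b_img
out = 1.0 / (1.0 + np.exp(-linear_transform))

print(reshaped.shape, out.shape)
print(out.min(), out.max())
```
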

So I want to retrieve the variables W_img and b_img from the first model, and use the second model through the following function, load_pretrained_CNN():

def load_pretrained_CNN():
    """
    - Loads the pretrained Inception-V3 model.
    """

    # define where the pretrained inception model is located:
    model_dir = "inception"

    path_to_saved_model = os.path.join(model_dir,
            "classify_image_graph_def.pb")

    with gfile.FastGFile(path_to_saved_model, "rb") as model_file:
        # create an empty GraphDef object:
        graph_def = tf.GraphDef()

        # import the model definitions:
        graph_def.ParseFromString(model_file.read())
        _ = tf.import_graph_def(graph_def, name="")

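I want to get certain layers of the second model. Unfortunately, when I run this code the features come out as all zeros, so I assume I am doing something wrong. I hope someone can help me understand how to fix it. Thanks.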
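
To make the setup easier to reproduce without the real files, here is a small self-contained sketch of the same pattern with toy stand-ins. It is not a diagnosis of the all-zero output. Assumptions: `tensorflow.compat.v1` is available (TF 1.13+ or TF 2.x), and the names `toy_input` and `toy_features`, the tiny shapes, and the temporary paths are all invented for illustration. It saves a checkpoint containing `img_transform/W_img` and `img_transform/b_img`, writes a separate GraphDef to a `.pb` file, then restores the two variables and imports the GraphDef into one graph and one session.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.13 or TF 2.x

tf.disable_v2_behavior()

tmp_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(tmp_dir, "model")    # hypothetical checkpoint prefix
pb_path = os.path.join(tmp_dir, "frozen.pb")  # hypothetical GraphDef file

# toy "first model": a checkpoint holding W_img and b_img (all ones)
tf.reset_default_graph()
tf.get_variable("img_transform/W_img", initializer=tf.ones([4, 3]))
tf.get_variable("img_transform/b_img", initializer=tf.ones([1, 3]))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_path)

# toy "second model": a variable-free GraphDef with one named output
tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=[2, 4], name="toy_input")
tf.multiply(x, 2.0, name="toy_features")
with open(pb_path, "wb") as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())

# load both into one graph: declare the variables, then import the GraphDef
tf.reset_default_graph()
W_img = tf.get_variable("img_transform/W_img", shape=[4, 3])
b_img = tf.get_variable("img_transform/b_img", shape=[1, 3])
saver = tf.train.Saver([W_img, b_img])  # restores only these two, by name

graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    features = sess.graph.get_tensor_by_name("toy_features:0")
    out = tf.nn.sigmoid(tf.matmul(features, W_img) + b_img)
    restored_W, result = sess.run(
        [W_img, out],
        feed_dict={"toy_input:0": np.ones((2, 4), np.float32)})

print(restored_W)  # the values saved in the checkpoint
print(result)      # sigmoid of (features . W_img + b_img)
```

Fetching the restored variables and the intermediate tensor in the same `sess.run` call, as done here, is one way to see at which stage the values stop looking sensible.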

0 answers:

No answers