How to save a Keras model?

Date: 2018-12-07 09:51:11

Tags: tensorflow keras finetunning

I am fine-tuning Universal-Sentence-encoder-large on my dataset, but after training the model I cannot save it to disk. Here is my code:

import tensorflow as tf  # needed for tf.string, tf.squeeze, tf.Session, etc.
from tensorflow.python.keras import layers
from tensorflow.python.keras.models import Model
import tensorflow_hub as hub

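def create_model():
    # Load the Universal Sentence Encoder (large, v3) as a TF1 Hub module
    module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
    embed = hub.Module(module_url)
    # Dimensionality of the sentence embedding produced by the module
    embed_size = embed.get_output_info_dict()['default'].get_shape()[1].value
    def UniversalEmbedding(x):
        # Squeeze the (batch, 1) string input to (batch,) and embed each sentence
        return embed(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]

    # One raw sentence per example
    input_text = layers.Input(shape=(1,), dtype=tf.string)
    # Wrap the Hub module in a Lambda layer (the wrapped function closes over `embed`)
    embedding = layers.Lambda(UniversalEmbedding, output_shape=(embed_size,))(input_text)
    dense = layers.Dense(256, activation='relu')(embedding)
    # Two-class softmax classification head
    pred = layers.Dense(2, activation='softmax')(dense)
    model = Model(inputs=[input_text], outputs=pred)
    return model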

model = create_model()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

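from tensorflow.python.keras import backend as K
import pickle

with tf.Session() as session:
    # Make Keras use this session so it shares the initialized Hub variables and lookup tables
    K.set_session(session)
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    # train_text, train_label, test_text and test_label are the training/test data (not shown)
    history = model.fit(train_text,
            train_label,
            validation_data=(test_text, test_label),
            epochs=2,
            batch_size=16)
    # This call fails with the TypeError shown below
    model.save('./model.h5')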

This is the error I get:

TypeError: can't pickle _thread.RLock objects
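The message says that Python's serialization machinery ran into a `threading.RLock`. Both `pickle` and `copy.deepcopy` go through the same reduce protocol, and a lock is a live, process-local object with no serializable state, so both refuse it. The sketch below reproduces that class of error using only the standard library. `FakeHubModule` is a hypothetical stand-in for an object such as `hub.Module` that holds a lock somewhere inside it. The full traceback is not included in the question, so which of these operations `model.save` actually performs on the `Lambda` layer (whose function closes over `embed`) is an assumption here, not a diagnosis.

```python
import copy
import pickle
import threading


class FakeHubModule:
    """Hypothetical stand-in for an object (like hub.Module) that holds a lock internally."""

    def __init__(self):
        self._lock = threading.RLock()  # live, process-local resource


def serialization_error(fn, obj):
    """Call fn(obj); return the TypeError message it raises, or None on success."""
    try:
        fn(obj)
    except TypeError as exc:
        return str(exc)
    return None


# Pickling anything that contains an RLock fails inside pickle
pickle_msg = serialization_error(pickle.dumps, {"lock": threading.RLock()})

# copy.deepcopy uses the same reduce protocol, so a config-like dict that
# references a lock-holding object fails the same way
layer_config = {"name": "lambda", "captured": FakeHubModule()}
deepcopy_msg = serialization_error(copy.deepcopy, layer_config)

print(pickle_msg)    # message mentions _thread.RLock
print(deepcopy_msg)  # message mentions _thread.RLock
```

The exact wording differs between Python versions ("can't pickle _thread.RLock objects" on older releases, "cannot pickle '_thread.RLock' object" on newer ones), but the cause is the same.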

Please help me save the model.
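One workaround often suggested for models that wrap a Hub module in a `Lambda` layer (an assumption here, not tested against this exact TensorFlow version) is to persist only the weights with `model.save_weights(...)`, then rebuild the architecture in code with `create_model()` and call `model.load_weights(...)` on the new model. The plain-Python sketch below shows the underlying idea: serialize only plain data, and recreate anything that holds a live resource from code. `Encoder`, `build_encoder` and the weight values are hypothetical; `build_encoder` plays the role of `create_model()`.

```python
import os
import pickle
import tempfile
import threading


class Encoder:
    """Hypothetical model: plain weights plus a live lock (like a wrapped hub.Module)."""

    def __init__(self):
        self.weights = [0.0, 0.0, 0.0]   # plain data: safe to serialize
        self._lock = threading.RLock()   # live resource: cannot be pickled

    def get_weights(self):
        return list(self.weights)

    def set_weights(self, weights):
        self.weights = list(weights)


def build_encoder():
    """Rebuild the object, including its live resources, from code (like create_model())."""
    return Encoder()


model = build_encoder()
model.set_weights([0.1, 0.2, 0.3])  # stands in for training

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pkl")
    # Save only the plain weights; pickling `model` itself would fail on the lock
    with open(path, "wb") as f:
        pickle.dump(model.get_weights(), f)

    # Restore: rebuild the architecture first, then load the saved weights into it
    restored = build_encoder()
    with open(path, "rb") as f:
        restored.set_weights(pickle.load(f))

print(restored.get_weights())  # [0.1, 0.2, 0.3]
```

The design choice mirrors `save_weights`/`load_weights`: the file holds numbers only, and the code that created the unpicklable parts is simply run again when loading.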

0 answers:

No answers