如何保存代表在Tensorflow中构建的神经网络的对象

时间:2019-06-06 02:22:27

标签: python oop tensorflow deep-learning save

我是Tensorflow的新手,正在github上玩一些代码。该代码为神经网络创建了一个类,其中包含构造网络,制定损失函数,训练网络,执行预测等方法。

骨架代码如下所示:

class NeuralNetwork:
    def __init__(...):

    def initializeNN():

    def trainNN():

    def predictNN():

等神经网络是用Tensorflow构造的,因此,类定义及其方法使用tensorflow语法。

现在,在脚本的主要部分中,我将通过

创建此类的实例
model = NeuralNetwork(...)

,并使用诸如model.predict之类的模型方法来制作图。

由于训练神经网络需要很长时间,因此我想保存对象“模型”以备将来使用,并可以调用其方法。我试过泡菜和莳萝,但都失败了。对于泡菜,我得到了错误:

TypeError:无法腌制_thread.RLock对象

在莳萝中,我得到了:

TypeError:无法腌制SwigPyObject对象

关于如何保存对象并仍然能够调用其方法的任何建议?这很重要,因为将来我可能希望对一组不同的点进行预测。

谢谢!

1 个答案:

答案 0 :(得分:0)

您应该执行以下操作:

# Build the graph
model = NeuralNetwork(...)
# Create a train saver/loader object
saver = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Train the model in the same way you are doing it currently
    model.train_model()
    # Once you are done training, just save the model definition and it's learned weights
    saver.save(sess, save_path)

然后,您完成了。然后,当您想再次使用该模型时,您可以做的是:

# Build the graph
model = NeuralNetwork()
# Create a train saver/loader object
loader = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Load the model variables
    loader.restore(sess, save_path)
    # Train the model again for example
    model.train_model()
相关问题