How can I read/write files on HDFS with TensorFlow?

Date: 2017-05-10 11:15:30

Tags: tensorflow deep-learning

I want to use TensorFlow to read and write files on HDFS. I installed TensorFlow with `pip install ten...`. When I try to read a file from HDFS it doesn't work: it just hangs there and returns no error. Do I need to install TensorFlow via ./configure and bazel build instead? Does it have to be built that way to support HDFS?

Here is my code writing files to the local file system:

with tf.Session(graph=graph,config=config) as sess:
    sess.run(init)
    summary_writer = tf.summary.FileWriter('./mnist_logs2/', graph_def=sess.graph_def)
    for i in range(2000000):
        batch=mnist.train.next_batch(10000)
        train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.8})

        if i%100==0:
            acc_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})
            print("step %d,test accuracy %g"%(i,acc_test))
            if acc_test>0.993:
                break

    saver_path=saver.save(sess,'/home/test/mnist/model.ckpt')

    print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

Here is my code writing files to HDFS; only the paths are changed:

with tf.Session(graph=graph,config=config) as sess:
    sess.run(init)
    summary_writer = tf.summary.FileWriter('hdfs://user/mlp/zpc/mnist_logs2/', graph_def=sess.graph_def)
    for i in range(2000000):
        batch=mnist.train.next_batch(10000)
        train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.8})

        if i%100==0:
            acc_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})
            print("step %d,test accuracy %g"%(i,acc_test))
            if acc_test>0.993:
                break

    saver_path=saver.save(sess,'hdfs://user/mlp/zpc/mnist_logs2/model.ckpt')

    print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))
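One detail worth checking in the paths above: in an `hdfs://` URI, the first component after the double slash is the namenode authority (host[:port]), not a directory. So in `hdfs://user/mlp/zpc/mnist_logs2/`, `user` is interpreted as the namenode hostname and the path becomes `/mlp/zpc/mnist_logs2/`, which may be why the process hangs trying to reach a host named `user`. A quick stdlib check illustrates the parsing (the `namenode:8020` host/port below is a hypothetical example):

```python
from urllib.parse import urlparse

# As written in the question: 'user' parses as the host, not a folder.
uri = urlparse('hdfs://user/mlp/zpc/mnist_logs2/')
print(uri.netloc)  # 'user'
print(uri.path)    # '/mlp/zpc/mnist_logs2/'

# With an explicit namenode authority, the intended /user/... directory
# stays in the path component.
uri2 = urlparse('hdfs://namenode:8020/user/mlp/zpc/mnist_logs2/')
print(uri2.netloc)  # 'namenode:8020'
print(uri2.path)    # '/user/mlp/zpc/mnist_logs2/'
```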

When I run the code that writes to HDFS, I launch it like this:

CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) python mnist_linux.py
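Besides CLASSPATH, TensorFlow's HDFS support also needs `libhdfs.so` and the JVM library on the loader path at runtime. A sketch of the environment setup, with hypothetical install locations that must be adjusted to the actual cluster:

```shell
# Hypothetical install paths -- replace with your actual JDK/Hadoop locations.
export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
export HADOOP_HDFS_HOME=/usr/local/hadoop

# libhdfs.so (Hadoop native libs) and libjvm.so must be findable by the
# dynamic loader when the Python process starts.
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${JAVA_HOME}/jre/lib/amd64/server:${HADOOP_HDFS_HOME}/lib/native"
echo "${LD_LIBRARY_PATH}"
```

With those exported, the `CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) python mnist_linux.py` invocation above supplies the Hadoop jars to the same process.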

1 Answer:

Answer 0 (score: 0)

In TensorFlow 2.x, you can save the model to an HDF5 file with the model.save() method.

import tensorflow as tf

# Create the model (create_model is assumed to be defined elsewhere)
model = create_model()
# Train the model
model.fit(train_images, train_labels, epochs=10)

# Save the entire model to a single HDF5 file.
# The '.h5' extension tells Keras to use the HDF5 format.
model.save('my_model.h5')
# Restore the entire model, including architecture, weights and optimizer state
new_model = tf.keras.models.load_model('my_model.h5')