如何在自定义数据集上从头开始训练初始模型后对测试图像进​​行预测?

时间:2016-12-19 04:37:00

标签: machine-learning dataset computer-vision tensorflow

我刚从头开始在我的自定义数据集上训练了onceptionv3(1675个火车图像,400个验证图像,2个类):

  1. 我不知道如何使用我新训练的模型对我的测试图像进​​行预测。(在哪里指向模型的label_image.py)

  2. 我新培训的模型在哪里得到了保存? 关于我的设置/运行的一些元数据:---

  3. 我在train_dir中生成了以下文件:

    • events.out.tfevents.1481980070.airig-灵-7559(4.9GB)
    • graph.pbtxt(18.5MB)
    • 和一堆 model.ckpt- .meta和model.ckpt - .index 文件
  4. 在运行火车脚本后,我得到了: -

    ....
    INFO:tensorflow:Stopping Training.
    INFO:tensorflow:Finished training! Saving model to disk.
    

    运行eval脚本后,我得到了: -

    .....
    INFO:tensorflow:Evaluation [0/25]
    INFO:tensorflow:Evaluation [1/25]
    INFO:tensorflow:Evaluation [2/25]
    INFO:tensorflow:Evaluation [3/25]
    INFO:tensorflow:Evaluation [5/25]
    INFO:tensorflow:Evaluation [5/25]
    INFO:tensorflow:Evaluation [6/25]
    INFO:tensorflow:Evaluation [7/25]
    INFO:tensorflow:Evaluation [8/25]
    INFO:tensorflow:Evaluation [9/25]
    INFO:tensorflow:Evaluation [10/25]
    INFO:tensorflow:Evaluation [11/25]
    INFO:tensorflow:Evaluation [13/25]
    INFO:tensorflow:Evaluation [13/25]
    INFO:tensorflow:Evaluation [14/25]
    INFO:tensorflow:Evaluation [15/25]
    INFO:tensorflow:Evaluation [16/25]
    INFO:tensorflow:Evaluation [17/25]
    INFO:tensorflow:Evaluation [18/25]
    INFO:tensorflow:Evaluation [19/25]
    INFO:tensorflow:Evaluation [20/25]
    INFO:tensorflow:Evaluation [21/25]
    INFO:tensorflow:Evaluation [22/25]
    INFO:tensorflow:Evaluation [23/25]
    INFO:tensorflow:Evaluation [25/25]
    I tensorflow/core/kernels/logging_ops.cc:79] eval/Recall@5[1]
    I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[1]
    INFO:tensorflow:Finished evaluation at 2016-12-19-03:59:04
    

2 个答案:

答案 0 :(得分:0)

  

我新训练的模型在哪里保存?

您完整的TensorFlow图(即所有变量,操作,集合等)都保存在.meta文件中。 .cpkt文件是一个检查点文件。该文件包含权重,偏差,渐变和所有其他变量的所有值。

  

我不知道如何使用新近训练的模型对测试图像进​​行预测。(在哪里将label_image.py指向模型)

要恢复训练有素的模型,请使用:

withth tf.Session() as sess:    
   saver = tf.train.import_meta_graph('my-model-1000.meta')
   saver.restore(sess,tf.train.latest_checkpoint('./'))

请注意,已设置图层名称。例如:

out_layer = tf.layers.dense(inputs=layer_1, units=6, name='prediction')

现在您可以将其用于预测:

sess.run(prediction, feed_dict)

其中预测是网络最后一层的输出变量的名称。

答案 1 :(得分:0)

一般情况下:

  1. 模型另存为检查点文件(model.ckpt)

  2. 您可以通过传递模型的路径作为保护对象的参数以及会话来加载模型:Saver.restore(sess, "path to you model.ckpt files")

  3. 因为要还原模型,请勿初始化全局变量,请勿执行:sess.run(tf.globale_variables_intializer().init())

  4. 还原模型后,只需执行以下操作:Sess.run(prediction, feed-dict{your input image here})

  5. 因为如上所述,它是2类分类,所以5的输出应该是2维向量,例如[0.22331,-23.21],它表示图像是哪个类的概率,因此您只需需要使用numpy来运行,例如:np.argmax([0.22331,-23.21]),这将为您提供第一个元素的索引,因为第一个元素大于第二个元素,这意味着您的图像被显示的可能性更高来自头等舱。