Confused about how the validation set is used here

Date: 2018-08-15 21:07:30

Tags: validation tensorflow train-test-split

In main.py of the px2graph project, the training and validation part looks like this:

splits = [s for s in ['train', 'valid'] if opt.iters[s] > 0]
start_round = opt.last_round - opt.num_rounds

# Main training loop
for round_idx in range(start_round, opt.last_round):
    for split in splits:

        print("Round %d: %s" % (round_idx, split))
        loader.start_epoch(sess, split, train_flag, opt.iters[split] * opt.batchsize)

        flag_val = split == 'train'

        for step in tqdm(range(opt.iters[split]), ascii=True):
            global_step = step + round_idx * opt.iters[split]
            to_run = [sample_idx, summaries[split], loss, accuracy]
            if split == 'train': to_run += [optim]

            # Do image summaries at the end of each round
            do_image_summary = step == opt.iters[split] - 1
            if do_image_summary: to_run[1] = image_summaries[split]

            # Start with lower learning rate to prevent early divergence
            t = 1/(1+np.exp(-(global_step-5000)/1000))
            lr_start = opt.learning_rate / 15
            lr_end = opt.learning_rate
            tmp_lr = (1-t) * lr_start + t * lr_end

            # Run computation graph
            result = sess.run(to_run, feed_dict={train_flag:flag_val, lr:tmp_lr})

            out_loss = result[2]
            out_accuracy = result[3]
            if sum(out_loss) > 1e5:
                print("Loss diverging...exiting before code freezes due to NaN values.")
                print("If this continues you may need to try a lower learning rate, a")
                print("different optimizer, or a larger batch size.")
                return

            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, global_step, out_loss, out_accuracy))

            # Log data
            if split == 'valid' or (split == 'train' and step % 20 == 0) or do_image_summary:
                writer.add_summary(result[1], global_step)
                writer.flush()

    # Save training snapshot
    saver.save(sess, 'exp/' + opt.exp_id + '/snapshot')
    with open('exp/' + opt.exp_id + '/last_round', 'w') as f:
        f.write('%d\n' % round_idx)
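The learning-rate warm-up in the snippet above (the sigmoid blend between `lr_start` and `lr_end`) can be checked in isolation. The helper below is my own illustration, not part of px2graph; the parameter names `start_div`, `midpoint`, and `scale` are hypothetical and simply mirror the constants 15, 5000, and 1000 used in the loop:

```python
import numpy as np

def warmup_lr(global_step, base_lr, start_div=15, midpoint=5000, scale=1000):
    """Sigmoid warm-up from base_lr/start_div toward base_lr,
    mirroring the schedule computed inside the training loop above."""
    t = 1.0 / (1.0 + np.exp(-(global_step - midpoint) / scale))
    lr_start = base_lr / start_div
    lr_end = base_lr
    return (1 - t) * lr_start + t * lr_end
```

At step 0 the rate sits close to `base_lr / 15`, at step 5000 it is exactly halfway between the two endpoints, and well past the midpoint it approaches `base_lr`, which is what the "prevent early divergence" comment is after.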

It seems the author only records per-batch results on the validation set. If I want to see whether the model is still improving or has reached its best performance, should I instead be looking at results computed over the entire validation set?

1 Answer:

Answer 0 (score: 0)

If the validation set is small enough, you can evaluate loss and accuracy over the whole validation set during training. If it is too large for that, it is better to run validation in batches over several steps and aggregate the per-batch results.
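To turn per-batch numbers into a single figure for the whole split, you can accumulate each batch's loss and accuracy weighted by its batch size and average at the end. A minimal sketch (the function name and the tuple format are my own; in practice each tuple would come from a `sess.run` call on one validation batch):

```python
def evaluate_full_split(batches):
    """Average per-batch (loss, accuracy, batch_size) tuples over an
    entire split, weighting each batch by its size so that a smaller
    final batch does not skew the result."""
    total_loss = total_acc = total_n = 0.0
    for loss, acc, n in batches:
        total_loss += loss * n
        total_acc += acc * n
        total_n += n
    return total_loss / total_n, total_acc / total_n
```

Tracking this split-level average across rounds, rather than individual batch values, gives a much less noisy signal for deciding whether the model is still improving.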