AssertionError Tensorflow

Time: 2017-09-04 17:03:58

Tags: python tensorflow

I am trying to use this code: https://github.com/KGPML/Hyperspectral

def run_training():
    """Train the IndianPines MLP for a number of steps."""
    # Get the sets of images and labels for training, validation, and
    # test on IndianPines.

    # Concatenating all the training and test .mat files
    for i in range(TRAIN_FILES):
        Training_data = input_data.read_data_sets(os.path.join(DATA_PATH, 'Train_'+str(IMAGE_SIZE)+'_'+str(1+1)+'.mat'), 'train')

    for i in range(TEST_FILES):
        Test_data = input_data.read_data_sets(os.path.join(DATA_PATH, 'Test_'+str(IMAGE_SIZE)+'_'+str(0+1)+'.mat'), 'test')

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

        # Build a Graph that computes predictions from the inference model.
        logits = IndianPinesMLP.inference(images_placeholder,
                                          FLAGS.hidden1,
                                          FLAGS.hidden2,
                                          FLAGS.hidden3)

        # Add to the Graph the Ops for loss calculation.
        loss = IndianPinesMLP.loss(labels=labels_placeholder, logits=logits)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = IndianPinesMLP.training(loss, FLAGS.learning_rate)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = IndianPinesMLP.evaluation(labels=labels_placeholder, logits=logits)

        # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.merge_all_summaries()

        # Add the variable initializer Op (replaces the deprecated
        # tf.initialize_all_variables()).
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()

        # Instantiate a SummaryWriter to output summaries and the Graph.
        # summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        # And then after everything is built:

        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and labels
            # for this particular training step.
            feed_dict = fill_feed_dict(Training_data,
                                       images_placeholder,
                                       labels_placeholder)

            # Run one step of the model.  The return values are the activations
            # from the `train_op` (which is discarded) and the `loss` Op.  To
            # inspect the values of your Ops or variables, you may include them
            # in the list passed to sess.run() and the value tensors will be
            # returned in the tuple from the call.
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict=feed_dict)

            duration = time.time() - start_time

            # Write the summaries and print an overview fairly often.
            if step % 50 == 0:
                # Print status to stdout.
                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
                # Update the events file.
                # summary_str = sess.run(summary_op, feed_dict=feed_dict)
                # summary_writer.add_summary(summary_str, step)
                # summary_writer.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                # os.path.join avoids the unescaped backslash in '.\model-MLP-...'.
                saver.save(sess, os.path.join('.', 'model-MLP-'+str(IMAGE_SIZE)+'X'+str(IMAGE_SIZE)+'.ckpt'), global_step=step)

                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        Training_data)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        Test_data)

and I get this error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-23-0683f80cdbe4> in <module>()
----> 1 run_training()

<ipython-input-22-b34daa52b702> in run_training()
     60             feed_dict = fill_feed_dict(Training_data,
     61                                      images_placeholder,
---> 62                                      labels_placeholder)
     63 
     64             # Run one step of the model.  The return values are the activations

If I run these parts manually, I get no error:

<ipython-input-5-f04ef9a1e6b2> in fill_feed_dict(data_set, images_pl, labels_pl)
     15     # Create the feed_dict for the placeholders filled with the next
     16     # `batch size ` examples.
---> 17     images_feed, labels_feed = data_set.next_batch(batch_size)
     18     feed_dict = {
     19       images_pl: images_feed,

The same problem appears here:

~\Path to: \Spatial_dataset.py in next_batch(self, batch_size)
     87             start = 0
     88             self._index_in_epoch = batch_size
---> 89             assert batch_size <= self._num_examples
     90         end = self._index_in_epoch
     91         return self._images[start:end], np.reshape(self._labels[start:end],len(self._labels[start:end]))

AssertionError:

This is the error I now get when I run run_training().

What does this mean and how can I fix it? Google has not been helpful in this case. Thanks for your help.

1 Answer

Answer 0 (score: 0)

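The error is raised by this check in next_batch, which fails whenever the requested batch_size is larger than the number of examples in the dataset (self._num_examples):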

---> 89             assert batch_size <= self._num_examples
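A minimal sketch of why the check fails, assuming next_batch follows the usual TensorFlow MNIST-tutorial pattern. Only lines 87-91 of Spatial_dataset.py are visible in the traceback, so the epoch test, the hypothetical class name MiniDataSet, and the use of plain lists are assumptions, and shuffling is omitted. When an epoch ends, the index restarts at batch_size, so a single batch has to fit inside the dataset.

```python
class MiniDataSet(object):
    """Hypothetical, simplified stand-in for the DataSet in Spatial_dataset.py."""

    def __init__(self, images, labels):
        self._images = images
        self._labels = labels
        self._num_examples = len(images)
        self._index_in_epoch = 0

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples, restarting when an epoch ends."""
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        # Assumed epoch test (this line is not visible in the traceback).
        if self._index_in_epoch > self._num_examples:
            # Epoch finished: restart at 0 (the real code may also shuffle here).
            start = 0
            self._index_in_epoch = batch_size
            # Same check as line 89 in the traceback: one batch must fit in the data.
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


ok = MiniDataSet(list(range(80)), list(range(80)))
images, labels = ok.next_batch(20)    # fine: 20 <= 80

too_small = MiniDataSet(list(range(10)), list(range(10)))
try:
    too_small.next_batch(20)          # 20 > 10, so the assertion fails
except AssertionError:
    print("batch_size is larger than the number of examples")
```

So an AssertionError here usually means fewer examples were loaded than FLAGS.batch_size asks for.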

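Change batch_size so that it evenly divides both the number of training images left after the validation split and the total number of training images before the split.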

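For example, if your training set has 100 images and validation_size is 0.2, then 80 images are used for training and 20 for validation. So choose a batch_size that divides 80, for example 20, which divides both 80 and 100.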
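The arithmetic in this example can be checked with a short sketch. It assumes, as the example does, that validation_size is a fraction of the training images; the helper names split_sizes, candidate_batch_sizes and check_batch_size are hypothetical and not part of the Hyperspectral repository.

```python
from math import gcd


def split_sizes(total_examples, validation_fraction):
    """Return (num_train, num_validation) for an assumed fractional validation split."""
    num_validation = int(round(total_examples * validation_fraction))
    return total_examples - num_validation, num_validation


def candidate_batch_sizes(total_examples, validation_fraction):
    """Batch sizes that divide both the training split and the total count."""
    num_train, _ = split_sizes(total_examples, validation_fraction)
    # The common divisors of two numbers are exactly the divisors of their gcd.
    g = gcd(num_train, total_examples)
    return [d for d in range(1, g + 1) if g % d == 0]


def check_batch_size(batch_size, total_examples, validation_fraction):
    """Return a list of problems with batch_size (an empty list means it is fine)."""
    num_train, _ = split_sizes(total_examples, validation_fraction)
    problems = []
    if batch_size > num_train:
        # This is the condition that trips `assert batch_size <= self._num_examples`.
        problems.append("batch_size %d > %d training examples" % (batch_size, num_train))
    if num_train % batch_size:
        problems.append("batch_size %d does not divide %d" % (batch_size, num_train))
    if total_examples % batch_size:
        problems.append("batch_size %d does not divide %d" % (batch_size, total_examples))
    return problems


print(candidate_batch_sizes(100, 0.2))  # [1, 2, 4, 5, 10, 20]
print(check_batch_size(20, 100, 0.2))   # []
print(check_batch_size(128, 100, 0.2))  # three problems
```

With 100 images and a 0.2 split, 20 is the largest batch size that divides both 80 and 100.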
