TensorFlow summary: "You must feed a value for placeholder tensor"

Time: 2019-07-19 18:44:18

Tags: tensorflow

I am trying to use tf.summary in my code. For training, I take data from a placeholder, and after performing some computations on it I create a scalar summary on the loss output, but I get the error "You must feed a value for placeholder tensor".

This is a bit absurd, because I am not even trying to run the summary node in a session; I have merely created a summary node in the graph.

The code runs fine when I have no summary node, but as soon as I include even a single summary node it breaks with the above error.

I have tried every solution available on the Internet, and at this point I am really frustrated, because none of the solutions that worked for others work for me. Please help :)

Following other answers, I call tf.reset_default_graph() at the very beginning, before creating my own graph. I also pass the summary nodes explicitly to tf.summary.merge instead of using tf.summary.merge_all(). I do everything under a graph g, where g = tf.Graph().
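For context, here is a stripped-down sketch (TF 1.x style) of the overall setup I just described; the placeholder shape matches mine, but the loss here is only a stand-in for my real one:

    # Minimal sketch of my setup (TF 1.x); `loss` is a stand-in, my real
    # loss is computed by the Model class below.
    import tensorflow as tf

    tf.reset_default_graph()
    g = tf.Graph()

    with g.as_default():
        x = tf.placeholder(tf.float32, shape=(None, 64, 64, 3),
                           name="data_placeholder")
        loss = tf.reduce_sum(tf.pow(x, 2))         # stand-in loss
        loss_summary = tf.summary.scalar("loss", loss)
        merged = tf.summary.merge([loss_summary])  # explicit merge, no merge_all()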

Here is the constructor where I build the model and create the summary node on the loss output:

    class Model(object):
        def __init__(self, is_training, config, name):
            self.name = name
            self.config = config
            self.output_list = []
            self.index_dict = {}
            self.loss = 0
            self.lr_list = lr_list      # defined elsewhere in my script
            self.loss_list = loss_list  # defined elsewhere in my script

            H = W = self.config.allowed_lf_dim * self.config.patch_size

            self.copy_lf_batch_holder = tf.placeholder(
                tf.float32, shape=(None, H, W, config.channels),
                name=name + 'data_placeholder')

            self.struct(self.copy_lf_batch_holder, is_training)

            # Creating this single summary node is enough to trigger the error.
            self.loss_summary = tf.summary.scalar(name + "_loss", self.loss)

            if not is_training:
                return

            self.tvars = tf.trainable_variables()
            self.grads, _ = tf.clip_by_global_norm(
                tf.gradients(self.loss, self.tvars, name='train_gradients'),
                self.config.gradient_clip, name="clip_gradients_train")

            self.global_step = tf.train.get_or_create_global_step()
            self.learning_rate = self.config.learning_rate
            self.optimizer = tf.train.GradientDescentOptimizer(
                self.learning_rate, name="gradient_descent_train")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

            with tf.control_dependencies(update_ops):
                self.train_op = self.optimizer.apply_gradients(
                    zip(self.grads, self.tvars),
                    global_step=self.global_step, name="apply_gradients_train")

            print('exporting_graph...')

            print('graph_exported')
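For completeness, the two model instances used later (m_train and m_test) are created roughly like this (a sketch; test_config is my evaluation config, analogous to train_config, and is not shown above):

    # Sketch of how m_train and m_test are instantiated under graph g.
    # `test_config` is assumed to be analogous to `train_config`.
    with g.as_default():
        m_train = Model(is_training=True, config=train_config, name="train")
        m_test = Model(is_training=False, config=test_config, name="test")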

Here is the function that computes my loss:

    def struct(self, copy_lf_batch_holder, is_training):

        cell = tf.nn.rnn_cell.LSTMCell(
            self.config.hidden_size, forget_bias=0.0,
            state_is_tuple=True, reuse=not is_training)

        self.initial_state = cell.zero_state(self.config.batch_size, tf.float32)

        tf_resize_tensor = tf.constant(
            [self.config.patch_size, self.config.patch_size],
            name=self.name + 'interpolate_to_size')

        for col in range(1, self.config.lf_dim):

            state = self.initial_state

            for row in range(1, self.config.lf_dim):

                seq_input, key_target = self.prepare_input(
                    row, col, copy_lf_batch_holder, is_training)

                output, state_ = self.layers(
                    row, col, seq_input, cell, state,
                    tf_resize_tensor, is_training)
                state = state_

                output = tf.identity(output, name="output_from_layers")

                self.output_list.append(output)
                index = len(self.output_list) - 1
                self.index_dict[key_target] = index

                # Crop the ground-truth patch for this (row, col) position
                # out of the input batch.
                target = copy_lf_batch_holder[
                    :,
                    row*self.config.patch_size:
                        (row*self.config.patch_size) + self.config.patch_size,
                    col*self.config.patch_size:
                        (col*self.config.patch_size) + self.config.patch_size,
                    :]

                target = tf.identity(target, name="target_original")
                self.loss = self.loss + tf.reduce_sum(tf.pow(target - output, 2))

                if self.config.train_single_image:
                    break

            if self.config.train_single_image:
                break

        self.output_stack = tf.stack(self.output_list)

Here is the op that merges the summaries, together with the FileWriters:

    with g.as_default():

        merged = tf.summary.merge([m_train.loss_summary,
                                   m_test.loss_summary])
        train_writer = tf.summary.FileWriter(
            train_config.save_path + "tensorboard_train/")
        test_writer = tf.summary.FileWriter(
            train_config.save_path + "tensorboard_test/")

And here is where I run the summary node in a session; the "You must feed a value for placeholder tensor" error is raised here, and it persists even when I comment the summary node out of this fetch:

    fetch_train = {"loss_training": m_train.loss,
                   "train_op": m_train.train_op,
                   "train_merged_summary": merged}

    train_results = sess.run(fetch_train,
                             feed_dict={m_train.copy_lf_batch_holder: train_batch})
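Once the run succeeds, I would write the fetched summary out with the FileWriter, roughly like this (a sketch; `step` stands for my current iteration counter, which is not shown in the snippets above):

    # Write the fetched summary to TensorBoard; `step` is assumed to be
    # the current training iteration.
    train_writer.add_summary(train_results["train_merged_summary"],
                             global_step=step)
    train_writer.flush()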

My placeholder's shape is [None, 64, 64, 3] (the batch size is None).

0 Answers:

No answers yet.