Keras模型输出常数值作为预测

时间:2018-01-09 00:12:14

标签: tensorflow keras reinforcement-learning

我目前正在实施强化学习算法,但是keras似乎并不想合作。在训练中,我使用批量数据来提供模型,并且一切都像它应该的那样工作,但是当我想在单个数据点上使用模型进行预测时,它只是反复输出相同的值(使用不同的输入)。这是我的代码:

//初始化网络:

    def initialize_actor(self):

    state_variable_input = Input(shape=(3, ))

    dense = Dense(16, activation="relu", kernel_initializer="lecun_uniform")(state_variable_input)
    batch_norm1 = BatchNormalization()(dense)
    dense2 = Dense(16, activation="relu", kernel_initializer="lecun_uniform")(batch_norm1)
    batch_norm2 = BatchNormalization()(dense2)
    output = Dense(1, activation=lecun_tanh, kernel_initializer=RandomUniform(-3e-3, 3e-3),
                   bias_initializer=RandomUniform(-3e-3, 3e-3))(batch_norm2)

    model = Model(inputs=state_variable_input,
                  outputs=output)

    model.compile(optimizer="adam", loss="mse")

    return model, state_variable_input

//训练:

    def train_actor(self, samples):

    cur_state_var, action, reward, new_state_var = samples
    predicted_actions = self.actor_model.predict([cur_state_var]) # shape of cur_state_var (batch_size, 3), predicts n = batch_size actions as intended
    grads = self.sess.run(self.critic_grads, feed_dict={
        self.critic_var_in: cur_state_var,
        self.critic_action_in: predicted_action})[0]

    summ, _  = self.sess.run([self.merged, self.optimize], feed_dict={
        self.actor_var_in: cur_state_var,
        self.actor_critic_grad: grads
        })
    self.writer.add_summary(summ)

//预测:

  def act(self, cur_state_var):

      if np.random.random() < self.epsilon:
          return env.action_space.sample()

      else:
          action = self.actor_model.predict([cur_state_var], batch_size=1)[0] # shape of cur_state_var = (1,3), the network returns the same value for action at every call
          return action

我希望我的错误可以在我认为与此相关的代码中找到。我怀疑它与训练和预测时间的不同批量大小有关。非常感谢你的时间!

1 个答案:

答案 0 :(得分:0)

我经常遇到这个问题,其中一个原因是样本的功能不均衡。例如,该功能如下:

[1234, 0.001, 0.002, 0.01, 0.01]

正如您所看到的,第一列功能很大,其他功能太小,所以您应该执行标准化功能。

另一个原因可能是激活错误,你应该根据你的具体工作检查激活,特别是最后一层的激活。

我希望这可以帮到你。