什么是CNTK中以下tensorflow片段的等价物

时间:2017-04-12 22:45:14

标签: python cntk

我正在尝试在CNTK中实现DDPG并遇到以下代码(使用Tensorflow)来创建评论网络:

state_input = tf.placeholder("float",[None,state_dim])
action_input = tf.placeholder("float",[None,action_dim])

W1 = self.variable([state_dim,layer1_size],state_dim)
b1 = self.variable([layer1_size],state_dim)
W2 = self.variable([layer1_size,layer2_size],layer1_size+action_dim)
W2_action = self.variable([action_dim,layer2_size],layer1_size+action_dim)
b2 = self.variable([layer2_size],layer1_size+action_dim)
W3 = tf.Variable(tf.random_uniform([layer2_size,1],-3e-3,3e-3))
b3 = tf.Variable(tf.random_uniform([1],-3e-3,3e-3))

layer1 = tf.nn.relu(tf.matmul(state_input,W1) + b1)
layer2 = tf.nn.relu(tf.matmul(layer1,W2) + tf.matmul(action_input,W2_action) + b2)
q_value_output = tf.identity(tf.matmul(layer2,W3) + b3)

其中self.variable定义为:

def variable(self,shape,f):
    return tf.Variable(tf.random_uniform(shape,-1/math.sqrt(f),1/math.sqrt(f)))

忽略随机初始化(我只想要结构),我尝试了以下内容:

state_in = cntk.input(state_dim, dtype=np.float32)
action_in = cntk.input_variable(action_dim, dtype=np.float32)

W1 = cntk.parameter(shape=(state_dim, layer1_size))
b1 = cntk.parameter(shape=(layer1_size))
W2 = cntk.parameter(shape=(layer1_size, layer2_size))
W2a = cntk.parameter(shape=(action_dim, layer2_size))
b2 = cntk.parameter(shape=(layer2_size))
W3 = cntk.parameter(shape=(layer2_size, 1))
b3 = cntk.parameter(shape=(1))

l1 = cntk.relu(cntk.times(state_in, W1) + b1)
l2 = cntk.relu(cntk.times(l1, W2) + cntk.times(action_in, W2a) + b2)
Q = cntk.times(l2, W3) + b3

但是,layer2的初始化失败,出现以下错误(片段):

  

RuntimeError:'Plus'操作:操作数'输出('Times24_Output_0',   [#,*],[300])'具有动态轴,与动态轴不匹配   其他操作数的'[#]'。

我想知道我做错了什么以及如何准确地重新创建相同的模型。

1 个答案:

答案 0 :(得分:2)

原因是您已将state_in定义为cntk.input,将action_in定义为cntk.input_variable,默认情况下,类型略有不同:默认情况下,cntk.input会创建一个无法绑定到序列数据的变量,而cntk.input_variable默认情况下会创建一个必须绑定到序列数据的变量(NB input_variable已弃用,某些IDE(如PyCharm)将使用删除线显示此项,请使用cntk.input()或cntk.sequence.input())。

错误说加号操作无法添加cntk.times(l1, W2)具有动态轴[#](意味着小批量维度)的cntk.times(action_in, W2a),其中包含动态轴[#,*](表示小批量和序列维度)。

最简单的解决方法是声明 action_in = cntk.input(action_dim, dtype=np.float32) 这使得其余的操作变得严格。