tf.Module无法收集tf.keras.Model中的变量

时间:2019-05-19 14:23:31

标签: tensorflow

如果我有一个继承自tf.Module的类(例如my_module),并且该类内部有tf.keras.Models。我应该使用my_module.variables获取模型中的所有变量吗?

尝试了tf2.0中的一些简单示例。似乎tf.module无法收集tf.keras.models中的变量。

def nn(layers_sizes):
    model = tf.keras.Sequential()
    for i, size in enumerate(layers_sizes):
        model.add(tf.keras.layers.Dense(
            units=size,
            activation=tf.keras.layers.ReLU() if i < len(layers_sizes) - 1 else None,
        ))
    return model

class Actor(tf.Module):

  def __init__(self, name=None):
    super(Actor, self).__init__(name=name)
    self.check = tf.Variable(initial_value=np.array((1,2,3)))
    self.nn = nn([3,16])

  def call(self, inputs):
    self.check.assign_add(np.array((1,2,3)))
    self.x = self.nn(inputs)

if __name__ == "__main__":
    model3 = Actor(name="test")
    input = np.array((1.,2.,3.,4.)).reshape(-1,1)
    model3.call(input)
    print(model3.variables)
    print(model3.nn.variables)

我期望model3.variables包含model3.nn.variables。

1 个答案:

答案 0 :(得分:0)

好的,这是一个错误。已在tf2beta中修复。