DynamicPartition返回单个输出而不是多个输出

时间:2017-07-15 07:01:55

标签: java tensorflow

这是我的代码,使用DynamicPartition操作构建图形,使用掩码将矢量[1,2,3,4,5,6]分割为两个矢量[1,2,3]和[4,5,6] [1,1,1,0,0,0]:

@Test
public void dynamicPartition2() {
    Graph graph = new Graph();

    Output a = graph.opBuilder("Const", "a")
            .setAttr("dtype", DataType.INT64)
            .setAttr("value", Tensor.create(new long[]{6}, LongBuffer.wrap(new long[] {1, 2, 3, 4, 5, 6})))
            .build().output(0);

    Output partitions = graph.opBuilder("Const", "partitions")
            .setAttr("dtype", DataType.INT32)
            .setAttr("value", Tensor.create(new long[]{6}, IntBuffer.wrap(new int[] {1, 1, 1, 0, 0, 0})))
            .build().output(0);

    graph.opBuilder("DynamicPartition", "result")
            .addInput(a)
            .addInput(partitions)
            .setAttr("num_partitions", 2)
            .build().output(0);

    try (Session s = new Session(graph)) {
        List<Tensor> outputs = s.runner().fetch("result").run();

        try (Tensor output = outputs.get(0)) {
            LongBuffer result = LongBuffer.allocate(3);
            output.writeTo(result);

            assertArrayEquals("Shape", new long[]{3}, output.shape());
            assertArrayEquals("Values", new long[]{4, 5, 6}, result.array());
        }

        //Test will fail here
        try (Tensor output = outputs.get(1)) {
            LongBuffer result = LongBuffer.allocate(3);
            output.writeTo(result);

            assertArrayEquals("Shape", new long[]{3}, output.shape());
            assertArrayEquals("Values", new long[]{1, 2, 3}, result.array());
        }
    }
}

调用s.runner().fetch("result").run()后,返回长度为1的列表,其值为[4,5,6]。我的图表似乎只产生一个输出。

如何获得分裂向量的其余部分?

2 个答案:

答案 0 :(得分:1)

DynamicPartition操作返回多个输出(每个分区一个),但Session.Runner.fetch调用仅请求第0个输出。

Java API缺少Python API所具有的一堆便利糖,但您可以通过显式请求所有输出来执行您想要的操作。换句话说,改变自:

List<Tensor> outputs = s.runner().fetch("result").run();

List<Tensor> outputs = s.runner().fetch("result", 0).fetch("result", 1).run();

希望有所帮助。

答案 1 :(得分:0)

不确定java(我不知道它,也没有调查的环境),但在python中一切正常。例如这个

import tensorflow as tf
a = tf.constant([1, 2, 3, 4, 5, 6])
b = tf.constant([1, 1, 1, 0, 0, 0])
c = tf.dynamic_partition(a, b, 2)
with tf.Session() as sess:
    v1, v2 = sess.run(c)
    print v1
    print v2

返回正确的分区。