Tensorflow定制操作 - 如何从Tensors读取和写入?

时间:2017-03-23 19:10:21

标签: tensorflow

我正在使用本教程编写自定义Tensorflow操作系统,而且我无法理解如何读取和写入Tensors。

让我说我的OpKernel中有一个Tensor const Tensor& values_tensor = context->input(0);(其中context = OpKernelConstruction*

如果Tensor有形状,比如说[2,10,20],我怎样才能将其编入索引(例如auto x = values_tensor[1, 4, 12]等)?

等价,如果我有

Tensor *output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
  0,
  {batch_size, value_len - window_size, window_size},
  &output_tensor
));

如何分配到output_tensor,如output_tensor[1, 2, 3] = 11等?

对不起这个愚蠢的问题,但是文档真的让我吵架了,内置操作的Tensorflow内核代码中的示例以某种方式混淆了这一点,我感到非常困惑:)

谢谢你!

1 个答案:

答案 0 :(得分:1)

读取和写入tensorflow::Tensor个对象的最简单方法是使用Eigen tensor方法将它们转换为tensorflow::Tensor::tensor<T, NDIMS>()。请注意,您必须将张量中的(C ++)元素类型指定为模板参数T

例如,要从DT_FLOAT32张量中读取特定值:

const Tensor& values_tensor = context->input(0);
auto x = value_tensor.tensor<float, 3>()(1, 4, 12);

要将特定值写入DT_FLOAT32张量:

Tensor* output_tensor = ...;
output_tensor->tensor<float, 3>()(1, 2, 3) = 11.0;

还有便捷方法可以访问scalarvectormatrix

相关问题