创建我自己的资源类型(tf.resource)

时间:2017-06-09 10:41:48

标签: tensorflow

我目前的代码:

// For Eigen::ThreadPoolDevice.
#define EIGEN_USE_THREADS 1

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

using namespace tensorflow;

// Registers the op that creates (or reuses) an ArrayContainer resource and
// outputs a scalar resource handle to it.
//   T            - element dtype of the container.
//   container /
//   shared_name  - the usual resource-sharing attrs read through the kernel's
//                  ContainerInfo (cinfo_); both default to empty.
// Note: the relative order of the .Attr calls is significant (it defines the
// OpDef attr order), so it is kept as originally declared.
REGISTER_OP("ArrayContainerCreate")
    .Output("resource: resource")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::ScalarShape)
    .SetIsStateful()
    .Doc(R"doc(Array container, random index access)doc");

// Registers the op that reads the size of the ArrayContainer referenced by
// `handle` and returns it as an int32 scalar.
// A .Doc string is provided for consistency with ArrayContainerCreate.
REGISTER_OP("ArrayContainerGetSize")
    .Input("handle: resource")
    .Output("out: int32")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(Returns the size of the array container referenced by `handle`.)doc");

// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h
// Resource object backing the ArrayContainer ops: created by
// ArrayContainerCreateOp::CreateResource and looked up by
// ArrayContainerGetSizeOp. It currently stores only its element dtype;
// get_size() is a placeholder.
struct ArrayContainer : public ResourceBase {
  // `explicit`: a DataType should never silently convert into a container.
  explicit ArrayContainer(const DataType& dtype) : dtype_(dtype) {}

  string DebugString() override { return "ArrayContainer"; }

  // No memory is tracked yet, so report zero bytes.
  int64 MemoryUsed() const override { return 0; }

  // Guards the container's mutable state; taken by get_size().
  mutex mu_;
  // Element dtype fixed at creation; compared against the "T" attr in
  // ArrayContainerCreateOp::VerifyResource.
  const DataType dtype_;

  // Returns the number of elements in the container.
  // NOTE(review): placeholder implementation, always returns 42.
  int32 get_size() {
    mutex_lock l(mu_);
    return 42;
  }
};

// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h
// Kernel for ArrayContainerCreate. It subclasses ResourceOpKernel (see
// resource_op_kernel.h) and supplies only the two hooks that base requires:
// how to build a new ArrayContainer, and how to validate an existing one.
class ArrayContainerCreateOp : public ResourceOpKernel<ArrayContainer> {
 public:
  explicit ArrayContainerCreateOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
  }

 private:
  // ResourceOpKernel hook that allocates a fresh container (called with the
  // base class's mu_ held). `new` reports failure by throwing/aborting rather
  // than returning nullptr, so no null check is needed.
  Status CreateResource(ArrayContainer** ret) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *ret = new ArrayContainer(dtype_);
    return Status::OK();
  }

  // ResourceOpKernel hook for an already-existing resource: reject it if it
  // was created with a different element dtype than this op's "T" attr.
  Status VerifyResource(ArrayContainer* ar) override {
    if (ar->dtype_ != dtype_) {
      return errors::InvalidArgument("Data type mismatch: expected ",
                                     DataTypeString(dtype_), " but got ",
                                     DataTypeString(ar->dtype_), ".");
    }
    return Status::OK();
  }

  // Element dtype read from the "T" attr.
  DataType dtype_;
};
REGISTER_KERNEL_BUILDER(Name("ArrayContainerCreate").Device(DEVICE_CPU),
                        ArrayContainerCreateOp);

class ArrayContainerGetSizeOp : public OpKernel {
public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    ArrayContainer* ar;
    OP_REQUIRES_OK(context, GetResourceFromContext(context, "handle", &ar));
    core::ScopedUnref unref(ar);

    int32 size = ar->get_size();
    Tensor* tensor_size = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &tensor_size));
    tensor_size->flat<int32>().setConstant(size);
  }
};
REGISTER_KERNEL_BUILDER(Name("ArrayContainerGetSize").Device(DEVICE_CPU), ArrayContainerGetSizeOp);

代码可以正常编译。需要说明的是,一开始我遇到了一些 `undefined symbol: _ZN6google8protobuf8internal26fixed_address_empty_stringE` 错误,后来通过添加以下额外的编译器(链接)标志解决了这个问题:

# Link the custom op against the protobuf C++ extension module shipped with the
# Python `protobuf` package, so unresolved google::protobuf symbols (such as
# fixed_address_empty_string) can be resolved when the op library is loaded.
import os

from google.protobuf.pyext import _message as msg

lib = msg.__file__
lib_dir = os.path.dirname(lib)

extra_compiler_flags = [
    # Embed the module's directory as an rpath so it is found at load time.
    "-Xlinker", "-rpath", "-Xlinker", lib_dir,
    # Add the directory to the link search path and link the file by its exact
    # name ("-l:<filename>" is the GNU ld form for a literal file name).
    "-L", lib_dir, "-l", ":" + os.path.basename(lib)]

这个方法是我在这里(here)读到的。

然后我通过tf.load_op_library将其作为模块加载。

接着,我运行了下面这段 Python 代码:

# Build the graph: create an int32 ArrayContainer resource, then query its size.
handle = mod.array_container_create(T=tf.int32)
size = mod.array_container_get_size(handle)

当我尝试对 `size` 求值时,收到了如下错误:

InvalidArgumentError (see above for traceback): Trying to access resource located in device 14ArrayContainer from device /job:localhost/replica:0/task:0/cpu:0
         [[Node: ArrayContainerGetSize = ArrayContainerGetSize[_device="/job:localhost/replica:0/task:0/cpu:0"](array_container)]]

设备名称(`14ArrayContainer`)看起来是错乱的。这是为什么?代码哪里有问题?

为了进一步测试,我在 `ArrayContainerCreateOp` 中添加了以下代码:
    // Debug dump: build a ResourceHandle for cinfo_'s container/name and print
    // the device/container/name fields stored in it, followed by the kernel's
    // real device name and cinfo_'s name for comparison.
    // NOTE(review): in the reported output all three handle fields read
    // "14ArrayContainer". That string looks like the Itanium-mangled
    // typeid(ArrayContainer).name(), which suggests the handle's fields are
    // being read at the wrong layout (e.g. a C++ ABI / protobuf mismatch with
    // the TensorFlow binary) rather than a bug in this code — TODO confirm.
    ResourceHandle rhandle = MakeResourceHandle<ArrayContainer>(context, cinfo_.container(), cinfo_.name());
    printf("created. device: %s\n", rhandle.device().c_str());
    printf("container: %s\n", rhandle.container().c_str());
    printf("name: %s\n", rhandle.name().c_str());
    printf("actual device: %s\n", context->device()->attributes().name().c_str());
    printf("actual name: %s\n", cinfo_.name().c_str());

得到的输出如下:

created. device: 14ArrayContainer
container: 14ArrayContainer
name: 14ArrayContainer
actual device: /job:localhost/replica:0/task:0/cpu:0
actual name: _2_array_container

显然,这里出了问题。

这看起来像是 protobuf 出了问题?也许我链接了错误的库?但我还没找到应该链接哪个库。

(我也在这里(here)就此问题发了帖。)

1 个答案:

答案 0 :(得分:0)

这是上游的 bug 10950(upstream bug 10950),应该会在 TensorFlow 1.2.2 中修复。