当我在新的自定义c ++ op中调用标准操作(例如MatMul)时,Bazel会返回错误

时间:2018-05-25 17:09:21

标签: tensorflow

我在tensorflow中实现了一个新的自定义c ++ op。在相应的操作内核的Compute函数中,调用了一些标准的ops(例如MatMul)。 主要的源代码是:

REGISTER_OP("NewOp")
.Input("input: int32")
.Output("output: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  return Status::OK();
});

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

using namespace tensorflow;
using namespace tensorflow::ops;

class MyNewOp : public OpKernel {
public:
    explicit MyNewOp(OpKernelConstruction* context) : OpKernel(context) {}
    void Compute(OpKernelContext* context) override {
        // Grab the input tensor
        ……
        // Create an output tensor
        ……
        Scope root = Scope::NewRootScope();
        auto A = Const(root, { {35.f, 22.f}, {-10.f, 0.f} });
        auto b = Const(root, { {30.f, 55.f} });
        auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
        std::vector<Tensor> results;
        ClientSession session(root);
        TF_CHECK_OK(session.Run({v}, &results));
        // Set the output tensor according to the results of MatMul
        ……
    }
};
REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU), MyNewOp);

相应的Bazel BUILD文件是:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
    name = "MyNewOp.so",
    srcs = ["mynewop.cc"],
    deps = [
    "//tensorflow/cc:cc_ops",
    "//tensorflow/cc:client_session",
    "//tensorflow/core:tensorflow",
    ],
)

当我构建上述目标时,Bazel会返回错误:

tensorflow/cc:cc_ops cannot depend on tensorflow/core:framework

我该如何解决这个问题?我想知道我是否可以在新的自定义c ++ op中调用ternsorflow预定义操作?非常感谢你!

1 个答案:

答案 0 :(得分:0)

您遇到的问题是因为您的自定义操作取决于此rule明确禁止的tensorflow/core:framework

disallowed_deps=[
      clean_dep("//tensorflow/core:framework"),
      clean_dep("//tensorflow/core:lib")
]

最好的方法是找到另一种解决方案。

如果您确实希望拥有此依赖关系,那么 hacky方式会重新实现tf_custom_op_library规则而不会出现禁用依赖关系。

这可以通过以下方式完成:

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "clean_dep")

tf_cc_shared_object(
    name = "MyNewOp.so",
    srcs = ["mynewop.cc"],
    copts = tf_copts(is_external=True),
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
    "//tensorflow/core:framework",
    "//tensorflow/cc:cc_ops",
    "//tensorflow/cc:client_session",
    "//tensorflow/core:tensorflow",],
    linkopts= select({
              "//conditions:default": [
                  "-lm",
              ],
              clean_dep("//tensorflow:windows"): [],
              clean_dep("//tensorflow:windows_msvc"): [],
              clean_dep("//tensorflow:darwin"): [],
          }),
)

工作正常:

Target //tensorflow/user_ops:MyNewOp.so up-to-date:
  bazel-bin/tensorflow/user_ops/MyNewOp.so
INFO: Elapsed time: 46.399s, Critical Path: 19.71s
INFO: 397 processes, local.
INFO: Build completed successfully, 400 total actions
相关问题