我已经为我的自定义Op实现了一个内核,并将其放入/tensorflow/core/user_ops
目录下,文件名为custom_op.cc
。在Op中,我完成了所有的注册工作,比如REGISTER_OP
和REGISTER_KERNEL_BUILDER
。
然后,我在Python中为这个Op实现了梯度,并将其放置在同一文件夹下,文件名为custom_op_grad.py
。我在这里也完成了所有的注册工作(@ops.RegisterGradient
)。
我创建了BUILD
文件,内容如下:
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")tf_custom_op_library( name = "custom_op.so", srcs = ["custom_op.cc"],)py_library( name = "custom_op_grad", srcs = ["custom_op_grad.py"], srcs_version = "PY2", deps = [ ":custom_op_grad", "//tensorflow:tensorflow_py", ],)
之后,我重新构建了TensorFlow:
pip uninstall tensorflowbazel cleanbazel build -c opt //tensorflow/tools/pip_package:build_pip_packagecp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkgpip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl
当我尝试通过调用tf.user_ops.custom_op
使用我的Op时,它告诉我该模块中没有这个Op。
可能还有一些额外的步骤我需要做?还是我在BUILD
文件中做错了什么?
回答:
好的,我找到了解决方案。我只是删除了BUILD
文件,我的自定义Op成功构建并可以使用tensorflow.user_ops.custom_op()
在Python中导入。
为了使用梯度,我不得不将其代码直接放入tensorflow/python/user_ops/user_ops.py
中。这不是最优雅的解决方案,但目前可以使用。