如何在Python中导入TensorFlow的自定义Op?

我已经为我的自定义Op实现了一个内核,并将其放入/tensorflow/core/user_ops目录下,文件名为custom_op.cc。在Op中,我完成了所有的注册工作,比如REGISTER_OPREGISTER_KERNEL_BUILDER

然后,我在Python中为这个Op实现了梯度,并将其放置在同一文件夹下,文件名为custom_op_grad.py。我在这里也完成了所有的注册工作(@ops.RegisterGradient)。

我创建了BUILD文件,内容如下:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")tf_custom_op_library(        name = "custom_op.so",        srcs = ["custom_op.cc"],)py_library(        name = "custom_op_grad",        srcs = ["custom_op_grad.py"],        srcs_version = "PY2",        deps = [        ":custom_op_grad",        "//tensorflow:tensorflow_py",        ],)

之后,我重新构建了TensorFlow:

pip uninstall tensorflowbazel cleanbazel build -c opt //tensorflow/tools/pip_package:build_pip_packagecp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkgpip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl

当我尝试通过调用tf.user_ops.custom_op使用我的Op时,它告诉我该模块中没有这个Op。

可能还有一些额外的步骤我需要做?还是我在BUILD文件中做错了什么?


回答:

好的,我找到了解决方案。我只是删除了BUILD文件,我的自定义Op成功构建并可以使用tensorflow.user_ops.custom_op()在Python中导入。

为了使用梯度,我不得不将其代码直接放入tensorflow/python/user_ops/user_ops.py中。这不是最优雅的解决方案,但目前可以使用。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注