为什么Keras的Lambda层会导致Mask_RCNN问题?

我正在使用这个仓库中的Mask_RCNN包:https://github.com/matterport/Mask_RCNN

我尝试使用这个包来训练我自己的数据集,但在开始时就出现了错误。

2020-11-30 12:13:16.577252: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.12020-11-30 12:13:16.587017: E tensorflow/stream_executor/cuda/cuda_driver.cc:314] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected2020-11-30 12:13:16.587075: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (7612ade969e5): /proc/driver/nvidia/version does not exist2020-11-30 12:13:16.587479: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMATo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.2020-11-30 12:13:16.593569: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2300000000 Hz2020-11-30 12:13:16.593811: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1b2aa00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:2020-11-30 12:13:16.593846: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default VersionTraceback (most recent call last):  File "machines.py", line 345, in <module>    model_dir=args.logs)  File "/content/Mask_RCNN/mrcnn/model.py", line 1837, in __init__    self.keras_model = self.build(mode=mode, config=config)  File "/content/Mask_RCNN/mrcnn/model.py", line 1934, in build    anchors = KL.Lambda(lambda x: tf.Variable(anchors), name="anchors")(input_image)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 926, in __call__    input_list)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1117, in _functional_construction_call    outputs = call_fn(cast_inputs, *args, **kwargs)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py", line 904, in call    self._check_variables(created_variables, tape.watched_variables())  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py", line 931, in _check_variables    raise ValueError(error_str)ValueError: The following Variables were created within a Lambda layer (anchors)but are not tracked by said layer:  <tf.Variable 'anchors/Variable:0' shape=(1, 261888, 4) dtype=float32>The layer cannot safely ensure proper Variable reuse across multiplecalls, and consquently this behavior is disallowed for safety. Lambdalayers are not well suited to stateful computation; instead, writing asubclassed Layer is the recommend way to define layers withVariables.

我查看了导致问题的代码部分(位于仓库中的文件: /mrcnn/model.py 行: 1935):IN[0]: anchors = KL.Lambda(lambda x: tf.Variable(anchors), name="anchors")(input_image)

如果有人知道如何解决这个问题,或者已经解决了这个问题,请提供解决方案。


回答:

根本原因:Tensorflow 2.X中Keras的Lambda层的行为与Tensorflow 1.X有所不同。在Tensorflow 1.X的Keras中,所有的tf.Variable和tf.get_variable都会自动通过变量创建器上下文追踪到layer.weights中,因此它们会自动接收梯度并可训练。这种方法在Tensorflow 2.X中用于自动图编译时会出现问题,自动图编译会将Python代码转换为执行图,因此这种方法被移除,现在Lambda层会检查变量的创建并引发你所看到的错误。简而言之,Tensorflow 2.X中的Lambda层必须是无状态的。如果你想创建变量,在Tensorflow 2.X中正确的方法是子类化层类,并将可训练的权重添加为类成员。

解决方案:有两种选择 –

  1. 改用Tensorflow 1.X。这种错误不会被触发。

  2. 用Keras层的子类替换Lambda层:

class AnchorsLayer(tensorflow.keras.layers.Layer):   def __init__(self, anchors):     super(AnchorLayer, self).__init__()     self.anchors_v = tf.Variable(anchors)      def call(self):     return self.anchors_v# 然后用这个替换Lambda调用:      anchors_layer = AnchorLayers(anchors)   anchors = anchors_layer()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注