I wrote a loss function assuming the loss would be computed on double-precision floats rather than on tensors. Here is my function:
def prediction_loss(a, b):
    IGNORE = .0025
    EPSILON = .00001

    if IGNORE > abs(a) and IGNORE > abs(b) and np.sign(a) == np.sign(b):
        return 0

    scale = min(abs(a), abs(b))
    distance = abs(a - b)

    if abs(scale) < EPSILON:
        scale = max(abs(a), abs(b))
    if abs(scale) < EPSILON:
        scale = 1

    distance **= 2
    return min(distance, distance / scale)
When I use it in model.compile, I get the following error:
OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-44-92af3f50a682> in <module>()
      9   keras.layers.Dense(1)
     10 ])
---> 11 model.compile(loss=prediction_loss, optimizer=keras.optimizers.SGD(lr=0.001, momentum=0.9, nesterov=True))

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
    371 
    372     # Creates the model loss and weighted metrics sub-graphs.
--> 373     self._compile_weights_loss_and_weighted_metrics()
    374 
    375     # Functions for train, test and predict will

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _compile_weights_loss_and_weighted_metrics(self, sample_weights)
   1651     #                   loss_weight_2 * output_2_loss_fn(...) +
   1652     #                   layer losses.
-> 1653     self.total_loss = self._prepare_total_loss(masks)
   1654 
   1655   def _prepare_skip_target_masks(self):

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _prepare_total_loss(self, masks)
   1711 
   1712               if hasattr(loss_fn, 'reduction'):
-> 1713                 per_sample_losses = loss_fn.call(y_true, y_pred)
   1714                 weighted_losses = losses_utils.compute_weighted_loss(
   1715                     per_sample_losses,

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/losses.py in call(self, y_true, y_pred)
    219     y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
    220         y_pred, y_true)
--> 221     return self.fn(y_true, y_pred, **self._fn_kwargs)
    222 
    223   def get_config(self):

<ipython-input-43-4630edd6290a> in prediction_loss(a, b)
     14     EPSILON=.00001
     15 
---> 16     if IGNORE > abs(a) and IGNORE > abs(b) and np.sign(a)==np.sign(b):
     17         return 0
     18 

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in __bool__(self)
    763       `TypeError`.
    764     """
--> 765     self._disallow_bool_casting()
    766 
    767   def __nonzero__(self):

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in _disallow_bool_casting(self)
    532     else:
    533       # Default: V1-style Graph execution.
--> 534       self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
    535 
    536   def _disallow_iteration(self):

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in _disallow_in_graph_mode(self, task)
    521       raise errors.OperatorNotAllowedInGraphError(
    522           "{} is not allowed in Graph execution. Use Eager execution or decorate"
--> 523           " this function with @tf.function.".format(task))
    524 
    525   def _disallow_bool_casting(self):

OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Apparently TensorFlow passes tf.Tensor objects as the arguments a and b, and these cannot be used in Python logical operations. How should I modify the function so that it works? I want to ignore cases where a and b are both small in magnitude and have the same sign.
Answer:
Correct, a tf.Tensor cannot be used as a Python bool. Use keras.backend.switch() to express the conditional instead.
See the following link for how to use it:
That page lists all the backend functions you can use to adapt your expression, such as greater, greater_equal, equal, and so on.
Rewrite your statement "if IGNORE > abs(a) and IGNORE > abs(b) and np.sign(a)==np.sign(b):" with the Keras backend functions and that should solve your problem.
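For example, here is a minimal sketch of your whole loss rewritten with backend ops. It assumes you are on tf.keras (as the traceback suggests) and that the loss should be computed element-wise per sample; the constants come from your function, everything else is one possible translation rather than the only correct one.

# Minimal sketch, assuming tf.keras and an element-wise (per-sample) loss.
import tensorflow as tf
from tensorflow.keras import backend as K

def prediction_loss(a, b):
    IGNORE = 0.0025
    EPSILON = 0.00001

    abs_a = K.abs(a)
    abs_b = K.abs(b)

    # scale = min(|a|, |b|), falling back to max(|a|, |b|) and then to 1
    # whenever the current value is below EPSILON, as in the original function.
    scale = K.minimum(abs_a, abs_b)
    scale = K.switch(K.less(scale, EPSILON), K.maximum(abs_a, abs_b), scale)
    scale = K.switch(K.less(scale, EPSILON), K.ones_like(scale), scale)

    distance = K.square(a - b)
    loss = K.minimum(distance, distance / scale)

    # Zero out samples where both values are small in magnitude and share a sign.
    ignore = tf.logical_and(
        tf.logical_and(K.less(abs_a, IGNORE), K.less(abs_b, IGNORE)),
        K.equal(K.sign(a), K.sign(b)))
    return K.switch(ignore, K.zeros_like(loss), loss)

With this version every element of the batch gets its own loss value, which is what Keras expects from a custom loss; K.switch selects between the two branches element-wise instead of evaluating a Python if.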