TypeError: multiple values for argument 'weight_decay'

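I am using the AdamW optimizer together with a cosine-decay learning rate scheduler with warmup. I wrote the custom scheduler from scratch and use the AdamW optimizer provided by the TensorFlow Addons library.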

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import wandb

class CosineScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 warmup_learning_rate=0.0,
                 warmup_steps=0):
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Cosine decay from learning_rate_base, measured from the end of warmup.
        learning_rate = 0.5 * self.learning_rate_base * (1 + tf.cos(
            np.pi *
            (tf.cast(step, tf.float32) - self.warmup_steps) / float(self.total_steps - self.warmup_steps)))
        if self.warmup_steps > 0:
            # Linear ramp from warmup_learning_rate up to learning_rate_base.
            slope = (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(step < self.warmup_steps, warmup_rate, learning_rate)
        lr = tf.where(step > self.total_steps, 0.0, learning_rate, name='learning_rate')
        wandb.log({"lr": lr})
        return lr

learning_rate = CosineScheduler(learning_rate_base=0.001,
                                total_steps=23000,
                                warmup_learning_rate=0.0,
                                warmup_steps=1660)
loss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)

I get the following error, which says that the weight_decay argument received multiple values:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-6f9fd0a9c1cb> in <module>
      1 loss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
----> 2 optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)

/opt/conda/lib/python3.7/site-packages/typeguard/__init__.py in wrapper(*args, **kwargs)
    923
    924     def wrapper(*args, **kwargs):
--> 925         memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
    926         check_argument_types(memo)
    927         retval = func(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/typeguard/__init__.py in __init__(self, func, frame_locals, args, kwargs, forward_refs_policy)
    126
    127         if args is not None and kwargs is not None:
--> 128             self.arguments = signature.bind(*args, **kwargs).arguments
    129         else:
    130             assert frame_locals is not None, 'frame must be specified if args or kwargs is None'

/opt/conda/lib/python3.7/inspect.py in bind(*args, **kwargs)
   3013         if the passed arguments can not be bound.
   3014         """
-> 3015         return args[0]._bind(args[1:], kwargs)
   3016
   3017     def bind_partial(*args, **kwargs):

/opt/conda/lib/python3.7/inspect.py in _bind(self, args, kwargs, partial)
   2954                         raise TypeError(
   2955                             'multiple values for argument {arg!r}'.format(
-> 2956                                 arg=param.name)) from None
   2957
   2958                     arguments[param.name] = arg_val

TypeError: multiple values for argument 'weight_decay'

What is causing this problem, and how can I fix it?


Answer:

The problem is that weight_decay is the first positional argument of tfa.optimizers.AdamW. In

optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)

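the positional value learning_rate is bound to weight_decay, and the keyword weight_decay=0.1 then supplies the same parameter a second time. That is what raises the error. According to the documentation, the learning rate is the second positional argument (and an optional one), not the first.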
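The same failure can be reproduced without TensorFlow. The sketch below uses a hypothetical stand-in function, adamw_like, that only mirrors the parameter order described above (weight_decay first, learning_rate second and optional); it is not the real optimizer, and the schedule value is a placeholder.

```python
# Hypothetical stand-in that mirrors the parameter order described above:
# weight_decay comes first, learning_rate is second and optional.
def adamw_like(weight_decay, learning_rate=0.001):
    return {"weight_decay": weight_decay, "learning_rate": learning_rate}


schedule = "cosine-schedule"  # placeholder for the CosineScheduler instance

message = None
try:
    # The positional value fills weight_decay, then the keyword names it again.
    adamw_like(schedule, weight_decay=0.1)
except TypeError as exc:
    message = str(exc)
    print(message)
```

The printed message names weight_decay because the first positional slot had already claimed that parameter before the keyword was processed.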

Just write

optimizer = tfa.optimizers.AdamW(0.1, learning_rate)

or

optimizer = tfa.optimizers.AdamW(weight_decay=0.1, learning_rate=learning_rate)

or

optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=0.1)
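As a rough check that the three forms are equivalent, the sketch below binds each call style against a signature with inspect.signature(...).bind, the same mechanism the traceback shows typeguard using. The function adamw_like is a hypothetical stand-in with the parameter order described in this answer, not the real tfa.optimizers.AdamW, and schedule is a placeholder object.

```python
import inspect


def adamw_like(weight_decay, learning_rate=0.001):
    """Hypothetical stand-in; only the parameter order matters here."""


sig = inspect.signature(adamw_like)
schedule = object()  # placeholder for the CosineScheduler instance

# Bind each of the three call styles without actually calling the function.
bound = [
    dict(sig.bind(0.1, schedule).arguments),
    dict(sig.bind(weight_decay=0.1, learning_rate=schedule).arguments),
    dict(sig.bind(learning_rate=schedule, weight_decay=0.1).arguments),
]

# All three styles should map the same values to the same parameter names.
same = all(b == bound[0] for b in bound)
print(same)
```

Passing both values by keyword is probably the most robust choice, since it does not depend on remembering the positional order.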
