如何从这段Pytorch代码中找到Keras中使用的等效’batch_size’?

我正在使用来自github的Pytorch代码

我正在尝试将其移植到Keras。在Keras中,训练神经网络使用model.fit,它有一个batch_size参数。我试图设置这个参数,但无法从上面链接的Pytorch脚本中确定它。

在脚本中,第4个框中有一个名为sliding_window的函数,其中有一个名为step的参数。我不确定这是否等同于batch_size的设置。

另外,我正在研究如何从Pytorch代码的第11个框中设置学习率调度器:

base_lr = 0.01params_dict = dict(net.named_parameters())params = []for key, value in params_dict.items():    if '_D' in key:        # 解码器权重以标称学习率进行训练        params += [{'params':[value],'lr': base_lr}]    else:        # 编码器权重以 lr / 2 的学习率进行训练(我们使用VGG-16权重作为初始化)        params += [{'params':[value],'lr': base_lr / 2}]optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005)# 我们定义调度器scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)

我一直使用Keras的默认学习率。关于如何将这个调度器转换成Keras代码的任何见解也将不胜感激。


回答:

下面是一个在Keras中编写学习率调度器的示例:

from keras.callbacks import Callbackfrom keras import backend as Kclass LRSchedule(Callback):    def __init__(self, schedule):        super(LRSchedule, self).__init__()        self.schedule = schedule    def on_train_begin(self, logs = {}):        self.epoch_counter = 0        self.schedule_index = 0    def on_epoch_end(self, epoch, logs = {}):        self.epoch_counter += 1        if len(self.schedule)  > self.schedule_index + 1:            next_epoch = self.schedule[self.schedule_index + 1]            if self.epoch_counter == next_epoch:                K.set_value(self.model.optimizer.lr, self.model.optimizer.lr / 2.0)                self.schedule_index += 1

在您的代码中,您将以这种方式调用回调:

lr_scheduler = LRSchedule([25, 35, 45])model.fit_generator(..., callbacks = [lr_scheduler])

请注意,这个调度器在达到某个纪元时,会将学习率设置为较低的值(通过除以2)。将其修改为更复杂的调度策略是非常简单的。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注