如何从这段Pytorch代码中找到Keras中使用的等效’batch_size’?

我正在使用来自github的Pytorch代码

我正在尝试将其移植到Keras。在Keras中,训练神经网络使用model.fit,它有一个batch_size参数。我试图设置这个参数,但无法从上面链接的Pytorch脚本中确定它。

在脚本中,第4个框中有一个名为sliding_window的函数,其中有一个名为step的参数。我不确定这是否等同于batch_size的设置。

另外,我正在研究如何从Pytorch代码的第11个框中设置学习率调度器:

base_lr = 0.01params_dict = dict(net.named_parameters())params = []for key, value in params_dict.items():    if '_D' in key:        # 解码器权重以标称学习率进行训练        params += [{'params':[value],'lr': base_lr}]    else:        # 编码器权重以 lr / 2 的学习率进行训练(我们使用VGG-16权重作为初始化)        params += [{'params':[value],'lr': base_lr / 2}]optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005)# 我们定义调度器scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)

我一直使用Keras的默认学习率。关于如何将这个调度器转换成Keras代码的任何见解也将不胜感激。


回答:

下面是一个在Keras中编写学习率调度器的示例:

from keras.callbacks import Callbackfrom keras import backend as Kclass LRSchedule(Callback):    def __init__(self, schedule):        super(LRSchedule, self).__init__()        self.schedule = schedule    def on_train_begin(self, logs = {}):        self.epoch_counter = 0        self.schedule_index = 0    def on_epoch_end(self, epoch, logs = {}):        self.epoch_counter += 1        if len(self.schedule)  > self.schedule_index + 1:            next_epoch = self.schedule[self.schedule_index + 1]            if self.epoch_counter == next_epoch:                K.set_value(self.model.optimizer.lr, self.model.optimizer.lr / 2.0)                self.schedule_index += 1

在您的代码中,您将以这种方式调用回调:

lr_scheduler = LRSchedule([25, 35, 45])model.fit_generator(..., callbacks = [lr_scheduler])

请注意,这个调度器在达到某个纪元时,会将学习率设置为较低的值(通过除以2)。将其修改为更复杂的调度策略是非常简单的。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注