使用Keras Tuner调整预-existing模型

我在研究Keras Tuner作为超参数优化的一种方法,但我看到的所有例子都是定义一个全新的模型。例如,在Keras Tuner的Hello World中:

def build_model(hp):    model = keras.Sequential()    model.add(layers.Flatten(input_shape=(28, 28)))    for i in range(hp.Int('num_layers', 2, 20)):        model.add(layers.Dense(units=hp.Int('units_' + str(i), 32, 512, 32),                               activation='relu'))    model.add(layers.Dense(10, activation='softmax'))    model.compile(        optimizer=keras.optimizers.Adam(            hp.Choice('learning_rate', [1e-2, 1e-3, 1e-4])),        loss='sparse_categorical_crossentropy',        metrics=['accuracy'])    return model

我已经有一个希望调整的模型,但这是否意味着我必须像上面那样重写模型,并将超参数嵌入到模型主体中,还是可以简单地在模型顶部传递超参数?例如,像这样:

def build_model(hp):    model = MyExistingModel(        batch_size=hp['batch_size'],        seq_len=hp['seq_len'],        rnn_hidden_units=hp['hidden_units'],        rnn_type='gru',        num_rnn_layers=hp['num_rnn_layers']    )    optimizer = optimizer_factory['adam'](        learning_rate=hp['learning_rate'],        momentum=0.9,    )    model.compile(        optimizer=optimizer,        loss='sparse_categorical_crossentropy',        metrics=['sparse_categorical_accuracy'],    )    return model

据我所见,上面的代码似乎是可行的。模型初始化参数通过HyperParameters实例传递给内部的TF层,并从那里访问…虽然我不太确定如何传递…我想可以通过预定义一个HyperParameters对象并将其传递给调谐器,这样它就会传递给build_model

hp = HyperParameters()hp.Choice('learning_rate', [1e-1, 1e-3])tuner = RandomSearch(    build_model,    max_trials=5,    hyperparameters=hp,    tune_new_entries=False,    objective='val_accuracy')

我的模型内部有两个RNN(LSTM或GRU)和一个MLP。但我还没有找到一个Keras Tuner的build_model例子,它可以像这样接收一个现有模型并简单地传递超参数。模型相当复杂,我希望避免重新定义它(同时避免代码重复)。


回答:

事实上,这是可能的,正如这个GitHub问题所明确指出的…

然而,与其通过hyperparameters参数将hp对象传递给调谐器,我更建议按照这里提出的方法覆盖调谐器的run_trial方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注