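I trained my model on Linux using MirroredStrategy() across 2 GPUs. Training stopped at epoch 610. I want to resume training, but when I load the model and evaluate it, the kernel dies. I am using Jupyter Notebook. If I reduce the training dataset, the code runs, but only on 1 GPU. Is the distribution strategy saved in the model I load, or do I need to include it again?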
Update

I tried including MirroredStrategy():
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    new_model = load_model('\\models\\model_0610.h5',
                           custom_objects = {'dice_coef_loss': dice_coef_loss,
                                             'dice_coef': dice_coef},
                           compile = True)
    new_model.evaluate(train_x, train_y, batch_size = 2, verbose=1)
New error

Error when I include MirroredStrategy():
ValueError: 'handle' is not available outside the replica context or a 'tf.distribute.Stragety.update()' call.
Source code:
# Imports are assumed here; the original post does not show them.
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model

smooth = 1

def dice_coef(y_true, y_pred):
    # Flatten both masks and measure their overlap.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return (1. - dice_coef(y_true, y_pred))

new_model = load_model('\\models\\model_0610.h5',
                       custom_objects = {'dice_coef_loss': dice_coef_loss,
                                         'dice_coef': dice_coef},
                       compile = True)
new_model.evaluate(train_x, train_y, batch_size = 2, verbose=1)

observe_var = 'dice_coef'
strategy = 'max'  # greater dice_coef is better
model_resume_dir = '//models_resume//'

model_checkpoint = ModelCheckpoint(model_resume_dir + 'resume_{epoch:04}.h5',
                                   monitor=observe_var, mode='auto',
                                   save_weights_only=False,
                                   save_best_only=False, period = 2)

new_model.fit(train_x, train_y, batch_size = 2, epochs = 5000, verbose=1,
              shuffle = True, validation_split = .15,
              callbacks = [model_checkpoint])

new_model.save(model_resume_dir + 'final_resume.h5')
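For readers unfamiliar with the metric, here is a minimal plain-Python sketch of what the smoothed Dice coefficient above computes, worked on a tiny hand-made mask. The names dice_coef_plain and dice_loss_plain and the sample lists are hypothetical and only mirror the arithmetic of the Keras version; they are not part of the original script.

```python
def dice_coef_plain(y_true, y_pred, smooth=1.0):
    """Smoothed Dice coefficient on two flat lists of mask values (illustrative only)."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)


def dice_loss_plain(y_true, y_pred, smooth=1.0):
    """Loss used for training: 1 minus the Dice coefficient."""
    return 1.0 - dice_coef_plain(y_true, y_pred, smooth)


# Hypothetical 4-pixel masks: two true pixels, one of them predicted.
y_true = [1, 1, 0, 0]
y_pred = [1, 0, 0, 0]

# intersection = 1, so (2*1 + 1) / (2 + 1 + 1) = 3/4
print(dice_coef_plain(y_true, y_pred))   # 0.75
print(dice_loss_plain(y_true, y_pred))   # 0.25

# With two empty masks the smooth term avoids 0/0: (0 + 1) / (0 + 0 + 1)
print(dice_coef_plain([0, 0], [0, 0]))   # 1.0
```

A higher coefficient means better overlap, which is why the checkpoint comment above notes that a greater dice_coef is better.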
Answer:

new_model.evaluate() and compile = True when loading the model were causing the problem. I set compile = False and added the compile line from the original script.
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    new_model = load_model('\\models\\model_0610.h5',
                           custom_objects = {'dice_coef_loss': dice_coef_loss,
                                             'dice_coef': dice_coef},
                           compile = False)
    # Recompile inside the scope so the optimizer variables are created under the strategy.
    new_model.compile(optimizer = Adam(learning_rate = 1e-4),
                      loss = dice_coef_loss, metrics = [dice_coef])
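The fix above can be wrapped up as a rough sketch for resuming the run. The helper names epoch_from_checkpoint and load_for_resume are hypothetical, and the filename pattern (an underscore, digits, then .h5) is an assumption based on the model_0610.h5 name in the question. Passing initial_epoch to fit() is optional; it only makes the epoch counter and the checkpoint filenames continue from where training stopped.

```python
import re


def epoch_from_checkpoint(path):
    """Return the epoch number encoded in a name like 'model_0610.h5' (assumed naming scheme)."""
    match = re.search(r"_(\d+)\.h5$", path)
    if match is None:
        raise ValueError(f"no epoch number found in {path!r}")
    return int(match.group(1))


def load_for_resume(path, loss, metrics, custom_objects=None, learning_rate=1e-4):
    """Load an uncompiled model under MirroredStrategy and recompile it there."""
    # Imported lazily so the filename helper above works without TensorFlow installed.
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = tf.keras.models.load_model(
            path, custom_objects=custom_objects, compile=False
        )
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=loss,
            metrics=metrics,
        )
    return model


# Usage sketch, reusing the names from the question (not executed here):
# ckpt = '\\models\\model_0610.h5'
# new_model = load_for_resume(ckpt, dice_coef_loss, [dice_coef])
# new_model.fit(train_x, train_y, batch_size=2, epochs=5000,
#               initial_epoch=epoch_from_checkpoint(ckpt),
#               validation_split=.15, callbacks=[model_checkpoint])
```

Two caveats with this approach: loading with compile = False discards the saved optimizer state, so Adam starts fresh, and with initial_epoch set, epochs = 5000 is the index of the final epoch rather than the number of additional epochs.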