我在Google Colab上编写了一个Resnet50模型。在训练并保存模型后,加载模型而不重启运行时可以得到相同的结果。但是,当我重启Google Colab的运行时并运行xtrain, ytest, x_val, y_val
,然后再次加载模型时,我得到了不同的结果。
这是我的代码:
#超参数和回调
batch_size = 128
num_epochs = 120
input_shape = (48, 48, 1)
num_classes = 7
#编译模型
from keras.optimizers import Adam, SGD
model = ResNet50(input_shape = (48, 48, 1), classes = 7)
optimizer = SGD(learning_rate=0.0005)
model.compile(optimizer= optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
history = model.fit(
data_generator.flow(xtrain, ytrain,),
steps_per_epoch=len(xtrain) / batch_size,
epochs=num_epochs,
verbose=1,
validation_data= (x_val,y_val))
import matplotlib.pyplot as plt
model.save('Fix_Model_resnet50editSGD5st.h5')
#绘制图表
accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
num_epochs = range(len(accuracy))
plt.plot(num_epochs, accuracy, 'r', label='训练准确率')
plt.plot(num_epochs, val_accuracy, 'b', label='验证准确率')
plt.title('训练和验证准确率')
plt.ylabel('准确率')
plt.xlabel('轮次')
plt.legend()
plt.figure()
plt.plot(num_epochs, loss, 'r', label='训练损失')
plt.plot(num_epochs, val_loss, 'b', label='验证损失')
plt.title('训练和验证损失')
plt.ylabel('损失')
plt.xlabel('轮次')
plt.legend()
plt.show()
#加载模型
from keras.models import load_model
model_load = load_model('Fix_Model_resnet50editSGD5st.h5')
model_load.summary()
testdatamodel = model_load.evaluate(xtest, ytest)
print("测试损失 " + str(testdatamodel[0]))
print("测试准确率: " + str(testdatamodel[1]))
traindata = model_load.evaluate(xtrain, ytrain)
print("训练损失 " + str(traindata[0]))
print("训练准确率: " + str(traindata[1]))
valdata = model_load.evaluate(x_val, y_val)
print("验证损失 " + str(valdata[0]))
print("验证准确率: " + str(valdata[1]))
在训练并保存模型后,运行加载模型而不重启Google Colab的运行时:如你所见,
测试损失: 0.9411 – 准确率: 0.6514
训练损失: 0.7796 – 准确率: 0.7091
在重启运行时后再次运行加载模型:
测试损失: 0.7928 – 准确率: 0.6999
训练损失: 0.8189 – 准确率: 0.6965
回答:
你需要设置随机种子,以便在每次迭代中,无论是在同一个会话中还是在重启后,都能得到相同的结果。
tf.random.set_seed(seed)
查看 https://www.tensorflow.org/api_docs/python/tf/random/set_seed