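I am trying to train a neural network for a regression problem, and I have added Keras's early stopping callback to avoid overfitting.
When I monitor "val_loss", early stopping halts the program almost immediately and leaves me with a useless network. When I monitor "val_mse" instead, training keeps running and never stops, even though I can see "val_mse" increasing during training and I have set patience = 0.
I seem to have misunderstood the early stopping callback: I assumed it watches the monitored value and stops training as soon as that value starts to increase again.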
    # Imports assumed (standalone Keras); the original post did not show them
    import numpy as np
    from keras import optimizers
    from keras.callbacks import EarlyStopping, TensorBoard
    from keras.layers import Dense, Dropout, Input
    from keras.models import Model

    np.random.seed(7)

    # Define input
    tf_features_64 = np.load("IN_2.npy")
    tf_labels_64 = np.load("OUT_2.npy")
    tf_features_32 = tf_features_64.astype(np.float32)
    tf_labels_32 = tf_labels_64.astype(np.float32)
    X = tf_features_32
    Y = tf_labels_32[0:10680, 4:8]

    # Define callbacks
    tbCallBack = TensorBoard(log_dir='./Graph{}', histogram_freq=0, write_graph=True, write_images=True)  # TensorBoard monitoring
    esCallback = EarlyStopping(monitor='val_mse', min_delta=0, patience=0, verbose=1, mode='min')

    # Create layers
    visible = Input(shape=(33,))
    x = Dropout(.1)(visible)
    # x = Dense(63)(x)
    # x = Dropout(.4)(x)
    output = Dense(4)(x)

    Optimizer = optimizers.Adam(lr=0.001
                                # amsgrad = True
                                )
    model = Model(inputs=visible, outputs=output)
    model.compile(optimizer=Optimizer, loss=['mse'], metrics=['mae', 'mse'])
    model.fit(X, Y, epochs=8000, batch_size=20, shuffle=True, validation_split=0.35, callbacks=[tbCallBack, esCallback])
As an example, I get the following output, where I can clearly see the validation MSE increasing from epoch to epoch.
      20/6942 [..............................] - ETA: 0s - loss: 0.0022 - mean_absolute_error: 0.0373 - mean_squared_error: 0.0022
    1620/6942 [======>.......................] - ETA: 0s - loss: 0.0011 - mean_absolute_error: 0.0251 - mean_squared_error: 0.0011
    3260/6942 [=============>................] - ETA: 0s - loss: 0.0015 - mean_absolute_error: 0.0290 - mean_squared_error: 0.0015
    4900/6942 [====================>.........] - ETA: 0s - loss: 0.0017 - mean_absolute_error: 0.0301 - mean_squared_error: 0.0017
    6500/6942 [===========================>..] - ETA: 0s - loss: 0.0016 - mean_absolute_error: 0.0301 - mean_squared_error: 0.0016
    6942/6942 [==============================] - 0s 37us/step - loss: 0.0016 - mean_absolute_error: 0.0294 - mean_squared_error: 0.0016 - val_loss: 0.0011 - val_mean_absolute_error: 0.0240 - **val_mean_squared_error: 0.0011**
    **Epoch 334/8000**
      20/6942 [..............................] - ETA: 0s - loss: 0.0025 - mean_absolute_error: 0.0367 - mean_squared_error: 0.0025
    1620/6942 [======>.......................] - ETA: 0s - loss: 0.0012 - mean_absolute_error: 0.0257 - mean_squared_error: 0.0012
    3260/6942 [=============>................] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0274 - mean_squared_error: 0.0014
    4860/6942 [====================>.........] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0268 - mean_squared_error: 0.0014
    6400/6942 [==========================>...] - ETA: 0s - loss: 0.0012 - mean_absolute_error: 0.0254 - mean_squared_error: 0.0012
    6942/6942 [==============================] - 0s 39us/step - loss: 0.0012 - mean_absolute_error: 0.0249 - mean_squared_error: 0.0012 - val_loss: 0.0032 - val_mean_absolute_error: 0.0393 - **val_mean_squared_error: 0.0032**
    **Epoch 335/8000**
      20/6942 [..............................] - ETA: 0s - loss: 9.5175e-04 - mean_absolute_error: 0.0243 - mean_squared_error: 9.5175e-04
    1620/6942 [======>.......................] - ETA: 0s - loss: 0.0017 - mean_absolute_error: 0.0312 - mean_squared_error: 0.0017
    3260/6942 [=============>................] - ETA: 0s - loss: 0.0013 - mean_absolute_error: 0.0271 - mean_squared_error: 0.0013
    4860/6942 [====================>.........] - ETA: 0s - loss: 0.0014 - mean_absolute_error: 0.0277 - mean_squared_error: 0.0014
    6460/6942 [==========================>...] - ETA: 0s - loss: 0.0013 - mean_absolute_error: 0.0266 - mean_squared_error: 0.0013
    6942/6942 [==============================] - 0s 38us/step - loss: 0.0013 - mean_absolute_error: 0.0268 - mean_squared_error: 0.0013 - val_loss: 0.0046 - val_mean_absolute_error: 0.0491 - **val_mean_squared_error: 0.0046**
    **Epoch 336/8000**
Answer:
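Your code has no metric named val_mse, so your callback is monitoring the wrong metric. As your own log output shows, the metric is reported as val_mean_squared_error, and that is not the same key as val_mse. When the monitored key is missing from the epoch logs, EarlyStopping has nothing to compare and never triggers a stop (it should only emit a warning), which is why training keeps running.
You should change the monitored metric from val_mse to val_mean_squared_error, and it should then work as expected.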
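To make the difference concrete, here is a rough sketch of the stopping rule in plain Python. It is a simplified model under my own assumptions, not Keras's actual implementation: the function name first_stop_epoch and the history list are hypothetical, and the numbers are the validation values from the three completed epochs in your log.

```python
def first_stop_epoch(history, monitor, patience=0, min_delta=0.0):
    """Simplified 'min'-mode early stopping (illustrative, not Keras code).

    Returns the 1-based index of the epoch at which training would stop,
    or None if the rule never fires.
    """
    best = float("inf")
    wait = 0
    for epoch, logs in enumerate(history, start=1):
        current = logs.get(monitor)
        if current is None:
            # Monitored key is absent from the logs: no comparison is made.
            continue
        if current < best - min_delta:
            # Improvement: remember it and reset the counter.
            best = current
            wait = 0
        else:
            # No improvement: stop once the counter reaches patience.
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Validation values copied from the three completed epochs in the log above.
history = [
    {"val_loss": 0.0011, "val_mean_squared_error": 0.0011},
    {"val_loss": 0.0032, "val_mean_squared_error": 0.0032},
    {"val_loss": 0.0046, "val_mean_squared_error": 0.0046},
]

print(first_stop_epoch(history, "val_mse"))                 # None
print(first_stop_epoch(history, "val_mean_squared_error"))  # 2
print(first_stop_epoch(history, "val_loss", patience=2))    # 3
```

With the misspelled key the rule never fires. With the correct key and patience=0 it fires on the first epoch that fails to improve, which probably also explains why monitoring val_loss stopped your run almost immediately: validation loss is noisy from epoch to epoch, so a larger patience is likely worth trying.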