I am trying to train a deep neural network on the MNIST dataset.
BATCH_SIZE = 100
train_data = train_data.batch(BATCH_SIZE)
validation_data = validation_data.batch(num_validation_samples)
test_data = scaled_test_data.batch(num_test_samples)

validation_inputs, validation_targets = next(iter(validation_data))

input_size = 784
output_size = 10
hidden_layer_size = 50

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
    tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
    tf.keras.layers.Dense(output_size, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

NUM_EPOCHS = 5
model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs, validation_targets))
model.fit raises the following error:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-58-c083185dafc6> in <module>
      1 NUM_EPOCHS = 5
----> 2 model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs,validation_targets))

~/anaconda3/envs/py3-TF2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    726         max_queue_size=max_queue_size,
    727         workers=workers,
--> 728         use_multiprocessing=use_multiprocessing)
    729 
    730   def evaluate(self,

~/anaconda3/envs/py3-TF2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    222           validation_data=validation_data,
    223           validation_steps=validation_steps,
--> 224           distribution_strategy=strategy)
    225 
    226       total_samples = _get_total_number_of_samples(training_data_adapter)

~/anaconda3/envs/py3-TF2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    562         class_weights=class_weights,
    563         steps=validation_steps,
--> 564         distribution_strategy=distribution_strategy)
    565   elif validation_steps:
    566     raise ValueError('`validation_steps` should not be specified if '

~/anaconda3/envs/py3-TF2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    604       max_queue_size=max_queue_size,
    605       workers=workers,
--> 606       use_multiprocessing=use_multiprocessing)
    607   # As a fallback for the data type that does not work with
    608   # _standardize_user_data, use the _prepare_model_with_inputs.

~/anaconda3/envs/py3-TF2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, batch_size, epochs, steps, shuffle, **kwargs)
    252     if not batch_size:
    253       raise ValueError(
--> 254           "`batch_size` or `steps` is required for `Tensor` or `NumPy`"
    255           " input data.")
    256 

ValueError: `batch_size` or `steps` is required for `Tensor` or `NumPy` input data.
The training and validation data come from the MNIST dataset. One part of the data is used for training, another part for testing.
What am I doing wrong here?
Update: Following Dominques's suggestion, I changed the model.fit call to
model.fit(train_data, batch_size=128, epochs=NUM_EPOCHS, validation_data=(validation_inputs,validation_targets))
But now I get the following error:
ValueError: The `batch_size` argument must not be specified for the given input type. Received input: <BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>, batch_size: 128
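This second error makes sense once you note that a dataset built with .batch() already yields whole batches, so the consumer must not be told a batch size on top of that. A minimal pure-Python sketch of the batching idea (no TensorFlow involved; batch here is a hypothetical helper for illustration, not a tf API):

```python
def batch(iterable, batch_size):
    """Group an iterable into lists of length batch_size (the last one may be shorter)."""
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf

# The batch size is baked into the stream itself: the consumer just
# iterates over ready-made batches and never needs a batch_size argument.
samples = list(range(10))
print(list(batch(samples, 4)))  # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

A BatchDataset works the same way, which is why Keras rejects an explicit batch_size=128 for that input type.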
Answer:
The TensorFlow documentation gives more clues as to why you are seeing this error.
https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. validation_data will override validation_split. validation_data could be:
• A tuple (x_val, y_val) of Numpy arrays or tensors
• A tuple (x_val, y_val, val_sample_weights) of Numpy arrays
• A dataset
For the first two cases, batch_size must be provided. For the last case, validation_steps must be provided.
Since you have already batched the validation dataset, consider using it directly and specifying the validation steps, as follows.
BATCH_SIZE = 100
train_data = train_data.batch(BATCH_SIZE)
validation_data = validation_data.batch(BATCH_SIZE)
...
model.fit(train_data, epochs=NUM_EPOCHS, validation_data=validation_data, validation_steps=1)
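One caveat: with validation_steps=1, only a single batch of validation data is evaluated per epoch. If the validation set is batched with BATCH_SIZE, covering all of it takes ceil(num_samples / batch_size) steps. A quick sketch of that arithmetic, with hypothetical sizes (the original question never states num_validation_samples):

```python
import math

# Hypothetical size, assuming a 10% split of MNIST's 60,000 training images.
num_validation_samples = 6000
BATCH_SIZE = 100

# Number of batches needed to see every validation sample exactly once.
validation_steps = math.ceil(num_validation_samples / BATCH_SIZE)
print(validation_steps)  # -> 60
```

Passing this value as validation_steps (or simply omitting it, so the whole dataset is consumed each epoch) evaluates the metrics on the full validation set rather than on one batch.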