我有一个已经训练好用于识别手写数字的模型。现在我有一个新的数字样本需要进一步训练到这个模型中。有没有办法做到这一点?
import os

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

### train model ###

# Load the MNIST digits and divide the pixel values by 255 before training.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255, x_test / 255

# Flatten(28x28) -> Dense(128, relu) -> Dense(10, softmax)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(units=128, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=10, activation=tf.nn.softmax),
])
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for five epochs, then store the result on disk.
model.fit(x_train, y_train, epochs=5)
model.save('traineddata.model')
回答:
你可以加载图像,将其转换为灰度图,调整大小为 (28, 28),然后转换为只包含一个样本的训练数组,并使用 fit() 进行训练:
# Load the new sample and bring it into the same form as the training images.
x_example = cv2.imread('image.png')
if x_example is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    raise FileNotFoundError("cannot read image: 'image.png'")
x_example = cv2.cvtColor(x_example, cv2.COLOR_BGR2GRAY)
x_example = cv2.resize(x_example, (28, 28))
# The model was trained on x_train/255, so the new sample needs the same scaling.
x_example = x_example / 255
y_example = 1

x_data = np.array([x_example])  # must have shape (1, 28, 28), not (28, 28)
y_data = np.array([y_example])  # one label per sample: shape (1,), not a bare scalar

model.fit(x_data, y_data, epochs=5)
但是在 epochs=5 的情况下,预测效果并不好。对于 epochs=10,它能正确预测这张图片,但我没有检查它是否仍然能正确预测其他图片。
也许将图像添加到 x_train、y_train 中并重新训练整个数据集会更好。
# Extend the training set by the new sample: the image is stacked along the
# first axis, the label is added to the flattened label vector.
x_data = np.concatenate((x_train, [x_example]), axis=0)
y_data = np.concatenate((np.ravel(y_train), np.ravel(y_example)))

# Train again on the old data plus the new sample.
model.fit(x_data, y_data, epochs=5)
这就像现实生活中的情况——当你学习一个新元素时,你会比记住旧元素更容易记住它。当你重新学习所有元素并包含新元素时,你会刷新所有信息,从而记住所有元素。
我用于测试的最小工作代码。
import warnings

# Installed before the remaining imports so that warnings raised while
# importing them are suppressed as well.
warnings.filterwarnings('ignore')  # hide/suppress warnings

import os

import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf


def _banner(title):
    """Print *title* framed by two 50-dash rules."""
    print('-'*50)
    print(title)
    print('-'*50)


def _load_example(path, size=(28, 28)):
    """Read an image file and prepare it like the training images.

    The image is converted to grayscale, resized to *size* and divided by
    255, i.e. the same scaling that is applied to ``x_train``/``x_test``
    before the model is trained.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread does not raise on failure; it returns None.
        raise FileNotFoundError(f"cannot read image: {path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = cv2.resize(image, size)
    return image / 255


### train model ###

def build():
    """Create and compile the network.

    Layers: Flatten(28x28) -> Dense(128, relu) -> Dense(10, softmax),
    compiled with Adam and sparse categorical cross-entropy.

    Returns:
        The compiled, untrained model.
    """
    _banner('# Building model ')

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
    model.add(tf.keras.layers.Dense(units=128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(units=10, activation=tf.nn.softmax))

    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    return model


def train(model, x_train, y_train, epochs=5):
    """Fit *model* on ``(x_train, y_train)`` for *epochs* epochs."""
    _banner('# Training model')

    model.fit(x_train, y_train, epochs=epochs)


def save(model):
    """Save *model* to 'traineddata.model'."""
    _banner('# Saving model')

    model.save('traineddata.model')


def load():
    """Load and return the model stored at 'traineddata.model'."""
    _banner('# Loading model')

    return tf.keras.models.load_model('traineddata.model')


def test_one(model, x_example, y_example):
    """Predict a single image and print the expected and predicted labels."""
    _banner('# Testing one element')

    # create arrays holding one (or more) images / labels
    x_data = np.array([x_example])
    y_data = np.array([y_example])
    print('x_data shape:', x_data.shape)
    print('y_data shape:', y_data.shape)

    print(' expected:', y_data)

    # one row of class scores per input image; argmax picks the predicted class
    y_results = model.predict(x_data)
    print('predicted:', y_results.argmax(axis=1))


def retrain_one(model, x_example, y_example, epochs=5):
    """Continue training *model* on a single image/label pair only."""
    _banner('# Retraining one element')

    # create arrays holding one (or more) images / labels
    x_data = np.array([x_example])
    y_data = np.array([y_example])
    print('x_data shape:', x_data.shape)
    print('y_data shape:', y_data.shape)

    print('y_data:', y_data)

    model.fit(x_data, y_data, epochs=epochs)


def retrain_all(model, x_train, y_train, x_example, y_example, epochs=5):
    """Continue training *model* on the training set plus the new sample.

    *x_example* must be scaled like *x_train*, since both end up in the
    same array.
    """
    _banner('# Retraining all elements')

    # create arrays holding all images / labels, with the new sample appended
    x_data = np.append(x_train, [x_example], axis=0)
    y_data = np.append(y_train, y_example)
    print('x_data shape:', x_data.shape)
    print('y_data shape:', y_data.shape)

    print('y_data:', y_data)

    model.fit(x_data, y_data, epochs=epochs)


# --- main ---

def main():
    """Train or load the model, then compare both retraining strategies."""
    # - load train/test images -

    print('>>> Loading train/test data ...')
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train/255
    x_test = x_test/255

    # - train + save or load -

    if not os.path.exists('traineddata.model'):
        print('>>> Building model ...')
        model = build()
        print('>>> Training model ...')
        train(model, x_train, y_train)
        print('>>> Saving model ...')
        save(model)
    else:
        print('>>> Loading model ...')
        model = load()

    #print(' - test on single example - ')
    #index = 0
    #test_one(model, x_train[index], y_train[index])

    print(' - image - ')

    # Scaled to [0, 1] by the helper so it matches x_train (which is divided
    # by 255 above); feeding raw 0-255 pixels would not match the training data.
    x_example = _load_example('image.png')
    y_example = 1

    print('>>> Predicting without training')
    test_one(model, x_example, y_example)

    print('>>> Predicting with training one element (epochs=10)')
    retrain_one(model, x_example, y_example, epochs=10)  # alternatives: epochs=5, epochs=7
    test_one(model, x_example, y_example)

    print('>>> Predicting with retraining all elements')
    retrain_all(model, x_train, y_train, x_example, y_example)
    test_one(model, x_example, y_example)

    #print('>>> Saving new model')
    #model.save('traineddata.model')


if __name__ == '__main__':
    main()