神经网络在第一次迭代后表现异常

我刚开始接触神经网络领域,正在尝试我的第一个实际工作样本,使用的是手写数字MNIST数据集。我编写了一个代码,按我的理解应该能正常工作(至少在一定程度上),但我无法弄清楚为什么它在读取第一个训练样本后就卡住了。我的代码如下:

from keras.datasets import mnistimport numpy as npdef relu(x):    return (x > 0) * xdef relu_deriv(x):    return x > 0(x_train, y_train), (x_test, y_test) = mnist.load_data();images = x_train[0:1000].reshape(1000, 28*28)labels = y_train[0:1000]test_images = x_test[0:1000].reshape(1000, 28*28)test_labels = y_test[0:1000]# converting the labels to a matrixone_hot_labels = np.zeros((len(labels),10))for i,j in enumerate(labels):    one_hot_labels[i][j] = 1labels = one_hot_labelsalpha = 0.005hidden_size = 5 # size of the hidden layer# initial weight matrixesw1 = .2 * np.random.random(size=[784, hidden_size]) - .1w2 = .2 * np.random.random(size=[hidden_size, 10]) - .1for iteration in range(1000):    error = 0    for i in range(len(images)):        layer_0 = images[i:i+1]        layer_1 = relu(np.dot(layer_0, w1))        layer_2 = np.dot(layer_1, w2)        delta_2 = (labels[i:i+1] - layer_2)        error += np.sum((delta_2) ** 2)        delta_1 = delta_2.dot(w2.T) * relu_deriv(layer_1)        w2 += alpha * np.dot(layer_1.T, delta_2)        w1 += alpha * np.dot(layer_0.T, delta_1)    print("error: {0}".format(error))

发生的情况是,在第一次迭代中,显然存在一个很大的误差,并且在第一次迭代后误差被校正为1000,但之后无论进行多少次迭代,误差都一直停留在那个值上,无法继续改变。


回答:

您没有对图像数据进行归一化处理。图像数据的值范围从0到255。由于这些较大的值,权重的更新变得很大,导致第一次迭代后权重变得非常大。您可以按照以下方式对图像数据进行归一化处理。

images = x_train[0:1000].reshape(1000, 28*28)images = images / 255labels = y_train[0:1000]test_images = x_test[0:1000].reshape(1000, 28*28)test_images = test_images / 255test_labels = y_test[0:1000]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注