我的Python代码在处理MNIST数据集时显示出巨大的错误值,我遗漏了什么?

下面是我的代码,这是我自学机器学习和Python的第一个认真尝试。我尝试从头开始编写代码,没有使用像NumPy这样的库。对于单一输入和输出,代码能够正常工作,但当处理实际数据集(在这种情况下是784个输入到10个输出)时,错误值会返回无穷大。我检查了所有我认为可能的问题,但没有成功。

这个代码可能是一个临时解决方案。我开始研究Trask的GitHub代码,他的多输入/输出代码是可行的,但当我修改它以使用MNIST数据集时,一切都变得混乱。能否有人帮我看一下,告诉我我遗漏了什么,以及问题出在哪里?我非常感激。

for i in range (x_train.shape[0]):    x_labels[i,x_label[i]]=1def w_sum(a,b):    assert(len(a) == len(b))    output = 0    for i in range(len(a)):        output += (a[i] * b[i])    return outputdef neural_network(input1, weights):    pred = vect_mat_mul(input1,weights)    return preddef vect_mat_mul(vect,matrix):    output = np.zeros(10)    for i in range(10):        output[i] = w_sum(vect[0],matrix[:,i])    return outputdef outer_prod(a, b):    out = np.zeros((len(a), len(b)))    for i in range(len(a)):        for j in range(len(b)):            out[i][j] = a[i] * b[j]    return out(x_train,x_label),(t_test,t_label)=ks.datasets.mnist.load_data()x_train=x_train.reshape(x_train.shape[0],28*28)weights=np.random.rand(784,10)x_labels=np.zeros((x_train.shape[0],10))alpha = 0.00001for i in range(x_train.shape[0]):    error = np.zeros(10)    delta = np.zeros(10)    for iter in range(50):        x_train_to_NN = np.array([x_train[i]])        pred = neural_network(x_train_to_NN, weights)        for j in range(10):            error[j] = (pred[j] - x_labels[i, j]) ** 2            delta[j] = pred[j] - x_labels[i, j]        weight_deltas = outer_prod(x_train[i], delta) #calculate the gradient        for idx in range(784):            for jdx in range(10):                weights[idx][jdx] -= alpha * weight_deltas[idx][jdx] #update weight matrixprint('key=', i, '\n Error=', error, '\n Delta=', delta, '\n Prediction=', pred)

回答:

我最终找到了答案,那就是”梯度裁剪“。问题在于计算梯度时,需要对其进行限制(归一化)以避免梯度爆炸。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注