如何保存神经网络的权重

我在将训练好的神经网络的权重保存到文本文件时遇到了问题。以下是我的代码

def nNetwork(trainingData,filename):    lamda = 1    input_layer = 1200    output_layer = 10    hidden_layer = 25    X=trainingData[0]    y=trainingData[1]    theta1 = randInitializeWeights(1200,25)    theta2 = randInitializeWeights(25,10)    m,n = np.shape(X)    yk = recodeLabel(y,output_layer)    theta = np.r_[theta1.T.flatten(), theta2.T.flatten()]    X_bias = np.r_[np.ones((1,X.shape[0])), X.T]    #共轭梯度算法    result = scipy.optimize.fmin_cg(computeCost,fprime=computeGradient,x0=theta,args=(input_layer,hidden_layer,output_layer,X,y,lamda,yk,X_bias),maxiter=100,disp=True,full_output=True )    print result[1]  #最小值    theta1,theta2 = paramUnroll(result[0],input_layer,hidden_layer,output_layer)    counter = 0    for i in range(m):        prediction = predict(X[i],theta1,theta2)        actual = y[i]        if(prediction == actual):            counter+=1    print  str(counter *100/m) + '% 准确率'    data = {"Theta1":[theta1],            "Theta2":[theta2]}    op=open(filename,'w')    json.dump(data,op)    op.close()

def paramUnroll(params,input_layer,hidden_layer,labels):    theta1_elems = (input_layer+1)*hidden_layer    theta1_size = (input_layer+1,hidden_layer)    theta2_size = (hidden_layer+1,labels)    theta1 = params[:theta1_elems].T.reshape(theta1_size).T    theta2 = params[theta1_elems:].T.reshape(theta2_size).T    return theta1, theta2

我遇到了以下错误 raise TypeError(repr(o) + ” is not JSON serializable”)

请提供解决方案或其他保存权重的方法,以便我可以在其他代码中轻松加载它们。


回答:

保存numpy数组到纯文本文件的最简单方法是执行 numpy.savetxt(并使用 numpy.loadtxt 加载)。然而,如果你想使用JSON格式保存,可以使用 StringIO 实例来写入文件:

with StringIO as theta1IO:    numpy.savetxt(theta1IO, theta1)    data = {"theta1": theta1IO.getvalue() }    # 像往常一样以JSON格式写入

你也可以对其他参数使用相同的方法。

要检索数据,你可以这样做:

# 从JSON读取数据with StringIO as theta1IO:    theta1IO.write(data['theta1'])    theta1 = numpy.loadtxt(theta1IO)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注