权重和偏置值发生极大变化,最终达到无穷大和非数值-Tensorflow

我是一个Tensorflow的新手,今天在使用它时,我的代码在输入大数字时会产生错误。当我输入较小的数字时,这个问题不会发生。以下是它打印的内容… 权重应该为300,偏置值为13000。我输入这些数字只是为了确保错误不是由我的文件引起的,因为我从csv文件中读取输入。两种情况下都会产生这个错误!谢谢,任何帮助都将非常棒!

代码

import tensorflow as tfimport osimport numpy as npdatapoint_size = 20steps = 10000# = 300#actual_b = 13000learn_rate = 0.0001w1=tf.Variable(([1.0]),tf.float32)b=tf.Variable(([1.0]),tf.float32)x1= tf.placeholder(tf.float32)y_=tf.placeholder(tf.float32)init= tf.global_variables_initializer()sess=tf.Session()sess.run(init)i=0with tf.Session() as sess:    sess.run( tf.global_variables_initializer())x=1sess = tf.Session()sess.run(init)y_pred = x1 * w1 + bsquared_deltas = tf.square(y_ - y_pred)cost = tf.reduce_sum(squared_deltas)train_step = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost)for i in range(100000):    if x>=20:        x=0    sess.run(train_step, {x1:[1000,2000,3000,4000],y_:[313000,613000,913000,1213000]})    print("After %d iteration:" % i)    print("W: %f" % sess.run(w1))    print("b: %f" % sess.run(b))    x=x+1

错误

After 0 iteration:W: 1793999.000000b: 598.999207After 1 iteration:W: -10760400896.000000b: -3586799.250000After 2 iteration:W: 64551647182848.000000b: 21517219840.000000After 3 iteration:W: -387245349602852864.000000b: -129081785974784.000000After 4 iteration:W: 2323085181124570251264.000000b: 774361712747872256.000000After 5 iteration:W: -13936188870901762508193792.000000b: -4645396869013139619840.000000After 6 iteration:W: 83603198942318920401772609536.000000b: 27867733773982968636768256.000000After 7 iteration:W: -501535661866597157445008806117376.000000b: -167178554811854841375365267456.000000After 8 iteration:W: infb: 1002904201827890788552906230464512.000000After 9 iteration:W: nanb: -infAfter 10 iteration:W: nanb: nanAfter 11 iteration:W: nanb: nan

回答:

问题仅仅是你的学习率太大了。如果你看到你的变量在发散,那么尝试降低学习率是一个好方法。

实际上,在你的情况下,最好的做法是归一化输入,使它们具有较小的范围。然后你可以使用更高的学习率,它会更快地收敛。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注