### TensorFlow中Softmax函数的反向传播

我在尝试了解TensorFlow中tf.nn.softmax()函数的反向传播机制,以便在我的项目中使用。为此,我实现了一个简单的网络来验证TensorFlow网络中Softmax层的导数,与数学推导的导数相类似。

x=tf.placeholder(tf.float32,[5])y_true = tf.placeholder(tf.float32,[5])w=tf.Variable(tf.zeros([5]))logits = tf.multiply(x,w)y = tf.nn.softmax(logits)loss = tf.pow(y - y_true,2)cost = tf.reduce_mean(loss)train_x = [1.0,2.0,3.0,4.0,5.0]train_y = [3.0,4.0,5.0,6.0,7.0]sess = tf.Session()sess.run(tf.initialize_all_variables())# 以下函数用于打印所需的关键层值def get_val():    print('LOSS  : ', sess.run(loss,feed_dict={x:train_x,y_true:train_y}))    print('COST  : ', sess.run(cost,feed_dict={x:train_x,y_true:train_y}))    print('Y     : ', sess.run(y,feed_dict={x:train_x,y_true:train_y}))    print('LOGITS: ', sess.run(logits,feed_dict={x:train_x,y_true:train_y}))    print('W     : ', sess.run(w,feed_dict={x:train_x,y_true:train_y}))# 训练前get_val()# 使用标准的梯度下降优化器来计算权重值optimizer=tf.train.GradientDescentOptimizer(learning_rate=1).minimize(cost)# 只训练一次sess.run(optimizer,feed_dict={x:train_x,y_true:train_y})# 训练后get_val()

在这里,您可以看到我使用get_val()函数获得的值。

**训练前**LOSS  :  [ 7.8399997, 14.44,      23.04,      33.640003,  46.24     ]COST  :  25.040003Y     :  [0.2, 0.2, 0.2, 0.2, 0.2]LOGITS:  [0., 0., 0., 0., 0.]W     :  [0., 0., 0., 0., 0.]**训练后**LOSS  :  [ 8.916067, 15.904554, 24.835724, 35.293324, 37.2296  ]COST  :  24.435854Y     :  [0.01402173, 0.01194853, 0.01645466, 0.0591815,  0.8983936 ]LOGITS:  [-0.16000001, -0.32000008  0.,          1.2800003,   3.9999998 ]W     :  [-0.16000001, -0.16000004,  0.,         0.32000008,  0.79999995]

在这里你可以看到需要验证的权重导数函数...!!!

y_true = train_ym = 5alpha = 1 # 学习率x = train_x

使用这个函数,我将计算第一次训练后的权重值。

这些是我使用此函数得到的权重值。[-0.1792, -0.4864, -0.9216, -1.4848, -2.176]

但这些值与我训练TensorFlow网络后得到的权重值不一致。这些是训练后的权重值。[-0.16000001, -0.16000004, 0., 0.32000008, 0.79999995]

能否有人解释一下为什么我的函数没有按预期给出权重值?


回答:

enter image description here

上面的方程是权重导数的推导方程。可以利用梯度下降优化器相应地进行权重更新。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注