Why does my retrained model have such poor accuracy?

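I'm trying to retrain the last layer of a pretrained model on the same dataset (the MNIST handwritten digit dataset), but the retrained model's accuracy is far lower than the initial model's. The initial model reaches about 98% accuracy, while the retrained model fluctuates between 40% and 80% depending on the run. I get similar results when I don't train the first two layers.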

Here is a visualization of what I'm trying to do:

[Image: visualization of the setup]

And the code:

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

epochs1 = 150
epochs2 = 300
batch_size = 11000
learning_rate1 = 1e-3
learning_rate2 = 1e-4

# Base model
def base_model(input, reuse=False):
    with tf.variable_scope('base_model', reuse=reuse):
        layer1 = tf.contrib.layers.fully_connected(input, 300)
        features = tf.contrib.layers.fully_connected(layer1, 300)
        return features

mnist = input_data.read_data_sets('./mnist/', one_hot=True)

image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

features1 = base_model(image, reuse=False)
features2 = base_model(image, reuse=True)

# Logits1 trained with the base model
with tf.variable_scope('logits1', reuse=False):
    logits1 = tf.contrib.layers.fully_connected(features1, 10, tf.nn.relu)

# Logits2 trained while the base model is frozen
with tf.variable_scope('logits2', reuse=False):
    logits2 = tf.contrib.layers.fully_connected(features2, 10, tf.nn.relu)

# Var Lists
var_list_partial1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='logits1')
var_list_partial2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='base_model')
var_list1 = var_list_partial1 + var_list_partial2
var_list2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='logits2')

# Sanity check
print("var_list1:", var_list1)
print("var_list2:", var_list2)

# Cross Entropy Losses
loss1 = tf.nn.softmax_cross_entropy_with_logits(logits=logits1, labels=label)
loss2 = tf.nn.softmax_cross_entropy_with_logits(logits=logits2, labels=label)

# Train the final logits layer
train1 = tf.train.AdamOptimizer(learning_rate1).minimize(loss1, var_list=var_list1)
train2 = tf.train.AdamOptimizer(learning_rate2).minimize(loss2, var_list=var_list2)

# Accuracy operations
correct_prediction1 = tf.equal(tf.argmax(logits1, 1), tf.argmax(label, 1))
correct_prediction2 = tf.equal(tf.argmax(logits2, 1), tf.argmax(label, 1))
accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, "float"))
accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, "float"))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    batches = int(len(mnist.train.images) / batch_size)

    # Train base model and logits1
    for epoch in range(epochs1):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train1, feed_dict={image: batch_xs, label: batch_ys})

    # Train logits2 keeping the base model frozen
    for epoch in range(epochs2):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train2, feed_dict={image: batch_xs, label: batch_ys})

    # Print the both models after training
    accuracy = sess.run(accuracy1, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Initial Model Accuracy After training final model:", accuracy)
    accuracy = sess.run(accuracy2, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Final Model Accuracy After Training:", accuracy)
```

Thanks in advance!


Answer:

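Try removing the non-linearity (the ReLU) from `logits1` and `logits2`. The third positional argument of `tf.contrib.layers.fully_connected` is `activation_fn`, so passing `tf.nn.relu` clamps every negative logit to zero before it reaches `softmax_cross_entropy_with_logits`, which expects unscaled logits. Passing `None` gives a plain linear output layer.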
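To see the effect in isolation, here is a minimal pure-Python sketch (no TensorFlow; the function names are hypothetical, chosen for illustration). It assumes a single sample whose pre-activations in the output layer are all negative. With a ReLU, every logit becomes 0, the softmax is uniform, the predicted class is an arbitrary tie, and the gradient reaching the layer is exactly zero, so that sample gives the layer no learning signal. With a linear output the gradient is non-zero. This plausibly contributes to the run-to-run variance in the question, but it is a simplified illustration, not a reproduction of the TensorFlow run.

```python
import math

def softmax(z):
    # Subtract the max for numerical stability
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def grad_wrt_preactivation(pre, target, use_relu):
    """Gradient of softmax cross-entropy w.r.t. the last layer's pre-activations.

    use_relu=True mimics activation_fn=tf.nn.relu (logits = relu(pre));
    use_relu=False mimics activation_fn=None (logits = pre).
    """
    logits = [max(0.0, v) for v in pre] if use_relu else list(pre)
    probs = softmax(logits)
    # For softmax cross-entropy, d(loss)/d(logit_k) = p_k - y_k
    dlogits = [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]
    if use_relu:
        # ReLU only passes gradient where its input was positive
        return [g if v > 0 else 0.0 for g, v in zip(dlogits, pre)]
    return dlogits

# One sample whose correct class is 1, with all pre-activations negative
pre = [-1.0, -0.5, -2.0]
print(grad_wrt_preactivation(pre, 1, use_relu=True))   # [0.0, 0.0, 0.0]: no learning signal
print(grad_wrt_preactivation(pre, 1, use_relu=False))  # non-zero, pushes logit 1 upward
```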

I modified your code as follows:

```python
# Logits1 trained with the base model
with tf.variable_scope('logits1', reuse=False):
    #logits1 = tf.contrib.layers.fully_connected(features1, 10, tf.nn.relu)
    logits1 = tf.contrib.layers.fully_connected(features1, 10, None)

# Logits2 trained while the base model is frozen
with tf.variable_scope('logits2', reuse=False):
    #logits2 = tf.contrib.layers.fully_connected(features2, 10, tf.nn.relu)
    logits2 = tf.contrib.layers.fully_connected(features2, 10, None)
```

The results are now:

```
Initial Model Accuracy After training final model: 0.9805
Final Model Accuracy After Training: 0.9658
```

P.S. 300 + 300 neurons is too many for an MNIST classifier, but I assume classifying MNIST isn't really your goal 🙂
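For scale, here is a quick count of the trainable parameters along one path of the question's network (base model plus one logits head: 784 → 300 → 300 → 10). It assumes each fully connected layer has a bias, which is the default for `tf.contrib.layers.fully_connected`; the helper `dense_params` is hypothetical, written only for this arithmetic.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Layer sizes from the question: 784 inputs, two hidden layers of 300, 10 outputs
sizes = [784, 300, 300, 10]
per_layer = [dense_params(a, b) for a, b in zip(sizes, sizes[1:])]
print(per_layer)       # [235500, 90300, 3010]
print(sum(per_layer))  # 328810
```

Only the last 3,010 of these belong to the output layer, which is the part retrained while the base model is frozen.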
