神经网络中,训练集、验证集和测试集有什么区别?

我正在使用这个库来实现一个学习代理。

我已经生成了训练数据,但我不太确定验证集和测试集应该是什么。
老师说:

70%应该是训练数据,10%是测试数据,剩下的20%是验证数据。

编辑

我有以下训练代码,但我不知道什么时候该停止训练。

  def train(self, train, validation, N=0.3, M=0.1):    # N: learning rate    # M: momentum factor    accuracy = list()    while(True):        error = 0.0        for p in train:            input, target = p            self.update(input)            error = error + self.backPropagate(target, N, M)        print "validation"        total = 0        for p in validation:            input, target = p            output = self.update(input)            total += sum([abs(target - output) for target, output in zip(target, output)]) #calculates sum of absolute diference between target and output        accuracy.append(total)        print min(accuracy)        print sum(accuracy[-5:])/5        #if i % 100 == 0:        print 'error %-14f' % error        if ? < ?:            break

编辑

我可以使用验证数据获得平均误差0.2,可能经过20次训练迭代后,这应该是80%?

平均误差 = 给定验证数据输入/验证数据大小,验证目标和输出之间的绝对差之和。

1        avg error 0.520395         validation        0.2469378826842        avg error 0.272367           validation        0.2288324208793        avg error 0.249578            validation        0.216253590304        ...22        avg error 0.227753        validation        0.20023924471423        avg error 0.227905            validation        0.199875013416

回答:

训练集和验证集在训练期间使用。

for each epoch    for each training data instance        propagate error through the network        adjust the weights        calculate the accuracy over training data    for each validation data instance        calculate the accuracy over the validation data    if the threshold validation accuracy is met        exit training    else        continue training

完成训练后,针对测试集运行,并验证准确性是否足够。

训练集:此数据集用于调整神经网络上的权重。

验证集:此数据集用于最大限度地减少过拟合。 你没有使用此数据集来调整网络的权重,而只是验证训练数据集上的任何准确性提高是否确实提高了以前未向网络显示的数据集(或者至少网络未在上面训练过)(即验证数据集)的准确性。 如果训练数据集上的准确性提高,但验证数据集上的准确性保持不变或降低,则说明你的神经网络过拟合,你应该停止训练。

测试集:此数据集仅用于测试最终解决方案,以确认网络的实际预测能力。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注