使用Pybrain构建的简单前馈网络未输出预期值

我尝试使用pybrain来输出RGB值。输入层接受一个RGB值数组,所有隐藏层都是线性模型。我原本期望网络能够输出RGB值。然而,网络的输出结果却是一个值数组,这些值远不在0到255的范围内。这些图像大约是25张公牛的.jpg图片。每张图片是一个长度为575280的扁平化数组。我希望网络能够收敛到一张最终看起来像公牛的图像。

import numpy as npfrom pybrain.structure import FeedForwardNetwork, LinearLayer, SigmoidLayer, GaussianLayer, TanhLayerfrom pybrain.structure import FullConnection, BiasUnitimport testabullbull_x = 510bull_y = 398bull_flat = 575280n = FeedForwardNetwork()bias_unit = BiasUnit()in_layer = LinearLayer(bull_flat)hidden_A = LinearLayer(5)hidden_B = LinearLayer(10)out_layer = LinearLayer(bull_flat)n.addInputModule(in_layer)n.addModule(hidden_A)n.addModule(hidden_B)n.addOutputModule(out_layer)n.addModule(bias_unit)in_to_hidden = FullConnection(in_layer, hidden_A)hidden_to_hidden = FullConnection(hidden_A, hidden_B)hidden_to_out = FullConnection(hidden_B, out_layer)bias_to_hidden = FullConnection(hidden_B, out_layer)n.addConnection(in_to_hidden)n.addConnection(hidden_to_hidden)n.addConnection(bias_to_hidden)n.addConnection(hidden_to_out)n.sortModules()bull_img_array = testabull.crop_the_bull_images('../../imgs/thebull/')trainable_array = [] ## an array of flattened imagesfor im in bull_img_array:    flat_im = np.array(im).flatten()    trainable_array.append(flat_im)print nprint n.activate(trainable_array[0])output = Nonefor a in trainable_array:    output = n.activate(a)print output, len(output) 

如果有人有任何建议,我将非常感激。


回答:

首先,这里有两个问题。首先,你需要将输出缩放到0到255之间。你可以在输出后进行一些转换来实现这一点。通过获取最大值和最小值,然后将它们转换到0到255之间。

另一方面,这个网络可能不会学习你希望它学习的内容,你的隐藏层使用的是线性层。这不太有用,因为权重本身就构成了线性变换。你最终得到的将是一个线性函数。 ftp://ftp.sas.com/pub/neural/FAQ2.html#A_act

我建议在隐藏层使用SigmoidLayer,这当然会将值压缩到0到1之间。你可以通过在输出层乘以255来纠正这一点。可以通过一个固定的层或者只是在输出后转换值来实现。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注