Sklearn NN clf.predict() 总是返回全0或全1?

我正在使用sklearn的MLPClassifier()函数制作一个分类器,并有以下代码:

def train_NN(self, trained_data_list, test_size=0.25):    X = np.array([[7, 4], [3,7], [8,5], [2,3]])    y = np.array([0, 1, 1, 0])    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)    #奇怪的是,打印'X_train'会显示以下内容:[[2 3] [3 7] [7 4]],似乎丢失了我的点[8 5]    #打印'y_train'会显示以下内容:[0 1 0]    self.clf = MLPClassifier(solver='lbfgs', alpha=1e-10, hidden_layer_sizes=(2, 2), activation = 'tanh', random_state=5, max_iter=10000, learning_rate_init = 0.1)        self.clf.fit(X_train, y_train)    print(self.clf.predict(np.atleast_2d([7,4])))  #返回 0     print(self.clf.predict(np.atleast_2d([3,7])))  #返回 0     print(self.clf.predict(np.atleast_2d([8,5])))  #返回 0     print(self.clf.predict(np.atleast_2d([2,3])))  #返回 0 

我的理解是,’X’数组应该与’y’数组相关联。例如,将[3,7]值传递给clf.predict()应该返回’1’,因为这是我的NP数组定义的方式。[8,5]应该返回’1’,[2, 3]应该返回’0’,依此类推。

我计划在更大规模上进行这个操作,这里使用的数据只是为了演示我的问题。我刚接触NN领域,所以我相信我只是错过了一些小细节。


回答:

如果有人遇到这个问题,我找到了问题所在。我使用了错误的激活函数。结果发现,在这种类型的分类场景中,’logistic’比’tanh’效果好得多:

def train_NN(self, trained_data_list, test_size=0.25):    X = np.array([[7, 4], [3,7], [8,5], [2,3]])    y = np.array([0, 1, 1, 0])    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)    self.clf = MLPClassifier(solver='lbfgs', alpha=1e-10, hidden_layer_sizes=(2, 2), activation = 'logistic', random_state=5, max_iter=10000, learning_rate_init = 0.1)    self.clf.fit(X_train, y_train)    print(self.clf.predict(np.atleast_2d([7,4])))  #返回 0    print(self.clf.predict(np.atleast_2d([3,7])))  #返回 1     print(self.clf.predict(np.atleast_2d([8,5])))  #返回 1     print(self.clf.predict(np.atleast_2d([2,3])))  #返回 0 

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注