即使在数据集过采样后性能仍然非常低

我在使用MLPClassifier对心脏病进行分类。我使用了imblearn.SMOTE来平衡各个类别的对象。之前我得到了非常好的结果(85%的平衡准确率),但有人建议我在测试数据上不应该使用SMOTE,只在训练数据上使用。在我做出这些更改后,我的分类器性能下降得非常厉害(大约35%的平衡准确率),我不知道哪里出了问题。

以下是训练数据平衡但测试数据不平衡的简单基准测试:

预测结果

这是代码:

    def makeOverSamplesSMOTE(X,y):         from imblearn.over_sampling import SMOTE         sm = SMOTE(sampling_strategy='all')         X, y = sm.fit_sample(X, y)         return X,y    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=20)    ## Normalize data    from sklearn.preprocessing import StandardScaler    sc_X = StandardScaler()    X_train = sc_X.fit_transform(X_train)    X_test = sc_X.fit_transform(X_test)    ## SMOTE only on training data    X_train, y_train = makeOverSamplesSMOTE(X_train, y_train)    clf = MLPClassifier(hidden_layer_sizes=(20),verbose=10,                        learning_rate_init=0.5, max_iter=2000,                         activation='logistic', solver='sgd', shuffle=True, random_state=30)    clf.fit(X_train, y_train)    y_pred = clf.predict(X_test)

我想知道我哪里做错了,因为这似乎是准备数据的正确方法。


回答:

你的代码中的第一个错误是在将数据转换为标准格式时。你只需要在X_train上拟合StandardScaler一次。你不应该在X_test上重新拟合它。所以正确的代码应该是:

def makeOverSamplesSMOTE(X,y):     from imblearn.over_sampling import SMOTE     sm = SMOTE(sampling_strategy='all')     X, y = sm.fit_sample(X, y)     return X,yX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=20)## Normalize datafrom sklearn.preprocessing import StandardScalersc_X = StandardScaler()X_train = sc_X.fit_transform(X_train)X_test = sc_X.transform(X_test)## SMOTE only on training dataX_train, y_train = makeOverSamplesSMOTE(X_train, y_train)clf = MLPClassifier(hidden_layer_sizes=(20),verbose=10,                    learning_rate_init=0.5, max_iter=2000,                     activation='logistic', solver='sgd', shuffle=True, random_state=30)clf.fit(X_train, y_train)y_pred = clf.predict(X_test)

对于机器学习模型,尝试降低学习率。当前的学习率太高了。scikit-learn的默认学习率是0.001。尝试更改激活函数和层的数量。另外,并不是每个机器学习模型都适用于每个数据集,因此你可能需要查看你的数据并相应地选择机器学习模型。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注