随机森林在训练和测试中获得98%的准确率,但在其他情况下总是预测相同的类别

我已经花了30个小时在调试这个问题上,完全没有头绪,希望你们中的某个人能给我一个不同的视角。

问题是我在随机森林中使用训练数据框架时获得了98%-99%的良好准确率,但当我尝试加载新的样本进行预测时,模型总是猜测相同的类别。

#  打乱数据框架的记录。标签仍然附着df = df.sample(frac=1).reset_index(drop=True)#  提取标签并从数据中删除它们y = list(df['label'])X = df.drop(['label'], axis='columns')#  将数据拆分为训练和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE)#  构建模型model = RandomForestClassifier(n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH, random_state=RANDOM_STATE,oob_score=True)#  计算训练准确率in_sample_accuracy = model.fit(X_train, y_train).score(X_train, y_train)#  计算测试准确率test_accuracy = model.score(X_test, y_test)print()print('In Sample Accuracy: {:.2f}%'.format(model.oob_score_ * 100))print('Test Accuracy: {:.2f}%'.format(test_accuracy * 100))

我处理数据的方式是相同的,但当我对X_test或X_train进行预测时,我得到正常的98%准确率,而当我对新数据进行预测时,它总是猜测相同的类别。

    #  json文件格式不正确,此函数对其进行标准化    normalized_json = json_normalizer(json_file, "", training=False)    #  将json转换为包含特征的字典列表    features_dict = create_dict(normalized_json, label=None)    #  将字典转换为pandas数据框架    df = pd.DataFrame.from_records(features_dict)    print('Total amount of email samples: ', len(df))    print()    df = df.fillna(-1)    #  对字符串值进行独热编码    df = one_hot_encode(df, noOverride=True)    if 'label' in df.columns:        df = df.drop(['label'], axis='columns')    print(list(model.predict(df))[:100])    print(list(model.predict(X_train))[:100])

上面的测试场景中,你可以看到在最后两行中,我在X_train(用于训练模型的数据)和df(模型总是猜测类别0的样本外数据)上进行预测。

一些有用信息:

  • 数据集不平衡;类别0约有150,000个样本,而类别1约有600,000个样本
  • 有141个特征
  • 更改n_estimators和max_depth并不能解决问题

任何想法都会有帮助,如果你需要更多信息请告诉我,我的脑子现在已经累坏了,这就是我能想到的全部了。


回答:

已解决,问题出在数据集的不平衡上,我也意识到改变深度会得到不同的结果。

例如,10棵树深度为3 -> 似乎工作正常10棵树深度为6 -> 又回到只猜测相同的类别

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注