仅有一个数值特征的逻辑回归

当您只有一个数值特征时,如何正确使用 scikit-learnLogisticRegression 求解器?

我运行了一个简单的例子,但我发现很难解释。请问有人能解释一下我在这里做错了什么吗?

import pandasimport numpy as npfrom sklearn.linear_model import LogisticRegressionX = [1, 2, 3, 10, 11, 12]X = np.reshape(X, (6, 1))Y = [0, 0, 0, 1, 1, 1]Y = np.reshape(Y, (6, 1))lr = LogisticRegression()lr.fit(X, Y)print ("2 --> {0}".format(lr.predict(2)))print ("4 --> {0}".format(lr.predict(4)))

这是脚本运行结束后我得到的输出。既然根据高斯分布,4更接近于测试集中被分类为0的分布,4的预测结果不应该是0吗

2 --> [0]4 --> [1]

当您只有一个包含数值数据的列时,逻辑回归采取什么方法?


回答:

您处理单个特征的方法是正确的,但您错误地假设了仅仅因为4接近于0类特征,它也会被预测为0类

您可以绘制训练数据以及S形函数,假设分类阈值为 y=0.5,并使用回归模型中学习到的系数和截距:

import numpy as npimport matplotlib.pyplot as pltfrom sklearn.linear_model import LogisticRegressionX = [1, 2, 3, 10, 11, 12]X = np.reshape(X, (6, 1))Y = [0, 0, 0, 1, 1, 1]Y = np.reshape(Y, (6, 1))lr = LogisticRegression()lr.fit(X, Y)plt.figure(1, figsize=(4, 3))plt.scatter(X.ravel(), Y, color='black', zorder=20)def model(x):    return 1 / (1 + np.exp(-x))X_test = np.linspace(-5, 15, 300)loss = model(X_test * lr.coef_ + lr.intercept_).ravel()plt.plot(X_test, loss, color='red', linewidth=3)plt.axhline(y=0, color='k', linestyle='-')plt.axhline(y=1, color='k', linestyle='-')plt.axhline(y=0.5, color='b', linestyle='--')plt.axvline(x=X_test[123], color='b', linestyle='--')plt.ylabel('y')plt.xlabel('X')plt.xlim(0, 13)plt.show()

这是您情况下的S形函数的外观:

enter image description here

稍微放大一点:

enter image description here

对于您的特定模型,当 Y 处于0.5分类阈值时,X 的值在 3.1613.227 之间。您可以通过比较 lossX_test 数组来检查这一点(X_test[123] 是与上限相关的X值 – 如果您想得到确切的值,可以使用一些函数优化方法)

因此,4被预测为类 1 的原因是4高于 Y == 0.5 的界限

您可以进一步通过以下方式展示这一点:

print ("2 --> {0}".format(lr.predict(2)))print ("3 --> {0}".format(lr.predict(3)))print ("3.1 --> {0}".format(lr.predict(3.1)))print ("3.3 --> {0}".format(lr.predict(3.3)))print ("4 --> {0}".format(lr.predict(4)))

这将打印出以下内容:

2 --> [0]3 --> [0]3.1 --> [0]  # 低于阈值3.3 --> [1]  # 高于阈值4 --> [1]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注