鸢尾花数据集的多类分类

我知道我一次问了很多问题,但这些是我在使用逻辑回归进行鸢尾花数据集分析时产生的疑问。

这是我在鸢尾花数据集上使用LogisticRegression的代码。

iris = datasets.load_iris()X, y = iris.data, iris.targetx_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state= 81,                                                           test_size=0.3)logreg = LogisticRegression()logreg.fit(x_train, y_train)pred = logreg.predict(x_test)accuracy_score(y_test, pred) # 这给出准确率0.95555

我知道逻辑回归通过预测结果为10来工作,但对于这个鸢尾花数据集,需要根据指定分类为012

问)我需要使用StandardScalar标准化数据吗?

问)这是如何工作的?我知道LR通过预测来工作,但在鸢尾花这里,我们需要预测012

问)如果LogisticRegression也适用于多类分类,那么我如何优化上述代码,以在其他我想尝试的多类数据集上获得更好的预测?

问)我需要转换我的y_train吗,或者我需要对其进行任何类型的编码等才能使其工作?

如果有人能帮助我解决这些问题,我将不胜感激。任何好的参考资料也将受到欢迎。


回答:

我需要使用StandardScalar标准化数据吗?

一般来说,这被称为特征缩放,为此目的有不止一种缩放器,简而言之:

  1. StandardScaler:通常是你的首选选项,它非常常用。它通过标准化数据来工作(即居中),以使它们达到STD=1Mean=0。它会受到异常值的影响,并且只有在你的数据具有高斯分布时才应使用。
  2. MinMaxScaler:通常在你想将所有数据点带入特定范围时使用(例如[0-1])。它容易受到异常值的影响,因为它使用了范围
  3. RobustScaler:它对异常值“robust”,因为它根据四分位范围来缩放数据。然而,你应该知道,缩放后的数据中仍然会存在异常值。
  4. MaxAbsScaler:主要用于稀疏数据
  5. 单位归一化:基本上,它将每个样本的向量缩放为单位范数,与样本的分布无关。

现在,作为经验法则,我们通常因为以下一个(或多个)原因来缩放特征:

  1. 某些算法要求特征被缩放,例如神经网络。(例如为了避免梯度消失),另一个例子是当我们在SVM中使用RBF核时…等。
  2. 特征缩放提高/加速收敛
  3. 当特征在大小、单位和范围上差异很大时(例如5公斤和5000克),因为我们不希望算法错误地认为一个特征比另一个更重要(即对模型有更大的影响)。

如你所见,特征缩放与你Y中类别的数量无关。


…但对于这个鸢尾花数据集,需要根据指定分类为0或1或2…这是如何工作的?我知道LR通过预测是或否来工作,但在鸢尾花这里,我们需要预测0或1或2

嗯,与二元分类相反,这被称为多类分类

这里的基本思想是Scikit LogisticRegresser使用一对多(OvR)方案-默认情况下-来解决它(也称为一对所有),它工作(用我能想到的最简单的话来说)如下:

为每个类i训练一个逻辑回归分类器,以预测y = i的概率。在新的输入x上进行预测时,选择具有最大似然性(即最高假设结果)的类i,换句话说,它将多类分类问题简化为多个二元分类问题,欲了解更多详情,请查看这里


如果LogisticRegression也适用于多类分类,那么我如何优化上述代码,以在其他我想尝试的多类数据集上获得更好的预测?

嗯,你不需要进行任何优化,你抽象地使用Scikit库,所以它会处理优化问题,实际上它是通过使用一个求解器来实现的,关于求解器的比较,请查看这里(我曾经在Stackoverflow上写过)。


我需要转换我的y_train吗,或者我需要对其进行任何类型的编码等才能使其工作?

对于你的特定情况(即对于鸢尾花数据集),答案是不需要,因为它已经为你准备好了,但如果依赖变量中的值(即Y)不是数字,那么你应该将它们转换为数字,例如,如果你有4个类,你可以用一个数字来表示每个类(例如0, 1, 2, 3)。(替换0和1为男和女的示例)(你应该做相反的事情,但你从中明白了:D)。


我强烈推荐你开始学习的一个非常好的参考资料是这门由Andrew NG教授开设的课程,它将清除你所有的疑问。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注