我在Udemy上学习一个Python机器学习课程,使用以下数据集(仅展示前几行)
R&D Spend Administration Marketing Spend State Profit0 165349 136898 471784 New York 1922621 162598 151378 443899 California 1917922 153442 101146 407935 Florida 1910503 144372 118672 383200 New York 182902
该课程制作于2016年,因此一些模块已更新,我也在代码中进行了相应的更改(例如:使用ColumnTransformer的make_column_transformer)。代码的输出应该是一个浮点数组(在Udemy教程中确实如此),然而,由于某些原因,在代码更新后,我的变量x
在处理后被视为ndarray对象
。我不确定这是为什么,因为当我打印变量x
时,它输出的确实是一个浮点数数组。
原始数据文件可以在此链接(一个zip文件夹)中找到,文件名为50_startups.csv
。
我尝试添加.toarray()
,但这导致代码出错。
谢谢
import pandas as pd import matplotlib.pyplot as plt import numpy as np dataset = pd.read_csv("Startups (multiple linear regression).csv")x=dataset.iloc[:,:-1].valuesy=dataset.iloc[:,-1]#Encode categorical variables (New York, California, Florida)from sklearn.compose import ColumnTransformer, make_column_transformerfrom sklearn.preprocessing import OneHotEncoderpreprocess = make_column_transformer((OneHotEncoder(),[-1]),remainder="passthrough")x = preprocess.fit_transform(x)
回答:
在这种情况下,我认为这是由于您的输入和输出中混合了数据类型。例如,如果您检查x
:
xarray([[165349, 136898, 471784, 'New York'], [162598, 151378, 443899, 'California'], [153442, 101146, 407935, 'Florida'], [144372, 118672, 383200, 'New York']], dtype=object)
您会发现它的dtype=object
。这是因为数组中混合了整数和字符串。因此,直通数组(R&D Spend, Administration, 和Marketing Spend)保持相同的dtype
。在fit_transform
中,这个数组与您的OneHotEncoder
转换结果堆叠在一起,产生最终结果。因此,输出的dtype
与您提供的输入相同。
如果您想更改dtype
,您可以始终使用.astype(float)
。