当我尝试按一个分类列进行分层拆分时,返回给我一个错误。
Country ColumnA ColumnB ColumnC Label
AB 0.2 0.5 0.1 14
CD 0.9 0.2 0.6 60
EF 0.4 0.3 0.8 5
FG 0.6 0.9 0.2 15
这是我的代码:
X = df.loc[:, df.columns != 'Label']
y = df['Label']
# 训练/测试拆分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, stratify=df.Country)
from sklearn.linear_model import LinearRegression
lm = LinearRegression()
lm.fit(X_train,y_train)
lm_predictions = lm.predict(X_test)
所以我得到了如下错误:
ValueError: could not convert string to float: 'AB'
回答:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
'Country': ['AB', 'CD', 'EF', 'FG']*20,
'ColumnA' : [1]*20*4,
'ColumnB' : [10]*20*4,
'Label': [1,0,1,0]*20
})
df['Country_Code'] = df['Country'].astype('category').cat.codes
X = df.loc[:, df.columns.drop(['Label','Country'])]
y = df['Label']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, stratify=df.Country_Code)
lm = LinearRegression()
lm.fit(X_train,y_train)
lm_predictions = lm.predict(X_test)
- 将
country
列中的字符串值转换为数字,并保存为一个新列 - 在创建
x
训练数据时,删除label
(y
)以及字符串country
列
方法2
如果您稍后将获得用于预测的测试数据,您将需要一种机制,在进行预测之前将它们的country
转换为code
。在这种情况下,推荐的方法是使用LabelEncoder
,您可以使用fit
方法将字符串编码为标签,然后使用transform
方法对测试数据的国家进行编码。
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import preprocessing
df = pd.DataFrame({
'Country': ['AB', 'CD', 'EF', 'FG']*20,
'ColumnA' : [1]*20*4,
'ColumnB' : [10]*20*4,
'Label': [1,0,1,0]*20
})
# 训练-验证
le = preprocessing.LabelEncoder()
df['Country_Code'] = le.fit_transform(df['Country'])
X = df.loc[:, df.columns.drop(['Label','Country'])]
y = df['Label']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, stratify=df.Country_Code)
lm = LinearRegression()
lm.fit(X_train,y_train)
# 测试
test_df = pd.DataFrame({'Country': ['AB'], 'ColumnA' : [1],'ColumnB' : [10] })
test_df['Country_Code'] = le.transform(test_df['Country'])
print (lm.predict(test_df.loc[:, test_df.columns.drop(['Country'])]))