如何构建基于StandardScaler的自定义缩放器?

我正在尝试构建一个自定义缩放器,用于缩放数据集中的连续变量(美国成人收入数据集:https://www.kaggle.com/uciml/adult-census-income),以StandardScaler为基础。以下是我使用的Python代码:

from sklearn.base import BaseEstimator, TransformerMixinfrom sklearn.preprocessing import StandardScalerclass CustomScaler(BaseEstimator,TransformerMixin):             def __init__(self,columns,copy=True,with_mean=True,with_std=True):                self.scaler = StandardScaler(copy,with_mean,with_std)        self.columns = columns        self.mean_ = None        self.var_ = None                    def fit(self, X, y=None):        self.scaler.fit(X[self.columns], y)        self.mean_ = np.mean(X[self.columns])        self.var_ = np.var(X[self.columns])        return self        def transform(self, X, y=None, copy=None):                init_col_order = X.columns                X_scaled = pd.DataFrame(self.scaler.transform(X[self.columns]), columns=self.columns)                X_not_scaled = X.loc[:,~X.columns.isin(self.columns)]                return pd.concat([X_not_scaled, X_scaled], axis=1)[init_col_order]X=new_df_upsampled.copy()X.drop('income',axis=1,inplace=True)continuous = df.iloc[:, np.r_[0,2,10:13]] #basically independent variables that I consider continuouscolumns_to_scale = continuousscaler = CustomScaler(columns_to_scale)scaler.fit(X)

然而,当我尝试运行这个缩放器时,遇到了这个问题:enter image description here

那么,我在构建缩放器时遇到了什么错误?此外,您如何为这个数据集构建一个自定义缩放器?

谢谢!


回答:

我同意@AntoineDubuis的观点,ColumnTransformer是一种更好的(内置的)方法来实现这个功能。尽管如此,我想指出你的代码中出现的问题。

fit中,你有self.scaler.fit(X[self.columns], y);这表明self.columns应该是一个列名列表(或其他几种选择)。但你定义的参数是continuous = df.iloc[:, np.r_[0,2,10:13]],这是一个数据框

还有几个其他问题:

  1. 你应该只在__init__中设置来自其签名的属性,否则克隆将失败。将self.scaler移到fit中,并在__init__中直接保存其参数copy等。不要初始化mean_var_
  2. 你从未实际使用mean_var_。如果你想保留它们可以,但相关统计数据已经存储在缩放器对象中。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注