我是否正确使用了装饰器?

我不确定如何正确使用装饰器;我参考了Real Python针对多个方法的Try-Except。我正在编写一个线性回归类,我意识到在调用predict或类中的其他方法之前,需要先调用fit,但是每次定义方法时都检查self._fitted标志是否为False并抛出错误,实在是太麻烦了。所以我转而使用装饰器,我不确定我是否使用正确,虽然它确实按照我的期望运行,但是它忽略了其他类型的错误,如ValueError等。在这里寻求建议。

import functoolsfrom sklearn.exceptions import NotFittedErrordef NotFitted(func):    @functools.wraps(func)    def wrapper(*args, **kwargs):        try:            return func(*args, **kwargs)        except:            raise NotFittedError    return wrapperclass LinearRegression:    def __init__(self, fit_intercept: bool = True):        self.coef_ = None        self.intercept_ = None        self.fit_intercept = fit_intercept        # a flag to turn to true once we called fit on the data        self._fitted = Falsedef check_shape(self, X: np.array, y: np.array):    # if X is 1D array, then it is simple linear regression, reshape to 2D    # [1,2,3] -> [[1],[2],[3]] to fit the data    if X is not None and len(X.shape) == 1:        X = X.reshape(-1, 1)    # self._features = X    # self.intercept_ = y    return X, ydef fit(self, X: np.array = None, y: np.array = None):    X, y = self.check_shape(X, y)    n_samples, n_features = X.shape[0], X.shape[1]    if self.fit_intercept:           X = np.c_[np.ones(n_samples), X]    XtX = np.dot(X.T, X)    XtX_inv = np.linalg.inv(XtX)    XtX_inv_Xt = np.dot(XtX_inv, X.T)    _optimal_betas = np.dot(XtX_inv_Xt, y)    # set attributes from None to the optimal ones    self.coef_ = _optimal_betas[1:]    self.intercept_ = _optimal_betas[0]    self._fitted = True    return self@NotFitteddef predict(self, X: np.array):    """    after calling .fit, you can continue to .predict to get model prediction    """    # if self._fitted is False:    #     raise NotFittedError    if self.fit_intercept:        y_hat = self.intercept_ + np.dot(X, self.coef_)    else:        y_hat = self.intercept_    return y_hat

回答:

让我快速重复一下你想做的事情,以确保我没有误解。你希望有一个装饰器@NotFitted,这样你标注它的每个函数都会首先检查self._fitted是否为True,如果为False,则会抛出NotFittedError而不执行函数。

通过查看这个问题,你可以了解如何向装饰器传递额外的参数。
我不习惯使用装饰器,所以我必须快速测试一下,看看你的代码中发生了什么——为什么def wrapper不需要self参数:

>>> def deco1(func):...   def wrapper(*args, **kwargs):...     print("Args are {}".format(args))...   return wrapper>>> class Foo(object):...   @deco1...   def meth(self, a):...     print("a: "+a)>>> f = Foo()>>> f.meth("hello")Args are (<__main__.Foo object at 0x7f37676a4128>, 'hello')

如你所见,这里wrapper打印的第一个参数实际上是self*args只是将所有非关键字参数收集到一个元组中,包括self,这是这里的第一个参数。如果我们想的话,我们可以通过def wrapper(self, *args, **kwargs)来更明确地表示(参见链接的问题)。

我需要在装饰器中调用_fitted吗?

是的,因为self._fitted是你用来跟踪是否已经拟合的方法。你可以通过*args的第一个元素访问它,通过args[0]._fitted来访问。但我更喜欢明确地传递self。无论哪种方式,你可以在wrapper中检查self._fitted是否为True,如果不是,则失败。所以我定义了这个例子:

#!/bin/env/python3# Declaring my own NotFittedError, because I don't want to# from sklearn.exceptions import NotFittedError# just for this small example.class NotFittedError (Exception):    passdef NotFitted ( foo ):    def wrapper ( self, *args, **kwargs ):        if not self._fitted:            raise NotFittedError()        else:            foo ( self, *args, **kwargs )    return wrapperclass Foo() :    # Set self._fitted to false just to be explicit.    # The initial value should be False anyway.    def __init__(self):        self._fitted = False    def fit(self):        self._fitted = True    @NotFitted    def predict(self, X):        # code here that assumes fit was already called        print ( "Successfully Predicted!" )

现在我们可以使用它了。在下面的代码片段中,我将其作为tmp导入,因为我在一个名为tmp.py的文件中。你不需要这样做,因为你所有的代码都在同一个文件中。

>>> import tmp>>> f = tmp.Foo()>>> f.predict("a")Traceback (most recent call last):  File "<stdin>", line 1, in <module>  File "/home/generic/Downloads/tmp.py", line 12, in wrapper    raise NotFittedError()tmp.NotFittedError>>> f.fit()>>> f.predict("a")Successfully Predicted!

一些进一步的评论:

  • 如果你的唯一目标是抛出一个NotFittedError,也许你不需要做这些。我认为sklearn.NotFittedError无论如何都会被抛出。
  • 如果你想区分不同类型的错误,在这种情况下,了解你可以有多个except子句也可能对你有用。

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注