我是否正确使用了装饰器?

我不确定如何正确使用装饰器;我参考了Real Python针对多个方法的Try-Except。我正在编写一个线性回归类,我意识到在调用predict或类中的其他方法之前,需要先调用fit,但是每次定义方法时都检查self._fitted标志是否为False并抛出错误,实在是太麻烦了。所以我转而使用装饰器,我不确定我是否使用正确,虽然它确实按照我的期望运行,但是它忽略了其他类型的错误,如ValueError等。在这里寻求建议。

import functoolsfrom sklearn.exceptions import NotFittedErrordef NotFitted(func):    @functools.wraps(func)    def wrapper(*args, **kwargs):        try:            return func(*args, **kwargs)        except:            raise NotFittedError    return wrapperclass LinearRegression:    def __init__(self, fit_intercept: bool = True):        self.coef_ = None        self.intercept_ = None        self.fit_intercept = fit_intercept        # a flag to turn to true once we called fit on the data        self._fitted = Falsedef check_shape(self, X: np.array, y: np.array):    # if X is 1D array, then it is simple linear regression, reshape to 2D    # [1,2,3] -> [[1],[2],[3]] to fit the data    if X is not None and len(X.shape) == 1:        X = X.reshape(-1, 1)    # self._features = X    # self.intercept_ = y    return X, ydef fit(self, X: np.array = None, y: np.array = None):    X, y = self.check_shape(X, y)    n_samples, n_features = X.shape[0], X.shape[1]    if self.fit_intercept:           X = np.c_[np.ones(n_samples), X]    XtX = np.dot(X.T, X)    XtX_inv = np.linalg.inv(XtX)    XtX_inv_Xt = np.dot(XtX_inv, X.T)    _optimal_betas = np.dot(XtX_inv_Xt, y)    # set attributes from None to the optimal ones    self.coef_ = _optimal_betas[1:]    self.intercept_ = _optimal_betas[0]    self._fitted = True    return self@NotFitteddef predict(self, X: np.array):    """    after calling .fit, you can continue to .predict to get model prediction    """    # if self._fitted is False:    #     raise NotFittedError    if self.fit_intercept:        y_hat = self.intercept_ + np.dot(X, self.coef_)    else:        y_hat = self.intercept_    return y_hat

回答:

让我快速重复一下你想做的事情,以确保我没有误解。你希望有一个装饰器@NotFitted,这样你标注它的每个函数都会首先检查self._fitted是否为True,如果为False,则会抛出NotFittedError而不执行函数。

通过查看这个问题,你可以了解如何向装饰器传递额外的参数。
我不习惯使用装饰器,所以我必须快速测试一下,看看你的代码中发生了什么——为什么def wrapper不需要self参数:

>>> def deco1(func):...   def wrapper(*args, **kwargs):...     print("Args are {}".format(args))...   return wrapper>>> class Foo(object):...   @deco1...   def meth(self, a):...     print("a: "+a)>>> f = Foo()>>> f.meth("hello")Args are (<__main__.Foo object at 0x7f37676a4128>, 'hello')

如你所见,这里wrapper打印的第一个参数实际上是self*args只是将所有非关键字参数收集到一个元组中,包括self,这是这里的第一个参数。如果我们想的话,我们可以通过def wrapper(self, *args, **kwargs)来更明确地表示(参见链接的问题)。

我需要在装饰器中调用_fitted吗?

是的,因为self._fitted是你用来跟踪是否已经拟合的方法。你可以通过*args的第一个元素访问它,通过args[0]._fitted来访问。但我更喜欢明确地传递self。无论哪种方式,你可以在wrapper中检查self._fitted是否为True,如果不是,则失败。所以我定义了这个例子:

#!/bin/env/python3# Declaring my own NotFittedError, because I don't want to# from sklearn.exceptions import NotFittedError# just for this small example.class NotFittedError (Exception):    passdef NotFitted ( foo ):    def wrapper ( self, *args, **kwargs ):        if not self._fitted:            raise NotFittedError()        else:            foo ( self, *args, **kwargs )    return wrapperclass Foo() :    # Set self._fitted to false just to be explicit.    # The initial value should be False anyway.    def __init__(self):        self._fitted = False    def fit(self):        self._fitted = True    @NotFitted    def predict(self, X):        # code here that assumes fit was already called        print ( "Successfully Predicted!" )

现在我们可以使用它了。在下面的代码片段中,我将其作为tmp导入,因为我在一个名为tmp.py的文件中。你不需要这样做,因为你所有的代码都在同一个文件中。

>>> import tmp>>> f = tmp.Foo()>>> f.predict("a")Traceback (most recent call last):  File "<stdin>", line 1, in <module>  File "/home/generic/Downloads/tmp.py", line 12, in wrapper    raise NotFittedError()tmp.NotFittedError>>> f.fit()>>> f.predict("a")Successfully Predicted!

一些进一步的评论:

  • 如果你的唯一目标是抛出一个NotFittedError,也许你不需要做这些。我认为sklearn.NotFittedError无论如何都会被抛出。
  • 如果你想区分不同类型的错误,在这种情况下,了解你可以有多个except子句也可能对你有用。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注