我不确定如何正确使用装饰器;我参考了Real Python和针对多个方法的Try-Except。我正在编写一个线性回归类,我意识到在调用predict
或类中的其他方法之前,需要先调用fit
,但是每次定义方法时都检查self._fitted
标志是否为False
并抛出错误,实在是太麻烦了。所以我转而使用装饰器,我不确定我是否使用正确,虽然它确实按照我的期望运行,但是它忽略了其他类型的错误,如ValueError等。在这里寻求建议。
import functoolsfrom sklearn.exceptions import NotFittedErrordef NotFitted(func): @functools.wraps(func) def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except: raise NotFittedError return wrapperclass LinearRegression: def __init__(self, fit_intercept: bool = True): self.coef_ = None self.intercept_ = None self.fit_intercept = fit_intercept # a flag to turn to true once we called fit on the data self._fitted = Falsedef check_shape(self, X: np.array, y: np.array): # if X is 1D array, then it is simple linear regression, reshape to 2D # [1,2,3] -> [[1],[2],[3]] to fit the data if X is not None and len(X.shape) == 1: X = X.reshape(-1, 1) # self._features = X # self.intercept_ = y return X, ydef fit(self, X: np.array = None, y: np.array = None): X, y = self.check_shape(X, y) n_samples, n_features = X.shape[0], X.shape[1] if self.fit_intercept: X = np.c_[np.ones(n_samples), X] XtX = np.dot(X.T, X) XtX_inv = np.linalg.inv(XtX) XtX_inv_Xt = np.dot(XtX_inv, X.T) _optimal_betas = np.dot(XtX_inv_Xt, y) # set attributes from None to the optimal ones self.coef_ = _optimal_betas[1:] self.intercept_ = _optimal_betas[0] self._fitted = True return self@NotFitteddef predict(self, X: np.array): """ after calling .fit, you can continue to .predict to get model prediction """ # if self._fitted is False: # raise NotFittedError if self.fit_intercept: y_hat = self.intercept_ + np.dot(X, self.coef_) else: y_hat = self.intercept_ return y_hat
回答:
让我快速重复一下你想做的事情,以确保我没有误解。你希望有一个装饰器@NotFitted
,这样你标注它的每个函数都会首先检查self._fitted
是否为True
,如果为False
,则会抛出NotFittedError
而不执行函数。
通过查看这个问题,你可以了解如何向装饰器传递额外的参数。
我不习惯使用装饰器,所以我必须快速测试一下,看看你的代码中发生了什么——为什么def wrapper
不需要self
参数:
>>> def deco1(func):... def wrapper(*args, **kwargs):... print("Args are {}".format(args))... return wrapper>>> class Foo(object):... @deco1... def meth(self, a):... print("a: "+a)>>> f = Foo()>>> f.meth("hello")Args are (<__main__.Foo object at 0x7f37676a4128>, 'hello')
如你所见,这里wrapper
打印的第一个参数实际上是self
。*args
只是将所有非关键字参数收集到一个元组中,包括self
,这是这里的第一个参数。如果我们想的话,我们可以通过def wrapper(self, *args, **kwargs)
来更明确地表示(参见链接的问题)。
我需要在装饰器中调用
_fitted
吗?
是的,因为self._fitted
是你用来跟踪是否已经拟合的方法。你可以通过*args
的第一个元素访问它,通过args[0]._fitted
来访问。但我更喜欢明确地传递self
。无论哪种方式,你可以在wrapper
中检查self._fitted
是否为True
,如果不是,则失败。所以我定义了这个例子:
#!/bin/env/python3# Declaring my own NotFittedError, because I don't want to# from sklearn.exceptions import NotFittedError# just for this small example.class NotFittedError (Exception): passdef NotFitted ( foo ): def wrapper ( self, *args, **kwargs ): if not self._fitted: raise NotFittedError() else: foo ( self, *args, **kwargs ) return wrapperclass Foo() : # Set self._fitted to false just to be explicit. # The initial value should be False anyway. def __init__(self): self._fitted = False def fit(self): self._fitted = True @NotFitted def predict(self, X): # code here that assumes fit was already called print ( "Successfully Predicted!" )
现在我们可以使用它了。在下面的代码片段中,我将其作为tmp
导入,因为我在一个名为tmp.py
的文件中。你不需要这样做,因为你所有的代码都在同一个文件中。
>>> import tmp>>> f = tmp.Foo()>>> f.predict("a")Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/generic/Downloads/tmp.py", line 12, in wrapper raise NotFittedError()tmp.NotFittedError>>> f.fit()>>> f.predict("a")Successfully Predicted!
一些进一步的评论:
- 如果你的唯一目标是抛出一个
NotFittedError
,也许你不需要做这些。我认为sklearn.NotFittedError无论如何都会被抛出。 - 如果你想区分不同类型的错误,在这种情况下,了解你可以有多个
except
子句也可能对你有用。