I want to compute the SVM loss without using any loops, but I can't get it right. I need some hints. Here is my attempt:
```python
def svm_loss_vectorized(W, X, y, reg):
    loss = 0.0
    scores = np.dot(X, W)
    correct_scores = scores[y]
    deltas = np.ones(scores.shape)
    margins = scores - correct_scores + deltas
    margins[margins < 0] = 0  # max -> Boolean array indexing
    margins[np.arange(scores.shape[0]), y] = 0  # Don't count j = yi
    loss = np.sum(margins)
    # Average
    num_train = X.shape[0]
    loss /= num_train
    # Regularization
    loss += 0.5 * reg * np.sum(W * W)
    return loss
```
It should produce the same loss value as the following function:
```python
def svm_loss_naive(W, X, y, reg):
    num_classes = W.shape[1]
    num_train = X.shape[0]
    loss = 0.0
    for i in range(num_train):
        scores = X[i].dot(W)
        correct_class_score = scores[y[i]]
        for j in range(num_classes):
            if j == y[i]:
                continue
            margin = scores[j] - correct_class_score + 1  # note delta = 1
            if margin > 0:
                loss += margin
    loss /= num_train  # mean
    loss += 0.5 * reg * np.sum(W * W)  # l2 regularization
    return loss
```
Answer:
Here's one vectorized approach –
```python
delta = 1
N = X.shape[0]
M = W.shape[1]
scoresv = X.dot(W)
marginv = scoresv - scoresv[np.arange(N), y][:, None] + delta
mask0 = np.zeros((N, M), dtype=bool)
mask0[np.arange(N), y] = 1
mask = (marginv < 0) | mask0
marginv[mask] = 0
loss_out = marginv.sum() / N  # mean
loss_out += 0.5 * reg * np.sum(W * W)  # l2 regularization
```
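The key step above is the fancy indexing `scoresv[np.arange(N), y]`, which pulls out each row's correct-class score, and the `[:, None]` that turns it into a column so it broadcasts against every class. A tiny standalone sketch, using made-up shapes and labels just for illustration:

```python
import numpy as np

# Toy data: 3 samples, 4 classes (hypothetical values for illustration only)
N, M = 3, 4
scoresv = np.arange(N * M, dtype=float).reshape(N, M)
y = np.array([1, 3, 0])  # correct class for each sample

# One score per row: scoresv[0, 1], scoresv[1, 3], scoresv[2, 0]
correct = scoresv[np.arange(N), y]          # shape (N,)

# [:, None] makes it (N, 1), so it broadcasts across all M classes
margins = scoresv - correct[:, None] + 1    # delta = 1
```

Note that this picks exactly one element per row, whereas plain `scoresv[y]` would select whole rows, which is a common source of bugs here.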
Additionally, we can optimize the np.sum(W * W) step with np.tensordot, like so –
float(np.tensordot(W,W,axes=((0,1),(0,1))))
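As a quick sanity check (on random data, not part of the original answer), contracting both axes of W against itself with np.tensordot yields the same scalar as summing the elementwise square:

```python
import numpy as np

# Contracting axes (0, 1) of W with axes (0, 1) of W sums W[i, j] * W[i, j]
# over all i, j -- i.e. the same scalar as np.sum(W * W).
rng = np.random.default_rng(0)
W = rng.standard_normal((3073, 10))
s1 = np.sum(W * W)
s2 = float(np.tensordot(W, W, axes=((0, 1), (0, 1))))
assert np.isclose(s1, s2)
```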
Runtime test
The proposed approach as a function –
```python
def svm_loss_vectorized_v2(W, X, y, reg):
    delta = 1
    N = X.shape[0]
    M = W.shape[1]
    scoresv = X.dot(W)
    marginv = scoresv - scoresv[np.arange(N), y][:, None] + delta
    mask0 = np.zeros((N, M), dtype=bool)
    mask0[np.arange(N), y] = 1
    mask = (marginv <= 0) | mask0
    marginv[mask] = 0
    loss_out = marginv.sum() / N  # mean
    loss_out += 0.5 * reg * float(np.tensordot(W, W, axes=((0, 1), (0, 1))))
    return loss_out
```
Timings –
```python
In [86]: W = np.random.randn(3073, 10)
    ...: X = np.random.randn(500, 3073)
    ...: y = np.random.randint(0, 10, (500))
    ...: reg = 4.56

In [87]: svm_loss_naive(W, X, y, reg)
Out[87]: 70380.938069371899

In [88]: svm_loss_vectorized_v2(W, X, y, reg)
Out[88]: 70380.938069371914

In [89]: %timeit svm_loss_naive(W, X, y, reg)
100 loops, best of 3: 10.2 ms per loop

In [90]: %timeit svm_loss_vectorized_v2(W, X, y, reg)
100 loops, best of 3: 2.94 ms per loop
```