在PyTorch中为多个损失计算`loss.backward`的性能如何?

我想为我的模型计算多个损失函数的梯度。

我想了解在使用retain_graph=True的情况下,进行连续的backwards调用是廉价的还是昂贵的。

理论上,我预期第一次调用应该比后续的调用慢,因为计算图不需要重新评估,只需要进行一些矩阵乘法运算。

在实践中,我发现很难进行基准测试。

我的代码如下:

# Code in file nn/two_layer_net_nn.pyimport torchD_in = 40model = torch.load('model.pytorch')device = torch.device('cpu')def loss1(y_pred,x):    return (y_pred*(0.5-x.clamp(0,1))).sum()def loss2(y_pred,x):    return (y_pred*(1-x.clamp(0,1))).sum()# Predict random input    x = torch.rand(1,D_in, device=device,requires_grad=True)y_pred = model(x)# Is this %%timeitloss = loss1(y_pred,x)loss.backward(retain_graph=True)202 µs ± 4.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)# Slower than this?%%timeitloss = loss2(y_pred,x)loss.backward(retain_graph=True)216 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)# Are successive backwards calls cheap?loss = lossX(y_pred,x)loss.backward(retain_graph=True)

我认为%%timeit不起作用是因为它会运行多次迭代然后进行平均计算。

我该如何测量连续调用backward是否会很快?retain_graph=True对性能实际上意味着什么?


回答:

我认为你只问了

  1. 第一次反向传播是否比后续的慢?

还有两个问题也可以在这里一起讨论:

  1. 合并损失是否更有效?
  2. 保持图(如果不是其值的话)是否需要更多内存?

然而,在此之前,让我们强调一下retain_graph的实际用途:如果你在计算上碰巧在多个时间点有多个输出,那么它用于多次通过你的图。例如,考虑联合多任务学习(参见这个问题及其答案对此的讨论)。

回到问题:一般来说,我预期保留图并不真的重要。毕竟,它只是将部分计算保存在内存中以供将来使用,而不“做”任何事情。

尽管如此 – 第一次反向传播会花费更长时间,因为PyTorch会在计算梯度时缓存一些需要的计算。

所以这里是证明:

import numpy as npimport torchimport torch.nn as nnimport timeimport osimport psutilD_in = 1024model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 1024))device = torch.device('cpu')def loss1(y_pred,x):    return (y_pred*(0.5-x.clamp(0,1))).sum()def loss2(y_pred,x):    return (y_pred*(1-x.clamp(0,1))).sum()def timeit(func, repetitions):    time_taken = []    mem_used = []    for _ in range(repetitions):        time_start = time.time()        mem_used.append(func())        time_taken.append(time.time() - time_start)    return np.round([np.mean(time_taken), np.min(time_taken), np.max(time_taken), \           np.mean(mem_used), np.min(mem_used), np.max(mem_used)], 4).tolist()# Predict random inputx = torch.rand(1,D_in, device=device,requires_grad=True)def init():    out = model(x)    loss = loss1(out, x)    loss.backward()def func1():    x = torch.rand(1, D_in, device=device, requires_grad=True)    loss = loss1(model(x),x)    loss.backward()    loss = loss2(model(x),x)    loss.backward()    del x    process = psutil.Process(os.getpid())    return process.memory_info().rssdef func2():    x = torch.rand(1, D_in, device=device, requires_grad=True)    loss = loss1(model(x),x) + loss2(model(x),x)    loss.backward()    del x    process = psutil.Process(os.getpid())    return process.memory_info().rssdef func3():    x = torch.rand(1, D_in, device=device, requires_grad=True)    loss = loss1(model(x),x)    loss.backward(retain_graph=True)    loss = loss2(model(x),x)    loss.backward(retain_graph=True)    del x    process = psutil.Process(os.getpid())    return process.memory_info().rssdef func4():    x = torch.rand(1, D_in, device=device, requires_grad=True)    loss = loss1(model(x),x) + loss2(model(x),x)    loss.backward(retain_graph=True)    del x    process = psutil.Process(os.getpid())    return process.memory_info().rssinit()print(timeit(func1, 100))print(timeit(func2, 100))print(timeit(func3, 100))print(timeit(func4, 100))

结果如下(抱歉我的格式有点懒):

# time mean, time min, time max, memory mean, memory min, memory max[0.1165, 0.1138, 0.1297, 383456419.84, 365731840.0, 384438272.0][0.127, 0.1233, 0.1376, 400914759.68, 399638528.0, 434044928.0][0.1167, 0.1136, 0.1272, 400424468.48, 399577088.0, 401223680.0][0.1263, 0.1226, 0.134, 400815964.16, 399556608.0, 434307072.0]

然而,如果你跳过第一次反向传播(注释掉对init()函数的调用),那么func1中的第一次反向运行确实会花费更长时间:

# time mean, time min, time max, memory mean, memory min, memory max[0.1208, 0.1136, **0.1579**, 350157455.36, 349331456.0, 350978048.0][0.1297, 0.1232, 0.1499, 393928540.16, 350052352.0, 401854464.0][0.1197, 0.1152, 0.1547, 350787338.24, 349982720.0, 351629312.0][0.1335, 0.1229, 0.1793, 382819123.2, 349929472.0, 401776640.0]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注