当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播
1 detach()[source]
返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。
即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播
源码为:
def detach(self): """Returns a new Variable, detached from the current graph. Result will never require gradient. If the input is volatile, the output will be volatile too. .. note:: Returned Variable uses the same data tensor, as the original one, and in-place modifications on either of them will be seen, and may trigger errors in correctness checks. """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result
可见函数进行的操作有:
将grad_fn设置为None 将Variable的requires_grad设置为False如果输入 volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时这么设置来加快运算速度),那么返回的Variable volatile=True。(volatile已经弃用)
注意:
返回的Variable和原始的Variable公用同一个data tensor。in-place函数修改会在两个Variable上同时体现(因为它们共享data tensor),当要对其调用backward()时可能会导致错误。
举例:
比如正常的例子是:
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) out = a.sigmoid() out.sum().backward() print(a.grad)
返回:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.1966, 0.1050, 0.0452])
当使用detach()但是没有进行更改时,并不会影响backward():
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) out = a.sigmoid() print(out) #添加detach(),c的requires_grad为False c = out.detach() print(c) #这时候没有对c进行更改,所以并不会影响backward() out.sum().backward() print(a.grad)
返回:
(deeplearning) userdeMBP:pytorch user$ python test.py
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
可见c,out之间的区别是c是没有梯度的,out是有梯度的
如果这里使用的是c进行sum()操作并进行backward(),则会报错:
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) out = a.sigmoid() print(out) #添加detach(),c的requires_grad为False c = out.detach() print(c) #使用新生成的Variable进行反向传播 c.sum().backward() print(a.grad)
返回:
(deeplearning) userdeMBP:pytorch user$ python test.py