The website uses cookies. By using this site, you agree to our use of cookies as described in the Privacy Policy.
I Agree

pytorch 中register_buffer()

wanghua609 2020-03-17 11:58:09 7075

今天在看DSSINet代码的ssim.py时,遇到了一个用法

  1. class NORMMSSSIM(torch.nn.Module):
  2. def __init__(self, sigma=1.0, levels=5, size_average=True, channel=1):
  3. super(NORMMSSSIM, self).__init__()
  4. self.sigma = sigma
  5. self.window_size = 5
  6. self.levels = levels
  7. self.size_average = size_average
  8. self.channel = channel
  9. self.register_buffer('window', create_window(self.window_size, self.channel, self.sigma))
  10. self.register_buffer('weights', torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))

那么这个register_buffer()是干什么用呢?官方解释如下

  1. nn.modules.module.py
  2. Adds a persistent buffer to the module.向模块添加持久缓冲区。
  3. This is typically used to register a buffer that should not to be
  4. considered a model parameter. For example, BatchNorm's ``running_mean``
  5. is not a parameter, but is part of the persistent state.这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm的“running_mean”不是参数,而是持久状态的一部分。
  6. Buffers can be accessed as attributes using given names.
  7. 缓冲区可以使用给定的名称作为属性访问。
  8. Args:
  9. name (string): name of the buffer. The buffer can be accessed
  10. from this module using the given name 名称(字符串):缓冲区的名称。可以使用给定的名称从该模块访问缓冲区
  11. tensor (Tensor): buffer to be registered.
  12. Example::
  13. >>> self.register_buffer('running_mean', torch.zeros(num_features))

应该就是在内存中定一个常量,同时,模型保存和加载的时候可以写入和读出。

pytorch一般情况下,是将网络中的参数保存成orderedDict形式的,这里的参数其实包含两种,一种是模型中各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数,另一种就是buffer,前者每次optim.step会得到更新,而不会更新后者。

  1. class myModel(nn.Module):
  2. def __init__(self, kernel_size=3):
  3. super(Depth_guided1, self).__init__()
  4. self.kernel_size = kernel_size
  5. self.back_end = torch.nn.Sequential(
  6. torch.nn.Conv2d(3, 32, 3, padding=1),
  7. torch.nn.ReLU(True),
  8. torch.nn.Conv2d(3, 64, 3, padding=1),
  9. torch.nn.ReLU(True),
  10. torch.nn.Conv2d(64, 3, 3, padding=1),
  11. torch.nn.ReLU(True),
  12. )
  13. mybuffer = np.arange(1,10,1)
  14. self.mybuffer_tmp = np.randn((len(mybuffer), 1, 1, 10), dtype='float32')
  15. self.mybuffer_tmp = torch.from_numpy(self.mybuffer_tmp)
  16. # register preset variables as buffer
  17. # So that, in testing , we can use buffer variables.
  18. self.register_buffer('mybuffer', self.mybuffer_tmp)
  19. # Learnable weights
  20. self.conv_weights = nn.Parameter(torch.FloatTensor(64, 10).normal_(mean=0, std=0.01))
  21. # Other code
  22. def forward(self):
  23. ...
  24. # 这里使用 self.mybuffer!

注记:

1.定义parameter和buffer都只需要传入Tensor即可。也不需要将其转成gpu,这是因为,当网络进行.cuda时候,会自动将里面的层的参数,buffer等转换成相应的GPU上。

2. self.register_buffer可以将tensor注册成buffer,在forward中使用self.mybuffer,而不是self.mybuffer_tmp

3.网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。

4.buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数。

Measure
Measure
Summary | 3 Annotations
一种是模型中各种module含的参数
2021/02/05 07:53
另一种就是buffer
2021/02/05 07:53
前者每次optim.step会得到更新,而不会更新后者
2021/02/05 07:53