요약


`register_buffer`는 모델의 상태(state)로서 관리하고 싶은 텐서를 등록하는 데 사용됩니다. 즉, 이 메서드는 `state_dict`에 포함되어서, `torch.nn.Module.state_dict()`에 함께 저장되어, torch.save을 할 때, 함께 저장됩니다. 또한, ` register_buffer`으로 등록된 텐서는 기본적으로 기울기를 계산하지 않습니다.

 

기능 1. state_dict을 통해 모델을 저장/로드 할 때, 함께 포함되도록

`torch.nn.Module`로 딥러닝 네트워크를 구성하고, 필요한 텐서(non-trainable)도 함께 저장이 가능합니다.

아래의 예시를 살펴보겠습니다.

  • 9번줄: ` self.register_buffer("running_mean", torch.zeros(10))`으로 텐서를 하나 저장합니다. 이렇게되면 `self. running_mean`에 속성으로도 저장됩니다.
  • 21번줄: `torch.save(model.state_dict(), "model_with_buffer.pth")`에서 ` model.state_dict()`을 이용해서 state_dict을 저장합니다. 이 때, `. register_buffer`으로 등록한 `running_mean= torch.zeros(10)`도 함께 state_dict에 저장됩니다.
  • 24번줄: register_buffer을 사용하지않고 저장하려면, state_dict의 딕셔너리에 key-value을 별도로 이렇게 저장해줘야합니다.

 

기능 2. non-trainable parameter을 저장하는 경우

배치 정규화에서 배치 단위의 평균(mean)과 분산(var)은 통계량값만 저장하고, gradient로는 사용되지 않습니다. 이 때도 사용이 가능합니다.

 

반응형

+ Recent posts