요약
register_buffer는 모델의 상태(state)로서 관리하고 싶은 텐서를 등록하는 데 사용됩니다. 즉, 이 메서드는 state_dict에 포함되어서, torch.nn.Module.state_dict()에 함께 저장되어, torch.save을 할 때, 함께 저장됩니다. 또한, register_buffer으로 등록된 텐서는 기본적으로 기울기를 계산하지 않습니다.
기능 1. state_dict을 통해 모델을 저장/로드 할 때, 함께 포함되도록
torch.nn.Module로 딥러닝 네트워크를 구성하고, 필요한 텐서(non-trainable)도 함께 저장이 가능합니다.
아래의 예시를 살펴보겠습니다.
- 9번줄: self.register_buffer("running_mean", torch.zeros(10))으로 텐서를 하나 저장합니다. 이렇게되면 self. running_mean에 속성으로도 저장됩니다.
- 21번줄: torch.save(model.state_dict(), "model_with_buffer.pth")에서 model.state_dict()을 이용해서 state_dict을 저장합니다. 이 때, register_buffer으로 등록한 running_mean= torch.zeros(10)도 함께 state_dict에 저장됩니다.
- 24번줄: register_buffer을 사용하지않고 저장하려면, state_dict의 딕셔너리에 key-value을 별도로 이렇게 저장해줘야합니다.
기능 2. non-trainable parameter을 저장하는 경우
배치 정규화에서 배치 단위의 평균(mean)과 분산(var)은 통계량값만 저장하고, gradient로는 사용되지 않습니다. 이 때도 사용이 가능합니다.
반응형
'Data science > Python' 카테고리의 다른 글
Coroutine: python 정리 (1) | 2024.07.17 |
---|---|
numpy array의 stride란? (0) | 2024.07.10 |
python with 구문 & context manager (1) | 2024.06.19 |
python 매직메서드 (__repr__, __str__, __slots__) (0) | 2024.06.03 |
파이썬 바이트 표현 (bytes, bytearray) (0) | 2024.05.07 |