Module의 apply() 메서드

2022. 1. 28. 13:50Deep Learning/PyTorch

apply() 메서드는

해당 Module의 모든 sub-module에 인수받은 함수를 적용시켜준다.

 

_apply 함수 prototype

 

 

예를 들면, 다음과 같이 Seqential 모듈의 가중치를 1로 초기화시킬 수 있다.

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)
        
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
반응형

'Deep Learning > PyTorch' 카테고리의 다른 글

AutoGrad 구조 및 이해  (0) 2022.02.02
PyTorch Containers  (0) 2022.01.28
Module 클래스  (0) 2022.01.28
PyTorch 란?  (0) 2022.01.28
Dataset의 기본 구성 요소  (0) 2022.01.28