Deep Learning/PyTorch
Module의 apply() 메서드
green_ne
2022. 1. 28. 13:50
apply() 메서드는
해당 Module의 모든 sub-module에 인수받은 함수를 적용시켜준다.
예를 들면, 다음과 같이 Seqential 모듈의 가중치를 1로 초기화시킬 수 있다.
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
반응형