Deep Learning/PyTorch
Module 클래스
green_ne
2022. 1. 28. 12:16
모든 신경말 모델들의 base 클래스로서, 여러 기능들을 한 곳에 모아놓는 상자 역할을 한다.
또한 이 클래스는 다른 nn.Module도 포함하여 트리 구조로 중첩할 수도 있다.
따라서 nn.Module 클래스는 빈 상자일 뿐 어떻게 설계할지는 사용자의 몫이다.
예를 들면, 기본적으로 다음과 같이 구성할 수 있다.
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
여기서 __init__에서 super를 호출하는 것을 볼 수 있는데, 이를 수행하는 이유는 다음과 같다.
기본적으로 nn.Module은 Parameter나 Buffer, Hook과 같은 여러 기능을 제공한다.
작성한 모델에서도 위 기능을 지원하기 위해 init에서 nn.Module 클래스 자체를 초기화해야 한다.
Python에서는 super클래스 생성자 및 초기자는 자동으로 호출되지 않는다. 따라서 명시적으로 호출되어야 하는데, 이를 super가 알아내어 해준다. 만약 Python3을 사용하고 있다면, super 호출에서 인수를 적지 않아도 알아서 호출할 수 있다.
반응형