Deep Learning/PyTorch

Module 클래스

green_ne 2022. 1. 28. 12:16

모든 신경말 모델들의 base 클래스로서, 여러 기능들을 한 곳에 모아놓는 상자 역할을 한다.

또한 이 클래스는 다른 nn.Module도 포함하여 트리 구조로 중첩할 수도 있다.

 

따라서 nn.Module 클래스는 빈 상자일 뿐 어떻게 설계할지는 사용자의 몫이다.

예를 들면, 기본적으로 다음과 같이 구성할 수 있다.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

 

 

여기서 __init__에서 super를 호출하는 것을 볼 수 있는데, 이를 수행하는 이유는 다음과 같다.

 

기본적으로 nn.Module은 Parameter나 Buffer, Hook과 같은 여러 기능을 제공한다.

작성한 모델에서도 위 기능을 지원하기 위해 init에서 nn.Module 클래스 자체를 초기화해야 한다.

Python에서는 super클래스 생성자 및 초기자는 자동으로 호출되지 않는다. 따라서 명시적으로 호출되어야 하는데, 이를 super가 알아내어 해준다. 만약 Python3을 사용하고 있다면, super 호출에서 인수를 적지 않아도 알아서 호출할 수 있다.

 

 

참고 : https://stackoverflow.com/questions/63058355/why-is-the-super-constructor-necessary-in-pytorch-custom-modules

 

 

 

반응형