Deep Learning/PyTorch

Dataset의 기본 구성 요소

green_ne 2022. 1. 28. 11:32

Dataset은 DataLoader의 대상이 되는 data 인수에 해당한다.

 

PyTorch에서는 2가지 dataset들을 지원한다.

- Map-style datasets = Dataset 클래스

- Iterable-style datasets = IterableDataset 클래스

 

# Custome Dataset 기본 뼈대

Custome Dataset을 만들기 위해서는 torch.utils.data.Dataset 클래스를 상속해서 만든다.

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, text, labels):
            # 데이터 위치, 파일명 저장
            # 데이터 load
            # 데이터 처리할 transforms들 정의
            self.data = text

    def __len__(self):
            # 최대 elements 수 반환
            return len(self.labels)
            
    def __getitem__(self, idx):
            # dataset에서 idx번째 데이터 반환
            # 데이터 전처리 등도 여기서 처리
            return self.data[idx]

 

모든 작업을 생성시점에 처리할 필요는 없으며, getitem이 가장 중요한 부분으로 최적화도 고려해야 한다.

이렇듯 직접 만든 Dataset에 대한 표준화된 처리방법이 필요한데, 이는 HuggingFace, FastAPI 등을 참고하면 된다.

 

 

 

참고 : https://pytorch.org/docs/stable/data.html#dataset-types

반응형