Deep Learning/PyTorch
Dataset의 기본 구성 요소
green_ne
2022. 1. 28. 11:32
Dataset은 DataLoader의 대상이 되는 data 인수에 해당한다.
PyTorch에서는 2가지 dataset들을 지원한다.
- Map-style datasets = Dataset 클래스
- Iterable-style datasets = IterableDataset 클래스
# Custome Dataset 기본 뼈대
Custome Dataset을 만들기 위해서는 torch.utils.data.Dataset 클래스를 상속해서 만든다.
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, text, labels):
# 데이터 위치, 파일명 저장
# 데이터 load
# 데이터 처리할 transforms들 정의
self.data = text
def __len__(self):
# 최대 elements 수 반환
return len(self.labels)
def __getitem__(self, idx):
# dataset에서 idx번째 데이터 반환
# 데이터 전처리 등도 여기서 처리
return self.data[idx]
모든 작업을 생성시점에 처리할 필요는 없으며, getitem이 가장 중요한 부분으로 최적화도 고려해야 한다.
이렇듯 직접 만든 Dataset에 대한 표준화된 처리방법이 필요한데, 이는 HuggingFace, FastAPI 등을 참고하면 된다.
참고 : https://pytorch.org/docs/stable/data.html#dataset-types
반응형