Naver Boostcamp

[Pytorch 활용하기] Multi-GPU 학습

HaneeOh 2023. 4. 9. 23:15

오늘날 딥러닝은 엄청난 크기의 데이터를 다루고 있다.

만약 데이터를 학습시킬 때 하나의 GPU가 아닌 여러 개의 GPU를 활용한다면, 큰 성능적인 이점을 얻을 수 있을 것이다.

 

멀티 GPU에 학습을 분산하는 방법에는 두 가지가 있다.

  • 모델을 나누는 방법
  • 데이터를 나누는 방법

 

모델을 나누는 방법은 의외로 꽤 예전부터 활용해왔던 기법이다. CNN의 초기 모델인 AlexNet에서 사용되기도 했다.

하지만 모델의 병목, 파이프라인의 어려움 등으로 인해 모델 병렬화는 고난이도 과제에 속한다.

 

Model Parallel

class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)
        
        # 첫번째 모델을 cuda 0에 할당
        self.seq1 = nn.Sequential(
            self.conv1, self.bn1, self.relu, self.maxpool, self.layer1, self.layer2
        ).to('cuda:0')
        
        # 두번째 모델을 cuda 1에 할당
        self.seq2 = nn.Sequential(
            self.layer3, self.layer4, self.avgpool
        ).to('cuda:1')
        
        self.fc.to('cuda:1')

# 두 모델을 연결하기
    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))

 


Data Parallel

  • 데이터를 나눠 GPU에 할당 후 결과의 평균을 취하는 방법
  • minibatch 수식과 유사한데 한번에 여러 GPU에서 수행

 

Pytorch에서는 아래 두 가지 방식을 제공한다.

Data Parallel

  • 단순히 데이터를 분배한 후 평균을 취한다.
  • → GPU 사용 불균형 문제 발생, Batch 사이즈 감소 (한 GPU가 병목), GIL
# 모델을 DataParallel로 래핑합니다.
parallel_model = torch.nn.DataParallel(model)

# 멀티 GPU에서 순전파를 수행합니다.
predictions = parallel_model(inputs)

# 손실 함수를 계산합니다.
loss = loss_function(predictions, labels)

# 역전파를 수행합니다.
loss.mean().backward()

# GPU 손실 평균값을 계산하고 역전파를 수행합니다.
optimizer.step()

# 새로운 매개변수로 순전파를 수행합니다.
predictions = parallel_model(inputs)

 

DistributedDataParallel

  • 각 CPU마다 process 생성하여 개별 GPU에 할당
  • → 기본적으로 DataParallel로 하나 개별적으로 연산의 평균을 냄
# DistributedSampler를 사용하여 데이터를 샘플링합니다.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

# DataLoader를 생성합니다.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True,
                                          num_workers=3, pin_memory=True,
                                          sampler=train_sampler)
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from multiprocessing import Pool

def main():
    n_gpus = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=n_gpus, args=(n_gpus,))

def main_worker(gpu, n_gpus):
    image_size = 224
    batch_size = 512
    num_worker = 8
    epochs = ...
    batch_size = int(batch_size / n_gpus)
    num_worker = int(num_worker / n_gpus)
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:2568',
                            world_size=n_gpus, rank=gpu) # 멀티프로세싱 통신 규약 정의
    model = MODEL
    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

def f(x):
    return x*x

if __name__ == '__main__':
    with Pool(5) as p:
        print(p.map(f, [1, 2, 3]))