model.save()
- 학습의 결과를 저장하기 위한 함수
- 모델 architecture와 파라미터를 저장
- 모델 학습 중간 과정을 저장하여 최선의 결과모델을 선택할 수 있음
- 만들어진 모델을 외부 연구자와 공유하여 학습 재연성 향상
모델의 parameter만 save & load
torch.save(model.state_dict(), os.path.join(MODEL_PATH, "model.pt"))
new_model = TheModelClass()
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))
모델 architecture와 parameter를 함께 save&load
torch.save(model, os.path.join(MODEL_PATH, "model.pt"))
model = torch.load(os.path.join(MODEL_PATH, "model.pt"))
Checkpoints
- 학습 중간의 결과를 저장하여 최선의 결과를 선택
- earlystopping 기법 사용시 이전 학습의 결과물을 저장
- 일반적으로 epoch, loss, metric을 함께 저장하여 확인
torch.save({
# 모델의 정보를 epoch과 함께 저장
'epoch': e,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': epoch_loss,
}, f"saved/checkpoint_model_{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}.pt")
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
Transfer Learning
- 다른 데이터셋으로 pretrained된 모델을 현재 데이터에 적용
- 현재의 DL에서는 가장 일반적인 학습 기법
- backbone architecture가 잘 학습된 모델에서 일부분만 변경하여 학습을 수행함
- Freezing: pretrained model을 활용시 모델의 일부 parameter를 frozen시킴
v99 = models.vgg16(pretrained=True).to(device)
class MyNewNet(nn.Module):
def __init__(self):
super(MyNewNet, self).__init__()
self.vgg19 = models.vgg19(pretrained=True) # vgg19의 backbone architecture 활용
self.linear_layers = nn.Linear(1000, 1) # 모델에 마지막 linear layer 추가
# Defining the forward pass
def forward(self, x):
x = self.vgg19(x)
return self.linear_layers(x)
for param in my_model.parameters():
param.requires_grad = False # frozen
for param in my_model.linear_layers.parameters():
param.requires_grad = True # 마지막 레이어는 autograd 활성화
'Naver Boostcamp' 카테고리의 다른 글
마스크 착용 상태 분류 대회 (Public 5위, Private 2위) (0) | 2023.04.24 |
---|---|
[Pytorch 활용하기] Multi-GPU 학습 (0) | 2023.04.09 |
[Pytorch 구조 학습하기]Dataset과 DataLoader (0) | 2023.03.21 |
[Pytorch 기본] apply (1) | 2023.03.15 |
[Pytorch 기본]hook(pre forward hook, forward hook, backward hook) (1) | 2023.03.15 |