Naver Boostcamp

[Pytorch 기본]hook(pre forward hook, forward hook, backward hook)

HaneeOh 2023. 3. 15. 02:28

1. Hook

hook이란 패키지를 만드는 코드에서 중간에 원하는 코드를 삽입할 수 있는 기능이다. 순전파 이후에 모델의 가중치를 변경하거나, 파라미터 업데이트를 실시간으로 확인하는 등, 내 입맛대로 바꾸고 싶은 모델을 일부 변형하여 사용할 수 있다.

 

hook은 크게 Tensor에 적용하는 hookModule에 적용하는 hook으로 나눌 수 있다.

 

1-1. Tensor

Tensor는 forward hook이 없고 backward hook만 적용할 수 있다.

import torch

tensor = torch.rand(1, requires_grad=True)

def tensor_hook(grad):
    pass

tensor.register_hook(tensor_hook)

# 🦆 tensor는 backward hook만 있어요!
tensor._backward_hooks

 

1-2. Module

Module에 적용되는 hook은 forward_pre_hook, forward_hook, full_backward_hook가 있다. 그 밖에도 backward_hook과 state_dict_hook이 있지만 backward_hook은 사라진 기능이고 state_dict_hook은 사용자가 사용하는 것이 아닌 load_state_dict 함수가 사용하는 기능이다.

모델에 등록된 hook은 __dict__를 통해 확인할 수 있다.

 

전체적인 실행 순서가 잘 이해되지 않아 찾아봤는데

forward_pre_hook 👉forward👉forward_hook👉backward👉full_backward_hook

의 순서대로 진행이 된다.

 

hook의 실행 순서를 구체적으로 알고 싶다면 nn.Module의 소스코드를 읽어보면 된다.

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

 

GitHub - pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration

Tensors and Dynamic neural networks in Python with strong GPU acceleration - GitHub - pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration

github.com

__call__을 통해 함수가 실행되기 때문에 _call_impl이라는 함수를 보면

코드 내에서 hook의 순서를 이해할 수 있다.

 

 

forward_pre_hook

forward 실행 전에 실행되는 hook이다.

import torch
from torch import nn


# Add 모델을 수정하지 마세요! 
class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)

        return output

# 모델 생성
add = Add()

# TODO: 답을 x1, x2, output 순서로 list에 차례차례 넣으세요! 
answer = []

# TODO : pre_hook를 이용해서 x1, x2 값을 알아내 answer에 저장하세요
def pre_hook(module, input):
    answer.append(input[0])
    answer.append(input[1])
    # return input[0], input[1]


# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1)
x2 = torch.rand(1)

add.register_forward_pre_hook(pre_hook)
output = add(x1, x2)

pre_hook(module, input)에서 input 값으로는 forward 함수로 받은 x1, x2가 튜플의 형태로 저장되어 있다. 여기서 반환값을 통해 모델의 input을 수정할 수 있다. pre_hook을 등록하는 방법은 모델 생성 이후 register 함수를 사용하면 된다.

 

 

forward_hook

forward 이후에 실행되는 hook이다.

import torch
from torch import nn

# Add 모델을 수정하지 마세요! 
class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)
        return output

# 모델 생성
add = Add()


# TODO : hook를 이용해서 전파되는 output 값에 5를 더해보세요!
def hook(module, input, output):
  return output + 5

add.register_forward_hook(hook)


# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1)
x2 = torch.rand(1)

output = add(x1, x2)

hook(module, input, output)에서 input으로는 forward에서 받은 인자(x1, x2)를 튜플 형태로 가지고 있고, output으로는 순전파의 결과값을 가지고 있다. 모델 내에서는 우선 forward를 실행해서 결과값을 저장한 뒤, forward_hook의 반환값이 있으면 forward_hook의 반환값으로 결과를 수정한다. 그래서 hook에서 input을 수정한다해도 forward에 적용되지 않는다. 반환값으로는 바꾸고자 하는 순전파의 결과값을 주면 된다.

해당 내용을 _call_impl에서 확인할 수 있다.

 

full_backward_hook

input에 대한 gradient가 계산될 때마다 호출되는 hook이다.

import torch
from torch import nn
from torch.nn.parameter import Parameter


# Model 모델을 수정하지 마세요! 
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()


# TODO: 답을 x1.grad, x2.grad, output.grad 순서로 list에 차례차례 넣으세요! 
answer = []

# TODO : hook를 이용해서 x1.grad, x2.grad, output.grad 값을 알아내 answer에 저장하세요
def module_hook(module, grad_input, grad_output):
    answer.append(grad_input[0])
    answer.append(grad_input[1])
    answer.append(grad_output[0])

model.register_full_backward_hook(module_hook)
# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.retain_grad()
output.backward()

module_hook(module, grad_input, grad_output)에서 grad_input은 forward에서 받은 인자(x1, x2)에 대한 gradient을 튜플의 형태로 저장하고 있다. grad_output은 순전파 output에 대한 gradient를 담고 있으므로 1이 되며 튜플로 저장되어 있다. grad_output을 수정할 수 없으며, input과 output도 수정할 수 없다. 반환값으로는 바꾸고자 하는 grad_input을 줄 수 있다.

 

즉 새로운 gradient를 반환값으로 줌으로써 grad_input 대신 활용할 수 있다. 그러나 grad_input의 의미가 왜곡되기 때문에 디버깅 이외의 상황에서는 권장하지 않는다.

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()


# TODO : hook를 이용해서 module의 gradient 출력의 합이 1이 되도록 하세요!
#        ex) (1.5, 0.5) -> (0.75, 0.25)
def module_hook(module, grad_input, grad_output):
  grad_input = tuple(map(lambda x: x/sum(grad_input), grad_input))
  return grad_input

model.register_full_backward_hook(module_hook)


# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

# x1.grad + x2.grad == 1

full_backward_hook을 통해 합이 1로 표준화된 grad_input을 반환해줌으로써 x1과 x2의 gradient의 합이 1이 되었다.

 

full_backward_hook에서는 input(forward의 인자)의 gradient와 output의 gradient 값만 알 수 있어서 모델 내부 Parameter의 gradient 값은 알 수 없다. 만약 모델 내부 Parameter의 gradient를 알고 싶다면 텐서에 사용되는 hook을 사용하면 된다.

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()

# TODO : hook를 이용해서 W의 gradient 값을 알아내 answer에 저장하세요
def tensor_hook(grad):
    print(grad)
model.W.register_hook(tensor_hook)

# 아래 코드는 수정하실 필요가 없습니다!
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

'Naver Boostcamp' 카테고리의 다른 글

[Pytorch 구조 학습하기]Dataset과 DataLoader  (0) 2023.03.21
[Pytorch 기본] apply  (1) 2023.03.15
[Pytorch 기본]nn.Module  (0) 2023.03.14
[AI Math] 경사하강법  (0) 2023.03.12
[AI Math]행렬  (0) 2023.03.12