Naver Boostcamp

[Pytorch 기본] apply

HaneeOh 2023. 3. 15. 11:51

Apply

apply는 모델 내에 이미 구현되어 있는 method가 아닌

사용자가 커스텀한 함수를 모델의 submodule에 전체적으로 적용하고 싶을 때 사용하는 기능이다.

모델 내부의 가중치값을 임의로 변경하거나 전체 submodule을 출력하는 등 다양하게 활용할 수 있다.

pytorch 공식 문서에도 일반적으로 모델의 파라미터를 (사용자 임의대로) initializing할 때 사용한다고 적혀있다.

 

 

사용 방법은 예제를 보면 알 수 있다.

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)

Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)

nn.Sequential로 묶어준 모듈에 apply를 적용했다.

apply는 인자로 모듈을 받아서 nn.Linear인 경우 모듈 내부의 가중치를 전부 1로 바꿔준다.

apply는 내부의 코드를 실행한 이후 실행 결과가 적용된 모듈을 반환한다.

 

 

 

Module 출력

모듈 내의 submodule을 순서대로 출력해보는 코드이다.

import torch
from torch import nn
from torch.nn.parameter import Parameter


# 아래 코드는 수정하실 필요가 없습니다!
# 하지만 아래 과제를 진행하기 전에 아래 코드를 보면서 최대한 이해해보세요!

# Function
class Function_A(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x + self.W

class Function_B(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x - self.W

class Function_C(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x * self.W

class Function_D(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x / self.W


# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()

        self.a = Function_A('plus')
        self.b = Function_B('substract')

    def forward(self, x):
        x = self.a(x)
        x = self.b(x)

        return x

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()

        self.c = Function_C('multiply')
        self.d = Function_D('divide')

    def forward(self, x):
        x = self.c(x)
        x = self.d(x)

        return x


# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x):
        x = self.ab(x)
        x = self.cd(x)

        return x


model = Model()

Layer_AB는 Function A와 Function B를,

Layer_CD는 Function C와 Function D를,

마지막으로 Model은 Layer_AB와 Layer_CD를 참조하고 있다.

 

def print_module(module):
    print(module)
    print("-" * 30)

# 🦆 apply는 apply가 적용된 module을 return 해줘요!
returned_module = model.apply(print_module)

이를 print_module 함수를 통해 apply해보면 출력은 다음과 같다.

Function_A()
------------------------------
Function_B()
------------------------------
Layer_AB(
  (a): Function_A()
  (b): Function_B()
)
------------------------------
Function_C()
------------------------------
Function_D()
------------------------------
Layer_CD(
  (c): Function_C()
  (d): Function_D()
)
------------------------------
Model(
  (ab): Layer_AB(
    (a): Function_A()
    (b): Function_B()
  )
  (cd): Layer_CD(
    (c): Function_C()
    (d): Function_D()
  )
)
------------------------------

apply는 Postorder Traversal의 방식으로 출력되는 것을 확인할 수 있다.

즉 모듈의 가장 밑단의 submodel부터 시작해서 그 부모 모델부터 차례대로 출력된다.

tree 구조로 생각하면 leaf node부터 출력되는 것과 동일하다.

 

 

 

가중치 초기화(Weight Initialization)

apply를 활용해 모듈 내부의 Parameter값을 변경할 수 있다.

 

import torch
from torch import nn
from torch.nn.parameter import Parameter


# 아래 코드는 수정하실 필요가 없습니다!
# 실행만 시켜주시고 다음 셀로 넘어가주세요!

# Function
class Function_A(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x + self.W

class Function_B(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x - self.W

class Function_C(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x * self.W

class Function_D(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x / self.W


# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()

        self.a = Function_A('plus')
        self.b = Function_B('substract')

    def forward(self, x):
        x = self.a(x)
        x = self.b(x)

        return x

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()

        self.c = Function_C('multiply')
        self.d = Function_D('divide')

    def forward(self, x):
        x = self.c(x)
        x = self.d(x)

        return x


# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x):
        x = self.ab(x)
        x = self.cd(x)

        return x

각 Function A, B, C, D는 하나짜리 텐서를 parameter로 가지고 있고, Layer_AB와 Layer_CD는 각 모듈을 참조하고 있으며,

최종적으로는 모듈이 Layer_AB와 Layer_CD를 참조하는 구조이다.

submodule을 가지는 module은 submodule의 parameter를 가지게 된다.

 

model = Model()

# TODO : apply를 이용해 모든 Parameter 값을 1로 만들어보세요!
def weight_initialization(module):
    module_name = module.__class__.__name__

    for param in module.parameters():
      param.data = torch.ones_like(param)


# 🦆 apply는 apply가 적용된 module을 return 해줘요!
returned_module = model.apply(weight_initialization)

# 아래 코드는 수정하실 필요가 없습니다!
x = torch.rand(1)

output = model(x)

weight_initialization은 모듈 내의 모든 submodule을 불러와서

각 submodule의 파라미터를 1로 바꿔주는 함수이다.

weight_initialization의 내부를 출력해 보면 다음과 같다.

 

Function_A
Parameter containing:
tensor([0.3533], requires_grad=True)

Function_B
Parameter containing:
tensor([0.1665], requires_grad=True)

Layer_AB
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)

Function_C
Parameter containing:
tensor([0.6069], requires_grad=True)

Function_D
Parameter containing:
tensor([0.9736], requires_grad=True)

Layer_CD
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)

Model
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)

Function A, B의 파라미터가 1로 먼저 변경되어

상위 모듈인 Layer_AB는 1을 파라미터로 가지는 것을 확인할 수 있다.

 

이처럼 apply를 적용하면 모듈 내의 가중치값을 일괄적으로 변경할 수 있다.

(조건문을 걸어 특정 모듈의 가중치만 바꾸어 줄 수도 있음)

 

 

Function 수정하기

현재 4개의 Function A, B, C, D가 있다.

- A : x + W
- B : x - W
- C : x * W
- D : x / W

이걸 다음처럼 linear transformation처럼 동작하도록 바꿔보자.

- A : x @ W + b
- B : x @ W + b
- C : x @ W + b
- D : x @ W + b

 

#@title 부덕이 모델
import torch
from torch import nn
from torch.nn.parameter import Parameter


# 아래 코드는 수정하실 필요가 없습니다!
# 실행만 시켜주시고 다음 셀로 넘어가주세요!

# Function
class Function_A(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(2, 2))

    def forward(self, x):
        return x + self.W

class Function_B(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(2, 2))

    def forward(self, x):
        return x - self.W

class Function_C(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(2, 2))

    def forward(self, x):
        return x * self.W

class Function_D(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(2, 2))

    def forward(self, x):
        return x / self.W


# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()

        self.a = Function_A('plus')
        self.b = Function_B('substract')

    def forward(self, x):
        x = self.a(x)
        x = self.b(x)

        return x

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()

        self.c = Function_C('multiply')
        self.d = Function_D('divide')

    def forward(self, x):
        x = self.c(x)
        x = self.d(x)

        return x


# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x):
        x = self.ab(x)
        x = self.cd(x)

        return x

원형 모듈이다. 위의 제시문처럼 linear transformation으로 바꿔주려면 어떻게 해야할까?

 

 

 

답은 apply와 hook을 함께 적용하는 것이다.

model = Model()

from functools import partial

# Parameter b 추가
def add_bias(module):
    module_name = module.__class__.__name__
    if module_name.split('_')[0] == "Function":
      module.b = Parameter(torch.rand(2,1))

# 1로 초기화
def weight_initialization(module):
    module_name = module.__class__.__name__
    # add_bias(module)
    if module_name.split('_')[0] == "Function":
        module.W.data.fill_(1.0)
        module.b.data.fill_(1.0)


# apply를 이용해 모든 Function을 linear transformation으로 바꾸자 (X @ W + b)
def hook(module, input, output):
    module_name = module.__class__.__name__  
    output = input[0] @ module.W.T
    output = torch.add(output, module.b)
    return output


def linear_transformation(module):
    module_name = module.__class__.__name__
    if module_name.split('_')[0] == "Function":
        module.register_forward_hook(hook)

returned_module = model.apply(add_bias)
returned_module = model.apply(weight_initialization)
returned_module = model.apply(linear_transformation)

add_bias 함수를 apply하여 편향 b를 인스턴스 변수로 추가해준다.

그리고 weight_initialization 함수를 apply하여 인스턴스 변수로 저장되어 있는 가중치의 element를 전부 1로 바꿔준다.

 

마지막으로... hook을 적용해야 하는데

linear_transformation 함수에서 모듈의 이름을 확인해서 Function으로 시작할 경우(Function A, B, C, D) forward_hook을 적용한다.

 

이때 hook의 작동 순서가 헷갈려서 만약 원래 forward 함수가 실행이 된 뒤 hook이 적용이 되는 거라면

제시문에서 요구하는 x @ W + b의 식이 구현되는 것이 아니라 (x + W) @ W + b가 작동되는 것이 아닌지에 대해 궁금증이 들었다.

결론은 forward 실행 이후 forward_hook이 실행되긴 하지만, forward_hook이 있는 경우 순전파의 결과값이 forward_hook의 결과값으로 변경되어

최종적으로는  x @ W + b의 식을 구현할 수 있게 된다. 

 

forward_hook의 실행 과정에 대해서는 자세히 정리해 두었으니 참고!

https://ohsy0512.tistory.com/27