Backpropagation(역전파)
딥러닝을 공부하다보면 필수적으로 공부해야 하는 이론 중 하나가 backpropagation 입니다. Backpropagation 에 의해 모델이 학습이 된다는 것은 잘 알고 있을 것입니다. 이번 포스팅에서는 Backpropagation 을 다뤄보겠습니다.
1. 개요
딥 러닝의 뉴럴 네트워크를 훈련은 두 단계로 구성됩니다.
- Loss Funtion(손실함수) 를 계산하기 위한 순방향 Pass
- 학습 가능한 Parameter(매개변수) 를 계산하기 위한 역방향 Pass
여기서, Loss Funtion 을 최소화 하기 위해 역방향으로 손실 함수의 값을 전달하여 가중치 $w$ 를 갱신하는 것을Backpropagation 이라 합니다.
간단한 네트워크를 통해 Backpropagation 을 이해해 봅시다.

위 네트워크는 다음과 같은 식을 가집니다.
$$ b=w_{1}*a $$
$$ c=w_{2}*a $$
$$ d=w_{3}*b+w_{4}*c $$
$$ L= 10-d $$
정리하면,
$$ L = w_{3} * w_{1} * a + w_{4} * w_{2} * a $$
가 됩니다.
실제로도 이런 형태의 네트워크를 가진 모델이 많은데, 비슷한 예로 ResNet 을 들 수 있습니다. a -> c -> d 로 가는 부분이 Skip Connection 으로 볼 수 있습니다.
그럼 본격적으로 살펴봅시다.
2. 간단한 Backpropagation
2.1 Bias 가 없는 Network
우선 a -> b 부분에서 forward pass 와 backpropagation 이 어떻게 계산되는지 살펴봅시다.

이 부분을 식으로 표기하면 $$ b = w_{1} * a $$ 입니다. 만약 $w_{1}$ 이 2 라면, $ b = 2 * a $ 로 나타낼 수 있습니다.
만약 a == 1 이라면, b == 2 일 것 입니다. 특정 input value 를 위 식에 집어넣어서 output 을 얻는 과정이 Forwad Pass 라고 할 수 있습니다. 그런데 정답이 b == 3 이어야 된다고 합시다. 어떻게 $w_{1}$ 을 갱신해야 할까요?
2.1 Gradient Descent (경사하강법)
지금처럼 변경해야할 $w$ 가 위 네트워크 처럼 4개 밖에 없다면 간단한 것 같습니다. 그런데 BERT 등 LLM 등 초 거대 모델 같은 경우에는 수백만개의 $w$ 가 존재합니다. 이 경우에는 gradient-descent(경사 하강법)을 통해 $w$ 를 갱신합니다.
$$ w \leftarrow w - \eta \frac{\partial\,\textrm{loss}}{\partial\,w} $$
여기서 $\eta$ 는 0.01 등의 아주 작은 값을 가지고, 왜 작은 값을 다루는지는 추후에 다루겠습니다.
2.3 Backpropagation 계산 후 $w_{1}$ 갱신 계산
Gradient Desecnt 방법을 통해 $w_{1}$ 을 갱신해 봅시다. $Loss$ 를 간단하게 $Loss = (b(계산값) - \hat{b}(정답))^2$ 라고 해 봅시다. 이는 Mean Squared Error(MSE), '평균제곱오차' 라고 하는 Loss Function(손실함수) 중 하나입니다.
$ \frac{\partial\,\textrm{loss}}{\partial\,w} $ 은 다음과 같이 계산할 수 있습니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,w} = \frac{\partial}{\partial\,w} (b - \hat{b})^{2}$$
여기서 $\hat{b} = w_{1} * a$ 이므로
$$\frac{\partial\,\textrm{loss}}{\partial\,w} = \frac{\partial}{\partial\,w} (b - w_{1} * a)^{2}$$
$$ = -2 * (b - w_{1} * a) * a $$
가 됩니다. 이를 gradient descnet 식에 집어넣으면
$$w_{1}(new) = w_{1} - 0.01 * -2 * (b - w_{1} * a) * a $$
a = 1, b = 3, $w_{1}$ = 2 를 집어넣으면
$$ = 2 - 0.01 * -2 * (3 - 2 * 1) * 1 = 2.02 $$
로 $w_{1}$ 값이 2.02 로 갱신시킬 수 있는 것을 확인할 수 있습니다. 이 값을 이용해 b 를 다시 계산한다면 b = 2.02 가 되어 처음 계산한 b = 2 보다 정답값 b = 3 에 더 가까워 진 것을 알 수 있습니다.
2.3 Bias 가 있는 Network
방금 전의 네트워크에서 bias 를 추가하여 $ b = w_{1} * a + bias_{1} $ 인 경우를 생각해보겠습니다.
이 경우, 경사하강법을 사용한다면 위와 마찬가지로 weight $w_{1}$ 과 bias $bias_{1}$ 을 갱신합니다.
$$ w \leftarrow w - \eta \frac{\partial\,\textrm{loss}}{\partial\,w} $$
$$ bias \leftarrow bias - \eta \frac{\partial\,\textrm{loss}}{\partial\,bias} $$
$w_{1}$ 을 갱신하는 것은 위에서 계산하였으니, $bias_{1}$ 을 갱신하는 것을 계산해보겠습니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,bias} = \frac{\partial}{\partial\,bias} (b - \hat{b})^{2}$$
$$= \frac{\partial}{\partial\,bias} (b - w_{1} * a - bias)^{2}$$
$$= -2 * (b - w_{1} * a - bias)$$
가 됩니다. 이 식을 통하여 bias 를 갱신할 수 있습니다.
3. 2개 이상의 Layer 를 가질 시 Backpropagation 계산
Backpropagation 이 여러개의 Layer 를 거칠 경우, 미분의 연쇄법칙을 이용하여 해당 weight $w_{i}$ 를 갱신할 수 있습니다. a -> b -> d 로 가는 네트워크를 예로 들겠습니다.

위 네트워크가 가진 weight 를 식으로 표현하면 다음과 같을 것입니다.
$$d = w_{3} * b + bias_{3}$$
$$b = w_{1} * a + bias_{1}$$
우선, $w_{3}$ 은 간단하게 계산을 하면 다음과 같이 계산할 수 있습니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{3}}=\frac{\partial}{\partial\,w_{3}}(d-\hat{d})^2=-2*(d-w_{3}*b-bias_{3})*b$$
$w_{1}$ 은 미분의 연쇄법칙을 통하여 다음 식과 같이 계산을 할 수 있습니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{1}} = \frac{\partial\,\textrm{loss}}{\partial\,b} * \frac{\partial\,b}{\partial\,w_{1}}$$
우선 $\frac{\partial\,\textrm{loss}}{\partial\,b}$를 계산하면
$$\frac{\partial\,\textrm{loss}}{\partial\,b}=\frac{\partial}{\partial\,b}(d-\hat{d})^{2}=(d-(w_{3}*b+bias_{3}))^{2}$$
$$=-2*(d-w_{3}*b-bias_{3})*w_{3}$$
$\frac{\partial\,b}{\partial\,w_{1}}$를 계산하면
$$\frac{\partial\,b}{\partial\,w_{1}}=\frac{\partial}{\partial\,w_{1}}(w_{1}*a+bias_{1})=a$$
종합하면 다음과 같습니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,b} * \frac{\partial\,b}{\partial\,w_{1}}=-2*(d-w_{3}*b-bias_{3})*w_{3}*a$$
이를 통해 레이어가 3개 이상인 네트워크에서도 Backpropagation 을 통해 weight 를 갱신할 수 있습니다.
4. 전체 네트워크
다시 전체 네트워크 이미지를 봅시다.

전체 네트워크를 보면, 경로가 합쳐지거나 나눠지는 경우가 있습니다. 예를 들어 d 를 보면
$$ d=w_{3}*b+w_{4}*c $$
으로 계산이 됩니다. 그림으로 표현하면 다음과 같습니다.

이 경우 그래디언트 계산은 다음과 같습니다.

전체 네트워크에 대한 gradient 계산하면 다음과 같이 나타낼 수 있습니다.

각 w 에 대한 gradient 값을 계산해보면 다음과 같을 것입니다.
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{4}}=\frac{\partial\,\textrm{loss}}{\partial\,d}*\frac{\partial\,d}{\partial\,w_{4}}$$
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{3}}=\frac{\partial\,\textrm{loss}}{\partial\,b}*\frac{\partial\,d}{\partial\,w_{3}}$$
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{2}}=\frac{\partial\,\textrm{loss}}{\partial\,d}*\frac{\partial\,d}{\partial\,c}*\frac{\partial\,c}{\partial\,w_{2}}$$
$$\frac{\partial\,\textrm{loss}}{\partial\,w_{1}}=\frac{\partial\,\textrm{loss}}{\partial\,d}*\frac{\partial\,d}{\partial\,b}*\frac{\partial\,b}{\partial\,w_{1}}$$
'DeepLeaning > 기초' 카테고리의 다른 글
| [정보이론] 정보와 Entorpy (0) | 2023.06.14 |
|---|



