Background
$W \in \mathbb{R}^{D_2 \times D_1}$, $X \in \mathbb{R}^{D_1 \times N}$ 에 대해 bias가 생략된 linear layer의 출력값 $Z=WX, Z \in \mathbb{R}^{D_2 \times N}$가 있다고 하자.
또한 $L : \mathbb{R}^{D_2 \times N} \rightarrow \mathbb{R}$은 최적화 하려는 object function이라고 하자.
Chain Rule 에 따라 gradient를 다음과 같이 구할수 있다.
$$
\frac{\partial{L}}{\partial{W}} = \frac{\partial{L}}{\partial{Z}} \frac{\partial{Z}}{\partial{W}}
$$
하만 이때 $ \frac{\partial{L}}{\partial{Z}} $는 행렬이지만 $ \frac{\partial{Z}}{\partial{W}} $은 4차원 텐서가 된다.
4차원 텐서를 직접적으로 계산할경우 공간적, 시간적 측면에서 매우 비효율적이다.
Methodology
$\frac{\partial{L}}{\partial{W_{i,j}}}$에 대해 생각해보자. 이는 $ \frac{\partial{L}}{\partial{Z}} $와 $ \frac{\partial{Z}}{\partial{W_{i,j}}} $의 내적(원소곱을 수행하고 각 원소를 다 더한것)으로 구할수 있다.
(두 행렬의 차원은 $Z$와 같다고 하자.)
$$
\frac{\partial{L}}{\partial{W_{i,j}}} = \frac{\partial{L}}{\partial{Z}} \cdot \frac{\partial{Z}}{\partial{W_{i,j}}}
$$
그럼 이제 $Z = WX$임을 고려하며 $\frac{\partial{Z}}{\partial{W_{i,j}}}$를 구해보자. (편의를 위해 이제부터 $ \frac{\partial{L}}{\partial{Z}} =T$로 치환하자.)
정의에 따르면 $Z$는 다음과 같이 구할수 있다.
$Z_{a,b} = \sum_{k=1}^{D_1} W_{a,k} X_{k,b}$
$Z_{a,b} \; \text{where} \; a \neq i$는 $W_{i,j}$를 포함하지 않으므로 ${\left ( \frac{\partial{Z}}{\partial{W_{i,j}}} \right ) }_{a,b} = 0 \; \text{where} \; a \neq i$이다.
$Z_{a,b} \; \text{where} \; a = i$의 경우에는 $W_{i,j}$가 $X_{j,b}$와 곱해지므로 ${\left ( \frac{\partial{Z}}{\partial{W_{i,j}}} \right ) }_{a,b} = X_{j,b} \; \text{where} \; a = i$이다.
즉, 다음과 같은 꼴이 나온다.
$$
\frac{\partial{Z}}{\partial{W_{i,j}}} =
\begin{bmatrix}
0 & 0 & ... & 0 \\
0 & 0 & ... & 0 \\
X_{j,1} & X_{j,2} & ... & X_{j,N} \\
0 & 0 & ... & 0 \\
\end{bmatrix}
$$
($X$가 포함된 행은 $i$번째 행이다.)
이제 $ \frac{\partial{L}}{\partial{W_{i,j}}} $를 구해보자.
$$
\begin{align}
\frac{\partial{L}}{\partial{W_{i,j}}}
&= T \cdot \frac{\partial{Z}}{\partial{W_{i,j}}} \\
&= \sum_{k=1}^{N} T_{i,k} X_{j, k} \\
&= \sum_{k=1}^{N} T_{i,k} {X^T}_{k, j} \\
\end{align}
$$
이걸 행렬곱으로 나타내면 다음과 같다.
$$
\frac{\partial{L}}{\partial{W}}
= T X^T
= \frac{\partial{L}}{\partial{Z}} X^T
$$
참고자료
http://cs231n.stanford.edu/handouts/derivatives.pdf
'인공지능' 카테고리의 다른 글
ML-Agents Crawler 환경을 stable-baselines3로 학습하기 (1) | 2024.09.10 |
---|---|
Cross Entropy Derivation (0) | 2024.03.11 |
CS234 Notes - Lecture 2 번역본 (0) | 2023.11.17 |
RNN - Recurrent neural network (0) | 2023.09.11 |
머신러닝에 쓰이는 정보이론 (0) | 2023.09.04 |