[5분 컷 이해] Swin transformer 쉬운 이해와 설명
요약
Swin Transformer은 ViT(Vision Transformer)와 유사하게 이미지를 패치로 나누어, 각 패치를 토큰처럼 취급하여Transfomer에 전달하는 모델이다. ViT와의 차이점은 ViT은 패치를 레이어를 통과시키면서, 고정크기로만 연산하는 것에 반해, Swin Transformer은 레이어를 지나면서 각 패치들을 합쳐서, 패치의 크기를 크게 만들고, Self-attention의 범위를 확장시키는데 그 차이가 있다(Figure 1). 추가적인 주요 특징으로는 윈도우 내 영역내 여러 패치가 존재할 때, Attention을 줄 영역을 윈도우 내로만 주는 것이 아니라, 이전 레이어의 Window영역을 가로지를 수 있게 다음 레이어에서 Attention할 영역의 Window을 매번 달리하는 것이 있다.
모델 설명: 1) 패치의 분리, 2) 패치의 임베딩, 3) 트렌스포머(어텐션 + Shifted window 파티셔닝), 4) 패치 병합과정의 연속
1. 이미지를 작은 단위의 이미지인 패치(Patch)형태로 분리한다: 각각의 패치는 겹치지않도록 분리하며 ViT(Vision Transformer)의 방식과 동일하게 이미지를 분리한다. 즉, 각 패치는 정사각형의 p x p의 형태를 갖는다. 예를들어, MNIST의 28 x 28의 이미지 크기라면, 4x4패치를 만든다고 가정하면, 49개의 패치가 생긴다(49=28/4 * 28/ 4) (Figure 2. 패치 파티셔닝의 예시).
2. 각 패치의 벡터 형태로 변환한다. 1. 단계에서의 패치의 크기는 $P \times P $이다. 이 이미지가 3차원이었다면, 이를 모두 concat함과 동시에 Flatten하여 $ P \cdot P \cdot 3 $의 형태로 변경한다. 즉, 채널에 해당하는 벡터까지 옆으로 펼쳐서(Flatten) 벡터화 한다.
3. 선형결합을 하여 C차원을 가진 벡터로 만든다. 패치의 가로X세로가 PXP인 것을 3채널을 펼쳐(Flatten)하여 PxPx3의 벡터를 만들었다. 이 벡터를 압축하여, 임의의 크기가 C인 벡터로 만들어준다(=쉽게 말해 Dense(C)을 한번 거쳐준다). 이 연산의 결과의 벡터는 $\mathbb{R}^{C}$이다.
4. Stage 1에서는 Swin transformer block을 통과시킨다: 이 때의 각 패치의 차원은 동일하게 유지한다. 즉 C에서 C로 나온다. 다만 벡터내 원소의 숫자는 다르다.
5. Stage 2. 인접패치를 합쳐준다: 첫 번째로 임베딩된 패치들은 근처의 2x2패치들을 합친다(Concat). 3의 단계에서 각 패치들을 C차원의 벡터로 만들었기 때문에, 4C 차원의 벡터가 만들어진다. 이 합치는 연산은 토큰의 숫자를 많이 줄여든다. 일종의 Pooling과 비슷하다(=본문에 ”down sampling of resolution”라고 기술됨.).
6. Stage 2단계를 2번, 위와 비슷한 Stage 3단계를 6번, Stage 4단계를 2번진행한다.
Window 영역 내에서만 attention을 적용하는 이유? 계산 복잡도 때문에...
Transformer에서는 토큰과 토큰사이의 관계를 모두 Attention을 적용하며 학습한다. 토큰의 수가 N이라면, NxN의 Attention weight matrix가 나오게 되는 것이다. Vision문제에서 이를 쉽게 적용할 수 있을까? 그렇지 않다. 이미지 자체가 128 x 128이면 엄청난 크기의 attention weight matrix가 생성되어야하기 때문이다. 저자들은 이를 해결하기 위한 방법을 고안했다.
1) Self-attention in non overlapped window: Window 영역을 정하고, 이 내에서만 self attention을 적용하는 것. 아래와 같이 작은 사각형은 각 패치를 의미하고, 두꺼운 줄을 가진 큰 사각형은 Window으로 정의해보자. 윈도우 영역내에는 4x4 (본문내 M x M)개의 패치가 있다. 이를 global self attention을 한다고생각하면, ....다음과 같은 계산이 나온다.
한 윈도우에, 패치가 M x M이 있다고 가정하고, 일반적인 global self-attention을 하는 경우(1)와 widnow내 h x w개의 패치가 있는 경우를 계산복잡도를 구하면 아래와 같다. 중요한건 총 패치수인 h x w 수가 증가할 때, (1)의 식은 제곱으로 복잡도가 증가하지만, (2)의 식은 선형적으로 증하한다는 것이다. 윈도우 내에만 self attention하는 것이 컴퓨터비전에서 계산 복잡도를 낮출 수 있기 때문에, 이 방법을 저자들은 생각했다.
Shift window 을 구현하는 방법과 그 이유: 윈도우 내에서만 attention을 계산하면, 윈도우 넘어 연결성이 부족하기에... 윈도우 영역을 돌려가며(cyclic) 윈도우를 정의
Shift attention을 적용하기위해서 각 attention을 적용하는 단계를 본 논문에서는 module이라고 부른다. 논문에서는 2개의 모듈을 사용하는데, 첫번째는 왼쪽 상단에서 규칙적인 모양을 가진 window로 attention을 하고, 두 번째 모듈에서는 규칙적으로 분할된 윈도우에서 $(\lfloor\frac{M}{2}\rfloor,\lfloor\frac{M}{2}\rfloor)$ 만큼 떼어냏어 옮긴 후에 attention을 적용하는 것이다.