[논문리뷰] ViT

ViT-AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

Paper:

Paper

Github:

https://github.com/google-research/vision_transformer

https://github.com/lucidrains/vit-pytorch

Abstract

2010년 후반, NLP task에서 self-attention을 이용한 Transformer model 이 좋은 성능을 보여었습니다.

Transformer model은 기존의 RNN 계열에 비해 효율적인 연산과 좋은 성능으로 인해 큰 규모의 모델도 쉽게 구성할수 있었습니다.

Computer Vision task에서도 self-attention을 도입하려는 시도는 꾸준히 있어왔습니다.

Self-attention을 기존의 CNN architecture에 도입하거나, Convolution 자체를 대체하는 등의 시도가 있었지만 성능이나 연산량에 한계가 존재했기 때문에, 여전히 ResNet 기반의 architecture 들이 vision task에서는 SOTA 를 차지했습니다.

하지만 Vision Transformer (ViT)는 Transformer의 architecture 를 최소한의 수정으로 image에 그대로 적용 하였으며 기존의 CNN 계열 SOTA 모델의 성능을 뛰어넘었습니다.

Vision Transformer - Embedding

우선 ViT 의 Input embedding 에 대해서 알아보겠습니다.

NLP 에서는 문장을 tokenizer를 통해 tokenize 한 1차원 token vector 를 input으로 사용합니다.

하지만 Image는 H×WH\times{W} 의 2차원으로 이루어져있기 때문에 그대로 사용할 수 없습니다.

따라서 전체 Image를 Patch로 나누고 1차원으로 flatten 하는 일종의 Token화 과정이 필요합니다.

Untitled

Vision Transformer의 입력은 다음과 같은 과정을 거쳐 구성됩니다.

  1. Patch Embedding
  2. Class Token
  3. Position Embedding
  4. Hybrid Model (Optional)

1. Patch Embedding

Image 를 token으로 만드는 가장 간단한 방법은 Image를 여러개의 patch 로 나누는 것입니다.

xRH×W×C{x\in\mathbb{R}^{H\times{W}\times{C}}} 의 Image를 P×PP\times{P} 의 patch로 나눈 후 flatten 을 해주면 (N=HWP2N=\frac {HW} {P^2}) 개 의 Patch 를 얻을 수 있습니다.

Untitled

위의 과정을 통해 얻은 xRN×(P2C){x\in\mathbb{R}^{N\times{(P^2C)}}} 를 일정한 DD dimension 의 latent vector 로 embedding 해주는데, 이를 Patch embedding 이라고 부릅니다.

해당 과정은 PyTorch 에서 아래와 같이 구현됩니다.

from einops.layers.torch import Rearrange

# Patching and Flatten for embedding
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
)

2. Class Token

Untitled

BERT 의 Class Token 과 유사하게, Image 전체의 정보를 함축하여 저장할수 있게하는 Learnable embedding 을 입력 patch들의 제일 앞에 추가해줍니다. (XclassX_{class} 에 해당 )

이후 z0z_0 이 L 개의 Transformer encoder를 통과하고 나온 최종 output z0Lz_0^L 에 MLP layer 로 이루어진 classification head 를 통과시켜 예측된 class yy 를 얻습니다.

아래는 Pytorch에서 cls token을 정의하고 추가해준 뒤 classification head 로 결과를 예측하는 일부분을 나타냅니다.

해당 예시에서는 기존 논문대로 cls token의 결과에 대해서만 classification head를 통과시키는 것 외에, 전체 output을 평균내어 class를 예측하는 옵션도 구현하였습니다.

# Init learnable class token
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...
def forward(self, img):
	# Add class token to embedding feature
	x = torch.cat((cls_tokens, x), dim=1)
	...
	# Mean whole state or only cls state
	x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 
	# Predict class
	return self.mlp_head(x)

3. Position Embedding

NLP task와 유사하게, 각 Image의 patch에 position embedding 을 더해줍니다.

Image에서 pixel의 위치정보는 pixel간의 상관관계 (correlation) 를 이해하는데에도 매우 중요한 역할을 합니다. CNN architecture가 vision task에서 좋은 성능을 발휘하는 것도 Convolution layer 에서 local한 정보를 잘 캐치할수 있기 때문입니다.

반면 self-attention은 global하게 일어나기 때문에 이러한 local 정보로부터 얻어지는 inductive bias를 학습하기 어렵습니다.

따라서, 위치 정보를 학습할 수 있는 Position embedding을 더해주게 됩니다. position embedding은 단순한 1차원 vector로 이루어져있으며, 더 고차원의 position embedding은 성능에 큰 영향을 주지는 않았다고 합니다.

Untitled

position embedding은 처음 initialize 될 때 위치에 대한 아무런 정보도 가지지 않은 random distribution으로 생성되며, spatial relation 은 온전히 학습을 통해 처음부터 학습됩니다.

이미 학습된 position embedding은 train 시의 input size에 대해서 fitting이 되어있기 때문에, train data와 다른 resolution 의 input 에 대해서는 의미가 없어집니다.

이를 해결하기 위해 다른 size의 input이 들어오게 된다면, position embedding을 그 사이즈에 맞게 interpolate 해주는 방식으로 기존의 위치 정보를 이용할 수 있게 합니다.

Pytorch 에서 position embedding 을 수행하는 코드는 아래와 같습니다.

# Init position embedding
# Path embedding 된 feature 에 더해줄것이므로 dim channel의 n+1 개 만큼 생성
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

def forward(self, img):
...
	x += self.pos_embedding[:, :(n + 1)]

4. Hybrid Architecture (optional)

1-3 의 과정과 같이, Image를 patch화 한 뒤 그대로 embedding을 진행할수도 있지만 Convolution layer를 통과시킨 feature에 대해서 embedding 을 진행할수도 있습니다.

이 과정을 거치게 되면 Convolution layer로 인해 순수 transformer layer 에서 부족한 inductive bias를 보다 수월하게 학습하는 효과를 기대할 수 있습니다.

Untitled

실험 결과 작은 사이즈의 ViT 에서는 Hybrid model이 더 뛰어난 성능을 보였지만, parameter 수가 일정 이상 커지게 되면 차이가 거의 없었습니다.

Overall

최종적으로 앞의 과정을 Pytorch 로 구현하면 아래와 같습니다.

forward 함수를 살펴보면,

  • Patch embedding
  • Add cls token
  • Add position embedding
  • LL x Transformer encdoer
  • Classification head

의 과정을 거쳐 Predicted class YY 가 나오는것을 확인할 수 있습니다.

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
				# Path embedding
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
				
				# Add cls token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
				# Add position embedding
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
				
				# Transformer encoder layers
        x = **self.transformer(x)**
	
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
			
        x = self.to_latent(x)
				# Predict
        return self.mlp_head(x)

Vision Transformer - Encoder

다음으로는 Vision Transformer 에서 Transformer encoder 에 대해서 알아보겠습니다.

Untitled

Untitled

앞선 과정을 통해 embedding 된 feature 는 위 그림과 같은 Transformer Encoder를 LL 번 통과하게 됩니다.

Transformer encoder는 Attention is all you need 에서 소개된것과 동일하게 구성되어 있으며, 다음과 같은 특징을 같습니다.

  1. Multi head self -attention
  2. MLP blocks
  3. Layer Normalization
  4. Residual Connection

1. Multi head self-attention

Untitled

Multi head self-attention은 위의 과정을 num_heads (hh) 개로 나누어 수행합니다.

각각의 head 끼리는 별개의 tensor로 취급되지만, (B,h,N,d/h)(B, h, N, d/h) 의 형태로 묶으면 한번에 multi head self-attention을 수행할 수 있습니다.

Pytorch code로 구현하면 아래와 같습니다.

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
			  
				# Generate qkv via linear layer
        qkv = self.to_qkv(x).chunk(3, dim = -1)
				# Reshape for multi head self-attention
        **q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)**
				
				# (b h n d) x (b h d n) -> (b h n n) * scale
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
				
				# softmax
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out) 

2. MLP Blocks

기존 논문과 동일하게, MLP block은 LayerNorm, Linear, GeLU 로 구성되어 있습니다.

PyTorch 로 구현하면 아래와 같습니다.

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

3. Layer Normalization

LayerNorm 은 NLP Transformer 에서 BatchNorm에 비해 좋은 성능을 보였습니다. ViT 에서는 Vision task에도 LayerNorm 을 그대로 차용했지만, 그에 대한 부가적인 설명이나 실험은 없었습니다.

4. Residual Connection

Transformer Block 에서, multi head self-attention과 MLP block을 통과한 직후 LayerNorm 이전의 feature 를 residual connection으로 더해줍니다.

이를 통해 Gradient vanishing으로 인한 정보 손실을 방지하여 더 좋은 성능을 이끌어낼 수 있습니다.

Overall

위의 요소들로 구성된 Transformer Block은 아래와 같이 구현할 수 있습니다.

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

Experiments

Untitled

논문에서는 이렇게 구성한 다양한 크기의 Vision Transformer model을 약 300M(3억) 개의 train data 로 학습을 진행했습니다.

Untitled

그 결과, SOTA를 차지하고 있던 CNN-like model (ResNet) 보다 뛰어난 성능을 보여주었습Con니다.

Conclusions

ViT 에서는 기존에 self-attention을 Vision task에 적용하려했던 시도와 달리, Image를 path들의 sequence 그 자체로 사용하여 NLP의 Transformer encoder를 그대로 사용하였습니다.

그 결과, 매우 큰 Dataset에서의 학습을 성공적으로 수행하였고 여러가지 Image classification dataset에 대해서 기존의 CNN model들의 성능을 뛰어넘고 SOTA를 달성하였습니다.

뿐만 아니라, Large dataset 에서 학습된 pre-trained model을 다른 down stream task로 fine tuning 할때에도 매우 좋은 성능을 보여주었습니다.

그 결과 현재 수많은 Vision model들에서 Transformer architecture를 적용하고 있고, 발전도 이어지고 있습니다.

하지만 한계점도 존재하는데, 대표적인 예로 Image의 크기에 따라 quadratic 하게 늘어나는 연산량이 있습니다.