ViT-AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
Paper:
Github:
https://github.com/google-research/vision_transformer
https://github.com/lucidrains/vit-pytorch
Abstract
2010년 후반, NLP task에서 self-attention을 이용한 Transformer model 이 좋은 성능을 보여었습니다.
Transformer model은 기존의 RNN 계열에 비해 효율적인 연산과 좋은 성능으로 인해 큰 규모의 모델도 쉽게 구성할수 있었습니다.
Computer Vision task에서도 self-attention을 도입하려는 시도는 꾸준히 있어왔습니다.
Self-attention을 기존의 CNN architecture에 도입하거나, Convolution 자체를 대체하는 등의 시도가 있었지만 성능이나 연산량에 한계가 존재했기 때문에, 여전히 ResNet 기반의 architecture 들이 vision task에서는 SOTA 를 차지했습니다.
하지만 Vision Transformer (ViT)는 Transformer의 architecture 를 최소한의 수정으로 image에 그대로 적용 하였으며 기존의 CNN 계열 SOTA 모델의 성능을 뛰어넘었습니다.
Vision Transformer - Embedding
우선 ViT 의 Input embedding 에 대해서 알아보겠습니다.
NLP 에서는 문장을 tokenizer를 통해 tokenize 한 1차원 token vector 를 input으로 사용합니다.
하지만 Image는 의 2차원으로 이루어져있기 때문에 그대로 사용할 수 없습니다.
따라서 전체 Image를 Patch로 나누고 1차원으로 flatten 하는 일종의 Token화 과정이 필요합니다.
Vision Transformer의 입력은 다음과 같은 과정을 거쳐 구성됩니다.
- Patch Embedding
- Class Token
- Position Embedding
- Hybrid Model (Optional)
1. Patch Embedding
Image 를 token으로 만드는 가장 간단한 방법은 Image를 여러개의 patch 로 나누는 것입니다.
의 Image를 의 patch로 나눈 후 flatten 을 해주면 () 개 의 Patch 를 얻을 수 있습니다.
위의 과정을 통해 얻은 를 일정한 dimension 의 latent vector 로 embedding 해주는데, 이를 Patch embedding 이라고 부릅니다.
해당 과정은 PyTorch 에서 아래와 같이 구현됩니다.
from einops.layers.torch import Rearrange
# Patching and Flatten for embedding
to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
2. Class Token
BERT 의 Class Token 과 유사하게, Image 전체의 정보를 함축하여 저장할수 있게하는 Learnable embedding 을 입력 patch들의 제일 앞에 추가해줍니다. ( 에 해당 )
이후 이 L 개의 Transformer encoder를 통과하고 나온 최종 output 에 MLP layer 로 이루어진 classification head 를 통과시켜 예측된 class 를 얻습니다.
아래는 Pytorch에서 cls token을 정의하고 추가해준 뒤 classification head 로 결과를 예측하는 일부분을 나타냅니다.
해당 예시에서는 기존 논문대로 cls token의 결과에 대해서만 classification head를 통과시키는 것 외에, 전체 output을 평균내어 class를 예측하는 옵션도 구현하였습니다.
# Init learnable class token
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...
def forward(self, img):
# Add class token to embedding feature
x = torch.cat((cls_tokens, x), dim=1)
...
# Mean whole state or only cls state
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
# Predict class
return self.mlp_head(x)
3. Position Embedding
NLP task와 유사하게, 각 Image의 patch에 position embedding 을 더해줍니다.
Image에서 pixel의 위치정보는 pixel간의 상관관계 (correlation) 를 이해하는데에도 매우 중요한 역할을 합니다. CNN architecture가 vision task에서 좋은 성능을 발휘하는 것도 Convolution layer 에서 local한 정보를 잘 캐치할수 있기 때문입니다.
반면 self-attention은 global하게 일어나기 때문에 이러한 local 정보로부터 얻어지는 inductive bias를 학습하기 어렵습니다.
따라서, 위치 정보를 학습할 수 있는 Position embedding을 더해주게 됩니다. position embedding은 단순한 1차원 vector로 이루어져있으며, 더 고차원의 position embedding은 성능에 큰 영향을 주지는 않았다고 합니다.
position embedding은 처음 initialize 될 때 위치에 대한 아무런 정보도 가지지 않은 random distribution으로 생성되며, spatial relation 은 온전히 학습을 통해 처음부터 학습됩니다.
이미 학습된 position embedding은 train 시의 input size에 대해서 fitting이 되어있기 때문에, train data와 다른 resolution 의 input 에 대해서는 의미가 없어집니다.
이를 해결하기 위해 다른 size의 input이 들어오게 된다면, position embedding을 그 사이즈에 맞게 interpolate 해주는 방식으로 기존의 위치 정보를 이용할 수 있게 합니다.
Pytorch 에서 position embedding 을 수행하는 코드는 아래와 같습니다.
# Init position embedding
# Path embedding 된 feature 에 더해줄것이므로 dim channel의 n+1 개 만큼 생성
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
def forward(self, img):
...
x += self.pos_embedding[:, :(n + 1)]
4. Hybrid Architecture (optional)
1-3 의 과정과 같이, Image를 patch화 한 뒤 그대로 embedding을 진행할수도 있지만 Convolution layer를 통과시킨 feature에 대해서 embedding 을 진행할수도 있습니다.
이 과정을 거치게 되면 Convolution layer로 인해 순수 transformer layer 에서 부족한 inductive bias를 보다 수월하게 학습하는 효과를 기대할 수 있습니다.
실험 결과 작은 사이즈의 ViT 에서는 Hybrid model이 더 뛰어난 성능을 보였지만, parameter 수가 일정 이상 커지게 되면 차이가 거의 없었습니다.
Overall
최종적으로 앞의 과정을 Pytorch 로 구현하면 아래와 같습니다.
forward
함수를 살펴보면,
- Patch embedding
- Add cls token
- Add position embedding
- x Transformer encdoer
- Classification head
의 과정을 거쳐 Predicted class 가 나오는것을 확인할 수 있습니다.
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
# Path embedding
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# Add cls token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
# Add position embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# Transformer encoder layers
x = **self.transformer(x)**
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
# Predict
return self.mlp_head(x)
Vision Transformer - Encoder
다음으로는 Vision Transformer 에서 Transformer encoder 에 대해서 알아보겠습니다.
앞선 과정을 통해 embedding 된 feature 는 위 그림과 같은 Transformer Encoder를 번 통과하게 됩니다.
Transformer encoder는 Attention is all you need 에서 소개된것과 동일하게 구성되어 있으며, 다음과 같은 특징을 같습니다.
- Multi head self -attention
- MLP blocks
- Layer Normalization
- Residual Connection
1. Multi head self-attention
Multi head self-attention은 위의 과정을 num_heads () 개로 나누어 수행합니다.
각각의 head 끼리는 별개의 tensor로 취급되지만, 의 형태로 묶으면 한번에 multi head self-attention을 수행할 수 있습니다.
Pytorch code로 구현하면 아래와 같습니다.
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
# Generate qkv via linear layer
qkv = self.to_qkv(x).chunk(3, dim = -1)
# Reshape for multi head self-attention
**q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)**
# (b h n d) x (b h d n) -> (b h n n) * scale
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# softmax
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
2. MLP Blocks
기존 논문과 동일하게, MLP block은 LayerNorm, Linear, GeLU 로 구성되어 있습니다.
PyTorch 로 구현하면 아래와 같습니다.
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
3. Layer Normalization
LayerNorm 은 NLP Transformer 에서 BatchNorm에 비해 좋은 성능을 보였습니다. ViT 에서는 Vision task에도 LayerNorm 을 그대로 차용했지만, 그에 대한 부가적인 설명이나 실험은 없었습니다.
4. Residual Connection
Transformer Block 에서, multi head self-attention과 MLP block을 통과한 직후 LayerNorm 이전의 feature 를 residual connection으로 더해줍니다.
이를 통해 Gradient vanishing으로 인한 정보 손실을 방지하여 더 좋은 성능을 이끌어낼 수 있습니다.
Overall
위의 요소들로 구성된 Transformer Block은 아래와 같이 구현할 수 있습니다.
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
Experiments
논문에서는 이렇게 구성한 다양한 크기의 Vision Transformer model을 약 300M(3억) 개의 train data 로 학습을 진행했습니다.
그 결과, SOTA를 차지하고 있던 CNN-like model (ResNet) 보다 뛰어난 성능을 보여주었습Con니다.
Conclusions
ViT 에서는 기존에 self-attention을 Vision task에 적용하려했던 시도와 달리, Image를 path들의 sequence 그 자체로 사용하여 NLP의 Transformer encoder를 그대로 사용하였습니다.
그 결과, 매우 큰 Dataset에서의 학습을 성공적으로 수행하였고 여러가지 Image classification dataset에 대해서 기존의 CNN model들의 성능을 뛰어넘고 SOTA를 달성하였습니다.
뿐만 아니라, Large dataset 에서 학습된 pre-trained model을 다른 down stream task로 fine tuning 할때에도 매우 좋은 성능을 보여주었습니다.
그 결과 현재 수많은 Vision model들에서 Transformer architecture를 적용하고 있고, 발전도 이어지고 있습니다.
하지만 한계점도 존재하는데, 대표적인 예로 Image의 크기에 따라 quadratic 하게 늘어나는 연산량이 있습니다.