ViT

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

Data size

batch size = 8

image = (3,224,224)

x = torch.randn(8,3,224,224)
x.shape
torch.Size([8, 3, 224, 224])

Patch Embedding

The image is embedded from Batch × C × H × W to Batch × N × (P × P × C), where P is the patch size and N the number of patches. With a 224 × 224 image and 16 × 16 patches, N = (224/16)² = 196 and each patch flattens to 16 · 16 · 3 = 768 values.

patch_size = 16
print('x : ',x.shape)
patches_reshape = x.reshape((8,-1,16*16*3))  # naive reshape: right target shape, but the elements are not grouped per patch
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
print('patch : ',patches.shape)

print(torch.eq(patches_reshape,patches))

x :  torch.Size([8, 3, 224, 224])
patch :  torch.Size([8, 196, 768])
tensor([[[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]],

        [[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]],
    
        [[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]],
    
        ...,
    
        [[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]],
    
        [[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]],
    
        [[ True, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False,  True]]])

I tried to do this embedding with reshape instead of rearrange, but comparing the resulting tensors shows that the two functions do not give the same result.

shape_test = torch.tensor([[1,2,3],
                           [4,5,6]])
print('original_shape : ',shape_test.shape)

tensor_reshape = shape_test.reshape(3,2)
tensor_rearrange = rearrange(shape_test,'a b -> b a ')

print('reshape_tensor : ',tensor_reshape)
print('rearrange_tensor : ',tensor_rearrange)
original_shape :  torch.Size([2, 3])
reshape_tensor :  tensor([[1, 2],
        [3, 4],
        [5, 6]])
rearrange_tensor :  tensor([[1, 4],
        [2, 5],
        [3, 6]])

Transposing this 2×3 tensor with reshape and with rearrange gives tensors of the same size, but the elements end up in a different order.

reshape simply fills the target shape with the elements in their original row-major order, regardless of the source layout, whereas rearrange actually moves the axes, so the two are not interchangeable here (see the sketch below).
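
A minimal sketch (my own check, not from the original post) of how to reproduce the einops patch split with plain PyTorch: the tensor has to be split into grid and patch axes and permuted before flattening, which a single reshape cannot do. It reuses the x and patches tensors from above.

p = 16
patches_torch = (x.reshape(8, 3, 14, p, 14, p)      # b c h s1 w s2
                  .permute(0, 2, 4, 3, 5, 1)        # b h w s1 s2 c
                  .reshape(8, 14 * 14, p * p * 3))  # b (h w) (s1 s2 c)

print(torch.equal(patches, patches_torch))          # True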

In ViT, instead of einops, a 2D convolutional layer with kernel_size and stride both equal to patch_size is used, and the result is then flattened.

This approach is reported to improve performance.

patch_size = 16
emb_size = 768
in_channels = 3

projection = nn.Sequential(
    nn.Conv2d(in_channels,emb_size,kernel_size = patch_size,stride = patch_size),
    Rearrange('b e (h) (w) -> b (h w) e'))

projection(x).shape
torch.Size([8, 196, 768])
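
A minimal sketch (my own check, not part of the original post) of why the two approaches are interchangeable: a Conv2d whose kernel and stride both equal the patch size is just a linear layer applied to each flattened patch. Copying the conv weights into an nn.Linear reproduces the projection, as long as the patch is flattened in the conv's (c, s1, s2) order.

conv = nn.Conv2d(3, emb_size, kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(3 * patch_size * patch_size, emb_size)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(emb_size, -1))  # (768, 3*16*16)
    linear.bias.copy_(conv.bias)

out_conv = rearrange(conv(x), 'b e h w -> b (h w) e')
flat_patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (c s1 s2)', s1=patch_size, s2=patch_size)
out_linear = linear(flat_patches)

print(torch.allclose(out_conv, out_linear, atol=1e-4))  # True up to float error

The next block prepends a learnable cls token to the 196 patch embeddings and adds learnable position embeddings (196 + 1 = 197 positions).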
img_size = 224

projected_x = projection(x)
print('Projected shape : ',projected_x.shape)

cls_token = nn.Parameter(torch.randn(1,1,emb_size))
positions = nn.Parameter(torch.randn((img_size//patch_size) ** 2 + 1,emb_size))
print('Cls Shape : ',cls_token.shape,', Pos shape : ',positions.shape)

batch_size = 8
cls_tokens = repeat(cls_token,'() n e -> b n e',b = batch_size)
print('Repeated Cls shape : ',cls_tokens.shape)

cat_x = torch.cat([cls_tokens,projected_x],dim = 1)

cat_x += positions
print('output : ',cat_x.shape)

Projected shape :  torch.Size([8, 196, 768])
Cls Shape :  torch.Size([1, 1, 768]) , Pos shape :  torch.Size([197, 768])
Repeated Cls shape :  torch.Size([8, 1, 768])
output :  torch.Size([8, 197, 768])

The two tensors are concatenated with torch.cat.

a = torch.zeros(2,3,3)
a
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
b = torch.ones(2,1,3)
b
tensor([[[1., 1., 1.]],

        [[1., 1., 1.]]])
cat = torch.cat([b,a],dim = 1)
print(cat)
print('shape : ',cat.shape)
tensor([[[1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
shape :  torch.Size([2, 4, 3])

The whole patch embedding implemented as a single class:

class PatchEmbedding(nn.Module): 
    def __init__(self,in_channels: int = 3,patch_size :int = 16,emb_size : int = 768,img_size : int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels,emb_size,kernel_size = patch_size,stride = patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1,1,emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1,emb_size))
        

    def forward(self,x: Tensor) -> Tensor:
        b,_,_,_ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token,'() n e -> b n e',b = b)
        x = torch.cat([cls_tokens,x],dim = 1)
        x += self.positions

        return x

Multi-head Attention

In ViT, q, k, and v all come from the same input tensor.

They are embedded through three linear projections, and scaled dot-product attention is then applied to each head.

Linear Projection

emb_size = 768
num_heads = 8
keys = nn.Linear(emb_size,emb_size)
queries = nn.Linear(emb_size,emb_size)
values = nn.Linear(emb_size,emb_size)
print(keys, queries, values)
Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)

Multi-Head

query = rearrange(queries(cat_x),'b n (h d) -> b h n d',h = num_heads)
key = rearrange(keys(cat_x),'b n (h d) -> b h n d',h = num_heads)
value = rearrange(values(cat_x),'b n (h d) -> b h n d',h = num_heads)

print('shape : ',query.shape,key.shape,value.shape)
shape :  torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96])
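
Each head works on emb_size / num_heads = 768 / 8 = 96 channels, which is why the head axis has size 8 and the last axis has size 96.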

Scaled Dot Product Attention

energy = torch.einsum('bhqd, bhkd -> bhqk',query,key)
print('energy : ',energy.shape)

scaling = emb_size ** (1/2)
att = F.softmax(energy / scaling,dim = -1)
print('att : ',att.shape)

out = torch.einsum('bhal, bhlv -> bhav',att,value)
print('out: ',out.shape)

out = rearrange(out,'b h n d -> b n (h d)')
print('out_2 : ',out.shape)
energy :  torch.Size([8, 8, 197, 197])
att :  torch.Size([8, 8, 197, 197])
out:  torch.Size([8, 8, 197, 96])
out_2 :  torch.Size([8, 197, 768])
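
A minimal sketch (my own check, not in the original post) comparing the manual computation above with F.scaled_dot_product_attention from PyTorch 2.x. Note that the original Transformer/ViT papers scale by the square root of the per-head dimension (sqrt(96) here), while this post scales by sqrt(emb_size); the scale argument lets us match the post's choice.

# assumes PyTorch >= 2.1 for the scale keyword
sdpa_out = F.scaled_dot_product_attention(query, key, value, scale=1 / emb_size ** 0.5)
sdpa_out = rearrange(sdpa_out, 'b h n d -> b n (h d)')
print(torch.allclose(out, sdpa_out, atol=1e-4))  # True up to float error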
class MultiHeadAttention(nn.Module):
    def __init__(self,emb_size: int = 768,num_heads : int = 8, dropout : float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads

        self.qkv = nn.Linear(emb_size,emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size,emb_size)

    def forward(self,x : Tensor,mask : Tensor = None) -> Tensor:
        qkv = rearrange(self.qkv(x),'b n (h d qkv) -> (qkv) b h n d',h = self.num_heads,qkv = 3)
        queries,keys,values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk',queries,keys)

        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)  # block masked positions before softmax

        scaling= self.emb_size ** (1/2)
        att = F.softmax(energy/ scaling,dim = -1)
        att = self.att_drop(att)

        out = torch.einsum('bhal,bhlv -> bhav',att,values)
        out = rearrange(out,'b h n d -> b n (h d)')
        out = self.projection(out)
        return out
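
A quick sanity check of the module (my own addition). If a mask is passed, it is expected to be a boolean tensor broadcastable to the (batch, heads, query, key) energy shape.

tokens = torch.randn(8, 197, 768)
print(MultiHeadAttention()(tokens).shape)  # torch.Size([8, 197, 768])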


Residual Block

class Residualblock(nn.Module):
    def __init__(self,fn):
        super().__init__()
        self.fn = fn

    def forward(self,x):
        res = x
        x = self.fn(x)
        output = x + res
        return output

Feed Forward MLP

  • Expand the hidden dimension, apply GELU and Dropout, then project back down to the original emb_size (a quick shape check follows the class below).
class FeedForwardBlock(nn.Sequential):
    def __init__(self,emb_size : int,expansion : int = 4,drop_p : float = 0.):
        super().__init__(
            nn.Linear(emb_size,expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size,emb_size),
        )
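
A quick shape check (my own addition): with the default expansion of 4, the hidden width is 4 * 768 = 3072 and the output comes back to emb_size.

ffn = FeedForwardBlock(emb_size=768)
print(ffn[0])                               # Linear(in_features=768, out_features=3072, bias=True)
print(ffn(torch.randn(8, 197, 768)).shape)  # torch.Size([8, 197, 768])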

Transformer Encoder Block

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,emb_size :int = 768,
                 drop_p:float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p : float = 0.,
                 **kwargs):
        super().__init__(
            Residualblock(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p),
            )),
            Residualblock(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p),
            )),
        )

Test

x = torch.randn(8,3,224,224)
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
torch.Size([8, 197, 768])

Architecture

class TransformerEncoder(nn.Sequential):
    def __init__(self,depth = 12,**kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
        

Head

class ClassificationHead(nn.Sequential):
    def __init__(self,emb_size:int = 768,n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e',reduction = 'mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size,n_classes)
        )
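
Note that this head mean-pools over all 197 tokens (the cls token included) before the classifier, whereas the ViT paper classifies from the cls token representation alone. A minimal sketch of that alternative, purely as a hypothetical variant (not what this post uses):

class ClsTokenHead(nn.Module):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__()
        self.norm = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, n_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc(self.norm(x[:, 0]))  # use only the cls token (index 0)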

Summary

class ViT(nn.Sequential):
    def __init__(self,
                 in_channels : int = 3,
                 patch_size : int = 16,
                 emb_size : int = 768,
                 img_size : int = 224,
                 depth : int = 12,
                 n_classes = 1000,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels,patch_size,emb_size,img_size),
            TransformerEncoder(depth,emb_size = emb_size,**kwargs),
            ClassificationHead(emb_size, n_classes)
        )
summary(ViT(),(3,224,224),device = 'cpu')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
    Residualblock-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 197, 3072]               0
           Linear-15             [-1, 197, 768]       2,360,064
          Dropout-16             [-1, 197, 768]               0
    Residualblock-17             [-1, 197, 768]               0
        LayerNorm-18             [-1, 197, 768]           1,536
           Linear-19            [-1, 197, 2304]       1,771,776
          Dropout-20          [-1, 8, 197, 197]               0
           Linear-21             [-1, 197, 768]         590,592
MultiHeadAttention-22             [-1, 197, 768]               0
          Dropout-23             [-1, 197, 768]               0
    Residualblock-24             [-1, 197, 768]               0
        LayerNorm-25             [-1, 197, 768]           1,536
           Linear-26            [-1, 197, 3072]       2,362,368
             GELU-27            [-1, 197, 3072]               0
          Dropout-28            [-1, 197, 3072]               0
           Linear-29             [-1, 197, 768]       2,360,064
          Dropout-30             [-1, 197, 768]               0
    Residualblock-31             [-1, 197, 768]               0
        LayerNorm-32             [-1, 197, 768]           1,536
           Linear-33            [-1, 197, 2304]       1,771,776
          Dropout-34          [-1, 8, 197, 197]               0
           Linear-35             [-1, 197, 768]         590,592
MultiHeadAttention-36             [-1, 197, 768]               0
          Dropout-37             [-1, 197, 768]               0
    Residualblock-38             [-1, 197, 768]               0
        LayerNorm-39             [-1, 197, 768]           1,536
           Linear-40            [-1, 197, 3072]       2,362,368
             GELU-41            [-1, 197, 3072]               0
          Dropout-42            [-1, 197, 3072]               0
           Linear-43             [-1, 197, 768]       2,360,064
          Dropout-44             [-1, 197, 768]               0
    Residualblock-45             [-1, 197, 768]               0
        LayerNorm-46             [-1, 197, 768]           1,536
           Linear-47            [-1, 197, 2304]       1,771,776
          Dropout-48          [-1, 8, 197, 197]               0
           Linear-49             [-1, 197, 768]         590,592
MultiHeadAttention-50             [-1, 197, 768]               0
          Dropout-51             [-1, 197, 768]               0
    Residualblock-52             [-1, 197, 768]               0
        LayerNorm-53             [-1, 197, 768]           1,536
           Linear-54            [-1, 197, 3072]       2,362,368
             GELU-55            [-1, 197, 3072]               0
          Dropout-56            [-1, 197, 3072]               0
           Linear-57             [-1, 197, 768]       2,360,064
          Dropout-58             [-1, 197, 768]               0
    Residualblock-59             [-1, 197, 768]               0
        LayerNorm-60             [-1, 197, 768]           1,536
           Linear-61            [-1, 197, 2304]       1,771,776
          Dropout-62          [-1, 8, 197, 197]               0
           Linear-63             [-1, 197, 768]         590,592
MultiHeadAttention-64             [-1, 197, 768]               0
          Dropout-65             [-1, 197, 768]               0
    Residualblock-66             [-1, 197, 768]               0
        LayerNorm-67             [-1, 197, 768]           1,536
           Linear-68            [-1, 197, 3072]       2,362,368
             GELU-69            [-1, 197, 3072]               0
          Dropout-70            [-1, 197, 3072]               0
           Linear-71             [-1, 197, 768]       2,360,064
          Dropout-72             [-1, 197, 768]               0
    Residualblock-73             [-1, 197, 768]               0
        LayerNorm-74             [-1, 197, 768]           1,536
           Linear-75            [-1, 197, 2304]       1,771,776
          Dropout-76          [-1, 8, 197, 197]               0
           Linear-77             [-1, 197, 768]         590,592
MultiHeadAttention-78             [-1, 197, 768]               0
          Dropout-79             [-1, 197, 768]               0
    Residualblock-80             [-1, 197, 768]               0
        LayerNorm-81             [-1, 197, 768]           1,536
           Linear-82            [-1, 197, 3072]       2,362,368
             GELU-83            [-1, 197, 3072]               0
          Dropout-84            [-1, 197, 3072]               0
           Linear-85             [-1, 197, 768]       2,360,064
          Dropout-86             [-1, 197, 768]               0
    Residualblock-87             [-1, 197, 768]               0
        LayerNorm-88             [-1, 197, 768]           1,536
           Linear-89            [-1, 197, 2304]       1,771,776
          Dropout-90          [-1, 8, 197, 197]               0
           Linear-91             [-1, 197, 768]         590,592
MultiHeadAttention-92             [-1, 197, 768]               0
          Dropout-93             [-1, 197, 768]               0
    Residualblock-94             [-1, 197, 768]               0
        LayerNorm-95             [-1, 197, 768]           1,536
           Linear-96            [-1, 197, 3072]       2,362,368
             GELU-97            [-1, 197, 3072]               0
          Dropout-98            [-1, 197, 3072]               0
           Linear-99             [-1, 197, 768]       2,360,064
         Dropout-100             [-1, 197, 768]               0
   Residualblock-101             [-1, 197, 768]               0
       LayerNorm-102             [-1, 197, 768]           1,536
          Linear-103            [-1, 197, 2304]       1,771,776
         Dropout-104          [-1, 8, 197, 197]               0
          Linear-105             [-1, 197, 768]         590,592
MultiHeadAttention-106             [-1, 197, 768]               0
         Dropout-107             [-1, 197, 768]               0
   Residualblock-108             [-1, 197, 768]               0
       LayerNorm-109             [-1, 197, 768]           1,536
          Linear-110            [-1, 197, 3072]       2,362,368
            GELU-111            [-1, 197, 3072]               0
         Dropout-112            [-1, 197, 3072]               0
          Linear-113             [-1, 197, 768]       2,360,064
         Dropout-114             [-1, 197, 768]               0
   Residualblock-115             [-1, 197, 768]               0
       LayerNorm-116             [-1, 197, 768]           1,536
          Linear-117            [-1, 197, 2304]       1,771,776
         Dropout-118          [-1, 8, 197, 197]               0
          Linear-119             [-1, 197, 768]         590,592
MultiHeadAttention-120             [-1, 197, 768]               0
         Dropout-121             [-1, 197, 768]               0
   Residualblock-122             [-1, 197, 768]               0
       LayerNorm-123             [-1, 197, 768]           1,536
          Linear-124            [-1, 197, 3072]       2,362,368
            GELU-125            [-1, 197, 3072]               0
         Dropout-126            [-1, 197, 3072]               0
          Linear-127             [-1, 197, 768]       2,360,064
         Dropout-128             [-1, 197, 768]               0
   Residualblock-129             [-1, 197, 768]               0
       LayerNorm-130             [-1, 197, 768]           1,536
          Linear-131            [-1, 197, 2304]       1,771,776
         Dropout-132          [-1, 8, 197, 197]               0
          Linear-133             [-1, 197, 768]         590,592
MultiHeadAttention-134             [-1, 197, 768]               0
         Dropout-135             [-1, 197, 768]               0
   Residualblock-136             [-1, 197, 768]               0
       LayerNorm-137             [-1, 197, 768]           1,536
          Linear-138            [-1, 197, 3072]       2,362,368
            GELU-139            [-1, 197, 3072]               0
         Dropout-140            [-1, 197, 3072]               0
          Linear-141             [-1, 197, 768]       2,360,064
         Dropout-142             [-1, 197, 768]               0
   Residualblock-143             [-1, 197, 768]               0
       LayerNorm-144             [-1, 197, 768]           1,536
          Linear-145            [-1, 197, 2304]       1,771,776
         Dropout-146          [-1, 8, 197, 197]               0
          Linear-147             [-1, 197, 768]         590,592
MultiHeadAttention-148             [-1, 197, 768]               0
         Dropout-149             [-1, 197, 768]               0
   Residualblock-150             [-1, 197, 768]               0
       LayerNorm-151             [-1, 197, 768]           1,536
          Linear-152            [-1, 197, 3072]       2,362,368
            GELU-153            [-1, 197, 3072]               0
         Dropout-154            [-1, 197, 3072]               0
          Linear-155             [-1, 197, 768]       2,360,064
         Dropout-156             [-1, 197, 768]               0
   Residualblock-157             [-1, 197, 768]               0
       LayerNorm-158             [-1, 197, 768]           1,536
          Linear-159            [-1, 197, 2304]       1,771,776
         Dropout-160          [-1, 8, 197, 197]               0
          Linear-161             [-1, 197, 768]         590,592
MultiHeadAttention-162             [-1, 197, 768]               0
         Dropout-163             [-1, 197, 768]               0
   Residualblock-164             [-1, 197, 768]               0
       LayerNorm-165             [-1, 197, 768]           1,536
          Linear-166            [-1, 197, 3072]       2,362,368
            GELU-167            [-1, 197, 3072]               0
         Dropout-168            [-1, 197, 3072]               0
          Linear-169             [-1, 197, 768]       2,360,064
         Dropout-170             [-1, 197, 768]               0
   Residualblock-171             [-1, 197, 768]               0
          Reduce-172                  [-1, 768]               0
       LayerNorm-173                  [-1, 768]           1,536
          Linear-174                 [-1, 1000]         769,000
================================================================
Total params: 86,415,592
Trainable params: 86,415,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 364.33
Params size (MB): 329.65
Estimated Total Size (MB): 694.56
----------------------------------------------------------------
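
A final forward pass as a quick sanity check (my own addition):

x = torch.randn(8, 3, 224, 224)
logits = ViT()(x)
print(logits.shape)  # torch.Size([8, 1000])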