연세대 인공지능학회 YAI

[논문 리뷰] Focal Self-attention for Local-Global Interactions in Vision Transformers 본문

컴퓨터비전 : CV/CV 논문 리뷰

[논문 리뷰] Focal Self-attention for Local-Global Interactions in Vision Transformers

_YAI_ 2023. 1. 14. 18:09

YAI 9기 김석님이 비전논문팀에서 작성한 글입니다.

Focal Self-attention for Local-Global Interactions in Vision Transformers


0. Abstract

  • 목적
    • Self attention을 통한 짧은 것에서부터 긴 단위까지 visual dependency를 모두 capture할 수 있도록 설계하면서도 quadratic computational overhead로 인한 resolution이 높은 task에 관해서 어려운 상황도 극복할 수 있어야 함
  • Method
    • SoTA model의 경우 coarse-grain이나 fine-grained local attention을 적용하여 computational & memory cost와 성능을 개선하는 방식을 채택함
    • 하지만 이와 같은 방법들은 Multi-Layer transformer의 self-attention mechanism의 modeling에 문제를 가져오기 때문에 sub-optimal한 방법에 그치지 않음
    • Focal self-attention mechanism에서 각 token은 인접한 token을 정밀하게 구성을 구체화(fine granularity)하고 멀리 있는 경우에는 coarse granularity를 보이는 특징을 가짐
  • Focal Transformer Proposal
    • Focal self-attention이라는 fine-grained local과 coarse-grained global interaction을 모두 포함하는 새로운 mechanism을 제시함
    • 위 mechanism을 사용하여 Vision Transformer model을 변형한 Focal Transformer를 선보여 SoTA의 ViT보다 image classification과 object detection 면에서 모두 월등히 뛰어난 성능을 보임

1. Backgrounds

Transformers

  • NLP에서 transformer가 널리 알려진 model인 만큼 CV에서도 이를 사용하려는 수요가 늘어남
  • ViT에서 처음으로 transformer를 CV에 적용하였는데 이와 같은 full-Transformer model이 image classification, object detection, semantic segmentation 등에 모두 좋은 성능을 보일 뿐만아니라 action recognition, object tracking, scene flow estimation 등에도 좋은 결과를 보임
  • Self-attention
    • 대개 사용하는 CNN과는 달리 Transformer의 가장 큰 feature이자 핵심
    • 각 transformer layer마다 global content에 따라 model에 필요한 region 별 image 자료와 short-range와 long-range 모두 각각 상호작용을 진행
    • 이러한 self attention을 사용하게 되면 기존 CNN같이 local surrounding은 물론 global context까지 동시에 포착이 가능함
    • 이런 점에서 object detection과 segmentation과 같이 image의 resolution이 높은 task의 경우, global하고 fine-grained한 self-attention 방법을 사용하는 것은 grid에 따른 quadratic computational cost를 절감한다는 점에서, self-attention 사용은 거의 자명하다.

Proposal

  • Focal transformer는 새로운 방식의 self attention mechanism을 제안하여 resolution이 높은 input에 대하여 Transformer layer에서 local, global 상호작용이 모두 진행되도록 제안

→ 이때, 근접한 region이 멀리 떨어진 region보다 visual dependency가 더 크다는 점에서 local region에는 fine-grained self attention을 , global region에는 coarse-grained attention을 적용한다는 것이 차이가 있음

  • 이때, feature map에 있는 query token은 인접한 곳 중에서 가장 granularity가 큰 곳에 attend하고 멀리 떨어진 곳의 경우 이를 summarize한 token을 attend하여 coarse-grained한 visual dependency를 포착하도록 한다는 점에서 query로부터 멀리 떨어질수록 granularity가 coarse하도록 설계함

→ 이런 점에서 이와 같은 구조는 full self-attention mechanism과는 달리 self-attention 연산에 사용되는 token을 최대한 적게 사용하면서 resolution이 높은 feature map을 완전히 cover할 수 있다는 점에서 매우 효율적이고 이를 각 token이 focal manner에 따라 attend한다는 점에서 focal self-attention이라 칭함

  • 이러한 focal self-attention을 기반으로 설계한 mechanism이 focal transformer이고 아래와 같이 두 가지 과정을 거침
    1. Resolution이 높은 image에 관하여 합리적인 연산 비용을 사용할 수 있도록 multi-scale 구조로 설계
    2. Feature map을 multiple window로 나누어 동일한 주변 공간을 공유하도록 하여 각 token마다 focal self-attention을 적용하는 불필요한 행동을 없앰
  • 이와 같은 focal self-attention의 효율성을 입증하기 위해 image classification, object detection & segmentation 등 종합적으로 연구를 진행하였고 그 결과 기존 transformer의 model size와 complexity가 유사함에도 SoTA와 비교하였을 때 월등히 좋은 결과를 보임

2. Architecture

Focal Transformer

  • Resolution이 높은 vision task에도 적용하기 위해 위와 같이 early stage에서 high resolution feature map을 얻는 구조로 설계하여 아래와 같은 과정을 거침

    1. Input image($I \in R^{H\times W \times 3}$)의 경우 4x4 patch로 partitioning 진행하여 $4\times 4 \times 3$ 차원의 $\frac H 4 \times \frac W 4$개의 visual token을 output으로 정함
    2. Patch embedding layer에 convolution layer를 filter size를 4x4, stride도 4로 지정하여 $d$차원의 hidden feature로 projection 진행
  • 이로부터 얻은 feature map은 4개의 stage의 focal Transformer block을 거침

    • 이때, 각 stage마다 focal Transformer block은 $N_i$($i \in {1,2,3,4}$)개의 focal Transformer layer로 구성됨
    • 각 stage마다 patch embedding layer를 추가적으로 사용하여 factor 2만큼의 공간을 축소하고 feature dimension은 2만큼 증가함
  • Image classification의 경우 마지막 stage의 output의 평균을 받아 classification layer에 전달하는 방식을 사용함

  • Object detection의 경우 feature map의 detection 방식에 따라 최근 3개나 4개 전부의 stage를 detector head에 feed를 진행함

  • 여기서 input feature의 dimension을 $d$, model capacity의 경우 각 stage의 focal Transformer layer의 개수 ${ N_1,N_2,N_3,N_4}$에 따라 customizing이 진행됨

  • Self-attention

    • 기존의 방법의 경우 fine-grain의 short과 long-range interaction을 모두 다룰 수 있지만 resolution이 높은 feature map의 attention의 경우 computational cost가 상당히 많이 들게 됨

      → Feature map의 size가 $\frac H 4 \times \frac W 4 \times d$의 경우, self-attention의 complexity가 $O((\frac H 4 \times \frac W 4 )^2d)$로 object detection에서는 $min(H,W)$가 적어도 800은 요구된다는 점에서 시간과 메모리를 상당히 많이 소모하는 것을 알 수 있음

    • 이를 해결하기 위해 focal self-attention 방법을 사용함


3. Approach

Focal self-attention

  • Transformer layer를 resolution이 높은 input에서도 scale이 가능하게 하기 위해 설계
  • 이는 fine-grain 단위로 모든 token을 attend하는 대신 이러한 fine-grained token을 local하게만 적용하고 global하게는 이를 summarize한 결과를 사용함
  • 이와 같은 방법을 사용하면 많은 region을 standard self-attention만큼 cover가 가능한 상황에서도 computational cost를 줄일 수 있음

  • Query position에서 주변에 coarse-grain을 사용하면 focal self-attention은 동일한 양의 visual token을 attend하면서 더 넓은 단위의 receptive field를 얻을 수 있음
  • 이와 같은 focal mechanism을 사용하면 long-range의 self attention을 시간과 메모리를 상대적으로 덜 사용하면서도 가능하게 하였는데 이는 token이 summarize되었다는 점에서 애초에 인접한 장소에 요구되는 token의 개수가 적기 때문에 가능하게 된 일이다.
  • 하지만, 현실적으로 각 query position마다 인접한 token에 대하여 extraction을 진행하는 것은 access 가능한 query에 있는 token마다 duplication을 진행해야 한다는 점에서 시간과 메모리 비용이 상당히 많이 들게 됨
    • 이를 해결하기 위해 대체로 사용하는 방법은 input feature map을 window 형태로 분할을 하는 방식을 사용
    • 이에 착안하여 focal Transformer에서는 window level의 focal self-attention을 적용함
      • Feature map이 $x \in R^{M\times N \times d}$이고 size가 $M\times N$인 경우 아래와 같은 과정 진행
      1. Feature map을 size가 $s_p \times s_p$인 grid window로 분할 진행
      2. 이때 token을 개별적으로 탐색하지 않고 각 window의 인접 지역을 탐색하는 방식을 진행하여 목표한 elaboration 진행

Window-wise attention

  • 명료함을 위해서 아래와 같은 term을 미리 정의함

  • Focal level($L$) : Focal self attention을 통해 extract한 token의 granularity level

  • Focal window size($s_w^l$) : Level $l \in { 1,...,L}$에서의 summarized token으로부터 받는 sub-window의 size

  • Focal region size($s_r^l$) : Level $l$에서 horizontal하고 vertical하게 attend된 sub-window의 개수

  • 위와 같은 3개의 term을 정의한 다음 self-attention module을 아래와 같은 두 가지 과정으로 세분화 할 수 있음

    1. Sub-window pooling

      • Feature map $x \in R^{M \times N \times d}$에 대하여 spatial dimension이 $M\times N$이고 feature dimension이 $d$일때, 모든 $L$ level에 관하여 sub-window pooling을 모두 진행

      • Focal level $l$에서, $x$를 $s_w^l \times s_w^l$ size의 grid의 sub-window로 분할함

      • 이후에는 단순한 linear layer $f_p^l$을 추가하여 sub-window에 pooling을 진행하여 아래와 같은 식을 만족함

        $$
        x^l = f_p^l(\hat x) \in R^{\frac M {s_w^l} \times \frac N {s_w^l} \times d} \ \hat x = Reshape(x) \in R^{(\frac M {s_w^l} \times \frac N {s_w^l} \times d) \times ({s_w^l} \times {s_w^l})}
        $$

      • Pooling을 진행한 feature map ${x^l}_1^L$은 fine-grain coarse-grain 양측 면에서 모두 풍부한 정보를 포함하고 있음

      • 이때, input feature map의 첫 번째 local level에 $s_w^l=1$로 setting하였다는 점에서 sub-window pooling을 진행할 필요는 없음

      • 기본적으로 focal window의 size가 대체로 매우 작은 것을 감안하면 sub-window pooling에 필요한 추가적인 parameter는 무시할 만하다.

    2. Attention computation

      • $L$ level에 관하여 pooling을 완료한 feature map ${x^l}^L_1$를 얻게 되면 3개의 linear projection layer $f_q,f_k,f_v$를 사용하여 첫 번째 level에서부터 query를 연산하고 key와 value를 모든 level에서 연산을 진행함

        $$
        Q = f_q(x^1), \ K={K^l}_1^L = f_k({x^1,...,x^L}), \ V={V^l}_1^L = f_v({x^1,...,x^L})
        $$

      • Focal self-attention을 수행하기 위해서는 feature map의 각 query token에 관하여 인접한 token의 extraction 과정이 수행되어야 하고 이때 size $s_p \times s_p$의 window partition 내의 token의 경우 인접한 요소들을 서로 공유함

      • $i$번째 window의 query $Q_i$의 경우, $s_r^l \times s_r^l$ key와 value를 $K^l$과 $V^l$로부터 추출을 진행하고 모든 $L$에 관하여 key와 value를 모두 모아서 $K_i = {K_i^l,...K_i^L} \in R^{s\times d}, \ V_i = {V_i^l,...V_i^L} \in R^{s \times d}$를 획득함

      • 이때, $s$는 $s=\sum_{t=1}^L(s_r^l)^2$로 모든 level에 관한 focal region의 총합을 의미함

      • Overlap region을 완전히 배제하는 엄밀한 focal self-attention 방법과는 달리 focal Transformer에서는 overlap된 region의 pyramid information을 얻기 위해 이를 두기로 함

      • Relative position bias까지 포함하여 $Q_i$에 대한 focal self-attention은 아래와 같이 구할 수 있음

        $$
        Attention(Q_i,K_i,V_i) = Softmax(\frac {Q_iK_i^T} {\sqrt d} + B)V_i
        $$

      • 여기에서 $B={B^l}_1^L$가 learnable relative position bias를 의미하고 $L$개의 focal level에 관하여 $L$개의 subset을 구성함

      • Sub-window pooling 때와 마찬가지로 input feature map의 첫 번째 local level에서는 $B^1 \in R^{(2s_p-1) \times (2s_p - 1)}$로 parameterize하여 horizontal과 vertical position이 모두 $[-s_p + 1, s_p -1]$ 범위에 오도록 함

      • 다른 focal level의 경우 query 마다 granularity가 모두 다르다는 점에서 모든 query를 동일하게 window 안에 집어넣어 $B^l \in R^{s_r^l \times s_r^l}$로 나타내어 query window와 각각의 $s_r^l \times s_r^l$ pooling token간의 relative position bias를 나타냄

      • 기본적으로 focal self-attention은 서로 다른 window끼리는 영향을 주지 않기 때문에 위의 $Attention$ 식의 경우 따로 따로 계산이 가능함

Complexity analysis

  • Window wise attention에서 언급한 과정을 통해 computational complexity를 분석할 수 있음
  • Input feature map $x \in R^{M\times N \times d}$에 대하여 focal level 1에서 $\frac M {s_w^l} \times \frac N {s_w^l}$개의 sub-window를 생성할 수 있음
  • 각 sub-window마다 pooling operation은 $O((s_w^l)^2d)$의 complexity를 가지고 있고 모든 sub-window를 가지고 오면 $O((MN)d)$, 여기에 모든 focal level을 가지고 오면 $O(L(MN)d)$만큼의 complexity를 가지며 이는 각 focal level의 sub-window size와는 별개이다.
  • Window wise attention에서 정의한 $Attention$ computation을 고려하면 size가 $s_p \times s_p$인 특정 query window에서의 computational cost는 $O((s_p)^2\sum_l(s_r^l)^2d)$이고 전체 input feature map에 관한 computational cost는 $O(\sum_l(s_r^l)^2(MN)d)$이다.
  • 위와 같은 사항을 모두 정리하여 도출한 focal self-attention에 대한 전체 computational cost는 $O((L+\sum_l(s_r^l)^2)(MN)d)$이다.
  • 이때 모든 query에 관하여 global receptive field를 강하게 보이게 하는 극단적인 case의 경우 $s_r^L = \frac {2max(M,n)} {s_w^L}$로 setting 가능

4. Trials

Image Classification

  • ImageNet-1K를 사용하여 서로 다른 method를 비교하였는데 공정성을 위해 모든 model은 epoch 300, batch-size를 1024로 두었음
  • 이때, learning rate의 경우 $10^{-3}$으로 초기화하였고 $10^{-5}$부터 시작하여 20개의 epoch마다 warming up 과정을 진행하며 Adam W를 optimizer로 사용하였고 weight decay를 0.05로 설정하여 maximum gradient norm은 5.0까지 올라감
  • Data augmentation의 경우 모두 동일 하게 적용하였고 random erasing을 진행한 후 augmentation 반복 후 EMA 값을 구하게 됨

  • 시행 결과 baseline model에 대한 summarize를 진행한 결과 Focal Transformer가 parameter 개수와 complexity가 비슷한 기존의 SoTA model보다 좋은 성능을 보임
  • 이는 window size를 늘렸을 때에도 마찬가지로 SoTA보다 좋은 성능을 보였고 51.1M개의 parameter만으로 83.5%의 Top-1을 달성함
Comments