연세대 인공지능학회 YAI

[Continual Learning] What is continual learning? 본문

Learning/Continual

[Continual Learning] What is continual learning?

_YAI_ 2022. 8. 13. 14:08

What is continual learning?

 

* YAI 9기 조현우 님이 작성한 글입니다.


Introduction

사람과 다르게 인공지능은 순차적으로 들어오는 데이터를 전부 기억하지 못합니다.

동물을 분류하는 인공지능 모델을 예로 들어보겠습니다. Day 1에는 개의 이미지만 있어서, 이것으로만 학습을 하면 모델은 개를 인식하고 분류할 수 있게 됩니다. 그 다음에 Day 2에 고양이 이미지가 새로 들어와서, 이것으로 새로 학습을 시키면 모델은 고양이를 분류할 수 있게 됩니다. 하지만 고양이로만 학습을 시킨다면 이전에 학습했던 개에 대한 정보는 전부 잃어버리게 될 것입니다.

 

 

이처럼 인공지능 모델이 새로운 데이터를 학습할때 기존의 데이터에 대한 정보를 잃어버리는 현상을 Catastrophic forgetting이라고 하고, catastrophic forgetting을 해결하기 위한 연구가 continual learning입니다. 말 그대로 순차적으로 들어오는 데이터를 처리하기 위한 연구라고 할 수 있고 lifelong learning, incremental learning으로도 불립니다.


Naive ideas for continual learning

어떤 태스크를 처리하기 위해 데이터를 한번 학습한 인공지능 모델이 이후에 새로운 데이터가 들어온다면 어떻게 학습해야 할까요? 다음과 같은 경우를 생각해 볼 수 있습니다.

 

  1. 이전 데이터로 학습한 모델을 새로운 데이터로만 학습한다.
  2. 이전의 데이터와 새로 들어온 데이터를 합친 데이터셋을 만들고, 새로운 데이터셋으로 모델을 다시 학습한다.
  3. 모델의 구조를 변형시켜 기존 데이터와 새로운 데이터를 잘 학습시킬 수 있게 한다.

1은 위에서 언급했던 문제가 그대로 발생합니다. 기존 데이터를 고려하지 않고 새로운 데이터로만 모델을 학습하면 기존 데이터에 대한 정보를 잃어버리고, 이전의 태스크를 해결하지 못하게 됩니다.

 

2는 자원과 시간이 충분하다라는 가정이 있을 경우에만 이상적인 상황입니다. 이전의 데이터도 전부 보존되면 새로운 방법을 생각할 것 없이 모든 데이터로 모델을 다시 학습시키면 이전/현재 데이터에 대한 정보도 전부 학습하면서 모든 태스크를 처리할 수 있게 될 것입니다. 하지만 그런 가정은 컴퓨팅 자원이 유한하고 학습 시간이 제한되는 현실적인 상황과는 괴리가 있습니다. 다만 이전에 있던 데이터를 전부 가져오기 보다는 제한적인 메모리를 사용하여 기존 데이터의 일부와 새로운 데이터를 합친 데이터셋으로 모델을 학습시키는 것은 어느정도 말이 되는 가정이겠죠.

 

3은 모델의 설계관점에서 바라본 아이디어입니다. 1,2와는 별개로 기존 데이터에 대한 정보를 유지하고 새로운 데이터를 학습할 수 있는 구조로 모델을 continual learning에 적합하게 바꿔본다는 생각입니다.

 

실제로 continaul learning을 해결하기 위해 제시된 논문들은 어떤 아이디어를 바탕으로 했는지, 뒤에서 자세히 살펴보겠습니다.


Class-incremental / Task-incremental

continual learning을 해결하기 위해 제시된 방법들을 살펴보기 앞서, continual learning의 유형에 대해 알아보겠습니다. 논문에서 제시하는 유형은 class-incremental learning, task-incremental learning, incremental domain learning으로 제시하지만 앞의 두개에 대한 정의와 “task”에 대한 정의만 간단히 살펴보겠습니다.

 

  • Task: an isolated training phase with a new batch of data, belonging to a new group of classes, a new domain, or a different output space.
  • Class-incremental learning: an output space for for all observed class labels ${\mathcal{Y}^{(t)}} \subset {\mathcal{Y}^{(t+1)}}$ with $P(\mathcal{Y}^{(t)}) \neq P(\mathcal{Y}^{(t+1)})$
  • Task-incremental leraning: ${\mathcal{Y}^{(t)}} \neq {\mathcal{Y}^{(t+1)}}$ which additionaly requires task label $t$ to indicate the isolated output nodes $\mathcal{Y}^{(t)}$ for current task $t$.

즉 batch마다 제시되는 하나의 데이터 셋이 하나의 task라 할 수 있고, class-incremental learning은 이전의 태스크를 포함하여 현재까지 제시된 모든 태스크에 대한 예측을 하는 것이며 task-incremental은 현재의 태스크에 대한 예측만 수행하는 것입니다. Class-incremental learning이 task-incremental learning에 포함되고, 특수한 경우라고 생각할 수 있습니다.


Continual Learning Approach

아래 그림은 Survey 논문에서 저자들이 제시하는 continual learning의 taxonomy입니다. 크게 3가지로 분류할 수 있습니다.

 

  • Replay methods
  • Regularization-based methods
  • Parameter isolation methods

 

Replay methods는 위에서 제시한 2번 방법에 가깝습니다. 기존의 데이터를 저장해놓은 다음 새로 들어오는 데이터와 함께 모델을 재학습합니다. Catastrophic forgetting을 해결할 수 있는 심플하고 효율적인 방법이지만, 새로운 데이터를 기존 데이터와 함께 학습하는 만큼 새로운 태스크에 대한 성능이 떨어집니다. 이를 해결하기 위해 기존의 태스크는 유지하면서 새로운 태스크에 대한 정보만 업데이트 할 수 있도록 하는 연구들이 진행됐습니다.  

 

Regularizationi-based method는 loss 함수에 새로운 regularization term을 추가해서 catastrophic forgetting을 막는 방법입니다. 크게 data-focused methods와 prior-focused methods로 나눌 수 있는데, 전자는 기존 데이터에 대해선 그대로 classification loss를 사용하고 새로운 데이터에 대해서만 knowledge distillation loss를 사용한다는 접근 방식입니다. Prior-focused method는 기존 모델의 파라미터에 대한 분포에서, 각 파라미터가 독립적이라는 가정을 하고 파라미터의 변화가 매우 작게 이루어지도록 하여 기존 태스크와 새로운 태스크를 전부 수행할 수 있도록 파라미터를 수정합니다.  

 

Parameter isolation method는 위에서 3번 접근과 유사하게, 각 태스크에 대해 모델을 따로 두고 각각에 대한 파라미터를 추정한다는 아이디어 입니다. 기존 태스크에 대해서는 파라미터를 동결하고, 새로운 태스크에 대해서만 학습을 진행해서 파라미터를 얻게됨으로써 catastrophic forgetting을 방지합니다.   

 


Reference

Matthias De Lange et al., “A Continual Learning Survey: Defying Forgetting in Classification Tasks”, TPAMI 2021

Comments