본문 바로가기

알아두면 쓸데없는 신비한 잡학사전

About GAN(Generative Adversarial Nets)

데이터를 생성하는데 사용하는 GAN은 무엇일까?

 

GAN은 생성 모델이다.

기계 학습은 크게 지도 학습과 비지도 학습으로 구분한다. 지도 학습과 비지도 학습을 구분하는 기준은 데이터의 특성이다. 정답 라벨이 있는 데이터를 가지고 예측이나 분류 과제를 수행하는 모델을 만드는 것을 지도 학습이라고 부른다. 대표적으로 손 글씨 숫자 데이터 MNIST에는 손 글씨 숫자와 그 숫자에 대한 라벨(0-9)이 있다. 지도 학습에 따르면 손 글씨 숫자 이미지를 보고 어떠한 숫자인지 분류하는 것은 지도 학습의 전형이라고 할 수 있다.

 

지도 학습을 활용하여서 데이터에 따른 정답을 맞히는 모델을 판별 모델(Discriminative Model)이라고 한다. 일반적으로 판별 모델을 구현하기 위해서 지도 학습의 방법을 사용한다. 생성 모델(Generative Model)은 판별 모델과 자주 비교하면서 소개된다. 생성 모델의 목적은 정답을 맞히는 것이 아니다. 생성 모델은 이름에서 알 수 있듯이 생성을 하는 것이 주된 목적이다. MNIST 손 글씨 숫자 데이터로 예를 들면 판별 모델은 숫자 이미지를 보고 어떠한 숫자인지 판별하는 모델이라면 생성 모델은 손 글씨 숫자 데이터를 생성하는 모델이다. 생성 모델은 데이터 세트에 있는 손 글씨 숫자 데이터 이미지와 유사한 이미지를 생성해내는 것이 주된 목적이다.

 

앞선 기계 학습은 지도 학습과 비지도 학습으로 분류된다고 하였으니 생성 모델은 비지도 학습으로 구현되는 것 아니냐고 추론할 수 있다. 물론 생성 모델에 있어서 데이터 세트의 정답 세트는 중요하지 않다. 손 글씨 이미지를 생성할 때, 해당 손 글씨 이미지의 정답 라벨 자체에 대한 중요도는 떨어진다. 하지만 그렇다고 해서 생성 모델이 전적으로 비지도 학습인 것은 아니다. 지도 학습의 아이디어를 사용해서 어떠한 정답에 가까워지도록 학습함으로서 생성 모델을 구현할 수 있다.

 

GAN이 등장하기 전까지 생성 모델의 대표적인 학습 방법은 Variational Auto Encoder(VAE)이었다. 물론 VAE는 생성 모델에 있어서 지금도 많이 연구되고 활용한다. VAE는 지도 학습처럼 동작한다. VAE는 입력 데이터로 들어온 이미지와 정답 레이블이 동일하다. 즉 3이라고 쓰인 손 글씨 숫자가 입력으로 VAE 모델의 입력으로 들어오면 모델은 입력으로 들어온 손 글씨 숫자 3과 동일한 결과를 내도록 학습한다. 이를 위해서 입력으로 들어온 이미지에서 특징을 추출하는 Encoder와 이를 통해 유사하지만 새로운 이미지를 생성하는 Decoder로 VAE을 구성한다.

 

GAN은 어떻게 데이터를 생성하는가.

Ian Goodfellow가 제안한 GAN 역시 지도 학습을 활용하여서 학습이 이루어진다. GAN은 Generator(생성기)와 Discriminator(판별기)로 구성한다. Generator의 목적은 이름에서 알 수 있듯이 랜덤한 데이터로부터 원하는 데이터(주로 이미지)를 생성한다. Discriminator는 Generator가 생성한 이미지와 원래 데이터 세트에 있는 이미지를 보고 어떠한 것이 원래 데이터에 있는 이미지인가를 판별하는 역할을 한다.

 

GAN의 동작을 가장 잘 설명하는 비유는 Ian Goodfellow의 논문에서도 언급된 것처럼 경찰과 지폐 위조범의 비유이다. 여기서 지폐 위조범은 Generator, 경찰은 Discriminator에 대응한다. 지폐 위조범(Generator)는 가짜 지폐를 만든다. 물론 진짜 지폐와 유사하게 만드는 것이 중요하다. 그리고 경찰(Discriminator) 지폐 위조범이 만든 가짜 지폐와 진짜 지폐를 구별해야한다. 물론 기술이 부족한 지폐 위조범은 처음에 생성한 지폐는 조악할 것이다. 하지만 점차 기술이 쌓여감에 따라 진짜와 비슷한 위조 지폐가 만들어진다. 성장한 위조 실력에 따라서 경찰도 판별하는 능력을 길러야한다. 이렇게 지폐 위조범과 경찰이 서로의 실력을 자극하게 되어서 성장하는 동안 위조 지폐는 진짜 지폐와 점차 비슷해진다.

 

그렇다면 어떻게 GAN은 학습하지?

GAN 학습을 위해 사용하는 Value Function(Cost Function)을 뜯어보면서 GAN의 학습이 어떻게 이루어지는지 확인해보자.

GAN의 수식

GAN의 학습은 Generator와 Discriminator가 다르다. Generator는 Value Function이 최소화되도록 학습이 진행되는 반면 Discriminator는 Value Function이 최대화되도록 학습이 이루어진다. 따라서 GAN의 학습을 이해하기 위해서는 Value Function을 이해해야한다.

 

Value Function은 크게 두 부분으로 이루어진다. 하나는 원래 데이터와 관련된 부분(전자)이고 다른 하나는 Generator에 의해서 생성된 데이터와 관련된 부분(후자)이다.

 

Pdata 는 원래 데이터에 대한 확률 분포로서 식의 앞 부분은 원래 데이터로부터 샘플링한 데이터와 관련된 식이다. 원래 데이터에 대해서는 Discriminator는 0이 아닌 1을 예측해야한다. 따라서 logDx 의 값은 최대가 되어야하며, 그 때의 로그 값은 Dx=1 인 경우로서 0(=log1) 이 된다. 반면 Generator는 원래의 데이터에 대해서 Discriminator가 잘 동작하는 것과는 큰 상관이 없으므로 Generator 학습에 있어서 식의 전반부는 활용되지 않는다.

 

Pz 는 데이터를 생성하기 위해서 필요한 확률 분포로서 일반적으로 정규 분포를 사용한다. Value Function의 후반부는 정규 분포로부터 샘플링한 z 에서 시작한다. z 는 새로운 이미지를 생성하기 위한 시드(seed)로서 노이즈라고 표현한다. 노이즈로부터 Generator는 이미지를 생성한다(Gz ). 여기서 학습이 잘된 Discriminator는 Generator가 생성한 이미지를 0으로 예측해야한다. 따라서 1-D(Gz) 가 최대가 되는 방향으로 학습이 이루어진다. 반면 Generator는 Discriminator를 속이고 D(Gz)  값이 1이 되어야한다. 따라서 Generator는 1-D(Gz) 가 최소가 되는 방향으로 학습을 진행한다.

 

GAN으로 생성한 MNIST

GAN의 아이디어를 코드로 옮겨서 MNIST 데이터를 생성해보자. 여기서 주의할 것은 Generator를 학습할 때에는 Discriminator의 가중치는 갱신해서는 안된다. 만약 Generator를 학습할 때 Discriminator의 가중치도 같이 갱신된다면 Discriminator는 제대로 학습되지 않는다. Generator를 학습할 때 Discriminator는 가짜 이미지에 대해서 1이라고 출력하도록 학습이 진행된다. 하지만 직전까지 Discriminator는 가짜 이미지는 0이고 진짜 이미지가 1이라고 학습되었다. 따라서 Generator를 학습하면서 Discriminator를 학습하면 이처럼 충돌이 발생한다. 혼란을 피하기 위해서 Discriminator는 Discriminator를 학습할 때(즉 진짜 이미지는 1이고 가짜 이미지는 0이다)에만 가중치를 변경하고 Generator를 학습할 때에는 가중치를 갱신해서는 안된다. 이것에만 주의하면 코드 작성은 상당히 용이하다.

GAN으로 생성한 손 글씨 숫자

약 300,000번의 학습 과정을 거쳐서 생성된 데이터이다. 물론 깊지 않은 신경망이었고, 학습 횟수도 많지 않았기 때문에 생성된 이미지라는 것이 다소 티가 나지만 그래도 상당히 알아볼 수 있는 숫자의 형태이다.

 

https://github.com/dhsong95/generative-adversarial-nets-mnist

 

dhsong95/generative-adversarial-nets-mnist

Generate MNIST using GAN. Contribute to dhsong95/generative-adversarial-nets-mnist development by creating an account on GitHub.

github.com