본문 바로가기
딥러닝 모델 글로만 이해하기

VAE (variational auto encoder) 쉬운 설명 (글로만 이해하기)

by bigpicture 2023. 11. 1.
반응형

VAE는 인코더와 디코더로 구성됩니다. 모델 학습이 완료된 뒤에는 디코더만 사용됩니다. 임의의 z값들을 생성해서 디코더에 넣으면 이미지가 생성되는 VAE를 만들 것입니다. 모델이 만들어지는 순서대로 설명하겠습니다. 

 

1. 인코더 만들기

입력 이미지 해상도가 24x24 라면 입력 이미지는 576 차원의 벡터입니다. 

576차원을 입력받아 512차원을 내보내는 뉴런들을 구성합니다. 은닉 1층입니다. 출력값에 Relu 함수를 적용합니다. 

512차원을 입력받아 256차원을 내보내는 뉴런들을 구성합니다. 은닉 2층입니다. 출력값에 Relu 함수를 적용합니다. 

256차원을 입력받아 latent 변수들의 평균과 분산을 내보내는 뉴런들을 구성할 겁니다. 은닉 3층입니다. 각 latent 변수들은 정규분포 형태로 구현할 것입니다. latent 변수는 32개로 정하겠습니다. 따라서 은닉 3층의 출력값은 32개의 평균과 32개의 표준편차여야 합니다. 표준편차에 해당하는 출력값은 ln(varience) 로 해석하고, 이를 표준편차로 변형합니다. 출력값을 exp(0.5*ln(varience)) 하면 표준편차로 변합니다. log(var)을 사용하는 이유는 학습시 안정성이 높기 때문입니다. 은닉 3층은 256차원을 입력받아 64차원을 내보내는 뉴런들로 구성합니다. 

인코더의 최종 출력은 32개의 평균과 32개의 표준편차입니다. 

 

2. latent 변수의 reparemetrize

각 latent 변수가 정규분포로 표현되었으니, 각 정규분포에서 값을 랜덤추출해서 디코더의 입력값으로 사용해야 합니다. 

back propagation 알고리즘 적용을 위해 약간의 트릭을 적용해주겠습니다. 어떤 latent 변수를 z라고 놓고 평균을 $\mu$, 표준편차를 $s$라고 합시다. 이때 z를 평균과 표준편차에 대한 변수 형태로 표현해야 역전파를 적용할 수 있습니다. 평균과 표준편차는 신경망 weight 들로 계산된 것이므로 평균과 표준편차를 신경망 weight 으로 편미분을 해야하기 때문입니다. 평균이 $\mu$이고 표준편차가 $s$ 인 정규분포에서 추출된 표본 z를 수식으로 나타내려고 시도해보시면 어렵다는 것을 알 수 있습니다. 이때 사용하는 트릭이 reparametrize 트릭입니다. 표준정규분포에서 추출된 표본을 $\varepsilon$으로 두면 아래와 같이 나타낼 수 있습니다. 

$z=\mu+\sigma \times \varepsilon$

각 latent 변수에서 값을 랜덤추출하면 32개의 값이 생깁니다. 

 

3. 디코더

latent 변수들에서 생성한 32개의 값을 디코더 첫 은닉층의 입력값으로 사용합니다. 32차원을 입력받아 256차원을 출력하는 뉴런들을 구성합니다. 은닉층 4라고 두겠습니다. Relu 함수를 적용합니다. 

256차원을 입력받아 512차원을 출력하는 뉴런들을 구성합니다. 은닉 5층이라고 두겠습니다. Relu 함수를 적용합니다. 

512차원을 입력받아 입력 이미지의 차원인 576 차원을 출력하는 신경망을 구성합니다. 출력된 값을 24x24 형태로 나열하면 이미지가 됩니다. 

모델 구성은 끝났습니다. 이제 모델을 어떻게 최적화할지 설명하겠습니다.

 

4. 손실함수

최적화하려면 목적이 필요합니다. 목적을 정의하기 위해 손실함수를 알아봅시다. 손실함수는 아래와 같은 코드로 정의되어 있습니다. 

 

def loss_fn(x, x_hat, logvar):
    # Reconstruction loss (could use also MSE, but BCE is good for inputs bounded in [0, 1])
    bce = F.binary_cross_entropy(x_hat, x, reduction='mean')

    # KLD between Gaussian rvs
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

 

손실함수는 binary cross entropy(bce)와 KL divergence 의 합입니다. 위 코드에서 bce 는 입력이미지 x와 출력 이미지 y의 차이를 표현합니다. 이미지 픽셀 값을 [0,1]로 스케일링 했기 때문에 위와 같은 계산이 가능해집니다. KL divergence 는 우리가 예측한 z의 분포와, z의 실제 분포라고 가정한 표준정규분포 사이의 차이입니다. 

 

손실함수를 최소화하는 방향으로 최적화가 진행됩니다. 손실함수는 우리가 만들려는 모델의 목적을 보여줍니다. 우리가 만들려는 모델의 목적은 아래 두가지입니다. 

 

1) 출력 이미지 x' 와 입력이미지 x가 같아지도록 함

2) latent 변수가 표준정규분포를 따르도록 함

 

이렇게 프로그래밍된 코드를 이해하는 것은 오히려 쉬운데, 이 내용을 확률론의 관점으로 이해하는 것은 어렵습니다. 추상화되어 있어 그런 것 같습니다. 

반응형

댓글