본문 바로가기
Personal Projects/Model

[생성 신경망] GAN 모델 구현

by muns91 2024. 4. 17.
GAN 실습

 

 지난 시간에는 적대적 생성 신경망 (Generative Adversarial Network, GAN)에 대해 간단한 이론을 살펴보았습니다. 지난 글에서는 이미 이론을 살펴보았으니, 직접 모델을 구현해봐야겠지요? 아래를 통해 모델 구현 환경 및 코드를 살펴보면서 과정을 설명하도록 하겠습니다. 

 

 

GitHub - Muns91/Generative-Adversarial-Network

Contribute to Muns91/Generative-Adversarial-Network development by creating an account on GitHub.

github.com


 

라이브러리 & 모듈 import

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout, Reshape, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
import os
import matplotlib.image as mpimg

 

 해당 코드에서는 모델 구현, 배열 그리고 이미지를 위한 라이브러리와 모듈들을 import 하였습니다. 해당 코드를 통해 원하는 것들을 import하고 라이브러리와 모듈의 기능들을 사용하실 수 있습니다. 

 


데이터 불러오기

if not os.path.exists('gan_images'):
    os.makedirs('gan_images')
    
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_train = (X_train - 127.5) / 127.5  # 데이터 정규화

true = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

 

 해당 코드에서는 1) GAN을 통해 생성할 이미지를 저장할 폴더를 확인하고 생성하는 코드, 2) 데이터를 로드하고 저장할 변수 지정, reshape 및 정규화 그리고 3) 진짜 이미지와 가짜 이미지에 대한 레이블을 생성하는 코드로 이루어져있습니다. 아래를 통해 추가적인 설명을 진행하도록 하겠습니다. 

 

1) 폴더 확인 및 생성

if not os.path.exists('gan_images'):
    os.makedirs('gan_images')

 

 이 코드에는 'gan_images' 라는 이름의 디렉토리가 현재 작업 디렉토리에 존재하는지를 확인합니다. 만약 디렉토리가 존재하지 않는다면 'os.makedirs('gan_images')'를 호출하여 해당 폴더를 생성하게 됩니다. 이 폴더는 훈련 중에서 생성된 이미지를 저장하는 용도입니다. 

 

2) 데이터 로딩 및 전처리

(X_train, _), (_, _) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_train = (X_train - 127.5) / 127.5  # 데이터 정규화

 

 실습에서는 손글씨 데이터인 'MNIST' 데이터를 사용합니다.' 따라서 데이터를 로드하고 이를 X_train에 저장합니다. 그리고 저장받은 데이터를 사용하기 위해 추가적으로 차원을 '1' 하나 더 만들어 4차원 형태로 만들어냅니다. 여기서 추가된 1은 이미지의 채널 수를 의미합니다. 실습에서 사용된 MNIST 데이터는 흑백이기 때문에 채널 수는 1입니다. 

 

3) 레이블 생성 *

true = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

 

'true'와 'fake'는 각각 진짜 이미지와 가까 이미지에 대한 레이블을 생성합니다. true는 진짜 이미지에 해당하는 배치 크기만큼의 1로 채워진 배열을 생성하고, fake는 가짜 이미지에 해당하는 배치 크기만큼 0으로 채워진 배열을 생성합니다. 이 두 레이블은 훈련 과정에서 판별자가 이미지를 진짜로 판단하는지 혹은 가짜로 판단하는지에 대한 학습 목표로 사용됩니다. 

 


생성자 모델 *

# 생성자 모델을 만듭니다.
generator = Sequential([
    Dense(256, input_dim=100),
    LeakyReLU(alpha=0.2),
    Dense(512),
    LeakyReLU(alpha=0.2),
    Dense(28*28*1, activation='tanh'),  # 결과 이미지의 차원 맞춤
    Reshape((28, 28, 1))
])

generator.summary()

 

 실습에서 사용되는 생성자 모델입니다. 생성자 모델은 입력으로 랜덤 노이즈를 받아 가짜 이미지를 만들어 내는 역할을 합니다. 따라서 모델의 아웃풋은 이미지와 같은 형태로 Reshape를 해줘야 합니다. LeakyReLU에서 'alpha'는 음수 입력에 대해서 아주 작은 기울기를 허용하기 때문에 이로 인해 입력 값이 0 이하일 때도 완전히 0이 아닌 alpha 변수에 값을 곱한 값을 출력합니다.

 

생성자 모델 summary()

 

판별자 모델 *

# 판별자 모델을 만듭니다.
discriminator = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(512),
    LeakyReLU(alpha=0.2),
    Dropout(0.3),
    Dense(256),
    LeakyReLU(alpha=0.2),
    Dropout(0.3),
    Dense(1, activation='sigmoid')
])
discriminator.summary()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.trainable = False

 

 이번에는 판별자 모델입니다. 이 모델은 생성된 가짜 이미지가 실제 이미지와 얼마나 유사한지를 판별하는 기능을 수행합니다. 판별자의 목표는 진짜 이미지를 진짜(1)로 가짜 이미지를 가짜(0)으로 정확하게 구분하는 것입니다. 따라서 위에서 언급한 true와 fake 레이블이 1과 0으로 되어 있기 때문에 이를 최종적으로 답을 내놓기 위해 아웃풋의 출력은 1이고 이를 예측하기 위해 binary_crossentropy와 sigmoid를 사용하게 됩니다. 

 

discriminator.trainable = False

 

 마지막에 사용된 이 설정은 생성자와 판별자가 연결된 전체 GAN에서 훈련을 할때 판별자의 가중치가 업데이트되지 않도록 합니다. 이는 판별자를 고정시켜서 생성자만을 훈련하게 함으로써 전체적인 GAN 과정에서 꼭 필요합니다. 

 

판별자 모델 summary()

 

GAN 모델 만들기 *

# GAN 모델을 만듭니다.
g_input = Input(shape=(100,))
g_output = generator(g_input)
d_output = discriminator(g_output)
gan = Model(g_input, d_output)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

 

 

 우선적으로 생성자의 입력 함수(g_input)는 100차원의 랜덤 노이즈가 '생성자'의 입력으로 사용됩니다. 차원이 100인 이유는 생성자가 다양한 패턴과 특징을 학습하고 이를 기반으로 복잡한 이미지를 생성할 수 있을 정도로 충분히 많은 정보를 포함할 수 있기 때문입니다. 따라서 차원이 높아질 수록 생성자가 더 세밀하고 다양한 출력을 생성할 수 있게 됩니다. 하지만 차원 수를 너무 높게 설정하면 모델 학습이 더 어려워지고 과적합의 위험이 증가할 수 있습니다. 

 

 다음 g_ouput은 생성자가 만들어낸 가짜 이미지 입니다. 이는 바로 다음의 판별자의 입력으로 사용되게 되며 이를 통해 판별자는 가짜와 진짜에 대한 출력을 d_output을 통해 내보냅니다. 그래서 g_input과 d_output이 Model로 들어가 생성자와 판별자가 연결된 하나의 큰 모델로 만들어지게 됩니다. 

 


 

학 습

epoch = 5000
batch_size = 128
saving_interval = 200


for i in range(epoch):
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]
    d_loss_real = discriminator.train_on_batch(imgs, true)
    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

    # 손실 값 추출과 스칼라 변환 확인
    d_loss_real = d_loss_real if isinstance(d_loss_real, float) else d_loss_real[0]
    d_loss_fake = d_loss_fake if isinstance(d_loss_fake, float) else d_loss_fake[0]

    # 손실 값의 평균 계산
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    g_loss = gan.train_on_batch(noise, true)

    # 결과 출력
    print(f'epoch:{i} d_loss:{d_loss:.4f} g_loss:{g_loss:.4f}')

    if i % saving_interval == 0:
        noise = np.random.normal(0, 1, (25, 100))
        gen_imgs = generator.predict(noise)
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(5, 5)
        count = 0
        for j in range(5):
            for k in range(5):
                axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                axs[j, k].axis('off')
                count += 1
        fig.savefig(f"gan_images/gan_mnist_{i}.png")
        plt.close(fig)

 

 여기서부터는 학습 과정이 수행됩니다. 불러온 모델을 통해서 epoch 만큼 학습이 수행되고 saving_interval에 저장된 숫자만큼 epoch 당 생성한 이미지를 저장합니다. 나머지 자세한 사항은 아래를 통해 설명하도록 하겠습니다. 

 

1) 훈련 루프

for i in range(epoch):
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]
    d_loss_real = discriminator.train_on_batch(imgs, true)
    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

 

 해당 코드에서는 무작위로 이미지를 선택하여 판별자를 진짜 이미지로 훈련합니다. 그리고 노이즈로부터 가짜 이미지를 생성하게되고 생성된 가짜 이미지를 반별자에게 제공하여 가짜로 인식하도록 훈련합니다. 

 

2) 손실 계산 및 출력

    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    g_loss = gan.train_on_batch(noise, true)
    print(f'epoch:{i} d_loss:{d_loss:.4f} g_loss:{g_loss:.4f}')

 

 진짜와 가짜 이미지에 대한 판별자의 손실을 평균내어 계산합니다. 여기서 판별자의 손실을 평균 내는 이유는 두 가지 주요 손실 값, 즉 진짜 이미지에 대한 손실과 가짜 이미지에 대한 손실을 종합적으로 고려하기 위해서 입니다. 이를 통해서 판별자의 전반적인 성능을 균형 있게 평가하고 모델의 학습을 더 안정적으로 관리할 수 있습니다. 

 

 

3) 이미지 저장

    if i % saving_interval == 0:
        noise = np.random.normal(0, 1, (25, 100))
        gen_imgs = generator.predict(noise)
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(5, 5)
        count = 0
        for j in range(5):
            for k in range(5):
                axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                axs[j, k].axis('off')
                count += 1
        fig.savefig(f"gan_images/gan_mnist_{i}.png")
        plt.close(fig)

 

 정해진 epoch 간격마다 생성된 이미지를 저장하고 각 저장 시점에서 노이지를 새로 생성하여 새로운 이미지를 만들고 이를 5x5 그리드에 표시한 파일로 저장합니다. 

 


 

원본과 생성 이미지 비교

def load_and_compare_images(epoch, example_index=0):
    # MNIST 데이터 불러오기
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    X_train = (X_train - 127.5) / 127.5  # Normalize

    # 저장된 이미지 불러오기
    image_path = f"gan_images/gan_mnist_{epoch}.png"
    if not os.path.exists(image_path):
        print("해당 경로에 이미지 파일이 없습니다:", image_path)
        return

    generated_image = mpimg.imread(image_path)

    # 원본 이미지 선택
    original_image = X_train[example_index].reshape(28, 28)

    # 이미지 비교
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(original_image, cmap='gray')
    axs[0].set_title('Original MNIST Image')
    axs[0].axis('off')

    axs[1].imshow(generated_image)
    axs[1].set_title('Generated Image at Epoch ' + str(epoch))
    axs[1].axis('off')

    plt.show()

# 예를 들어, 3000번째 에포크의 생성된 이미지와 원본 이미지 비교
load_and_compare_images(3000, example_index=31)

 

 이 코드에서는 특정 epoch에서 생성된 GAN 이미지와 실제 데이터 셋의 이미지를 비교합니다. 목적은 생성자가 시간에 따라 얼마나 진짜 같은 이미지를 생성하는 지 시각적으로 확인하기 위해서 입니다. 아래 사진을 통해서 특정 Epoch 마다의 이미지를 비교보았습니다. 확실히 학습이 덜 되었을 때는 상대적으로 노이즈가 많이보이고 일정 구간이 지나면 글자의 모양이 나타나기 시작하는 것을 확인할 수 있습니다. 

 

epoch=600

 

epoch=3000

 

epoch 4800

 


 

마무리

 여기까지 기본적인 GAN에 대한 실습을 해보았습니다. 다음에는 Convolution layer를 적용한 DCGAN (Deep Convolutional GAN)을 활용한 실습을 진행하도록 하겠습니다.  늘 이론으로만 알고 있던 GAN을 직접 구현해보고 결과를 살펴보니, 빨리 어딘가에 써먹어 보고 싶은 생각이 드네요!. 그럼 이번 글은 여기까지로 마무리하도록 하겠습니다. 

 

참고 

모두의 딥러닝 : https://thebook.io/080324/

 

모두의 딥러닝 개정 3판

더북(TheBook): (주)도서출판 길벗에서 제공하는 IT 도서 열람 서비스입니다.

thebook.io

 

 

반응형