[딥러닝] 생성형 모델 : VAE / GAN / Diffusion (+ 이미지 생성 및 복원해보기)
이번 포스팅에서는 생성형 모델에 대해서 알아보려고 한다.
생성형 모델은 주어진 데이터 분포를 학습하여 새로운 데이터를 생성할 수 있는 머신러닝 모델이다. 생성형 모델의 등장은 딥러닝 분야에서 매우 혁신적이었으며, 지금까지도 많은 분야에서 응용되어 사용되고 있다.
생성 모델의 개념과, VAE, GAN, Diffusion Model과 같은 생성형 모델의 대표적인 유형들의 개념에 대해서 살펴보고, 이를 Tensorflow에서 활용하여 이미지를 생성하거나 복원해보는 작업들을 해보자.
생성 모델
1. 정의 : 주어진 데이터를 바탕으로 새로운 데이터 샘플을 자동으로 생성할 수 있는 알고리즘을 설계
2. 목적 : 데이터 세트의 기본 확률 분포를 이해하고 포착하여 유사한 통계를 가진 새로운 데이터 포인트를 생성하는 것
3. 모델의 학습 방법 : 비지도 학습
- 레이블이 없는 데이터로부터 학습
- 주어진 데이터 속 잠재 구조와 숨겨진 패턴을 학습하는 것이 목표이다.
4. 주요 구성 요소
- 입력 데이터 (Input Data): 모델이 학습할 실제 데이터
- 잠재 공간 (Latent Space): 데이터의 압축된 표현을 포함하는 저차원 공간
- 생성 데이터 (Generated Data): 학습된 분포를 기반으로 생성된 새로운 데이터
5. 판별 모델 vs 생성 모델
- 판별 모델은 특징이 주어졌을 때 레이블(클래스)를 예측하는 것.
- 생성 모델은 특징이 주어졌을 때 새로운 데이터 인스턴스를 생성하는 것.
대표적인 생성 모델의 유형
- PixelRNN / PixelCNN
- VAE (변분 오토인코더)
- GAN (생성적 적대 신경망)
- Diffusion Model
PixelRNN / PixelCNN
-
순차적인 방식으로 이미지를 생성하는 생성 모델
-
이미지를 픽셀 단위로 스캔하면서 각 픽셀 값을 예측
- 이전 픽셀들의 값을 기반으로 다음 픽셀의 확률 분포를 계산
- PixelCNN : CNN 구조를 사용하여 공간적 종속성 학습
- PixelRNN : LSTM 등 RNN 구조를 사용하여 순차적 종속성 학습
-
이미지 생성, 보완, 복원 등 다양한 작업에 사용
Autoencoder
입력 데이터를 효율적으로 인코딩하기 위해 비지도 학습 방식으로 학습되는 인공 신경망
오토인코더의 목표는 입력 데이터를 압축된 형태로 인코딩한 후 압축된 표현에서 원래의 입력 데이터를 재구성하는 것.
즉 원본 데이터의 특징을 압축했다가 이를 원본으로 최대한 정확하게 복원할 수 있도록 설계된 모델이다.
기본 구조
- 인코더(Encoder): 입력 데이터를 저차원 잠재 공간으로 매핑하여 중요한 특징만 추출 (입력층 → 은닉층 →잠재 공간 표현)
- 잠재 공간(Latent Space) : 인코더가 입력 데이터를 압축한 후 얻는 저차원 공간, 입력 데이터의 중요한 특징을 포함, 노이즈 제거
- 디코더(Decoder): 잠재 공간 표현에서 원래의 입력 데이터를 재구성하여 출력한다. (잠재 공간 표현 → 은닉층 → 출력층)
- 손실 함수(Loss Function) : 입력 데이터와 재구성된 데이터 간의 차이를 측정 (MSE 등 사용)
Autoencoder의 주요 목적
- 데이터 압축: 고차원 데이터를 저차원으로 압축
- 노이즈 제거: 노이즈가 있는 데이터를 깨끗하게 복원
- 이상 탐지: 정상 패턴을 학습하여 이상 데이터를 탐지
VAE (Variational Autoencoder, 변분 오토인코더)
기본적인 오토인코더는 데이터의 중요한 특징을 압축하고 이를 원래대로 복원하는 데 중점을 두고 있어 새로운 데이터를 생성하는 데 적합하지 않다. 또한 잠재 공간에서 샘플링된 벡터가 훈련 데이터에서 본 적 없는 값이라면 디코더가 이를 제대로 복원할 수 없다.
이를 극복하기 나온 생성 모델이 VAE(변분 오토인코더)이다.
Encoder : 입력 데이터를 저차원 Latent Space(잠재 공간)의 확률 분포(평균과 분산)으로 매핑한다.
- 입력 데이터 \(x\)를 받아 잠재 공간의 확률 분포 q(z|x)로 변환. 이 분포는 평균 벡터 μ와 표준편차(혹은 분산) σ로 표현된다.
- 입력 데이터에 대한 잠재 변수 \(z\)의 확률 분포를 나타낸다.
Decoder : Latent Space로부터 샘플링된 벡터(특징)를 입력으로 받아 재구성한 데이터를 출력한다.
샘플링 : 입력 데이터를 잠재 공간의 확률 분포로 변환한 후, 이 분포에서 샘플링된 벡터(\(z\))를 디코더가 원본 데이터로 복원하는 방법을 학습한다.
- (ϵ는 정규 분포에서 샘플링된 노이즈)
- Reparameterization trick : 위의 샘플링 과정은 연속적인 학습을 가능하게 한다. 즉 Backpropagation시 샘플링 과정이 미분 가능하다.
- 샘플링된 잠재 백터 (Latent Vector) : 잠재 벡터는 데이터의 특징을 나타내며, 확률 분포로부터 다양하게 샘플링된다는 것은 해당 특징들을 미세하게 조정해가면서 변화나 형태를 보여줄 수 있다는 것이다.
잠재 공간의 확률 분포를 학습 : 각 입력 데이터에 대해 평균과 분산을 계산하고 잠재 공간에서 샘플링할 때 약간의 변화를 주어 다양한 데이터를 생성할 수 있도록 한다.
- 일부가 변형된 샘플링된 데이터를 사용하기 때문에 원본을 기반으로 하는 다양한 데이터를 생성하도록 학습시킬 수 있다.
목표 : VAE는 생성모델이기 때문에 Decoder를 학습시키는 것이 목표이다.
VAE의 학습
VAE 손실 (VAE Loss)
- VAE의 손실 함수는 재구성 손실과 쿨백 라이블러 발산 손실의 합으로 이루어진다.
- 해당 손실을 최소화하여 입력 데이터의 중요한 특징을 잠재 공간에 효과적으로 압축하고, 해당 잠재 공간이 의미 있는 분포를 가지도록 한다.
- 잠재 공간이 의미 있는 분포를 가지도록 한다는 것은 샘플링된 벡터로부터 유의미한 데이터가 나올 확률이 높아진다는 뜻이다.
재구성 손실 (Reconstruction Loss)
- 디코더의 출력과 원본 데이터 간의 차이를 최소화하는 것이 목표 → 데이터가 얼마나 잘 복원되었는지를 측정
- 일반적으로 MSE 혹은 Cross-enrtopy loss를 사용
쿨백 라이블러 발산 손실 (KL Divergence Loss)
- 인코더가 생성한 잠재 분포와 사전 정의된 분포의 차이를 최소화
- 해당 손실을 최소화하는 것은 곧 잠재 공간이 의미 있는 분포를 가지도록 하는 것이다.
GAN (Generative Adversarial Network, 생성적 적대 신경망)
생성자(Generator)와 판별자(Discriminator)의 대립과 경쟁을 통해 모델을 훈련, 사용자가 만족할만한 수준의 결과를 생성하는 모델
- 두 개의 신경망(생성자와 판별자)이 서로 경쟁하는 방식으로 학습한다.
- 생성자는 새로운 데이터 인스턴스(가짜 데이터)를 생성한다.
- 판별자는 생성자의 데이터가 진짜인지 가짜인지 구분하려고 한다.
- 생성 모델이 진짜같은 가짜를 생성하게 하는 것이 목표 -> 판별 모델이 진짜와 가짜를 구분할 수 없으면 성공적인 모델
- 생성자와 판별자의 경쟁 과정을 통해 생성자는 점점 더 진짜 같은 데이터를 생성하게 된다.
-
목적 : 생성자가 생성한 데이터가 판별자에 의해 진짜 데이터와 구별되지 않도록 하는 것
생성자 (Generator)
Generator(G) : 입력(Latent Vector)를 받아서 결과(합성된, 가짜 데이터)를 생성. 학습 과정을 통해 점점 더 진짜같은 데이터를 생성하게 된다.
- Latent Space : 데이터셋에서 중요한 특징을 저차원으로 압축한 공간
- Latent Vector(z) : 데이터의 표준 정규 분포 등에서 무작위로 샘플링된 벡터로서 Latent Space에서 각 Point를 무작위로 추출하여 사용한다. 다양한 유형의 합성된 데이터가 나오도록 하는 역할을 한다.
Generator의 구조
- N개의 Generator Block + Full Connected Layer + Sigmoid 함수로 구성
- Generator Block : Linear Layer + Batch nomalization + ReLU 함수로 구성
판별자 (Discriminator)
Discriminator(D) : 진짜 데이터(x)와, 생성자로부터 생성된 가짜 데이터(G(z))를 입력을 받아서 Real / Fake를 가려내는 판별기(Classifier)
Discriminator의 구조
- N개의 Discriminator Block + Full Connected Laye로 구성
- Discriminator Block : Linear Layer + ReLU 함수로 구성
훈련 과정
1. 생성기가 랜덤한 노이즈 벡터를 입력으로 받아 가짜 데이터를 생성
2. 판별기가 진짜 데이터와 가짜 데이터를 입력받아 각각에 대한 진위 여부를 판단
3. 판별기의 예측 결과를 바탕으로 생성기와 판별기가 각각의 Loss를 계산하고, 이를 통해 모델을 업데이트
4. 이상적으로는 생성기가 진짜 같은 데이터를 만들어내어 판별기가 더 이상 진짜와 가짜를 구별할 수 없을 때까지 훈련 진행
GAN의 훈련 목적과 Loss 함수
GAN의 훈련 목적은 다음과 같다.
- G(Generator)의 손실 최소화 : 실제 데이터와의 차이가 거의 없어야 한다.
- D(Discriminator)의 손실 최대화 : 가짜 데이터와 진짜 데이터를 오판별하는 빈도가 높아져야 한다.
GAN은 Discriminator와 Generator의 각각의 성능을 끌어올려서 둘 간의 대립을 통해 min(G), max(D)를 추구하는 것이 목적이다.
그렇기에 결국 각각의 성능을 끌어올리기 위해서 손실함수가 역할에 따라서 다른 의미를 추구하게 된다.
- Discriminator: D(x)는 1을, D(G(z))는 0을 출력할 것 -> logD(x) & log(1 – D(G(z)))가 max
- Generator: D(G(z))가 1을 출력할 것 -> log(1 – D(G(z)))가 min
Discriminator와 Generator 각각의 목표에 따라 손실 함수의 해석과 최적화 방향이 달라지며, 이것이 GAN 훈련의 어려움이자 중요한 포인트가 된다.
GAN의 훈련 전략
1. D와 G를 동시에 훈련해야 함. D는 max, G는 min을 추구
- 두 모델이 서로의 성능을 극대화하려는 상반된 목표를 가지고 있기 때문에 수렴하기 어렵고, 훈련이 불안정하다.
- 하지만 GAN에서 최적의 경우에 수렴되는 값이 있다는 것을 수학적으로 증명하였다.
2. 훈련 초기에 생성되는 G(z)는 품질이 좋지 않다.
- 매번 D(G(z))이 0에 가까우면, log(1-D(G(z)))이 0에 가까운 값을 갖게 된다.
- Gradient를 적용하기 힘들다.
- 따라서 휴리스틱한 훈련 전략을 사용하게 된다.
- 초기에는 log (1 – D(G(z)))를 사용하지 말고 D(G(z))를 최대화시키는 방향으로 훈련
- D(G(z))는 log (1 – D(G(z)))보다 값이 더 크기 때문에 gradient가 소실되는 경우를 피할 수 있다.
Diffusion Model
Diffusion Model은 점진적 노이즈 추가 및 제거 과정의 효과적 학습을 위한 다양한 방법론을 제안했으며 안정적인 학습 과정과 고품질 데이터 생성 능력으로 주목받고 있는 생성형 모델이다. GAN, VAE와 함께 주요 축으로 부상하였다.
- 데이터 생성 및 변환을 위한 확률적 모델, 노이즈를 추가하고 제거하는 과정을 통해 데이터의 변화를 모사하면서 학습.
- 학습 과정
- 순방향 과정 (Forward Diffusion) : 데이터에 노이즈를 추가하여 데이터 분포를 점점 더 단순한 분포로 변형
- 역방향 과정 (Reverse Diffusion) : 변형된 분포에서 노이즈를 제거하여 원본 데이터를 복원하는 방법을 학습
- 복잡한 데이터 분포를 단순한 분포 변환, 이를 다시 복원하는 과정을 모사하여 새로운 데이터 샘플을 생성하거나 기존 데이터를 변환하는 것이 주요 목표이다.
생성형 모델의 활용
- Dataset의 증가
- 예술적 스타일 이전 방법
- 동영상의 다음 프레임 예측 방법
- 슈퍼 해상도 이미지
- 이미지 다른 이미지로 변환
- 텍스트로 이미지 생성
- 사진으로부터 3D 모델 생성
Tensorflow로 이미지 생성 및 복원해보기
이제 직접 생성형 모델을 만들고 학습시켜서 이미지를 생성 혹은 복원해보는 작업을 해보자.
가장 간단한 데이터로 MNIST 숫자 이미지 데이터셋을 학습시킬 것이다.
- 1. Autoencoder와 VAE
- 2. GAN과 Diffusion Model
1. Autoencoder와 VAE
from IPython.core.display import display, HTML
display(HTML("<style>.container {width:90% !important;}</style>"))
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
데이터 시각화를 위한 코드¶
import numpy as np
import matplotlib.pyplot as plt
# Visualization function
def plot_latent_space(vae, n=15, figsize=15):
digit_size = 28
scale = 1.0
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = vae.get_layer("decoder").predict(z_sample)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(figsize, figsize))
start_range = digit_size // 2
end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
plt.xticks(pixel_range, np.round(grid_x, 1))
plt.yticks(pixel_range, np.round(grid_y, 1))
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap="Greys_r")
plt.show()
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, Lambda, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model
데이터 로드 및 전처리¶
MNIST 숫자 데이터를 사용하여 이를 인코딩한 후 복원하는 과정을 살펴본다.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data() # y를 사용하지 않음
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 0s 0us/step
x_train.shape, x_test.shape
((60000, 28, 28), (10000, 28, 28))
1. Autoencoder 모델 정의¶
# 입력 데이터 차원
input_dim = 28*28 # 28x28 픽셀 이미지가 1차원 벡터로 변환된 형태
encoding_dim = 32 # 잠재 공간의 차원 -> 늘릴 수록 표현이 잘 됨
# 인코더 정의
encoder = Sequential([
Flatten(input_shape=(28, 28)), # 입력 이미지를 1차원 벡터로 변환
Dense(encoding_dim, activation='relu') # 인코딩 레이어
])
# 디코더 정의
decoder = Sequential([
Dense(input_dim, activation='sigmoid', input_shape=(encoding_dim,)), # 디코딩 레이어
Reshape((28, 28)) # 1차원 벡터를 다시 28x28 이미지로 변환
])
# 오토인코더 모델 정의
autoencoder = Sequential([encoder, decoder])
# 손실 함수로 binary_crossentropy 사용
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
모델 훈련¶
autoencoder.fit(
x_train, x_train, # X값으로만 학습을 진행
epochs=50,
batch_size=256,
shuffle=True,
verbose=0,
validation_data=(
x_test, x_test
))
<keras.src.callbacks.History at 0x7b3f3bc9ba60>
재구성된 데이터 확인¶
- Autoencoder에 의해서 Latent Space로부터 다시 복원된 데이터를 확인해보자.
decoded_imgs = autoencoder.predict(x_test)
313/313 [==============================] - 1s 3ms/step
for i in range(5): #5개 확인
# 원본 이미지
plt.figure(figsize=(3,3))
plt.title("Original")
plt.imshow(x_test[i], cmap='gray')
plt.show()
# 재구성된 이미지
plt.figure(figsize=(3,3))
plt.title("Generated")
plt.imshow(decoded_imgs[i], cmap='gray')
plt.show()
2. VAE 모델 정의¶
latent_dim = 2
# 인코더 정의
encoder = Sequential([
Input(shape=(28, 28)), # 입력 이미지 크기
Flatten(), # 입력 이미지를 1차원 벡터로 변환
Dense(128, activation='relu'), # 중간 레이어
Dense(latent_dim + latent_dim) # z_mean과 z_log_var를 함께 출력
], name='encoder')
encoder.summary()
# 샘플링 레이어를 정의
class Sampling(layers.Layer):
def call(self, inputs):
z_mean, z_log_var = tf.split(inputs, num_or_size_splits=2, axis=1)
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# 디코더 정의
decoder = Sequential([
Input(shape=(latent_dim,)), # 잠재 공간 입력
Dense(128, activation='relu'), # 중간 레이어
Dense(28 * 28, activation='sigmoid'), # 출력 레이어
Reshape((28, 28)) # 1차원 벡터를 다시 28x28 이미지로 변환
], name='decoder')
decoder.summary()
# VAE 모델 정의
vae = Sequential([encoder, Sampling(), decoder], name='vae')
# 손실 함수 정의 (reconstruction_loss + kl_loss)
def vae_loss(inputs, outputs, z_mean_log_var):
z_mean, z_log_var = tf.split(z_mean_log_var, num_or_size_splits=2, axis=1)
reconstruction_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(inputs, outputs)) * 28 * 28
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
return reconstruction_loss + kl_loss
# 모델 컴파일
optimizer = tf.keras.optimizers.Adam()
vae.compile(optimizer, loss=lambda inputs, outputs: vae_loss(inputs, outputs, encoder(inputs)))
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten_4 (Flatten) (None, 784) 0 dense_10 (Dense) (None, 128) 100480 dense_11 (Dense) (None, 4) 516 ================================================================= Total params: 100996 (394.52 KB) Trainable params: 100996 (394.52 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 128) 384 dense_13 (Dense) (None, 784) 101136 reshape_3 (Reshape) (None, 28, 28) 0 ================================================================= Total params: 101520 (396.56 KB) Trainable params: 101520 (396.56 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
모델 훈련¶
# 모델 학습
vae.fit(
x_train, x_train, #모델 입력/출력 (입력 데이터가 얼마나 유지된 채로 나오는 지 학습)
epochs=30, batch_size=128,
verbose=0,
validation_data=(
x_test, x_test
))
<keras.src.callbacks.History at 0x7b3ea144d8d0>
결과 확인¶
- Latent Space의 확률 분포에 따라 어떻게 이미지가 변화하는지 확인
plot_latent_space(vae)
1/1 [==============================] - 0s 62ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 31ms/step 1/1 [==============================] - 0s 36ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 39ms/step 1/1 [==============================] - 0s 34ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 45ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 35ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 44ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 37ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 31ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 32ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 27ms/step 1/1 [==============================] - 0s 31ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 33ms/step 1/1 [==============================] - 0s 41ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 119ms/step 1/1 [==============================] - 0s 122ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 23ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 23ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 23ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 24ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 22ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 21ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 19ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 20ms/step 1/1 [==============================] - 0s 17ms/step 1/1 [==============================] - 0s 23ms/step 1/1 [==============================] - 0s 18ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 31ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 25ms/step 1/1 [==============================] - 0s 32ms/step 1/1 [==============================] - 0s 32ms/step 1/1 [==============================] - 0s 28ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 26ms/step 1/1 [==============================] - 0s 30ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 37ms/step 1/1 [==============================] - 0s 31ms/step 1/1 [==============================] - 0s 29ms/step 1/1 [==============================] - 0s 25ms/step
2. GAN과 Diffusion Model
from IPython.core.display import display, HTML
display(HTML("<style>.container {width:90% !important;}</style>"))
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
데이터 시각화를 위한 코드¶
# 생성된 이미지 시각화
def gan_images(generator, epoch, noise_dim, rows=5, cols=5):
noise = np.random.normal(0, 1, (rows * cols, noise_dim))
gen_imgs = generator.predict(noise)
gen_imgs = gen_imgs.reshape(rows * cols, 28, 28)
fig, axs = plt.subplots(rows, cols, figsize=(10, 10))
cnt = 0
for i in range(rows):
for j in range(cols):
axs[i, j].imshow(gen_imgs[cnt], cmap='gray')
axs[i, j].axis('off')
cnt += 1
plt.show()
# 결과 시각화
def plot_denoising_results(model, noisy_data, clean_data):
decoded_imgs = model.predict(noisy_data)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
ax = plt.subplot(3, n, i + 1)
plt.imshow(noisy_data[i].reshape(28, 28), cmap='gray')
plt.title("Noisy")
plt.axis('off')
ax = plt.subplot(3, n, i + 1 + n)
plt.imshow(decoded_imgs[i].reshape(28, 28), cmap='gray')
plt.title("Denoised")
plt.axis('off')
ax = plt.subplot(3, n, i + 1 + 2*n)
plt.imshow(clean_data[i].reshape(28, 28), cmap='gray')
plt.title("Original")
plt.axis('off')
plt.show()
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, Lambda, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model
데이터 로드 및 전처리¶
MNIST 숫자 데이터를 사용하여 이를 기반으로 이미지를 생성해보자.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data() # y를 사용하지 않음
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 0s 0us/step
x_train.shape, x_test.shape
((60000, 28, 28), (10000, 28, 28))
3. GAN 모델 정의¶
x_train = x_train.reshape(-1, 28, 28, 1)
# 생성기 모델 정의
def build_generator():
model = Sequential([
# '''생성기 입력은 노이즈'''
Dense(128, activation='relu', input_dim=100), #노이즈 크기 100
# '''생성기 출력은 이미지 크기 만큼'''
Dense(784, activation='sigmoid'),
Reshape((28, 28, 1)) # 784차원 벡터를 28x28 이미지로 변환
])
return model
# 판별기 모델 정의
def build_discriminator():
model = Sequential([
# 판별기 입력은 이미지
layers.Flatten(input_shape=(28, 28)),
Dense(128, activation='relu', input_dim=784),
# 판별기 출력은 출력이 1개
Dense(1, activation='sigmoid')
])
return model
# GAN 모델 구축 및 학습
def build_gan(generator, discriminator):
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
discriminator.trainable = False
gan_input = tf.keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
return gan
# 모델 생성
generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)
# 하이퍼파라미터 설정
batch_size = 128
epochs = 10000
sample_interval = 1000
noise_dim = 100
모델 훈련¶
# 학습 과정
def train(generator, discriminator, gan, x_train, batch_size, epochs, noise_dim, sample_interval):
half_batch = batch_size // 2
for epoch in range(epochs):
# 판별기 훈련
idx = np.random.randint(0, x_train.shape[0], half_batch)
real_imgs = x_train[idx]
noise = np.random.normal(0, 1, (half_batch, noise_dim))
fake_imgs = generator.predict(noise)
real_labels = np.ones((half_batch, 1))
fake_labels = np.zeros((half_batch, 1))
discriminator.train_on_batch(real_imgs, real_labels)
discriminator.train_on_batch(fake_imgs, fake_labels)
# 생성기 훈련
noise = np.random.normal(0, 1, (batch_size, noise_dim))
valid_y = np.ones((batch_size, 1))
gan.train_on_batch(noise, valid_y)
# 주기적으로 이미지 샘플링
if epoch % sample_interval == 0:
gan_images(generator, epoch, noise_dim)
모델 훈련 과정¶
- 생성기가 얼마나 점점 더 새로운 데이터를 잘 생성하는지 확인해보자.
train(generator, discriminator, gan, x_train, batch_size, epochs, noise_dim, sample_interval)
4. Diffusion 모델 정의¶
# 노이즈 추가 함수
def add_noise(data, noise_factor=0.5):
noisy_data = data + noise_factor * np.random.normal(size=data.shape)
noisy_data = np.clip(noisy_data, 0., 1.)
return noisy_data
# 노이즈 제거 모델 정의
def build_denoising_model():
model = tf.keras.Sequential([
Input(shape=(28, 28, 1)),
Conv2D(32, (3, 3), activation='relu', padding='same'),
MaxPooling2D((2, 2), padding='same'),
Conv2D(32, (3, 3), activation='relu', padding='same'),
UpSampling2D((2, 2)),
Conv2D(1, (3, 3), activation='sigmoid', padding='same')
])
model.compile(optimizer='adam', loss='binary_crossentropy')
return model
# 노이즈 데이터 생성
noise_factor = 0.5
x_train_noisy = add_noise(x_train, noise_factor) # '''학습 데이터 노이즈 입히기'''
x_test_noisy = add_noise(x_test, noise_factor) # '''테스트 데이터 노이즈 입히기'''
denoising_model = build_denoising_model()
모델 학습¶
# 모델 학습
denoising_model.fit(
x_train_noisy, x_train, # 입력과 출력은 노이즈 추가된 입력, 노이즈 없는 출력
epochs=10, batch_size=128,
verbose=0,
validation_data=(
x_test_noisy, x_test
))
<keras.src.callbacks.History at 0x78e64e30f2e0>
결과 확인 및 시각화¶
# 노이즈 제거 결과 시각화
plot_denoising_results(denoising_model, x_test_noisy, x_test)
313/313 [==============================] - 1s 2ms/step
학습된 Diffusion Model 사용¶
- 학습이 완료된 Diffusion Model을 사용해보자.
!pip install diffusers transformers scipy
Requirement already satisfied: diffusers in /usr/local/lib/python3.10/dist-packages (0.29.0) Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.2) Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.11.4) Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.10/dist-packages (from diffusers) (7.1.0) Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from diffusers) (3.14.0) Requirement already satisfied: huggingface-hub>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from diffusers) (0.23.3) Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from diffusers) (1.25.2) Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from diffusers) (2024.5.15) Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from diffusers) (2.31.0) Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from diffusers) (0.4.3) Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from diffusers) (9.4.0) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1) Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1) Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4) Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.2->diffusers) (2023.6.0) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.2->diffusers) (4.12.2) Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata->diffusers) (3.19.2) Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers) (2.0.7) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->diffusers) (2024.6.2)
import torch
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt
# 모델 및 토크나이저 로드
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained(model_id)
pipeline = pipeline.to(device)
/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead. deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: ``` pip install accelerate ``` . /usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn(
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]
# 텍스트 프롬프트를 사용하여 이미지 생성
prompt = "bamboo and panda"
with torch.autocast("cuda"):
image = pipeline(prompt)[0]
# 이미지 시각화
plt.imshow(image[0])
plt.axis("off")
plt.show()
0%| | 0/50 [00:00<?, ?it/s]
prompt = "yellow monkey and lake"
with torch.autocast("cuda"):
image = pipeline(prompt)[0]
# 이미지 시각화
plt.imshow(image[0])
plt.axis("off")
plt.show()
0%| | 0/50 [00:00<?, ?it/s]
해당 포스팅의 내용은 "상명대학교 민경하 교수님 "인공지능" 수업, 상명대학교 김승현 교수님 "딥러닝"수업을 기반으로 작성하였으며, 포스팅 자료는 해당 내용을 기반으로 재구성하여 만들어 사용하였습니다.
'Data Science > 머신러닝 & 딥러닝' 카테고리의 다른 글
[딥러닝] Transformer : 소개와 동작 원리 (+ 간단한 챗봇 만들기) (0) | 2024.06.11 |
---|---|
[딥러닝] 기계 번역 : Seq2Seq와 Attention (+ 모델 학습시켜 다국어 번역해보기) (0) | 2024.06.11 |
[딥러닝] NLP : 자연어 처리 기본 (+ 영화 리뷰글 긍정/부정 판단해보기) (0) | 2024.06.09 |
[딥러닝] 기억하는 신경망 : RNN, 그리고 개선 모델 (LSTM, GRU) (0) | 2024.06.08 |
[딥러닝] CNN : ResNet 모델로 동물 이미지 분류하기(CIFAR 이미지셋) (0) | 2024.06.08 |