#VAE #CONV #MNIST
MNIST 데이터셋 활용
Conv2d 층을 결합한 VAE(Variational AutoEncoder) 작성 예제
28 × 28 크기의 손글씨 숫자 이미지를 VAE로 인코딩한 뒤 다시 복원한다
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
필요한 라이브러리 임포트
class ConvVAE(pl.LightningModule):
    """Convolutional VAE that reconstructs 1x28x28 images (e.g. MNIST).

    Encoder: three stride-2 3x3 convolutions (spatial size 28 -> 14 -> 7 -> 4)
    followed by a 128-unit hidden layer that produces the mean and the
    log-variance of a Gaussian over a ``latent_dim``-dimensional latent z.
    Decoder: mirrors the encoder with transposed convolutions
    (4 -> 7 -> 14 -> 28) and ends in a sigmoid, so outputs lie in (0, 1)
    and can be scored against the input with binary cross-entropy.

    Args:
        lr: Learning rate used by Adam in ``configure_optimizers``.
        latent_dim: Size of the latent vector z (default 20).
        recon_weight: Multiplier applied to the BCE reconstruction term of
            the loss (default 500).
    """

    def __init__(self, lr, latent_dim=20, recon_weight=500):
        super().__init__()
        self.lr = lr
        self.recon_weight = recon_weight
        # Encoder
        self.conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)    # 28 -> 14
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)   # 14 -> 7
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)  # 7 -> 4
        self.fc1 = nn.Linear(128 * 4 * 4, 128)
        self.fc21 = nn.Linear(128, latent_dim)  # mean head
        self.fc22 = nn.Linear(128, latent_dim)  # log-variance head
        # Decoder
        self.fc3 = nn.Linear(latent_dim, 128)
        self.fc4 = nn.Linear(128, 128 * 4 * 4)
        self.conv4 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=0)  # 4 -> 7
        self.conv5 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)   # 7 -> 14
        self.conv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)    # 14 -> 28

    @property
    def learning_rate(self):
        """Alias for ``self.lr``.

        ``configure_optimizers`` reads ``self.lr``; this alias makes
        ``model.learning_rate = x`` update that same value instead of
        creating an unused attribute.
        """
        return self.lr

    @learning_rate.setter
    def learning_rate(self, value):
        self.lr = value

    def encode(self, x):
        """Map images to ``(mu, logvar)`` of the approximate posterior."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten (N, 128, 4, 4) -> (N, 2048)
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # mean and log-variance

    def decode(self, z):
        """Map latent vectors back to images with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        h4 = h4.view(h4.size(0), 128, 4, 4)
        h4 = F.relu(self.conv4(h4))
        h4 = F.relu(self.conv5(h4))
        return torch.sigmoid(self.conv6(h4))

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) (reparameterization trick).

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent and decode.

        Returns:
            ``(reconstruction, mu, logvar)``, where ``logvar`` is the clamped value.
        """
        mu, logvar = self.encode(x)
        # Clamp to [-4, 4] so exp(logvar) stays bounded.
        logvar = torch.clamp(logvar, min=-4, max=4)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        x, _ = batch  # labels are unused: the target is the input itself
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log('val_loss', loss)
        return loss

    def loss_function(self, recon_x, x, mu, logvar):
        """Return ``recon_weight * BCE + KLD``, both summed over the batch.

        BCE is the reconstruction error. KLD is the KL divergence between
        N(mu, exp(logvar)) and the standard normal prior.
        """
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return self.recon_weight * BCE + KLD

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
모델 정의
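reparameterize 함수에서는 인코더가 출력한 평균 $\mu$ 와 로그 분산 $\log\sigma^2$ 을 사용해 잠재 변수 z를 샘플링한다:
$z = \mu + \sigma \odot \epsilon,\quad \sigma = \exp(0.5\,\log\sigma^2),\quad \epsilon \sim \mathcal{N}(0, I)$
잡음 $\epsilon$ 을 따로 뽑아 더하는 방식(reparameterization trick) 덕분에 샘플링을 거쳐도 역전파가 가능하다. 잠재 공간을 확률 분포로 다루는 이 부분이 일반 AE(오토인코더)와의 차이점이다.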
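loss_function 은 두 항으로 구성된 표준 VAE 손실이다 (ChatGPT 등에 물어봐도 쉽게 찾을 수 있다).
BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
>> 재구성 손실(Binary Cross Entropy): 복원 이미지가 원본과 얼마나 다른지 측정
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
>> Kullback-Leibler Divergence 손실: $D_{KL} = -\tfrac{1}{2}\sum\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)$, 잠재 분포를 표준정규분포 $\mathcal{N}(0, I)$ 에 가깝게 만든다
최종 손실은 500 * BCE + KLD 로, 재구성 항에 큰 가중치를 주어 복원 품질을 우선한다.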
이제 학습에 사용할 데이터 준비
import os

# ToTensor converts images to float tensors with values in the 0-1 range,
# which matches the decoder's sigmoid output and the BCE loss.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# The original hard-coded 11 workers. Cap the count at the number of available
# CPUs so the loaders also work on smaller machines. Keep at least 1 worker,
# because persistent_workers=True requires num_workers > 0.
num_workers = min(11, os.cpu_count() or 1)
loader_kwargs = dict(batch_size=32, num_workers=num_workers, persistent_workers=True)

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
모델 INIT
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ConvVAE(lr=0.001)
model = model.to(device)
Trainer 정의
# Train for 10 epochs; every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=10)
(선택) Tuner의 learning rate finder로 적절한 학습률을 찾으려면
# Optional: run Lightning's learning-rate finder and plot loss vs. learning rate.
import matplotlib.pyplot as plt
from pytorch_lightning.tuner.tuning import Tuner

tuner = Tuner(trainer)
lr_finder = tuner.lr_find(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
    num_training=300,
)
fig = lr_finder.plot(suggest=True)  # marks the suggested learning rate on the curve
fig.show()
LR finder가 제안한 최적 학습률을 모델에 반영하고
# Apply the learning rate suggested by the LR finder.
# ConvVAE.configure_optimizers builds Adam from `self.lr`, so that is the
# attribute to update. suggestion() can return None when no suggestion
# could be computed; in that case keep the current learning rate.
suggested_lr = lr_finder.suggestion()
if suggested_lr is not None:
    model.lr = suggested_lr
이제 학습을 시작한다
# Train the model, using the MNIST test split as the validation set.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
학습된 모델로 복원이 잘 될까?
테스트 데이터를 복원시킨 뒤 원본과 함께 그림으로 나타내 비교해 보자.
** 여기서 loss 같은 수치는 단순 참고 지표일 뿐, 수치가 좋다고 복원이 잘 되는 건 아니다! (이 예제에서는 accuracy 는 계산하지 않는다) 결과는 직접 눈으로 확인하자.
# Take the first test batch (images and their labels) for visual inspection.
images, label = next(iter(test_loader))
# Resize to 28x28 for display.
resize_transform = transforms.Compose(
    [transforms.Resize((28, 28))]
)
실제 test 데이터를 확인하고
# Show the first 10 test images together with their ground-truth labels.
num_show = 10
# (N, 1, 28, 28) -> (N, 28, 28): drop the single channel axis for imshow.
s = resize_transform(images[:num_show])[:, 0].detach().cpu().numpy()
fig, axes = plt.subplots(1, num_show, figsize=(15, 8))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(s[i], cmap='gray')  # MNIST digits are grayscale
    ax.axis('off')
    # .item() turns the 0-d label tensor into a plain int; otherwise the title reads "tensor(7)".
    ax.set_title(f'Image {i + 1} : {label[i].item()}')
plt.tight_layout()
plt.show()
이걸 복원시키면
# Reconstruct the same 10 test images with the trained VAE.
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    # Send the inputs to whichever device the model currently lives on.
    recon, _, _ = model(images[:10].to(model.device))
# forward() still samples z through reparameterize, so reconstructions vary
# slightly from run to run.
s = resize_transform(recon)[:, 0].cpu().numpy()  # (N, 1, 28, 28) -> (N, 28, 28)
fig, axes = plt.subplots(1, 10, figsize=(15, 8))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(s[i], cmap='gray')
    ax.axis('off')
    ax.set_title(f'Image {i + 1}')
plt.tight_layout()
plt.show()
원본과 비슷하게 복원된다.
이걸 어디에 쓸까?
복원이 필요한 곳에서도 쓰겠지만,
추천 시스템, 이상치 탐지 시스템 등에서도 활용한다! 예를 들어 정상 데이터로만 학습한 VAE는 처음 보는 이상 데이터를 잘 복원하지 못하므로, 복원 오차가 큰 샘플을 이상치로 판단할 수 있다.