brunch

You can make anything
by writing

C.S.Lewis

by 유윤식 Jul 04. 2024

PyTorch 쓰세요(5).

#VAE #CONV #MNIST

MNIST 데이터셋을 활용

CONV2D 를 결합한 VAE 작성 예제


28 X 28 사이즈의 데이터를 VAE를 활용해 복원



import torch

import torch.nn as nn

import torch.optim as optim

import pytorch_lightning as pl

import torch.nn.functional as F

from torchvision import datasets, transforms

from torch.utils.data import DataLoader


필요한 라이브러리 작성



class ConvVAE(pl.LightningModule):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder applies three stride-2 convolutions (28 -> 14 -> 7 -> 4 pixels
    per side), flattens the result, and predicts the mean and log-variance of
    the latent distribution. The decoder mirrors this with three transposed
    convolutions (4 -> 7 -> 14 -> 28) and finishes with a sigmoid, so the
    reconstruction lies in the 0-1 range expected by binary cross-entropy.

    Args:
        lr: Learning rate passed to Adam in ``configure_optimizers``.
        latent_dim: Length of the latent vector ``z``.
        recon_weight: Factor applied to the reconstruction (BCE) term of the
            loss before the KL term is added.
    """

    def __init__(self, lr, latent_dim=20, recon_weight=500.0):
        super().__init__()

        self.lr = lr
        self.latent_dim = latent_dim
        self.recon_weight = recon_weight

        # Encoder: each convolution halves the spatial size (rounding up).
        self.conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)

        self.fc1 = nn.Linear(128 * 4 * 4, 128)
        self.fc21 = nn.Linear(128, latent_dim)  # mean head
        self.fc22 = nn.Linear(128, latent_dim)  # log-variance head

        # Decoder
        self.fc3 = nn.Linear(latent_dim, 128)
        self.fc4 = nn.Linear(128, 128 * 4 * 4)

        # output_padding=0 on the first layer gives 4 -> 7 (not 8), which is
        # what lets the next two layers land exactly on 14 and 28.
        self.conv4 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=0)
        self.conv5 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for a batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, z):
        """Map latent vectors ``z`` back to images with values in the 0-1 range."""
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        h4 = h4.view(h4.size(0), 128, 4, 4)
        h4 = F.relu(self.conv4(h4))
        h4 = F.relu(self.conv5(h4))
        return torch.sigmoid(self.conv6(h4))

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``; ``logvar`` is the clamped
            value that was actually used for sampling.
        """
        mu, logvar = self.encode(x)
        # Keep the log-variance bounded so exp() in sampling and in the KL term
        # stays within a moderate range.
        logvar = torch.clamp(logvar, min=-4, max=4)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def _shared_step(self, batch, log_name):
        """Compute the loss for one ``(inputs, labels)`` batch and log it as ``log_name``."""
        x, _ = batch  # labels are not needed for reconstruction
        recon_x, mu, logvar = self(x)
        loss = self.loss_function(recon_x, x, mu, logvar)
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log it as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def loss_function(self, recon_x, x, mu, logvar):
        """Return ``recon_weight * BCE + KLD``, both terms summed over the batch.

        BCE is the binary cross-entropy between the reconstruction and the
        input. KLD is the KL divergence between the predicted diagonal Gaussian
        and a standard normal.
        """
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return self.recon_weight * BCE + KLD

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


모델 정의


reparameterize 함수에서

잠재 변수 z를 평균과 로그 분산을 사용해 샘플링 시도 (z = mu + eps * std)


이 부분이 그냥 AE와 차이점


loss_function 같은 경우는 ChatGPT에 물어봐도 잘 나오는데


BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    >> Binary Cross Entropy loss for reconstruction

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    >> Kullback-Leibler Divergence loss



이제 학습에 사용할 데이터 준비


# ToTensor turns each image into a float tensor with values in the 0-1 range.
transform = transforms.Compose([transforms.ToTensor()])

# Options shared by the training and test loaders.
loader_options = dict(batch_size=32, num_workers=11, persistent_workers=True)

# Training split, reshuffled every epoch.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)

# Test split, kept in its original order.
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)



모델 INIT


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the VAE with an initial learning rate of 0.001 and place it on that device.
model = ConvVAE(lr=0.001)
model = model.to(device)



Trainer 정의


# Lightning trainer limited to ten epochs; every other option keeps its default.
trainer = pl.Trainer(max_epochs=10)



튜너를 사용하려면(Optional)


import matplotlib.pyplot as plt

from pytorch_lightning.tuner.tuning import Tuner


# Run Lightning's learning-rate finder on the model (num_training=300).
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
    num_training=300,
)

# Plot the finder's results with the suggested learning rate marked.
fig = lr_finder.plot(suggest=True)
fig.show()



실험을 통한 lr 최적값을 모델에 넣어주고


# The optimizer is built from `model.lr`, so the suggestion has to be stored
# under that name. Leave the current value alone if there is no suggestion.
suggested_lr = lr_finder.suggestion()
if suggested_lr is not None:
    model.lr = suggested_lr



학습을 시작하는데


# Train with the configured trainer, using the test split for validation.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)



과연 복원이 잘 될까?

테스트 데이터를 활용해서 복원을 시키고 그림으로 나타내보면

** 여기서 loss, accuracy는 그냥 단순 지표일 뿐, 수치가 좋다고 복원이 잘 되는 건 아니다!



# Pull the first batch of images and labels from the test loader.
test_batches = iter(test_loader)
images, label = next(test_batches)

# Resize step applied before plotting; the target size is 28x28.
resize_transform = transforms.Compose([transforms.Resize((28, 28))])



실제 test 데이터를 확인하고


# Take the first ten test images and move the channel axis last for matplotlib.
s = resize_transform(images[:10])
s = s.squeeze(0).permute(0, 2, 3, 1).detach().numpy()

# One row of ten panels, each titled with its position and its label.
fig, axes = plt.subplots(1, 10, figsize=(15, 8))
for idx, (panel, picture) in enumerate(zip(axes.flatten(), s)):
    panel.imshow(picture)
    panel.axis('off')
    panel.set_title(f'Image {idx + 1} : {label[idx]}')

plt.tight_layout()
plt.show()



이걸 복원시키면


# Reconstruct the first ten test images. This is inference only, so switch to
# eval mode and disable gradient tracking; the input is moved to whatever
# device the model currently lives on.
model.eval()
with torch.no_grad():
    s, _, _ = model(images[:10].to(model.device))

# Bring the result back to the CPU and move the channel axis last for matplotlib.
s = resize_transform(s.cpu()).squeeze(0).permute(0, 2, 3, 1).numpy()

# One row of ten panels, each titled with its position in the batch.
fig, axes = plt.subplots(1, 10, figsize=(15, 8))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(s[i])
    ax.axis('off')
    ax.set_title(f'Image {i + 1}')

plt.tight_layout()
plt.show()



비슷하다.


이걸 어디에 쓸까?

복원이 필요한 곳에서도 쓰겠지만


추천 시스템, 이상치 탐지 시스템 등에서도 활용!

작가의 이전글 Python: DuckDB(5)
브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari