brunch

You can make anything
by writing

C.S.Lewis

by 유윤식 Aug 12. 2021

PyTorch 쓰세요(3).

#CNN #Conv2D #컨볼루션

컨볼루션 == 에볼루션.

Mnist 예제에 CNN 적용 예제가 수두룩한데,

마지막으로 한 번 더!

Mnist 예제로 CNN 적용해보기 포스팅!!


callback 적용까지  상태에서 그냥 Full-Source  다시   적어보자면,


class LitMNIST(LightningModule):

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=0.001):

        super().__init__()

        self.data_dir = data_dir

        self.hidden_size = hidden_size

        self.learning_rate = learning_rate


        self.num_classes = 10

        self.dims = (1, 28, 28)

        channels, width, height = self.dims

        self.transform = transforms.Compose([

            transforms.ToTensor(),

            transforms.RandomCrop(28, padding=4),

            transforms.Normalize((0.1307, ), (0.3081, )),

        ])

        self.transform_test = transforms.Compose([

            transforms.ToTensor(),

            transforms.Normalize((0.1307, ), (0.3081, )),

        ])


        # Define PyTorch model

        self.model = nn.Sequential(

            nn.Conv2d(channels, 32, kernel_size=3, padding=1, padding_mode='zeros'),

            nn.MaxPool2d(2),

            nn.GELU(),

            nn.Conv2d(32, 64, kernel_size=3, padding=1, padding_mode='zeros'),

            nn.MaxPool2d(2),

            nn.GELU(),

            nn.Flatten(),

            nn.Linear(3136, hidden_size*100),

            nn.ReLU(),

            nn.Dropout(0.2),

            nn.Linear(hidden_size*100, hidden_size*50),

            nn.ReLU(),

            nn.Dropout(0.2),

            nn.Linear(hidden_size*50, hidden_size*10),

            nn.ReLU(),

            nn.Dropout(0.2),

            nn.Linear(hidden_size*10, self.num_classes)

        )


    def forward(self, x):

        x = self.model(x)

        return F.log_softmax(x, dim=1)


    def training_step(self, batch, batch_idx):

        x, y = batch

        logits = self(x) # 이 부분이 forward() 를 부른다.

        loss = F.nll_loss(logits, y)        

        return {'loss': loss}


    def validation_step(self, batch, batch_idx):

        x, y = batch

        logits = self(x)

        loss = F.nll_loss(logits, y)

        preds = torch.argmax(logits, dim=1)

        acc = accuracy(preds, y)


        self.log('val_loss', loss, prog_bar=True)

        self.log('val_acc', acc, prog_bar=True)

        return {'val_loss': loss, 'val_acc': acc}


    def test_step(self, batch, batch_idx):

        return self.validation_step(batch, batch_idx)


    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.89, verbose=True)


        return {

            "optimizer": optimizer,

            "lr_scheduler": {

                "scheduler": scheduler,

                "monitor": "val_acc",

                "strict": True,

                "name": "mnist_lr",

            }

        }


    def prepare_data(self):

        MNIST(self.data_dir, train=True, download=True)

        MNIST(self.data_dir, train=False, download=True)


    def setup(self, stage=None):

        if stage == 'fit' or stage is None:

            mnist_full = MNIST(self.data_dir, train=True, download=False, transform=self.transform)

            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])


        if stage == 'test' or stage is None:

            self.mnist_test = MNIST(self.data_dir, train=False, download=False, transform=self.transform_test)


    def train_dataloader(self):

        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, num_workers=4)


    def val_dataloader(self):

        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE, num_workers=4)


    def test_dataloader(self):

        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE, num_workers=4)



바뀐부분 / 강조하고 싶은 부분을 좀 강조해봤다.

CNN - Conv2D 는 이미 많은 사람들이 알고 있는 이미지 분석에서는 이제 Hello World! 라는 생각이 든다.


callback은 변하지 않았지만,


chk_callback = ModelCheckpoint(

    dirpath='./lightning_chks',

    filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',

    verbose=True,

    save_last=True,

    save_top_k=5,

    monitor='val_loss',

    mode='min'

)


earlystop_callback = EarlyStopping(

    monitor='val_loss',

    patience=5,

    verbose=True,

    mode='min'

)


lrmonitor_callback = LearningRateMonitor(logging_interval='step')



이제 학습을 진행해보면,


model = LitMNIST()

trainer = Trainer(

    gpus=AVAIL_GPUS,

    accelerator="dp",

    max_epochs=50,

    progress_bar_refresh_rate=10,

    callbacks=[chk_callback, earlystop_callback, lrmonitor_callback]

)

trainer.fit(model)


전체 50 Epochs 를 주었는데,

19번째에서 중지되었다.


성능은 이전 FCL 보다 더 좋아졌다.


왜 그럴까?



https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html


CNN은 이미지의 부분 부분을 본다.

일련의 시퀀스 데이터처럼 이미지를 바라보던 FCL 보다 이미지 특정 부분의 주변까지도 참고하면서,

배워나간다.


자세한 사항은 아무 책이나 읽어보면 자세히 설명해준다.

수식도 설명해준다.


이를 활용해서 Auto-Encoder 방식의 이미지 분류 모델이 존재하고,

더 나아가면 여러 대회나 논문에서 소개되어서 이슈를 만들어낸 모델 구조를 직접 적용해서 나만의 모델을 따로 만들어 낼 수 있다.


저 정도의 정확도라면 사람의 눈보다 더 빠르고 정확하게 숫자를 인식 할 수 있다고 생각한다.


BERT에 들어가기 전에 간단하게

자율주행 자동차에 쓰이는 로직도 한 번 소개해보면 좋을 것 같다.



작가의 이전글 PyTorch 쓰세요(2).
작품 선택
키워드 선택 0 / 3 0
댓글여부
afliean
브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari