brunch

You can make anything
by writing

C.S.Lewis

by 유윤식 Aug 12. 2021

PyTorch 쓰세요(2).

#콜백 #callback #pytorch_lightning

콜백을 지정해서 checkpoint, earlystop, 등등 컨트롤하는 방법.


어제에 이어서 Mnist 예제 코드를 그대로 활용하는데,

바뀌는 부분에 대해서만 설명을 적어보자면,



def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.9, verbose=True)


        return {

            "optimizer": optimizer,

            "lr_scheduler": {

                "scheduler": scheduler,

                "monitor": "val_acc",

                "strict": True,

                "name": "mnist_lr"

            }

        }


스케쥴러를 정의하고 사용하는데,

뒤에 따라붙는 옵션이 중요하다.


어제의 성능이나 오늘의 성능이 딱히 중요하지는 않지만...

비교를 위해서 어제의 성능을 잠시 다시 확인해보면,


validation loss 가 0.147 정도를 기록했다.


어제와 오늘은 랜덤이기 때문에 이 부분부터 집고 넘어가자면,

아래와 같은 코드를 impot 구문 아래에서 선언해주면 된다.



def seed_everything(seed = 7):

    random.seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    np.random.seed(seed)

    torch.manual_seed(seed)

    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.deterministic = True


seed_everything()


이건 인터넷에 떠도는 옵션이다.


이어서 콜백을 선언해볼텐데,

3가지!



from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor


chk_callback = ModelCheckpoint(

    dirpath='./lightning_chks',

    filename='sample-mnist-epoch{epoch:02d}-val_loss{val-loss:.2f}',

    verbose=True,

    save_last=True,

    save_top_k=5,

    monitor='val_loss',

    mode='min'

)


earlystop_callback = EarlyStopping(

    monitor='val_loss',

    patience=5,

    verbose=True,

    mode='min'

)


lrmonitor_callback = LearningRateMonitor(logging_interval='step')


쉽다.

그냥 선언해서 사용하면 된다.


trainer 선언부에 callback 옵션을 넘겨준다.



trainer = Trainer(

    gpus=AVAIL_GPUS,

    accelerator="dp",

    max_epochs=50,

    progress_bar_refresh_rate=10,

    callbacks=[chk_callback, earlystop_callback, lrmonitor_callback]

)



이제 학습을 돌리는데,

earlystop 이 걸릴 수 있다.


epochs 를 50 까지 확 늘려서 실험한다.


33번째 학습에서 멈췄다.

잘 하고 있는듯 하다.


이제 test 를 돌려보면,


trainer.test()


결과는,

오늘의 나는 어제의 나를 넘어섰다.


이후 스텝은 메뉴얼하게 체크포인트를 저장하고 다시 불러와서 같은 값을 도출하는지 확인했다.


요즘은 BERT를 진행하고 있어서,

이후 예제는 Mnist 게임을 벗어던지고 새로운 게임을 포스팅한다.


브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari