brunch

You can make anything
by writing

C.S.Lewis

by 첨물 May 24. 2021

빅분기실기연습(12)

R프로그램_교차유효성검증(K Fold)

오늘의 주제는 교차 유효성 검사 모델 평가

주어진 데이터 일부를 학습시켜 모델을 생성하고, 나머지는 모델을 검증을 해 본다는 것이다.


지난번에 데이터를 7:3으로 나눈 후 7로 훈련시키고, 3으로 검증을 했는데, 이게 가장 기본적인 교차검증(Cross validation) 이다. 


그럼 이것 말고 뭐가 더 있을까? LOOCV (Leave-One-Out Cross Validation) 이 있다. 아래와 같이 1개를 테스트용으로 두고, 나머지 전체를 훈련용으로 한다. 테스트 data를 다음 것으로 바꾼 후, 그 나머지로 훈련시킨다. 상당히 시간이 많이 걸릴 것이지만 꽤 정확한 모델이 될 수 있다. 

 



다음은 K-Fold Cross-Validation 이다. 아래 그림처럼 5번을 실시한다고 하면 전체 데이터를 5개로 나눈 후 첫 번째는 4/5 데이터로 훈련하고, 나머지 1/5 데이터를 검증한다. 다음엔 테스트 데이터를 바꾸어 진행한다.

이렇게 K 번을 모델 검증을 하게 되면 Cross Validation 보다는 모델 정확도가 높아지고, LOOCV 보다는 시간이 적게 소요된다. 


그럼 오늘은 IRIS 데이터를 이용하여 5-Fold Cross Validation을 통해 교차유효성 검증을 해 보도록 하겠다.

먼저 필요한 패키지는 의사결정나무를 쓸 수 있는 party와 K-fold 검증을 할 수 있는 cvTools이다. 



.libPaths("c:/myRproject/Library")

install.packages("party")

library(party)

install.packages("cvTools")

library(cvTools)


다음엔 IRIS 데이터를 위와 같이 5등분을 해 보자. 무작위로 행을 나누는데, 1,2,3,4,5로 150개 데이터를 5개의 그룹으로 나눈다.


cross<-cvFolds(nrow(iris), K=5)


5-fold CV:    

Fold   Index

   1     125

   2      97

   3     148

   4     101

   5     131

   1      22

   2      87

   3      73

   4       2

   5     121


cross는   5종의 데이터로 이루어진 리스트인데, which는 1,2,3,4,5 가 반복적으로 되어 있고, subsets은 1~150이 랜덤 하게 배치되어 있다. 


그럼 3개의 변수를 지정해보자. k는 1~5 그룹을 나타낸다. acc는 5번의 모델 검증을 했는데, 얼마나 잘 맞추었는지 정확도에 관한 변수이고, cnt는 5회 반복하는 count 변수이다.


k<-1:5

acc<-numeric()

cnt<-1


이걸 어떻게 하는지 살펴보도록 하겠다.

먼저 iris의 데이터를 5개의 그룹으로 랜덤하게 나눈다. test 할 첫 번째 데이터 셋은 iris의 데이터 중 30개를 선택하는데, cross의 subset이 행의 이름이 된다. 그런데 그걸 which==1 인 것만 취한다.

train은 나머지 120개의 값을 할당하는데, 이것은 train<-iris[-data_index,] 처럼 행 앞에 - 를 붙이기만 하면 된다. 그다음엔 ctree로 모델링을 하는 것이다. ctree(종속변수~., data=train)으로 하면 된다.

예측은 predict 함수를 쓴다. predict(model, test), 

다음엔 훈련 시킨 것으로 얼마나 잘 예측했는지 알아보는 것으로 table() 함수를 쓰고, 대각 방향이 제대로 맞춘 것이므로 전체 값으로 나누면 정확도이다.(acc),  이 작업을 5번 반복한다.  



for (i in k){

  data_index=cross$subsets[cross$which==i,1]

  test=iris[data_index,] #30개 데이터

  train<-iris[-data_index,] #120개의 데이터


  model<-ctree(Species~., data=train)

  pred<-predict(model, test)  


  t<-table(pred, test$Species)

  acc[cnt]<-(t[1,1]+t[2,2]+t[3,3])/sum(t)

  cnt=cnt+1

  }

acc

mean(acc)


pred         setosa versicolor virginica

  setosa         10          0         0

  versicolor      0          8         0

  virginica       0          1        11

            

pred         setosa versicolor virginica

  setosa         10          0         0

  versicolor      0         12         2

  virginica       0          0         6

            

pred         setosa versicolor virginica

  setosa          9          0         0

  versicolor      0          9         2

  virginica       0          1         9

            

pred         setosa versicolor virginica

  setosa         14          0         0

  versicolor      0          8         0

  virginica       0          0         8

            

pred         setosa versicolor virginica

  setosa          7          0         0

  versicolor      0          9         0

  virginica       0          2        12


> acc

[1] 0.9666667 0.9333333 1.0000000 0.9666667 0.9333333

> mean(acc)

[1] 0.96


네 번째 실행한 것은 100% 정확도가 나왔지만 나머지는 몇 개씩 틀린 값이 있었다. 따라서 이렇게 만들어진 모델은 96% 정확도를 가지고 있다고 말할 수 있다.


브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari