R프로그램_교차유효성검증(K Fold)
오늘의 주제는 교차 유효성 검사 모델 평가
주어진 데이터 일부를 학습시켜 모델을 생성하고, 나머지는 모델을 검증을 해 본다는 것이다.
지난번에 데이터를 7:3으로 나눈 후 7로 훈련시키고, 3으로 검증을 했는데, 이게 가장 기본적인 교차검증(Cross validation) 이다.
그럼 이것 말고 뭐가 더 있을까? LOOCV (Leave-One-Out Cross Validation) 이 있다. 아래와 같이 1개를 테스트용으로 두고, 나머지 전체를 훈련용으로 한다. 테스트 data를 다음 것으로 바꾼 후, 그 나머지로 훈련시킨다. 상당히 시간이 많이 걸릴 것이지만 꽤 정확한 모델이 될 수 있다.
다음은 K-Fold Cross-Validation 이다. 아래 그림처럼 5번을 실시한다고 하면 전체 데이터를 5개로 나눈 후 첫 번째는 4/5 데이터로 훈련하고, 나머지 1/5 데이터를 검증한다. 다음엔 테스트 데이터를 바꾸어 진행한다.
이렇게 K 번을 모델 검증을 하게 되면 Cross Validation 보다는 모델 정확도가 높아지고, LOOCV 보다는 시간이 적게 소요된다.
그럼 오늘은 IRIS 데이터를 이용하여 5-Fold Cross Validation을 통해 교차유효성 검증을 해 보도록 하겠다.
먼저 필요한 패키지는 의사결정나무를 쓸 수 있는 party와 K-fold 검증을 할 수 있는 cvTools이다.
.libPaths("c:/myRproject/Library")
install.packages("party")
library(party)
install.packages("cvTools")
library(cvTools)
다음엔 IRIS 데이터를 위와 같이 5등분을 해 보자. 무작위로 행을 나누는데, 1,2,3,4,5로 150개 데이터를 5개의 그룹으로 나눈다.
cross<-cvFolds(nrow(iris), K=5)
5-fold CV:
Fold Index
1 125
2 97
3 148
4 101
5 131
1 22
2 87
3 73
4 2
5 121
cross는 5종의 데이터로 이루어진 리스트인데, which는 1,2,3,4,5 가 반복적으로 되어 있고, subsets은 1~150이 랜덤 하게 배치되어 있다.
그럼 3개의 변수를 지정해보자. k는 1~5 그룹을 나타낸다. acc는 5번의 모델 검증을 했는데, 얼마나 잘 맞추었는지 정확도에 관한 변수이고, cnt는 5회 반복하는 count 변수이다.
k<-1:5
acc<-numeric()
cnt<-1
이걸 어떻게 하는지 살펴보도록 하겠다.
먼저 iris의 데이터를 5개의 그룹으로 랜덤하게 나눈다. test 할 첫 번째 데이터 셋은 iris의 데이터 중 30개를 선택하는데, cross의 subset이 행의 이름이 된다. 그런데 그걸 which==1 인 것만 취한다.
train은 나머지 120개의 값을 할당하는데, 이것은 train<-iris[-data_index,] 처럼 행 앞에 - 를 붙이기만 하면 된다. 그다음엔 ctree로 모델링을 하는 것이다. ctree(종속변수~., data=train)으로 하면 된다.
예측은 predict 함수를 쓴다. predict(model, test),
다음엔 훈련 시킨 것으로 얼마나 잘 예측했는지 알아보는 것으로 table() 함수를 쓰고, 대각 방향이 제대로 맞춘 것이므로 전체 값으로 나누면 정확도이다.(acc), 이 작업을 5번 반복한다.
for (i in k){
data_index=cross$subsets[cross$which==i,1]
test=iris[data_index,] #30개 데이터
train<-iris[-data_index,] #120개의 데이터
model<-ctree(Species~., data=train)
pred<-predict(model, test)
t<-table(pred, test$Species)
acc[cnt]<-(t[1,1]+t[2,2]+t[3,3])/sum(t)
cnt=cnt+1
}
acc
mean(acc)
pred setosa versicolor virginica
setosa 10 0 0
versicolor 0 8 0
virginica 0 1 11
pred setosa versicolor virginica
setosa 10 0 0
versicolor 0 12 2
virginica 0 0 6
pred setosa versicolor virginica
setosa 9 0 0
versicolor 0 9 2
virginica 0 1 9
pred setosa versicolor virginica
setosa 14 0 0
versicolor 0 8 0
virginica 0 0 8
pred setosa versicolor virginica
setosa 7 0 0
versicolor 0 9 0
virginica 0 2 12
> acc
[1] 0.9666667 0.9333333 1.0000000 0.9666667 0.9333333
> mean(acc)
[1] 0.96
네 번째 실행한 것은 100% 정확도가 나왔지만 나머지는 몇 개씩 틀린 값이 있었다. 따라서 이렇게 만들어진 모델은 96% 정확도를 가지고 있다고 말할 수 있다.