R프로그램_의사결정나무
오늘은 예측 분석 중 의사결정 나무 기법을 통해 R 내에 있는 IRIS(붓꽃)을 분류하는 예제를 해 보겠다.
일단 IRIS는 매우 유명한 예제이다. 얼마 전 재미있게 읽었던 야사와 만화로 배우는 인공지능 책에서 그 유래가 나오는데, 자세한 것은 아래 브런치를 보면 좋을 것 같다.
https://brunch.co.kr/@hvnpoet/82
붓꽃의 꽃받침(Sepal)의 길이와 너비, 꽃잎(Petal)의 길이와 너비라는 네 개의 변수 조합에 따라 붓꽃은 Setosa, Versicolour, Virginica로 나뉜다.
만약 어떤 붓꽃을 발견하고, 위 네 개의 변수를 자로 잰 후, 의사결정 나무라는 방법으로 이 꽃의 이름을 알아내 보자. 먼저 data가 어떻게 생겼는지 보자.
> head(iris)
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
5 5.0 3.6 1.4 0.2 setosa
6 5.4 3.9 1.7 0.4 setosa
그럼 이 150개의 데이터를 이용하여 의사 결정 나무를 그려보자
우선 R에서 이걸 하려면 rpart라는 패키지를 불러와야 한다.
install.packages("rpart")
library(rpart)
그리고 rpart 함수를 사용하여 분류해보자. rpart(종속변수~., data = )
놀랍게도 R은 자동으로 자신이 판단하여 네 개의 변수 중 하나를 선택해서 일정 이상과 이하 되는 것으로 나눈 후 그다음 변수를 이용하여 세 개의 품종이 나올 때까지 분류를 한다.
아래 결과를 보면
1) setosa로 먼저 분류 (맞을 확률은 0.333...)
2) 꽃잎의 길이를 2.45보다 작은 것과 크거나 같은 것으로 구별 , 작은 것은 모두 setosa
3) 꽃잎의 길이가 2.45보다 크거나 같은 것을 versicolor로 분류 (맞을 확률 0.5)
6) 다음 기준인 꽃잎의 너비가 1.75보다 작은 것을 versicolor로 분류 (맞을 확률 0.907...)
7) 꽃잎의 너비가 1.75보다 큰 것은 verginica (맞을 확률 0.978..)
iris.df=rpart(Species~.,data=iris)
iris.df
> iris.df
n= 150
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
3) Petal.Length>=2.45 100 50 versicolor (0.00000000 0.50000000 0.50000000)
6) Petal.Width< 1.75 54 5 versicolor (0.00000000 0.90740741 0.09259259) *
7) Petal.Width>=1.75 46 1 virginica (0.00000000 0.02173913 0.97826087) *
그러면 R로 그래프를 그려보자
x11() #새로운 창을 띄우기
plot(iris.df,compress=T,) #가지 형태만 만들기
text(iris.df,use.n=T, cex=0.7) #분류 기준 쓰기, cex는 글자 크기
post(iris.df, file="") #150개 중 어떻게 분류했는지 표시
그런데 이상하다. setosa는 100% 꽃잎의 길이 <2.45로 분류가 잘 되었는데, versicolor와 virginica는 꽃잎 너비로 분류하니 분류가 100% 잘 되지는 않았다.
다음은 rpart.plot 패키지를 이용하여 그래프를 다시 그려보았다. 좀 더 예쁜 의사결정나무 그림...
install.packages("rpart.plot")
library(rpart.plot)
tree <- rpart(Species ~ Sepal.Length + Sepal.Width + Petal.Length +
Petal.Width, data=iris, method = "class")
rpart.plot(tree)