brunch

You can make anything
by writing

C.S.Lewis

by 박경아 Oct 21. 2022

10. 로지스틱 회귀 적용해 보기

앞서 글에서는 분류를 위한 알고리즘으로 로지스틱 회귀에 대해 알아보았다. 로지스틱 회귀는 선형회귀와 비슷하지만 연산의 마지막 단계에 시그모이드 함수를 사용해 결과값을 0과 1사이의 확률로 변환해 데이터를 분류했다. 


이번 글에서는 사이킷런에서 제공하는 붓꽃 데이터 세트로 로지스틱 알고리즘을 적용해 보고 결과를 다양하게 분석해 보고자 한다. 참고로 붓꽃 데이터 세트는 꽃잎의 길이(lengh)와 너비(width), 꽃받침의 길이와 너비의 4개 피처를 제공하고 타깃 변수로는 붓꽃의 종류를 나타내는 0, 1, 2 라벨이 있다.

아이리스 데이터셋 불러오기



이진 분류 (Binary Classification)


먼저 데이터 클래스를 두 가지 종류로 나누는 이진분류부터 연습해 보자. 이진분류는 앞서 공부한 시그모이드 함수를 사용해 선형회귀식 wx + b가 0보다 같거나 크면 시그모이드 값이 0.5보다 같거나 커져 1로 분류하고, wx + b가 0보다 작으면 시그모이드 값이 0.5보다 작아져 0으로 분류한다. 


이진 분류를 위해 붓꽃 데이터 세트에서 붓꽃 종류 클래스가 0과 1인 데이터만 선택하고, 또한 결과를 시각화하기 위해 독립변수로는 꽃잎 길이(sepal length) 하나만 선택해 보자. 다른 피처없이 요 sepal length로 두 붓꽃을 구별해야 하는데 sepal length의 분포를 살펴보면 sepal length가 짧을수록 클래스 0이고, 길수록 클래스 1로 보이지만 겹치는 부분도 있다.


sepal length의 분포


데이터를 훈련 데이터와 테스트 데이터 셋으로 나누고 사이킷런의 로지스틱 회귀를 불러와 학습시킨다. 결과를 살펴보면 훈련 정확도는 0.87이고 테스트 정확도는 0.93으로 학습이 덜 된 언더피팅 상태이긴 하다. 선형회귀와 마찬가지로 회귀식의 계수와 절편값을 알 수 있는데 바로 이 선형회귀식이 0이 될 때 X값이 시그모이드가 0.5가 될 값으로 X가 이보다 커지면 1로 분류되고, 작으면 0으로 분류될 것이다. 


모델링


선형회귀식


이 결과를 원래 데이터셋과 함께 그래프로 나타내보자. 원래 클래스는 1이지만 sepal length가 X 기준값(5.450602851599672)보다 작은 경우 1로 분류되지 못했다. 마찬가지로 원래 0 클래스지만 sepal length가 경계값보다 커서 0으로 분류되지 못한 경우도 있다. 

로지스틱 회귀 결과 분석



다중 분류 (Multinomial Classification)

 

이번에는 클래스를 3개 이상으로 분류하는 다중분류를 연습해 보자. 이진분류는 하나의 선형방정식과 하나의 확률값이 나와 이를 기준으로 데이터를 분류하는 것이다. 다중 분류는 각 클래스별로 선형방정식이 나오고 그에 따른 확률값이 계산된다. 바꿔 말하면 3개 클래스로 예측하는 경우 각 데이터 샘플에 대해 3개의 선형방정식과 그에 따른 3개의 확률값이 나오고 그 중에 가장 큰 확률값을 가진 클래스로 예측되는 것이다. 


다중분류 이미지


앗! 그렇다면 하나의 확률값을 만들어 나머지 클래스는 자동으로 (1-확률값)이 되는 시그모이드가 아니라 각각의 선형방정식에서 나온 결과값들을 종합해 확률로 변환하는 함수가 필요하다. 바로 소프트맥스 함수인데 소프트맥스는 선형방정식에서 나온 결과값(z)을 자연상수를 밑으로 하는 지수함수의 지수로 사용하고 그 값들을 모두 더해 분모로 하고 각 결과값의 지수함수값을 분자로 그 비중에 따라 0-1 사이의 확률로 변환해준다. 

소프트맥스 함수 (출처 : 위키피디아)


그럼 붓꽃 데이터 세트의 3개 클래스를 모두 사용하고 시각화를 위해 피처는 2개만 골라 로지스틱 회귀 알고리즘 학습을 시켜보자. 역시 사이킷런의 로지스틱 회귀 알고리즘을 불러와 학습시키면 훈련, 테스트 모두 정확도가 0.82 정도로 앞서보다 언더피팅은 해소되었지만 성능은 다소 떨어졌다. 



모델링 결과 3개의 선형방정식을 위한 계수들과 절편들을 확인할 수 있는데 독립변수를 2개를 넣었기에 계수는 2개씩 3쌍이 나오고, 각각의 방정식을 위한 절편은 3개가 출력된다. Classification Report로 클래스 별 결과를 살펴보면 클래스 0은 f1 score가 1.00로 모두 제대로 분류되었지만 클래스 1과 2는 각각 0.73과 0.75로 잘못 분류된 경우가 발생했다. 



사이킷런 알고리즘에 predict() 함수를 붙이면 예측 클래스를 알 수 있지만 predict_proba() 함수를 붙이면 각 샘플에 대해 어느 클래스에 속하는 지 확률을 출력해 준다. 확률이 가장 높은 값으로 클래스가 예측되며 당연히 이들의 합은 항상 1이다. 


클래스별 예측 확률값


이제 마지막으로 결과를 시각화해보자. 피처 가운데 sepal length를 x 축으로, sepal width를 y 축으로 원본 데이터를 스캐터 플롯으로 표시하고 학습 결과로 나온 선형방정식의 계수와 절편을 이용해 x 값에 대응하는 y값을 가진 직선을 3개 그려보자. 



그래프로 결과를 살펴보아도 보라색으로 표시된 클래스 0 데이터들은 빨간 선을 경계로 나머지 클래스들과 잘 분류가 되지만, 초록색과 노란색으로 표시된 2번과 3번 클래스 데이터들은 잘 분류되지 못하는 것을 볼 수 있다. 


초록색 선은 좀 더 오른쪽으로 이동하고 오렌지 선은 반대 방향으로 기울어져야 할까?? 정확도를 올리기 위해서는 머신러닝 알고리즘이 각 클래스의 특징을 좀 더 구별해서 학습할 수 있도록 피처를 추가하거나 새로운 피처를 만들어 내야 할 것 같다. 다행히 원래 데이터 셋이 가지고 있는 4개를 모두 학습시켜보면 훈련 정확도가 0.98, 테스트 데이터 정확도가 0.97이 나온다. 


4개 피처를 모두 사용해 3캐 클래스로 예측한 결과


위에서 확인해 본 내용들은 아래 캐글 노트북에서 확인할 수 있다. 

https://www.kaggle.com/code/kyungapark/logistic-regression


브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari