티스토리 뷰
[FastCampus] 파이썬을 활용한 이커머스 데이터 분석하기¶
Ch03. 광고 반응률 예측¶
01. 분석의 목적¶
- Logistic Regression을 사용하여 고객별 반응율을 예측해보자.
- 연속된 값이 아닌 Yes or no 두 가지중 어디에 속하는지 이진분류를 예측하는 머신러닝 알고리즘이다.
02. 모듈, 데이터 로딩 및 데이터 확인¶
필요한 모듈과 이번에 사용할 데이터를 로딩한 후 간단하게 데이터를 확인해 보자.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
data = pd.read_csv("advertising.csv")
data
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | NaN | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
1000 rows × 10 columns
data.head()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | NaN | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
컬럼의 의미를 확인해 보자.
- Daily Time spent on site : 사이트에 머무른 시간
- Age : 나이
- Area Income : 지역에 대한 수입(개인의 수입을 알기엔 어려워서 지역의 대한 범위로 파악)
- Daily internet Usasge : 하루에 인터넷 사용량
- Ad Topic Line : 광고에 대한 설명 (텍스트 형식이라 쓸 수 있는 데이터인지는 판단해 보아야 한다.)
- City : 도시
- Male : 성별 (0은 여자, 1은 남자)
- Country : 나라
- Timestamp : 시간 (때에 따라 유의미할 수도 있고 아닐 수도 있다.)
- Clicked on Ad : 광고를 클릭 했는지 여부
Clicked on Ad가 이번에 예측할 컬럼이다. 이제 데이터의 정보들을 알아보자.
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Daily Time Spent on Site 1000 non-null float64
1 Age 916 non-null float64
2 Area Income 1000 non-null float64
3 Daily Internet Usage 1000 non-null float64
4 Ad Topic Line 1000 non-null object
5 City 1000 non-null object
6 Male 1000 non-null int64
7 Country 1000 non-null object
8 Timestamp 1000 non-null object
9 Clicked on Ad 1000 non-null int64
dtypes: float64(4), int64(2), object(4)
memory usage: 78.2+ KB
Age 컬럼은 916개로 결측치가 있다는 것을 알 수 있다. 이 부분은 조금 있다 처리하도록 하자.
data.describe()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | Clicked on Ad | |
---|---|---|---|---|---|---|
count | 1000.000000 | 916.000000 | 1000.000000 | 1000.000000 | 1000.000000 | 1000.00000 |
mean | 65.000200 | 36.128821 | 55000.000080 | 180.000100 | 0.481000 | 0.50000 |
std | 15.853615 | 9.018548 | 13414.634022 | 43.902339 | 0.499889 | 0.50025 |
min | 32.600000 | 19.000000 | 13996.500000 | 104.780000 | 0.000000 | 0.00000 |
25% | 51.360000 | 29.000000 | 47031.802500 | 138.830000 | 0.000000 | 0.00000 |
50% | 68.215000 | 35.000000 | 57012.300000 | 183.130000 | 0.000000 | 0.50000 |
75% | 78.547500 | 42.000000 | 65470.635000 | 218.792500 | 1.000000 | 1.00000 |
max | 91.430000 | 61.000000 | 79484.800000 | 269.960000 | 1.000000 | 1.00000 |
Area Income에서 min 값이 25%,50%,75%,max와 비교해서 조금 큰 차이를 보인다. 하지만 크게 치우져 있지 않은것 같아보여 그냥 넘어가도록 하겠다.
male 에서 mean이 48%라는 것은 남자가 48이라는 것을 의미한다.
Age Income에 대한 시각화를 distplot으로 그려보자.
sns.distplot(data['Area Income'])
/home/jaeyoon89/.local/lib/python3.6/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<AxesSubplot:xlabel='Area Income', ylabel='Density'>
그래프를 보면 min에서 25%,50%,75%,max 가 차이가 났던 이유를 볼 수 있다.
sns.distplot(data['Age'])
/home/jaeyoon89/.local/lib/python3.6/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
warnings.warn(msg, FutureWarning)
<AxesSubplot:xlabel='Age', ylabel='Density'>
data['Country'].nunique()
237
data['City'].nunique()
969
data['Ad Topic Line'].nunique()
1000
nunique() 메서드를 이용해서 Country 컬럼의 고유값 개수를 구해 보았더니 237개가 나왔다. 또 City 컬럼의 고유값은 969개이다. 마지막으로 Ad Topic Line 컬럼의 고유값은 1000개가 나왔다. 이런 텍스트들은 성별과 같이 숫자로 0 과 1로 바꾸기에는 고유값이 너무 많기 때문에 이번 단계에서는 빼주도록 하자.
03. Missing Value(결측치) 확인 및 처리¶
info() 메서드로 확인했던 Age 컬럼의 결측치를 isnull() 메서드로 확인해보자.
data.isnull()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | False | True | False | False | False | False | False | False | False | False |
1 | False | False | False | False | False | False | False | False | False | False |
2 | False | False | False | False | False | False | False | False | False | False |
3 | False | False | False | False | False | False | False | False | False | False |
4 | False | False | False | False | False | False | False | False | False | False |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | False | False | False | False | False | False | False | False | False | False |
996 | False | False | False | False | False | False | False | False | False | False |
997 | False | False | False | False | False | False | False | False | False | False |
998 | False | False | False | False | False | False | False | False | False | False |
999 | False | False | False | False | False | False | False | False | False | False |
1000 rows × 10 columns
위처럼 불리안 형식으로 나타난다. Age 컬럼 첫 번째 인덱스에 True로 되있는데 이것이 바로 결측치이다.
결측치를 간단히 확인해 보기 위해 sum() 메서드를 이용해보자.
data.isnull().sum() / 1000
Daily Time Spent on Site 0.000
Age 0.084
Area Income 0.000
Daily Internet Usage 0.000
Ad Topic Line 0.000
City 0.000
Male 0.000
Country 0.000
Timestamp 0.000
Clicked on Ad 0.000
dtype: float64
파이썬에서는 불리안인 True는 1로 False는0 값으로 바꾸어 출력해준다.
라인의 갯수가 기억이 안나면 다음의 len() 메서드를 이용하여 확인해 줄 수 있다.
len(data)
1000
이번엔 결측치를 처리해 보자. 영어로 impute라고 한다. 결측치는 아예 제거하는 방법이 있고 그안에 추정치를 넣어주거나 결측치 자체를 보존하는 방법이 있다. 머신러닝에서는 결측치를 처리하지 않고 머신러닝 알고리즘이 동작할 수 없다.
먼저 결측치가 있는 인덱스를 모두 제거해보자.
data.dropna()
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
5 | 59.99 | 23.0 | 59761.56 | 226.74 | Sharable client-driven software | Jamieberg | 1 | Norway | 5/19/2016 14:30 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
916 rows × 10 columns
위 방법의 문제점은 인덱스를 없애버리면 결측치가 있는 컬럼 말고 나머지 컬럼의 데이터도 없어진다. 이 데이터에선 84개만 없어졌지만 이 양이 20 % ~ 30 % 가 넘어버리면 나중에 예측할 때 좋지 않다.
이번엔 결측치가 있는 컬럼을 제거해보자.
data.drop('Age', axis=1)
Daily Time Spent on Site | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
1000 rows × 9 columns
위처럼 Age 컬럼을 지워버리게 되면 이 컬럼이 머신러닝 알고리즘에서 중요한 컬럼이 될 수 있기 때문에 좋지 않은 방법이다. 무분별한 드랍은 조심해야 할 사항이다.
그렇기 때문에 결측치를 다른 값으로 대체해주는 방법이 가장 좋다. 제일 많이 쓰는 방법은 평균값을 넣어 주는 것이다.
data['Age'].mean()
36.12882096069869
outlier가 있을 경우에는 컬럼의 중간값을 출력해주는 median() 메서드를 이용하여 출력된 값을 넣어줄 수 있다.
data['Age'].median()
35.0
mean과 median의 차이가 크지 않으므로 mean 값을 결측치에 넣어주자.
data = data.fillna(round(data['Age'].mean()))
# 소수점을 없애려면 round() 메서드를 이용하여 소수점을 제거할 수 있다.
data
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | 36.0 | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
1000 rows × 10 columns
결측치가 없어졌는지 확인해 보자.
data.isnull().sum()
Daily Time Spent on Site 0
Age 0
Area Income 0
Daily Internet Usage 0
Ad Topic Line 0
City 0
Male 0
Country 0
Timestamp 0
Clicked on Ad 0
dtype: int64
결측치가 평균값으로 잘 채워짐을 확인할 수 있다.
하지만 결측치가 70 % ~ 80 % 이라면 20%의 정보로 mean이나 median을 구하여 결측치를 채워 넣는 것은 의미가 없다. 왜냐하면 데이터가 너무 적기 때문이다. 결측치가 많을 경우 컬럼 자체를 제거해 주는 방법이 가장 좋다.
04. Train, Test Set 나누기¶
데이터를 나누기위한 모듈을 불러오자.
data
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Ad Topic Line | City | Male | Country | Timestamp | Clicked on Ad | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 68.95 | 36.0 | 61833.90 | 256.09 | Cloned 5thgeneration orchestration | Wrightburgh | 0 | Tunisia | 3/27/2016 0:53 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | Monitored national standardization | West Jodi | 1 | Nauru | 4/4/2016 1:39 | 0 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | Organic bottom-line service-desk | Davidton | 0 | San Marino | 3/13/2016 20:35 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | Triple-buffered reciprocal time-frame | West Terrifurt | 1 | Italy | 1/10/2016 2:31 | 0 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | Robust logistical utilization | South Manuel | 0 | Iceland | 6/3/2016 3:36 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | Fundamental modular algorithm | Duffystad | 1 | Lebanon | 2/11/2016 21:49 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | Grass-roots cohesive monitoring | New Darlene | 1 | Bosnia and Herzegovina | 4/22/2016 2:07 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | Expanded intangible solution | South Jessica | 1 | Mongolia | 2/1/2016 17:24 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | Proactive bandwidth-monitored policy | West Steven | 0 | Guatemala | 3/24/2016 2:35 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | Virtual 5thgeneration emulation | Ronniemouth | 0 | Brazil | 6/3/2016 21:43 | 1 |
1000 rows × 10 columns
from sklearn.model_selection import train_test_split
X = data[['Daily Time Spent on Site','Age','Area Income','Daily Internet Usage','Male']]
y = data['Clicked on Ad']
X
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
0 | 68.95 | 36.0 | 61833.90 | 256.09 | 0 |
1 | 80.23 | 31.0 | 68441.85 | 193.77 | 1 |
2 | 69.47 | 26.0 | 59785.94 | 236.50 | 0 |
3 | 74.15 | 29.0 | 54806.18 | 245.89 | 1 |
4 | 68.37 | 35.0 | 73889.99 | 225.58 | 0 |
... | ... | ... | ... | ... | ... |
995 | 72.97 | 30.0 | 71384.57 | 208.58 | 1 |
996 | 51.30 | 45.0 | 67782.17 | 134.42 | 1 |
997 | 51.63 | 51.0 | 42415.72 | 120.37 | 1 |
998 | 55.55 | 19.0 | 41920.79 | 187.95 | 0 |
999 | 45.01 | 26.0 | 29875.80 | 178.35 | 0 |
1000 rows × 5 columns
y
0 0
1 0
2 0
3 0
4 0
..
995 1
996 1
997 1
998 0
999 1
Name: Clicked on Ad, Length: 1000, dtype: int64
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2 , random_state = 100)
X_train
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
675 | 82.58 | 38.0 | 65496.78 | 225.23 | 1 |
358 | 51.38 | 59.0 | 42362.49 | 158.56 | 0 |
159 | 75.55 | 36.0 | 73234.87 | 159.24 | 0 |
533 | 91.43 | 36.0 | 46964.11 | 209.91 | 1 |
678 | 87.85 | 34.0 | 51816.27 | 153.01 | 0 |
... | ... | ... | ... | ... | ... |
855 | 50.87 | 24.0 | 62939.50 | 190.41 | 0 |
871 | 76.79 | 27.0 | 55677.12 | 235.94 | 0 |
835 | 63.11 | 34.0 | 63107.88 | 254.94 | 1 |
792 | 56.56 | 26.0 | 68783.45 | 204.47 | 1 |
520 | 46.61 | 42.0 | 65856.74 | 136.18 | 0 |
800 rows × 5 columns
X_test
Daily Time Spent on Site | Age | Area Income | Daily Internet Usage | Male | |
---|---|---|---|---|---|
249 | 62.20 | 25.0 | 25408.21 | 161.16 | 0 |
353 | 79.54 | 44.0 | 70492.60 | 217.68 | 1 |
537 | 61.72 | 26.0 | 67279.06 | 218.49 | 0 |
424 | 43.59 | 36.0 | 58849.77 | 132.31 | 1 |
564 | 64.75 | 36.0 | 63001.03 | 117.66 | 0 |
... | ... | ... | ... | ... | ... |
684 | 42.06 | 34.0 | 43241.19 | 131.55 | 0 |
644 | 78.35 | 46.0 | 53185.34 | 253.48 | 0 |
110 | 66.63 | 60.0 | 60333.38 | 176.98 | 0 |
28 | 70.20 | 34.0 | 32708.94 | 119.20 | 0 |
804 | 53.92 | 41.0 | 25739.09 | 125.46 | 1 |
200 rows × 5 columns
05. 로지스틱 리그레션 모델 만들고 평가하기¶
이번엔 사이킷런을 이용해서 리그레션 모델을 만들고 평가해보자.
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)
LogisticRegression()
model.coef_
array([[-6.64737762e-02, 2.66015818e-01, -1.15501902e-05,
-2.44285539e-02, 2.00758165e-03]])
coef를 구하고 그다음으로 predict() 메서드를 이용해서 X_test()를 출력하자.
pred = model.predict(X_test)
y_test
249 1
353 0
537 0
424 1
564 1
..
684 1
644 0
110 1
28 1
804 1
Name: Clicked on Ad, Length: 200, dtype: int64
from sklearn.metrics import accuracy_score, confusion_matrix
accuracy_score(y_test, pred)
0.9
confusion_matrix(y_test, pred)
array([[92, 8],
[12, 88]])
confusion_matrix의 출력을 보면 92라는 숫자는 0인데 0으로 예측한 것이고, 88은 1인데 1이라고 예측한 것의 갯수를 구분해서 나타내준다.
출처 : 패스트캠퍼스
'패스트캠퍼스 스터디' 카테고리의 다른 글
패스트캠퍼스study(ch.04 고객 이탈 예측/KNN) (0) | 2021.06.16 |
---|---|
패스트캠퍼스study(ch.02 고객별 연간 지출액 예측) (0) | 2021.06.09 |