SciKit Logistic Regression Demo
This notebook reproduces the logistic regression demo using SciKit-Learn's LogisticRegression instead of Statsmodels.
Setup
Let's import some packages:
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.special import logit
import matplotlib.pyplot as plt
import sklearn
import sklearn.metrics
from sklearn.linear_model import LogisticRegression
rng = np.random.RandomState(20201024)
Read Data
Load the UCLA grad admissions data.
students = pd.read_csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
students.head()
| | admit | gre | gpa | rank |
|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 |
| 1 | 1 | 660 | 3.67 | 3 |
| 2 | 1 | 800 | 4.00 | 1 |
| 3 | 1 | 640 | 3.19 | 4 |
| 4 | 0 | 520 | 2.93 | 4 |
Let's split the data into train and test sets:
test = students.sample(frac=0.2, random_state=rng)
train_mask = pd.Series(True, index=students.index)
train_mask[test.index] = False
train = students[train_mask].copy()
train.head()
| | admit | gre | gpa | rank |
|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 |
| 1 | 1 | 660 | 3.67 | 3 |
| 3 | 1 | 640 | 3.19 | 4 |
| 7 | 0 | 400 | 3.08 | 2 |
| 8 | 1 | 540 | 3.39 | 3 |
Train the Model
Let's create our logistic regression model. penalty='none' turns off regularization (newer versions of SciKit-Learn spell this penalty=None):
lg_mod = LogisticRegression(penalty='none')
Now we need to train (or fit) the model. SciKit-Learn doesn't know about Pandas; it works on NumPy arrays and matrices. Fortunately, a Pandas data frame converts cleanly to a matrix, but we need to separate out the predictors and the outcome:
feat_cols = ['gre', 'gpa', 'rank']
out_col = 'admit'
train_x = train[feat_cols]
train_y = train[out_col]
Now we can call fit. Unlike Statsmodels, SciKit-Learn doesn't have separate model and results objects. We just call fit, which trains the model in place:
lg_mod.fit(train_x, train_y)
LogisticRegression(penalty='none')
We can get the coefficients from it:
lg_mod.coef_
array([[ 0.00201959, 0.88632881, -0.57805603]])
And the intercept:
lg_mod.intercept_
array([-3.60332028])
The SciKit-Learn convention is that parameters estimated by the fit process are stored in fields ending in _.
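The underscore convention means we can pull the parameters out and apply the model by hand. As a sketch on hypothetical synthetic data (not the admissions data), applying the logistic function to X @ coef_ + intercept_ reproduces predict_proba for the positive class:

```python
import numpy as np
from scipy.special import expit
from sklearn.linear_model import LogisticRegression

# Hypothetical synthetic data standing in for the training set above
rng = np.random.RandomState(0)
X = rng.normal(size=(100, 3))
y = (rng.uniform(size=100) < expit(X @ np.array([0.5, -1.0, 2.0]))).astype(int)

mod = LogisticRegression().fit(X, y)

# The fitted parameters live in coef_ and intercept_; pushing the linear
# score through the logistic function recovers the class-1 probabilities
manual = expit(X @ mod.coef_[0] + mod.intercept_[0])
print(np.allclose(manual, mod.predict_proba(X)[:, 1]))
```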
Let's quickly train one with Statsmodels just to check that it's fitting the same model:
import statsmodels.api as sm
import statsmodels.formula.api as smf
smf.glm('admit ~ gre + gpa + rank', train, family=sm.families.Binomial()).fit().summary()
| Dep. Variable: | admit | No. Observations: | 320 |
|---|---|---|---|
| Model: | GLM | Df Residuals: | 316 |
| Model Family: | Binomial | Df Model: | 3 |
| Link Function: | logit | Scale: | 1.0000 |
| Method: | IRLS | Log-Likelihood: | -184.48 |
| Date: | Sat, 24 Oct 2020 | Deviance: | 368.97 |
| Time: | 21:11:55 | Pearson chi2: | 318. |
| No. Iterations: | 4 | | |
| Covariance Type: | nonrobust | | |
| | coef | std err | z | P>\|z\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | -3.6033 | 1.286 | -2.802 | 0.005 | -6.124 | -1.083 |
| gre | 0.0020 | 0.001 | 1.620 | 0.105 | -0.000 | 0.004 |
| gpa | 0.8863 | 0.372 | 2.384 | 0.017 | 0.158 | 1.615 |
| rank | -0.5781 | 0.144 | -4.022 | 0.000 | -0.860 | -0.296 |
Let's compute our training-set accuracy. SciKit-Learn's predict method outputs the class, not the score, so its output is directly 0/1. Let's compute that and compare to train_y:
train_d = lg_mod.predict(train_x)
np.mean(train_d == train_y)
0.7125
It uses a probability of 0.5 (more likely than not) as the decision boundary. More formally, it returns the positive class (1 in our case) when the log odds are positive: $\operatorname{log} \operatorname{Odds}(A) > 0$
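We can check that equivalence directly; a sketch on hypothetical synthetic data rather than the admissions data:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical synthetic data
rng = np.random.RandomState(1)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

mod = LogisticRegression().fit(X, y)

# predict() thresholds the log odds at 0, equivalently the probability at 0.5
by_log_odds = (mod.decision_function(X) > 0).astype(int)
by_proba = (mod.predict_proba(X)[:, 1] > 0.5).astype(int)
print(np.array_equal(mod.predict(X), by_log_odds),
      np.array_equal(mod.predict(X), by_proba))
```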
If we want the probabilities, we can call predict_proba:
lg_mod.predict_proba(train_x)
array([[0.79744403, 0.20255597], [0.67955536, 0.32044464], [0.85761431, 0.14238569], [0.77236289, 0.22763711], [0.77595587, 0.22404413], [0.46787112, 0.53212888], [0.68017007, 0.31982993], [0.60796732, 0.39203268], [0.28929082, 0.71070918], [0.64926804, 0.35073196], [0.7890339 , 0.2109661 ], [0.71303456, 0.28696544], [0.91225083, 0.08774917], [0.45512696, 0.54487304], [0.42895462, 0.57104538], [0.82024802, 0.17975198], [0.55209187, 0.44790813], [0.9006386 , 0.0993614 ], [0.33664913, 0.66335087], [0.43280414, 0.56719586], [0.55358476, 0.44641524], [0.69707229, 0.30292771], [0.75254157, 0.24745843], [0.54401017, 0.45598983], [0.66178295, 0.33821705], [0.77700399, 0.22299601], [0.53231559, 0.46768441], [0.84773499, 0.15226501], [0.72619966, 0.27380034], [0.87123671, 0.12876329], [0.81510723, 0.18489277], [0.65598571, 0.34401429], [0.68042809, 0.31957191], [0.80122093, 0.19877907], [0.67697604, 0.32302396], [0.79423431, 0.20576569], [0.62746801, 0.37253199], [0.90664576, 0.09335424], [0.94421386, 0.05578614], [0.65106381, 0.34893619], [0.90488435, 0.09511565], [0.80754187, 0.19245813], [0.61959104, 0.38040896], [0.73966299, 0.26033701], [0.57387359, 0.42612641], [0.79885067, 0.20114933], [0.87699062, 0.12300938], [0.67183458, 0.32816542], [0.9006386 , 0.0993614 ], [0.66569059, 0.33430941], [0.86318032, 0.13681968], [0.68828659, 0.31171341], [0.63455267, 0.36544733], [0.59043287, 0.40956713], [0.50108564, 0.49891436], [0.43523051, 0.56476949], [0.32293726, 0.67706274], [0.62237298, 0.37762702], [0.93828843, 0.06171157], [0.87451512, 0.12548488], [0.51068643, 0.48931357], [0.80275483, 0.19724517], [0.77354727, 0.22645273], [0.54401017, 0.45598983], [0.58065172, 0.41934828], [0.87341255, 0.12658745], [0.79375072, 0.20624928], [0.9288318 , 0.0711682 ], [0.75711137, 0.24288863], [0.7442132 , 0.2557868 ], [0.64681399, 0.35318601], [0.61378519, 0.38621481], [0.48777163, 0.51222837], [0.42240076, 0.57759924], [0.59327694, 0.40672306], [0.6165724 , 0.3834276 ], [0.81804031, 
0.18195969], [0.65154402, 0.34845598], [0.68849424, 0.31150576], [0.86516757, 0.13483243], [0.7314378 , 0.2685622 ], [0.89994589, 0.10005411], [0.68021755, 0.31978245], [0.48138388, 0.51861612], [0.65306421, 0.34693579], [0.40427821, 0.59572179], [0.87354875, 0.12645125], [0.67528622, 0.32471378], [0.85964292, 0.14035708], [0.88543467, 0.11456533], [0.52887802, 0.47112198], [0.6338775 , 0.3661225 ], [0.69274338, 0.30725662], [0.59733506, 0.40266494], [0.80589106, 0.19410894], [0.85336765, 0.14663235], [0.84375425, 0.15624575], [0.60931278, 0.39068722], [0.86166305, 0.13833695], [0.45810237, 0.54189763], [0.75141526, 0.24858474], [0.69877311, 0.30122689], [0.8997631 , 0.1002369 ], [0.66766018, 0.33233982], [0.72988405, 0.27011595], [0.63981229, 0.36018771], [0.83171045, 0.16828955], [0.73376108, 0.26623892], [0.7619676 , 0.2380324 ], [0.86213019, 0.13786981], [0.59350087, 0.40649913], [0.62099314, 0.37900686], [0.44931528, 0.55068472], [0.49590498, 0.50409502], [0.79930412, 0.20069588], [0.76337273, 0.23662727], [0.75998651, 0.24001349], [0.82199675, 0.17800325], [0.68492723, 0.31507277], [0.85870364, 0.14129636], [0.43859354, 0.56140646], [0.32100237, 0.67899763], [0.4936894 , 0.5063106 ], [0.7501525 , 0.2498475 ], [0.80137506, 0.19862494], [0.58254078, 0.41745922], [0.63134613, 0.36865387], [0.68427027, 0.31572973], [0.5901876 , 0.4098124 ], [0.54745314, 0.45254686], [0.78407363, 0.21592637], [0.31482547, 0.68517453], [0.63225909, 0.36774091], [0.68619193, 0.31380807], [0.71447559, 0.28552441], [0.80512456, 0.19487544], [0.85274858, 0.14725142], [0.70676829, 0.29323171], [0.52588903, 0.47411097], [0.85532682, 0.14467318], [0.62723091, 0.37276909], [0.80012274, 0.19987726], [0.77250279, 0.22749721], [0.75655142, 0.24344858], [0.88733065, 0.11266935], [0.67590521, 0.32409479], [0.81256919, 0.18743081], [0.45026522, 0.54973478], [0.56595064, 0.43404936], [0.91930117, 0.08069883], [0.4882454 , 0.5117546 ], [0.73796929, 0.26203071], [0.71247395, 0.28752605], 
[0.70288731, 0.29711269], [0.67092054, 0.32907946], [0.8772 , 0.1228 ], [0.60981602, 0.39018398], [0.78304946, 0.21695054], [0.71266404, 0.28733596], [0.80261514, 0.19738486], [0.31482547, 0.68517453], [0.38267737, 0.61732263], [0.38918776, 0.61081224], [0.41859643, 0.58140357], [0.61914402, 0.38085598], [0.81164066, 0.18835934], [0.71327725, 0.28672275], [0.74061288, 0.25938712], [0.59684708, 0.40315292], [0.80617405, 0.19382595], [0.50653111, 0.49346889], [0.64819235, 0.35180765], [0.60981602, 0.39018398], [0.79244863, 0.20755137], [0.63070859, 0.36929141], [0.65616529, 0.34383471], [0.63954805, 0.36045195], [0.68594502, 0.31405498], [0.89550516, 0.10449484], [0.67916107, 0.32083893], [0.56809418, 0.43190582], [0.86154211, 0.13845789], [0.77181234, 0.22818766], [0.73030981, 0.26969019], [0.36284433, 0.63715567], [0.69082537, 0.30917463], [0.58254078, 0.41745922], [0.91457703, 0.08542297], [0.70389972, 0.29610028], [0.43887584, 0.56112416], [0.79259365, 0.20740635], [0.65775483, 0.34224517], [0.59781175, 0.40218825], [0.56371842, 0.43628158], [0.74290994, 0.25709006], [0.67654242, 0.32345758], [0.8387374 , 0.1612626 ], [0.86035964, 0.13964036], [0.54089045, 0.45910955], [0.84270528, 0.15729472], [0.78603132, 0.21396868], [0.74568817, 0.25431183], [0.68661862, 0.31338138], [0.64206564, 0.35793436], [0.64525103, 0.35474897], [0.52247959, 0.47752041], [0.65240452, 0.34759548], [0.74628102, 0.25371898], [0.81151895, 0.18848105], [0.69648534, 0.30351466], [0.83337031, 0.16662969], [0.87134079, 0.12865921], [0.57918053, 0.42081947], [0.49147407, 0.50852593], [0.78506525, 0.21493475], [0.53327108, 0.46672892], [0.47647543, 0.52352457], [0.71988621, 0.28011379], [0.71226615, 0.28773385], [0.45932863, 0.54067137], [0.86798211, 0.13201789], [0.65552782, 0.34447218], [0.48359698, 0.51640302], [0.84711353, 0.15288647], [0.82124092, 0.17875908], [0.8926151 , 0.1073849 ], [0.70394522, 0.29605478], [0.85446255, 0.14553745], [0.42842545, 0.57157455], [0.78637226, 0.21362774], 
[0.81875467, 0.18124533], [0.61050829, 0.38949171], [0.6636857 , 0.3363143 ], [0.5620791 , 0.4379209 ], [0.75201525, 0.24798475], [0.63800146, 0.36199854], [0.70471704, 0.29528296], [0.63227911, 0.36772089], [0.69291275, 0.30708725], [0.7612707 , 0.2387293 ], [0.46468871, 0.53531129], [0.8425708 , 0.1574292 ], [0.56982394, 0.43017606], [0.61705182, 0.38294818], [0.7098472 , 0.2901528 ], [0.8590724 , 0.1409276 ], [0.65952993, 0.34047007], [0.54330948, 0.45669052], [0.65995559, 0.34004441], [0.83620457, 0.16379543], [0.83704178, 0.16295822], [0.80436205, 0.19563795], [0.75452383, 0.24547617], [0.8399504 , 0.1600496 ], [0.5457806 , 0.4542194 ], [0.83553884, 0.16446116], [0.65866748, 0.34133252], [0.82156374, 0.17843626], [0.34329824, 0.65670176], [0.61120011, 0.38879989], [0.63275058, 0.36724942], [0.65197439, 0.34802561], [0.90960354, 0.09039646], [0.57387359, 0.42612641], [0.6947955 , 0.3052045 ], [0.67031571, 0.32968429], [0.76534996, 0.23465004], [0.41118901, 0.58881099], [0.81688193, 0.18311807], [0.87265493, 0.12734507], [0.56329437, 0.43670563], [0.71687134, 0.28312866], [0.88522875, 0.11477125], [0.86503378, 0.13496622], [0.80293629, 0.19706371], [0.70597215, 0.29402785], [0.78886503, 0.21113497], [0.83782249, 0.16217751], [0.64071621, 0.35928379], [0.55478368, 0.44521632], [0.72582248, 0.27417752], [0.72735549, 0.27264451], [0.41067717, 0.58932283], [0.73985824, 0.26014176], [0.80230725, 0.19769275], [0.55622092, 0.44377908], [0.57896556, 0.42103444], [0.53889636, 0.46110364], [0.59658231, 0.40341769], [0.39792269, 0.60207731], [0.49937692, 0.50062308], [0.64635048, 0.35364952], [0.68599204, 0.31400796], [0.51807675, 0.48192325], [0.8703536 , 0.1296464 ], [0.85954686, 0.14045314], [0.65284485, 0.34715515], [0.65996455, 0.34003545], [0.48557889, 0.51442111], [0.56908929, 0.43091071], [0.57197412, 0.42802588], [0.40093793, 0.59906207], [0.78238201, 0.21761799], [0.78172776, 0.21827224], [0.52762582, 0.47237418], [0.64004598, 0.35995402], [0.63595194, 
0.36404806], [0.81272968, 0.18727032], [0.65845878, 0.34154122], [0.46100513, 0.53899487], [0.64794141, 0.35205859], [0.65863775, 0.34136225], [0.58804217, 0.41195783], [0.54575926, 0.45424074], [0.70516645, 0.29483355], [0.49049325, 0.50950675], [0.81936864, 0.18063136], [0.8174811 , 0.1825189 ], [0.52762582, 0.47237418], [0.66327264, 0.33672736]])
It returns the probability of each class. Since this is binary, each row will sum to 1.
Let's make sure the classes are in the order we expect:
lg_mod.classes_
array([0, 1], dtype=int64)
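As a quick sanity check, again on hypothetical synthetic data, the columns of predict_proba line up with classes_ and each row sums to 1:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical synthetic data
rng = np.random.RandomState(3)
X = rng.normal(size=(50, 2))
y = (X[:, 0] > 0).astype(int)

mod = LogisticRegression().fit(X, y)
proba = mod.predict_proba(X)

# Column order follows classes_, so proba[:, 1] is P(y = 1)
print(mod.classes_)
print(np.allclose(proba.sum(axis=1), 1.0))
```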
Now we can just get the probability of 1 and save it:
train['score'] = lg_mod.predict_proba(train_x)[:, 1]
train
| | admit | gre | gpa | rank | score |
|---|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 | 0.202556 |
| 1 | 1 | 660 | 3.67 | 3 | 0.320445 |
| 3 | 1 | 640 | 3.19 | 4 | 0.142386 |
| 7 | 0 | 400 | 3.08 | 2 | 0.227637 |
| 8 | 1 | 540 | 3.39 | 3 | 0.224044 |
| ... | ... | ... | ... | ... | ... |
| 395 | 0 | 620 | 4.00 | 2 | 0.509507 |
| 396 | 0 | 560 | 3.04 | 3 | 0.180631 |
| 397 | 0 | 460 | 2.63 | 2 | 0.182519 |
| 398 | 0 | 700 | 3.65 | 2 | 0.472374 |
| 399 | 0 | 600 | 3.89 | 3 | 0.336727 |
320 rows × 5 columns
We can get the log odds itself with decision_function:
lg_mod.decision_function(train_x)
array([-1.37039542e+00, -7.51729186e-01, -1.79561494e+00, -1.22170176e+00, -1.24225261e+00, 1.28692834e-01, -7.54553462e-01, -4.38775908e-01, 8.98830850e-01, -6.15823364e-01, -1.31911185e+00, -9.10168099e-01, -2.34143256e+00, 1.79976401e-01, 2.86117553e-01, -1.51802873e+00, -2.09126306e-01, -2.20434038e+00, 6.78262841e-01, 2.70419364e-01, -2.15165319e-01, -8.33394940e-01, -1.11221364e+00, -1.76497429e-01, -6.71249784e-01, -1.24829162e+00, -1.29442793e-01, -1.71694562e+00, -9.75425854e-01, -1.91193796e+00, -1.48354363e+00, -6.45455808e-01, -7.55739813e-01, -1.39394270e+00, -7.39909397e-01, -1.35064045e+00, -5.21369775e-01, -2.27335053e+00, -2.82882732e+00, -6.23718606e-01, -2.25271360e+00, -1.43411629e+00, -4.87812784e-01, -1.04421769e+00, -2.97673108e-01, -1.37912648e+00, -1.96423572e+00, -7.16494339e-01, -2.20434038e+00, -6.88758055e-01, -1.84195977e+00, -7.92121079e-01, -5.51798108e-01, -3.65755138e-01, -4.34256625e-03, 2.60541882e-01, 7.40305858e-01, -4.99632573e-01, -2.72158590e+00, -1.94148432e+00, -4.27522197e-02, -1.40360188e+00, -1.22845058e+00, -1.76497429e-01, -3.25449323e-01, -1.93147461e+00, -1.34768395e+00, -2.56888161e+00, -1.13690735e+00, -1.06798328e+00, -6.05063915e-01, -4.63251306e-01, 4.89232410e-02, 3.12925722e-01, -3.77528780e-01, -4.75024948e-01, -1.50312643e+00, -6.25833072e-01, -7.93089126e-01, -1.85889045e+00, -1.00192964e+00, -2.19662351e+00, -7.54771766e-01, 7.44989119e-02, -6.32535747e-01, 3.87670495e-01, -1.93270711e+00, -7.32192528e-01, -1.81232732e+00, -2.04493342e+00, -1.15640764e-01, -5.48887754e-01, -8.12976314e-01, -3.94373388e-01, -1.42352900e+00, -1.76126206e+00, -1.68643121e+00, -4.44424458e-01, -1.82917193e+00, 1.67984454e-01, -1.10617463e+00, -8.41462339e-01, -2.19459513e+00, -6.97621343e-01, -9.94034397e-01, -5.74549504e-01, -1.59779833e+00, -1.01378936e+00, -1.16349721e+00, -1.83309648e+00, -3.78456894e-01, -4.93765718e-01, 2.03437607e-01, 1.63804428e-02, -1.38195075e+00, -1.17126023e+00, -1.15260554e+00, -1.52993460e+00, 
-7.76508969e-01, -1.80456431e+00, 2.46872080e-01, 7.49169146e-01, 2.52437308e-02, -1.09942581e+00, -1.39491075e+00, -3.33212339e-01, -5.37996079e-01, -7.73466389e-01, -3.64740944e-01, -1.90385537e-01, -1.28956548e+00, 7.77655171e-01, -5.41920627e-01, -7.82375824e-01, -9.17221306e-01, -1.41863640e+00, -1.75632332e+00, -8.79739767e-01, -1.03648817e-01, -1.77700640e+00, -5.20355582e-01, -1.38706165e+00, -1.22249765e+00, -1.13386477e+00, -2.06376027e+00, -7.35016804e-01, -1.46679131e+00, 1.99599138e-01, -2.65348615e-01, -2.43288971e+00, 4.70270804e-02, -1.03544048e+00, -9.07429903e-01, -8.61085076e-01, -7.12351487e-01, -1.96617803e+00, -4.46538924e-01, -1.28352647e+00, -9.08358018e-01, -1.40271991e+00, 7.77655171e-01, 4.78199537e-01, 4.50727705e-01, 3.28537833e-01, -4.85916623e-01, -1.46070615e+00, -9.11354451e-01, -1.04915643e+00, -3.92345002e-01, -1.42533908e+00, -2.61259158e-02, -6.11102928e-01, -4.46538924e-01, -1.33974878e+00, -5.35257883e-01, -6.46251697e-01, -5.73403084e-01, -7.81229405e-01, -2.14825030e+00, -7.49919105e-01, -2.74079677e-01, -1.82815774e+00, -1.21857310e+00, -9.96195010e-01, 5.63040096e-01, -8.03980800e-01, -3.33212339e-01, -2.37084670e+00, -8.65937738e-01, 2.45725660e-01, -1.34063075e+00, -6.53304903e-01, -3.96355628e-01, -2.56267022e-01, -1.06114838e+00, -7.37927158e-01, -1.64886359e+00, -1.81828026e+00, -1.63927899e-01, -1.67849603e+00, -1.30116697e+00, -1.07574630e+00, -7.84358064e-01, -5.84340906e-01, -5.98229014e-01, -8.99790141e-02, -6.29625393e-01, -1.07887496e+00, -1.45991026e+00, -8.30616812e-01, -1.60970419e+00, -1.91286607e+00, -3.19410311e-01, 3.41070189e-02, -1.29543234e+00, -1.33281261e-01, 9.41677959e-02, -9.43897249e-01, -9.06415710e-01, 1.63045714e-01, -1.88323363e+00, -6.43427422e-01, 6.56356238e-02, -1.71213910e+00, -1.52477756e+00, -2.11773589e+00, -8.66156043e-01, -1.77003927e+00, 2.88278166e-01, -1.30319535e+00, -1.50793295e+00, -4.49449278e-01, -6.79762541e-01, -2.49604278e-01, -1.10938937e+00, -5.66700409e-01, 
-8.69862285e-01, -5.42006706e-01, -8.13772203e-01, -1.15965874e+00, 1.41480669e-01, -1.67748184e+00, -2.81132883e-01, -4.77053335e-01, -8.94642068e-01, -1.80760689e+00, -6.61200145e-01, -1.73673154e-01, -6.63096306e-01, -1.63025505e+00, -1.63638014e+00, -1.41378374e+00, -1.12288701e+00, -1.65785910e+00, -1.83636714e-01, -1.62540238e+00, -6.57361677e-01, -1.52697810e+00, 6.48630397e-01, -4.52359632e-01, -5.44035092e-01, -6.27729232e-01, -2.30880368e+00, -2.97673108e-01, -8.22635491e-01, -7.09613290e-01, -1.18223798e+00, 3.59052245e-01, -1.49536342e+00, -1.92463971e+00, -2.54543019e-01, -9.28994948e-01, -2.04290504e+00, -1.85774403e+00, -1.40474830e+00, -8.75901298e-01, -1.31809766e+00, -1.64211477e+00, -5.78474051e-01, -2.20017981e-01, -9.73529693e-01, -9.81246562e-01, 3.61166710e-01, -1.04523189e+00, -1.40077760e+00, -2.25838689e-01, -3.18528343e-01, -1.55900431e-01, -3.91244730e-01, 4.14128134e-01, 2.49233513e-03, -6.03035529e-01, -7.81447709e-01, -7.23385168e-02, -1.90408886e+00, -1.81153143e+00, -6.31567701e-01, -6.63136237e-01, 5.77004502e-02, -2.78136450e-01, -2.89910092e-01, 4.01558603e-01, -1.27960192e+00, -1.27576345e+00, -1.10615944e-01, -5.75563697e-01, -5.57837121e-01, -1.46784544e+00, -6.56433562e-01, 1.56296891e-01, -6.10002656e-01, -6.57229451e-01, -3.55877656e-01, -1.83550635e-01, -8.72022898e-01, 3.80315663e-02, -1.51207580e+00, -1.49937404e+00, -1.10615944e-01, -6.77912528e-01])
Analysis with Test Data
Now we'll predict our test data:
test_x = test[feat_cols]
test_y = test[out_col]
test_d = lg_mod.predict(test_x)
What's the accuracy?
np.mean(test_y == test_d)
0.7
Practice: compute the precision and recall.
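As a reference for the practice problem, precision and recall can be computed from the confusion counts; here is a sketch on hypothetical labels (not the test data), checked against SciKit-Learn's built-in metrics:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical labels and predictions standing in for test_y and test_d
y_true = np.array([0, 1, 1, 0, 1, 0, 1, 0])
y_pred = np.array([0, 1, 0, 0, 1, 1, 1, 0])

# precision = TP / (TP + FP); recall = TP / (TP + FN)
tp = np.sum((y_pred == 1) & (y_true == 1))
fp = np.sum((y_pred == 1) & (y_true == 0))
fn = np.sum((y_pred == 0) & (y_true == 1))
print(tp / (tp + fp), precision_score(y_true, y_pred))
print(tp / (tp + fn), recall_score(y_true, y_pred))
```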
ROC Curve
We can plot the receiver operating characteristic curve by passing the outcomes and the decision function scores to SciKit-Learn's roc_curve function. It returns the false positive rates, true positive rates, and thresholds, in that order:
test_lo = lg_mod.decision_function(test_x)
fpr, tpr, thresh = sklearn.metrics.roc_curve(test_y, test_lo)
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='grey', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
What's the area under that curve?
sklearn.metrics.roc_auc_score(test_y, test_lo)
0.7170909090909092
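Note that AUC depends only on how the scores rank the test cases, so passing the log odds or the probabilities gives the same value. A sketch with hypothetical scores:

```python
import numpy as np
from scipy.special import expit
from sklearn.metrics import roc_auc_score

# Hypothetical log-odds scores and labels
rng = np.random.RandomState(2)
scores = rng.normal(size=50)
labels = (rng.uniform(size=50) < expit(scores)).astype(int)

# expit is monotone, so it preserves the ranking and therefore the AUC
auc_log_odds = roc_auc_score(labels, scores)
auc_proba = roc_auc_score(labels, expit(scores))
print(np.isclose(auc_log_odds, auc_proba))
```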
Let's fill in the area under the curve for the slide demo:
plt.fill_between(fpr, tpr, 0, color='lightgrey')
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='red', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
Let's see a precision-recall curve:
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(test_y, test_lo)
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()
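The precision-recall curve can be summarized by a single number, the average precision; a small sketch with hypothetical scores:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score

# Hypothetical labels and scores
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.2, 0.9, 0.5])

precision, recall, thresholds = precision_recall_curve(y_true, scores)
ap = average_precision_score(y_true, scores)

# The curve always ends at recall 0, precision 1
print(recall[-1], precision[-1])
print(ap)
```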
Precision vs. Threshold
What happens to precision as we lower the threshold?
def precision_at_threshold(thresh):
    # everything scoring at or above the threshold is predicted positive
    mask = test_lo >= thresh
    # the mean of the true labels among predicted positives is the precision
    return np.mean(test_y[mask])
thresh = np.linspace(-2, 1, 20)
precs = [precision_at_threshold(t) for t in thresh]
plt.plot(thresh, precs)
plt.xlim(1, -2)
plt.xlabel('Threshold (Log Odds)')
plt.ylabel('Precision')
Text(0, 0.5, 'Precision')
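The thresholds here are in log-odds units; SciPy's expit (the inverse of the logit we imported above) converts them to probabilities:

```python
import numpy as np
from scipy.special import expit, logit

# Log-odds thresholds like the ones plotted above
log_odds = np.linspace(-2, 1, 4)   # [-2, -1, 0, 1]
probs = expit(log_odds)

# A threshold of 0 log odds corresponds to a probability of 0.5
print(probs)
print(np.allclose(logit(probs), log_odds))
```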