SciKit Logistic Regression Demo#
This notebook shows the logistic regression demo with SciKit-Learn’s LogisticRegression
instead of StatsModels.
Setup#
Let’s import some packages:
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.special import logit
import matplotlib.pyplot as plt
import sklearn
from sklearn.linear_model import LogisticRegression
rng = np.random.RandomState(20201024)
Read Data#
Load the UCLA grad admissions data.
students = pd.read_csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
students.head()
| | admit | gre | gpa | rank |
|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 |
| 1 | 1 | 660 | 3.67 | 3 |
| 2 | 1 | 800 | 4.00 | 1 |
| 3 | 1 | 640 | 3.19 | 4 |
| 4 | 0 | 520 | 2.93 | 4 |
Let’s split the data into training and test sets:
test = students.sample(frac=0.2, random_state=rng)
train_mask = pd.Series(True, index=students.index)
train_mask[test.index] = False
train = students[train_mask].copy()
train.head()
| | admit | gre | gpa | rank |
|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 |
| 1 | 1 | 660 | 3.67 | 3 |
| 3 | 1 | 640 | 3.19 | 4 |
| 7 | 0 | 400 | 3.08 | 2 |
| 8 | 1 | 540 | 3.39 | 3 |
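As an aside, SciKit-Learn can produce an equivalent 80/20 split directly. This is a sketch of that alternative (not part of the original demo), using train_test_split with a fixed seed:

from sklearn.model_selection import train_test_split

# hold out 20% of the rows for testing, as above
alt_train, alt_test = train_test_split(students, test_size=0.2, random_state=20201024)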
Train the Model#
Let’s create our logistic regression model. penalty='none'
will turn off regularization:
lg_mod = LogisticRegression(penalty='none')
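A version note: newer SciKit-Learn releases (1.2 and later) deprecate the string 'none' in favor of the Python value None, so on a recent install the equivalent call would be:

# on newer scikit-learn versions, pass None instead of the string 'none'
lg_mod = LogisticRegression(penalty=None)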
Now we need to train (or fit) the model. SciKit-Learn doesn’t know about Pandas; it works on NumPy arrays and matrices. Fortunately, a Pandas data frame can be used as a matrix, but we need to separate out the predictors and the outcome:
feat_cols = ['gre', 'gpa', 'rank']
out_col = 'admit'
train_x = train[feat_cols]
train_y = train[out_col]
Now we can call fit. Unlike Statsmodels, SciKit-Learn doesn’t have separate model and results objects; calling fit trains the model in place:
lg_mod.fit(train_x, train_y)
LogisticRegression(penalty='none')
We can get the coefficients from it:
lg_mod.coef_
array([[ 0.00201959, 0.88632881, -0.57805603]])
And the intercept:
lg_mod.intercept_
array([-3.60332028])
The SciKit-Learn convention is that parameters estimated by the fit process are stored in fields whose names end in _.
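Since coef_ is a bare array, it can be handy to label it with the feature names. A small convenience sketch (not in the original demo):

# pair each coefficient with its feature name for readability
pd.Series(lg_mod.coef_[0], index=feat_cols)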
Let’s quickly train one with Statsmodels just to see that it’s fitting the same model:
import statsmodels.api as sm
import statsmodels.formula.api as smf
smf.glm('admit ~ gre + gpa + rank', train, family=sm.families.Binomial()).fit().summary()
| Dep. Variable: | admit | No. Observations: | 320 |
|---|---|---|---|
| Model: | GLM | Df Residuals: | 316 |
| Model Family: | Binomial | Df Model: | 3 |
| Link Function: | logit | Scale: | 1.0000 |
| Method: | IRLS | Log-Likelihood: | -184.48 |
| Date: | Sat, 24 Oct 2020 | Deviance: | 368.97 |
| Time: | 21:11:55 | Pearson chi2: | 318. |
| No. Iterations: | 4 | | |
| Covariance Type: | nonrobust | | |
| | coef | std err | z | P>\|z\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | -3.6033 | 1.286 | -2.802 | 0.005 | -6.124 | -1.083 |
| gre | 0.0020 | 0.001 | 1.620 | 0.105 | -0.000 | 0.004 |
| gpa | 0.8863 | 0.372 | 2.384 | 0.017 | 0.158 | 1.615 |
| rank | -0.5781 | 0.144 | -4.022 | 0.000 | -0.860 | -0.296 |
Let’s compute our training-set accuracy. SciKit-Learn’s predict method outputs the class, not the score, so its output is directly 1/0. Let’s compute that and compare it to train_y:
train_d = lg_mod.predict(train_x)
np.mean(train_d == train_y)
0.7125
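SciKit-Learn estimators also have a score method, and for classifiers it computes exactly this accuracy, so this should return the same number:

# accuracy is the default score metric for classifiers
lg_mod.score(train_x, train_y)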
It’s using a probability of 0.5 (more likely than not) as the decision boundary. More formally, it works in terms of the log odds, returning the positive class (1 in our case) when \(\log \operatorname{Odds}(A) > 0\).
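We can verify that with a quick sanity check (using decision_function, which we’ll meet properly below): predict should agree everywhere with thresholding the log odds at 0.

# every prediction should match the rule "log odds > 0"
np.all(lg_mod.predict(train_x) == (lg_mod.decision_function(train_x) > 0))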
If we want the probabilities, we can call predict_proba:
lg_mod.predict_proba(train_x)
array([[0.79744403, 0.20255597],
[0.67955536, 0.32044464],
[0.85761431, 0.14238569],
[0.77236289, 0.22763711],
[0.77595587, 0.22404413],
[0.46787112, 0.53212888],
[0.68017007, 0.31982993],
[0.60796732, 0.39203268],
[0.28929082, 0.71070918],
[0.64926804, 0.35073196],
[0.7890339 , 0.2109661 ],
[0.71303456, 0.28696544],
[0.91225083, 0.08774917],
[0.45512696, 0.54487304],
[0.42895462, 0.57104538],
[0.82024802, 0.17975198],
[0.55209187, 0.44790813],
[0.9006386 , 0.0993614 ],
[0.33664913, 0.66335087],
[0.43280414, 0.56719586],
[0.55358476, 0.44641524],
[0.69707229, 0.30292771],
[0.75254157, 0.24745843],
[0.54401017, 0.45598983],
[0.66178295, 0.33821705],
[0.77700399, 0.22299601],
[0.53231559, 0.46768441],
[0.84773499, 0.15226501],
[0.72619966, 0.27380034],
[0.87123671, 0.12876329],
[0.81510723, 0.18489277],
[0.65598571, 0.34401429],
[0.68042809, 0.31957191],
[0.80122093, 0.19877907],
[0.67697604, 0.32302396],
[0.79423431, 0.20576569],
[0.62746801, 0.37253199],
[0.90664576, 0.09335424],
[0.94421386, 0.05578614],
[0.65106381, 0.34893619],
[0.90488435, 0.09511565],
[0.80754187, 0.19245813],
[0.61959104, 0.38040896],
[0.73966299, 0.26033701],
[0.57387359, 0.42612641],
[0.79885067, 0.20114933],
[0.87699062, 0.12300938],
[0.67183458, 0.32816542],
[0.9006386 , 0.0993614 ],
[0.66569059, 0.33430941],
[0.86318032, 0.13681968],
[0.68828659, 0.31171341],
[0.63455267, 0.36544733],
[0.59043287, 0.40956713],
[0.50108564, 0.49891436],
[0.43523051, 0.56476949],
[0.32293726, 0.67706274],
[0.62237298, 0.37762702],
[0.93828843, 0.06171157],
[0.87451512, 0.12548488],
[0.51068643, 0.48931357],
[0.80275483, 0.19724517],
[0.77354727, 0.22645273],
[0.54401017, 0.45598983],
[0.58065172, 0.41934828],
[0.87341255, 0.12658745],
[0.79375072, 0.20624928],
[0.9288318 , 0.0711682 ],
[0.75711137, 0.24288863],
[0.7442132 , 0.2557868 ],
[0.64681399, 0.35318601],
[0.61378519, 0.38621481],
[0.48777163, 0.51222837],
[0.42240076, 0.57759924],
[0.59327694, 0.40672306],
[0.6165724 , 0.3834276 ],
[0.81804031, 0.18195969],
[0.65154402, 0.34845598],
[0.68849424, 0.31150576],
[0.86516757, 0.13483243],
[0.7314378 , 0.2685622 ],
[0.89994589, 0.10005411],
[0.68021755, 0.31978245],
[0.48138388, 0.51861612],
[0.65306421, 0.34693579],
[0.40427821, 0.59572179],
[0.87354875, 0.12645125],
[0.67528622, 0.32471378],
[0.85964292, 0.14035708],
[0.88543467, 0.11456533],
[0.52887802, 0.47112198],
[0.6338775 , 0.3661225 ],
[0.69274338, 0.30725662],
[0.59733506, 0.40266494],
[0.80589106, 0.19410894],
[0.85336765, 0.14663235],
[0.84375425, 0.15624575],
[0.60931278, 0.39068722],
[0.86166305, 0.13833695],
[0.45810237, 0.54189763],
[0.75141526, 0.24858474],
[0.69877311, 0.30122689],
[0.8997631 , 0.1002369 ],
[0.66766018, 0.33233982],
[0.72988405, 0.27011595],
[0.63981229, 0.36018771],
[0.83171045, 0.16828955],
[0.73376108, 0.26623892],
[0.7619676 , 0.2380324 ],
[0.86213019, 0.13786981],
[0.59350087, 0.40649913],
[0.62099314, 0.37900686],
[0.44931528, 0.55068472],
[0.49590498, 0.50409502],
[0.79930412, 0.20069588],
[0.76337273, 0.23662727],
[0.75998651, 0.24001349],
[0.82199675, 0.17800325],
[0.68492723, 0.31507277],
[0.85870364, 0.14129636],
[0.43859354, 0.56140646],
[0.32100237, 0.67899763],
[0.4936894 , 0.5063106 ],
[0.7501525 , 0.2498475 ],
[0.80137506, 0.19862494],
[0.58254078, 0.41745922],
[0.63134613, 0.36865387],
[0.68427027, 0.31572973],
[0.5901876 , 0.4098124 ],
[0.54745314, 0.45254686],
[0.78407363, 0.21592637],
[0.31482547, 0.68517453],
[0.63225909, 0.36774091],
[0.68619193, 0.31380807],
[0.71447559, 0.28552441],
[0.80512456, 0.19487544],
[0.85274858, 0.14725142],
[0.70676829, 0.29323171],
[0.52588903, 0.47411097],
[0.85532682, 0.14467318],
[0.62723091, 0.37276909],
[0.80012274, 0.19987726],
[0.77250279, 0.22749721],
[0.75655142, 0.24344858],
[0.88733065, 0.11266935],
[0.67590521, 0.32409479],
[0.81256919, 0.18743081],
[0.45026522, 0.54973478],
[0.56595064, 0.43404936],
[0.91930117, 0.08069883],
[0.4882454 , 0.5117546 ],
[0.73796929, 0.26203071],
[0.71247395, 0.28752605],
[0.70288731, 0.29711269],
[0.67092054, 0.32907946],
[0.8772 , 0.1228 ],
[0.60981602, 0.39018398],
[0.78304946, 0.21695054],
[0.71266404, 0.28733596],
[0.80261514, 0.19738486],
[0.31482547, 0.68517453],
[0.38267737, 0.61732263],
[0.38918776, 0.61081224],
[0.41859643, 0.58140357],
[0.61914402, 0.38085598],
[0.81164066, 0.18835934],
[0.71327725, 0.28672275],
[0.74061288, 0.25938712],
[0.59684708, 0.40315292],
[0.80617405, 0.19382595],
[0.50653111, 0.49346889],
[0.64819235, 0.35180765],
[0.60981602, 0.39018398],
[0.79244863, 0.20755137],
[0.63070859, 0.36929141],
[0.65616529, 0.34383471],
[0.63954805, 0.36045195],
[0.68594502, 0.31405498],
[0.89550516, 0.10449484],
[0.67916107, 0.32083893],
[0.56809418, 0.43190582],
[0.86154211, 0.13845789],
[0.77181234, 0.22818766],
[0.73030981, 0.26969019],
[0.36284433, 0.63715567],
[0.69082537, 0.30917463],
[0.58254078, 0.41745922],
[0.91457703, 0.08542297],
[0.70389972, 0.29610028],
[0.43887584, 0.56112416],
[0.79259365, 0.20740635],
[0.65775483, 0.34224517],
[0.59781175, 0.40218825],
[0.56371842, 0.43628158],
[0.74290994, 0.25709006],
[0.67654242, 0.32345758],
[0.8387374 , 0.1612626 ],
[0.86035964, 0.13964036],
[0.54089045, 0.45910955],
[0.84270528, 0.15729472],
[0.78603132, 0.21396868],
[0.74568817, 0.25431183],
[0.68661862, 0.31338138],
[0.64206564, 0.35793436],
[0.64525103, 0.35474897],
[0.52247959, 0.47752041],
[0.65240452, 0.34759548],
[0.74628102, 0.25371898],
[0.81151895, 0.18848105],
[0.69648534, 0.30351466],
[0.83337031, 0.16662969],
[0.87134079, 0.12865921],
[0.57918053, 0.42081947],
[0.49147407, 0.50852593],
[0.78506525, 0.21493475],
[0.53327108, 0.46672892],
[0.47647543, 0.52352457],
[0.71988621, 0.28011379],
[0.71226615, 0.28773385],
[0.45932863, 0.54067137],
[0.86798211, 0.13201789],
[0.65552782, 0.34447218],
[0.48359698, 0.51640302],
[0.84711353, 0.15288647],
[0.82124092, 0.17875908],
[0.8926151 , 0.1073849 ],
[0.70394522, 0.29605478],
[0.85446255, 0.14553745],
[0.42842545, 0.57157455],
[0.78637226, 0.21362774],
[0.81875467, 0.18124533],
[0.61050829, 0.38949171],
[0.6636857 , 0.3363143 ],
[0.5620791 , 0.4379209 ],
[0.75201525, 0.24798475],
[0.63800146, 0.36199854],
[0.70471704, 0.29528296],
[0.63227911, 0.36772089],
[0.69291275, 0.30708725],
[0.7612707 , 0.2387293 ],
[0.46468871, 0.53531129],
[0.8425708 , 0.1574292 ],
[0.56982394, 0.43017606],
[0.61705182, 0.38294818],
[0.7098472 , 0.2901528 ],
[0.8590724 , 0.1409276 ],
[0.65952993, 0.34047007],
[0.54330948, 0.45669052],
[0.65995559, 0.34004441],
[0.83620457, 0.16379543],
[0.83704178, 0.16295822],
[0.80436205, 0.19563795],
[0.75452383, 0.24547617],
[0.8399504 , 0.1600496 ],
[0.5457806 , 0.4542194 ],
[0.83553884, 0.16446116],
[0.65866748, 0.34133252],
[0.82156374, 0.17843626],
[0.34329824, 0.65670176],
[0.61120011, 0.38879989],
[0.63275058, 0.36724942],
[0.65197439, 0.34802561],
[0.90960354, 0.09039646],
[0.57387359, 0.42612641],
[0.6947955 , 0.3052045 ],
[0.67031571, 0.32968429],
[0.76534996, 0.23465004],
[0.41118901, 0.58881099],
[0.81688193, 0.18311807],
[0.87265493, 0.12734507],
[0.56329437, 0.43670563],
[0.71687134, 0.28312866],
[0.88522875, 0.11477125],
[0.86503378, 0.13496622],
[0.80293629, 0.19706371],
[0.70597215, 0.29402785],
[0.78886503, 0.21113497],
[0.83782249, 0.16217751],
[0.64071621, 0.35928379],
[0.55478368, 0.44521632],
[0.72582248, 0.27417752],
[0.72735549, 0.27264451],
[0.41067717, 0.58932283],
[0.73985824, 0.26014176],
[0.80230725, 0.19769275],
[0.55622092, 0.44377908],
[0.57896556, 0.42103444],
[0.53889636, 0.46110364],
[0.59658231, 0.40341769],
[0.39792269, 0.60207731],
[0.49937692, 0.50062308],
[0.64635048, 0.35364952],
[0.68599204, 0.31400796],
[0.51807675, 0.48192325],
[0.8703536 , 0.1296464 ],
[0.85954686, 0.14045314],
[0.65284485, 0.34715515],
[0.65996455, 0.34003545],
[0.48557889, 0.51442111],
[0.56908929, 0.43091071],
[0.57197412, 0.42802588],
[0.40093793, 0.59906207],
[0.78238201, 0.21761799],
[0.78172776, 0.21827224],
[0.52762582, 0.47237418],
[0.64004598, 0.35995402],
[0.63595194, 0.36404806],
[0.81272968, 0.18727032],
[0.65845878, 0.34154122],
[0.46100513, 0.53899487],
[0.64794141, 0.35205859],
[0.65863775, 0.34136225],
[0.58804217, 0.41195783],
[0.54575926, 0.45424074],
[0.70516645, 0.29483355],
[0.49049325, 0.50950675],
[0.81936864, 0.18063136],
[0.8174811 , 0.1825189 ],
[0.52762582, 0.47237418],
[0.66327264, 0.33672736]])
It returns the probability of each class, one column per class; each row sums to 1. Since this problem is binary, there are two columns.
Let’s make sure the classes are in the order we expect:
lg_mod.classes_
array([0, 1], dtype=int64)
Now we can just get the probability of 1 and save it:
train['score'] = lg_mod.predict_proba(train_x)[:, 1]
train
| | admit | gre | gpa | rank | score |
|---|---|---|---|---|---|
| 0 | 0 | 380 | 3.61 | 3 | 0.202556 |
| 1 | 1 | 660 | 3.67 | 3 | 0.320445 |
| 3 | 1 | 640 | 3.19 | 4 | 0.142386 |
| 7 | 0 | 400 | 3.08 | 2 | 0.227637 |
| 8 | 1 | 540 | 3.39 | 3 | 0.224044 |
| ... | ... | ... | ... | ... | ... |
| 395 | 0 | 620 | 4.00 | 2 | 0.509507 |
| 396 | 0 | 560 | 3.04 | 3 | 0.180631 |
| 397 | 0 | 460 | 2.63 | 2 | 0.182519 |
| 398 | 0 | 700 | 3.65 | 2 | 0.472374 |
| 399 | 0 | 600 | 3.89 | 3 | 0.336727 |
320 rows × 5 columns
We can get the log odds themselves with decision_function:
lg_mod.decision_function(train_x)
array([-1.37039542e+00, -7.51729186e-01, -1.79561494e+00, -1.22170176e+00,
-1.24225261e+00, 1.28692834e-01, -7.54553462e-01, -4.38775908e-01,
8.98830850e-01, -6.15823364e-01, -1.31911185e+00, -9.10168099e-01,
-2.34143256e+00, 1.79976401e-01, 2.86117553e-01, -1.51802873e+00,
-2.09126306e-01, -2.20434038e+00, 6.78262841e-01, 2.70419364e-01,
-2.15165319e-01, -8.33394940e-01, -1.11221364e+00, -1.76497429e-01,
-6.71249784e-01, -1.24829162e+00, -1.29442793e-01, -1.71694562e+00,
-9.75425854e-01, -1.91193796e+00, -1.48354363e+00, -6.45455808e-01,
-7.55739813e-01, -1.39394270e+00, -7.39909397e-01, -1.35064045e+00,
-5.21369775e-01, -2.27335053e+00, -2.82882732e+00, -6.23718606e-01,
-2.25271360e+00, -1.43411629e+00, -4.87812784e-01, -1.04421769e+00,
-2.97673108e-01, -1.37912648e+00, -1.96423572e+00, -7.16494339e-01,
-2.20434038e+00, -6.88758055e-01, -1.84195977e+00, -7.92121079e-01,
-5.51798108e-01, -3.65755138e-01, -4.34256625e-03, 2.60541882e-01,
7.40305858e-01, -4.99632573e-01, -2.72158590e+00, -1.94148432e+00,
-4.27522197e-02, -1.40360188e+00, -1.22845058e+00, -1.76497429e-01,
-3.25449323e-01, -1.93147461e+00, -1.34768395e+00, -2.56888161e+00,
-1.13690735e+00, -1.06798328e+00, -6.05063915e-01, -4.63251306e-01,
4.89232410e-02, 3.12925722e-01, -3.77528780e-01, -4.75024948e-01,
-1.50312643e+00, -6.25833072e-01, -7.93089126e-01, -1.85889045e+00,
-1.00192964e+00, -2.19662351e+00, -7.54771766e-01, 7.44989119e-02,
-6.32535747e-01, 3.87670495e-01, -1.93270711e+00, -7.32192528e-01,
-1.81232732e+00, -2.04493342e+00, -1.15640764e-01, -5.48887754e-01,
-8.12976314e-01, -3.94373388e-01, -1.42352900e+00, -1.76126206e+00,
-1.68643121e+00, -4.44424458e-01, -1.82917193e+00, 1.67984454e-01,
-1.10617463e+00, -8.41462339e-01, -2.19459513e+00, -6.97621343e-01,
-9.94034397e-01, -5.74549504e-01, -1.59779833e+00, -1.01378936e+00,
-1.16349721e+00, -1.83309648e+00, -3.78456894e-01, -4.93765718e-01,
2.03437607e-01, 1.63804428e-02, -1.38195075e+00, -1.17126023e+00,
-1.15260554e+00, -1.52993460e+00, -7.76508969e-01, -1.80456431e+00,
2.46872080e-01, 7.49169146e-01, 2.52437308e-02, -1.09942581e+00,
-1.39491075e+00, -3.33212339e-01, -5.37996079e-01, -7.73466389e-01,
-3.64740944e-01, -1.90385537e-01, -1.28956548e+00, 7.77655171e-01,
-5.41920627e-01, -7.82375824e-01, -9.17221306e-01, -1.41863640e+00,
-1.75632332e+00, -8.79739767e-01, -1.03648817e-01, -1.77700640e+00,
-5.20355582e-01, -1.38706165e+00, -1.22249765e+00, -1.13386477e+00,
-2.06376027e+00, -7.35016804e-01, -1.46679131e+00, 1.99599138e-01,
-2.65348615e-01, -2.43288971e+00, 4.70270804e-02, -1.03544048e+00,
-9.07429903e-01, -8.61085076e-01, -7.12351487e-01, -1.96617803e+00,
-4.46538924e-01, -1.28352647e+00, -9.08358018e-01, -1.40271991e+00,
7.77655171e-01, 4.78199537e-01, 4.50727705e-01, 3.28537833e-01,
-4.85916623e-01, -1.46070615e+00, -9.11354451e-01, -1.04915643e+00,
-3.92345002e-01, -1.42533908e+00, -2.61259158e-02, -6.11102928e-01,
-4.46538924e-01, -1.33974878e+00, -5.35257883e-01, -6.46251697e-01,
-5.73403084e-01, -7.81229405e-01, -2.14825030e+00, -7.49919105e-01,
-2.74079677e-01, -1.82815774e+00, -1.21857310e+00, -9.96195010e-01,
5.63040096e-01, -8.03980800e-01, -3.33212339e-01, -2.37084670e+00,
-8.65937738e-01, 2.45725660e-01, -1.34063075e+00, -6.53304903e-01,
-3.96355628e-01, -2.56267022e-01, -1.06114838e+00, -7.37927158e-01,
-1.64886359e+00, -1.81828026e+00, -1.63927899e-01, -1.67849603e+00,
-1.30116697e+00, -1.07574630e+00, -7.84358064e-01, -5.84340906e-01,
-5.98229014e-01, -8.99790141e-02, -6.29625393e-01, -1.07887496e+00,
-1.45991026e+00, -8.30616812e-01, -1.60970419e+00, -1.91286607e+00,
-3.19410311e-01, 3.41070189e-02, -1.29543234e+00, -1.33281261e-01,
9.41677959e-02, -9.43897249e-01, -9.06415710e-01, 1.63045714e-01,
-1.88323363e+00, -6.43427422e-01, 6.56356238e-02, -1.71213910e+00,
-1.52477756e+00, -2.11773589e+00, -8.66156043e-01, -1.77003927e+00,
2.88278166e-01, -1.30319535e+00, -1.50793295e+00, -4.49449278e-01,
-6.79762541e-01, -2.49604278e-01, -1.10938937e+00, -5.66700409e-01,
-8.69862285e-01, -5.42006706e-01, -8.13772203e-01, -1.15965874e+00,
1.41480669e-01, -1.67748184e+00, -2.81132883e-01, -4.77053335e-01,
-8.94642068e-01, -1.80760689e+00, -6.61200145e-01, -1.73673154e-01,
-6.63096306e-01, -1.63025505e+00, -1.63638014e+00, -1.41378374e+00,
-1.12288701e+00, -1.65785910e+00, -1.83636714e-01, -1.62540238e+00,
-6.57361677e-01, -1.52697810e+00, 6.48630397e-01, -4.52359632e-01,
-5.44035092e-01, -6.27729232e-01, -2.30880368e+00, -2.97673108e-01,
-8.22635491e-01, -7.09613290e-01, -1.18223798e+00, 3.59052245e-01,
-1.49536342e+00, -1.92463971e+00, -2.54543019e-01, -9.28994948e-01,
-2.04290504e+00, -1.85774403e+00, -1.40474830e+00, -8.75901298e-01,
-1.31809766e+00, -1.64211477e+00, -5.78474051e-01, -2.20017981e-01,
-9.73529693e-01, -9.81246562e-01, 3.61166710e-01, -1.04523189e+00,
-1.40077760e+00, -2.25838689e-01, -3.18528343e-01, -1.55900431e-01,
-3.91244730e-01, 4.14128134e-01, 2.49233513e-03, -6.03035529e-01,
-7.81447709e-01, -7.23385168e-02, -1.90408886e+00, -1.81153143e+00,
-6.31567701e-01, -6.63136237e-01, 5.77004502e-02, -2.78136450e-01,
-2.89910092e-01, 4.01558603e-01, -1.27960192e+00, -1.27576345e+00,
-1.10615944e-01, -5.75563697e-01, -5.57837121e-01, -1.46784544e+00,
-6.56433562e-01, 1.56296891e-01, -6.10002656e-01, -6.57229451e-01,
-3.55877656e-01, -1.83550635e-01, -8.72022898e-01, 3.80315663e-02,
-1.51207580e+00, -1.49937404e+00, -1.10615944e-01, -6.77912528e-01])
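The logit function we imported from scipy.special at the top inverts the probabilities back into exactly these log odds, which makes for a nice consistency check against the score column we saved:

# logit(p) recovers the log odds, so this should be True
np.allclose(logit(train['score']), lg_mod.decision_function(train_x))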
Analysis with Test Data#
Now we’ll predict our test data:
test_x = test[feat_cols]
test_y = test[out_col]
test_d = lg_mod.predict(test_x)
What’s the accuracy?
np.mean(test_y == test_d)
0.7
Practice: compute the precision and recall.
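If you want to check your answer, SciKit-Learn’s precision_score and recall_score compute them directly; a minimal sketch:

from sklearn.metrics import precision_score, recall_score

# precision: fraction of predicted admits that were actually admitted
# recall: fraction of actual admits that we predicted
precision_score(test_y, test_d), recall_score(test_y, test_d)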
ROC Curve#
We can plot the receiver operating characteristic by using SciKit-Learn’s roc_curve function to compare the outcomes with the decision function’s scores. It returns the false positive rate, the true positive rate, and the thresholds:
test_lo = lg_mod.decision_function(test_x)
fpr, tpr, thresh = sklearn.metrics.roc_curve(test_y, test_lo)
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='grey', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
What’s the area under that curve?
sklearn.metrics.roc_auc_score(test_y, test_lo)
0.7170909090909092
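Since we already have the curve’s points, sklearn.metrics.auc integrates them directly (trapezoidal rule) and should agree with the score above:

# integrate the ROC curve we already computed
sklearn.metrics.auc(fpr, tpr)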
I’m going to fill in the area under the curve now for the slide demo:
plt.fill_between(fpr, tpr, 0, color='lightgrey')
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='red', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
Let’s see a precision-recall curve:
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(test_y, test_lo)
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()
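The analogous single-number summary for this curve is average precision; a quick check (not in the original demo):

# summarize the precision-recall curve in one number
sklearn.metrics.average_precision_score(test_y, test_lo)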
Precision vs. Threshold#
What if we decrease the threshold?
def precision_at_threshold(thresh):
    # predicted positives: test scores at or above the threshold
    mask = test_lo >= thresh
    # mean of the 0/1 labels among them = fraction admitted = precision
    return np.mean(test_y[mask])
thresh = np.linspace(-2, 1, 20)
precs = [precision_at_threshold(p) for p in thresh]
plt.plot(thresh, precs)
plt.xlim(1, -2)
plt.xlabel('Threshold (Log Odds)')
plt.ylabel('Precision')
Text(0, 0.5, 'Precision')
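For comparison, recall at each threshold can be computed the same way; a sketch mirroring precision_at_threshold:

def recall_at_threshold(thresh):
    mask = test_lo >= thresh
    # fraction of all actual admits that clear the threshold
    return np.sum(test_y[mask]) / np.sum(test_y)

recs = [recall_at_threshold(p) for p in thresh]
plt.plot(thresh, recs)
plt.xlim(1, -2)
plt.xlabel('Threshold (Log Odds)')
plt.ylabel('Recall')
plt.show()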