SciKit Logistic Regression Demo¶

This notebook reproduces the logistic regression demo using SciKit-Learn’s LogisticRegression instead of StatsModels.

Setup¶

Let’s import some packages:

import pandas as pd
import seaborn as sns
import numpy as np
from scipy.special import logit
import matplotlib.pyplot as plt
import sklearn
from sklearn.linear_model import LogisticRegression
rng = np.random.RandomState(20201024)

Read Data¶

Load the UCLA grad admissions data.

students = pd.read_csv("https://stats.idre.ucla.edu/stat/data/binary.csv")
students.head()
admit gre gpa rank
0 0 380 3.61 3
1 1 660 3.67 3
2 1 800 4.00 1
3 1 640 3.19 4
4 0 520 2.93 4

Let’s split the data into training and test sets:

test = students.sample(frac=0.2, random_state=rng)
train_mask = pd.Series(True, index=students.index)
train_mask[test.index] = False
train = students[train_mask].copy()
train.head()
admit gre gpa rank
0 0 380 3.61 3
1 1 660 3.67 3
3 1 640 3.19 4
7 0 400 3.08 2
8 1 540 3.39 3

Train the Model¶

Let’s create our logistic regression model. Passing penalty='none' turns off regularization, which is on by default (newer versions of SciKit-Learn spell this penalty=None):

lg_mod = LogisticRegression(penalty='none')

Now we need to train (or fit) the model. SciKit-Learn doesn’t work with Pandas directly; it operates on NumPy arrays and matrices. Fortunately, a Pandas data frame can be used as a matrix (SciKit-Learn uses its underlying NumPy array), but we need to separate out the predictors and the outcome:

feat_cols = ['gre', 'gpa', 'rank']
out_col = 'admit'
train_x = train[feat_cols]
train_y = train[out_col]

Now we can call fit. Unlike Statsmodels, SciKit-Learn doesn’t have separate model and results objects. We just call fit, which trains the model in-place:

lg_mod.fit(train_x, train_y)
LogisticRegression(penalty='none')

We can get the coefficients from it:

lg_mod.coef_
array([[ 0.00201959,  0.88632881, -0.57805603]])

And the intercept:

lg_mod.intercept_
array([-3.60332028])

The SciKit-Learn convention is that parameters estimated by the fit process are stored in fields whose names end in an underscore (_).
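
Since coef_ is a bare array, it can help to label the coefficients with their feature names. A quick sketch, using the feat_cols list we defined above:

# pair each coefficient with its feature name
pd.Series(lg_mod.coef_[0], index=feat_cols)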

Let’s quickly train one with Statsmodels just to check that it’s fitting the same model:

import statsmodels.api as sm
import statsmodels.formula.api as smf
smf.glm('admit ~ gre + gpa + rank', train, family=sm.families.Binomial()).fit().summary()
Generalized Linear Model Regression Results

Dep. Variable:              admit   No. Observations:     320
Model:                        GLM   Df Residuals:         316
Model Family:            Binomial   Df Model:               3
Link Function:              logit   Scale:             1.0000
Method:                      IRLS   Log-Likelihood:   -184.48
Date:            Sat, 24 Oct 2020   Deviance:          368.97
Time:                    21:11:55   Pearson chi2:        318.
No. Iterations:                 4
Covariance Type:        nonrobust

               coef    std err          z      P>|z|     [0.025     0.975]
Intercept   -3.6033      1.286     -2.802      0.005     -6.124     -1.083
gre          0.0020      0.001      1.620      0.105     -0.000      0.004
gpa          0.8863      0.372      2.384      0.017      0.158      1.615
rank        -0.5781      0.144     -4.022      0.000     -0.860     -0.296

Let’s compute our training set accuracy. SciKit-Learn’s predict method outputs the predicted class, not a score, so the result is directly 0/1. Let’s compute the predictions and compare them to train_y:

train_d = lg_mod.predict(train_x)
np.mean(train_d == train_y)
0.7125
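
As a sanity check, SciKit-Learn classifiers also provide a score method that computes the mean accuracy directly; it should agree with the computation above:

lg_mod.score(train_x, train_y)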

It’s using a probability of 0.5 (more likely than not) as the decision boundary. More formally, it is using the log odds: \(\log \operatorname{Odds}(A) > 0\)

It returns the positive class (1 in our case) if that quantity is greater than 0.
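
We can verify that correspondence by computing the log odds by hand from the fitted parameters; this small sketch should agree with predict (ignoring exact ties at 0):

# compute the log odds directly from the fitted intercept and coefficients
log_odds = lg_mod.intercept_[0] + train_x @ lg_mod.coef_[0]
# thresholding at 0 should reproduce predict's output
np.all(lg_mod.predict(train_x) == (log_odds > 0))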

If we want the probabilities, we can call predict_proba:

lg_mod.predict_proba(train_x)
array([[0.79744403, 0.20255597],
       [0.67955536, 0.32044464],
       [0.85761431, 0.14238569],
       [0.77236289, 0.22763711],
       [0.77595587, 0.22404413],
       [0.46787112, 0.53212888],
       [0.68017007, 0.31982993],
       [0.60796732, 0.39203268],
       [0.28929082, 0.71070918],
       [0.64926804, 0.35073196],
       [0.7890339 , 0.2109661 ],
       [0.71303456, 0.28696544],
       [0.91225083, 0.08774917],
       [0.45512696, 0.54487304],
       [0.42895462, 0.57104538],
       [0.82024802, 0.17975198],
       [0.55209187, 0.44790813],
       [0.9006386 , 0.0993614 ],
       [0.33664913, 0.66335087],
       [0.43280414, 0.56719586],
       [0.55358476, 0.44641524],
       [0.69707229, 0.30292771],
       [0.75254157, 0.24745843],
       [0.54401017, 0.45598983],
       [0.66178295, 0.33821705],
       [0.77700399, 0.22299601],
       [0.53231559, 0.46768441],
       [0.84773499, 0.15226501],
       [0.72619966, 0.27380034],
       [0.87123671, 0.12876329],
       [0.81510723, 0.18489277],
       [0.65598571, 0.34401429],
       [0.68042809, 0.31957191],
       [0.80122093, 0.19877907],
       [0.67697604, 0.32302396],
       [0.79423431, 0.20576569],
       [0.62746801, 0.37253199],
       [0.90664576, 0.09335424],
       [0.94421386, 0.05578614],
       [0.65106381, 0.34893619],
       [0.90488435, 0.09511565],
       [0.80754187, 0.19245813],
       [0.61959104, 0.38040896],
       [0.73966299, 0.26033701],
       [0.57387359, 0.42612641],
       [0.79885067, 0.20114933],
       [0.87699062, 0.12300938],
       [0.67183458, 0.32816542],
       [0.9006386 , 0.0993614 ],
       [0.66569059, 0.33430941],
       [0.86318032, 0.13681968],
       [0.68828659, 0.31171341],
       [0.63455267, 0.36544733],
       [0.59043287, 0.40956713],
       [0.50108564, 0.49891436],
       [0.43523051, 0.56476949],
       [0.32293726, 0.67706274],
       [0.62237298, 0.37762702],
       [0.93828843, 0.06171157],
       [0.87451512, 0.12548488],
       [0.51068643, 0.48931357],
       [0.80275483, 0.19724517],
       [0.77354727, 0.22645273],
       [0.54401017, 0.45598983],
       [0.58065172, 0.41934828],
       [0.87341255, 0.12658745],
       [0.79375072, 0.20624928],
       [0.9288318 , 0.0711682 ],
       [0.75711137, 0.24288863],
       [0.7442132 , 0.2557868 ],
       [0.64681399, 0.35318601],
       [0.61378519, 0.38621481],
       [0.48777163, 0.51222837],
       [0.42240076, 0.57759924],
       [0.59327694, 0.40672306],
       [0.6165724 , 0.3834276 ],
       [0.81804031, 0.18195969],
       [0.65154402, 0.34845598],
       [0.68849424, 0.31150576],
       [0.86516757, 0.13483243],
       [0.7314378 , 0.2685622 ],
       [0.89994589, 0.10005411],
       [0.68021755, 0.31978245],
       [0.48138388, 0.51861612],
       [0.65306421, 0.34693579],
       [0.40427821, 0.59572179],
       [0.87354875, 0.12645125],
       [0.67528622, 0.32471378],
       [0.85964292, 0.14035708],
       [0.88543467, 0.11456533],
       [0.52887802, 0.47112198],
       [0.6338775 , 0.3661225 ],
       [0.69274338, 0.30725662],
       [0.59733506, 0.40266494],
       [0.80589106, 0.19410894],
       [0.85336765, 0.14663235],
       [0.84375425, 0.15624575],
       [0.60931278, 0.39068722],
       [0.86166305, 0.13833695],
       [0.45810237, 0.54189763],
       [0.75141526, 0.24858474],
       [0.69877311, 0.30122689],
       [0.8997631 , 0.1002369 ],
       [0.66766018, 0.33233982],
       [0.72988405, 0.27011595],
       [0.63981229, 0.36018771],
       [0.83171045, 0.16828955],
       [0.73376108, 0.26623892],
       [0.7619676 , 0.2380324 ],
       [0.86213019, 0.13786981],
       [0.59350087, 0.40649913],
       [0.62099314, 0.37900686],
       [0.44931528, 0.55068472],
       [0.49590498, 0.50409502],
       [0.79930412, 0.20069588],
       [0.76337273, 0.23662727],
       [0.75998651, 0.24001349],
       [0.82199675, 0.17800325],
       [0.68492723, 0.31507277],
       [0.85870364, 0.14129636],
       [0.43859354, 0.56140646],
       [0.32100237, 0.67899763],
       [0.4936894 , 0.5063106 ],
       [0.7501525 , 0.2498475 ],
       [0.80137506, 0.19862494],
       [0.58254078, 0.41745922],
       [0.63134613, 0.36865387],
       [0.68427027, 0.31572973],
       [0.5901876 , 0.4098124 ],
       [0.54745314, 0.45254686],
       [0.78407363, 0.21592637],
       [0.31482547, 0.68517453],
       [0.63225909, 0.36774091],
       [0.68619193, 0.31380807],
       [0.71447559, 0.28552441],
       [0.80512456, 0.19487544],
       [0.85274858, 0.14725142],
       [0.70676829, 0.29323171],
       [0.52588903, 0.47411097],
       [0.85532682, 0.14467318],
       [0.62723091, 0.37276909],
       [0.80012274, 0.19987726],
       [0.77250279, 0.22749721],
       [0.75655142, 0.24344858],
       [0.88733065, 0.11266935],
       [0.67590521, 0.32409479],
       [0.81256919, 0.18743081],
       [0.45026522, 0.54973478],
       [0.56595064, 0.43404936],
       [0.91930117, 0.08069883],
       [0.4882454 , 0.5117546 ],
       [0.73796929, 0.26203071],
       [0.71247395, 0.28752605],
       [0.70288731, 0.29711269],
       [0.67092054, 0.32907946],
       [0.8772    , 0.1228    ],
       [0.60981602, 0.39018398],
       [0.78304946, 0.21695054],
       [0.71266404, 0.28733596],
       [0.80261514, 0.19738486],
       [0.31482547, 0.68517453],
       [0.38267737, 0.61732263],
       [0.38918776, 0.61081224],
       [0.41859643, 0.58140357],
       [0.61914402, 0.38085598],
       [0.81164066, 0.18835934],
       [0.71327725, 0.28672275],
       [0.74061288, 0.25938712],
       [0.59684708, 0.40315292],
       [0.80617405, 0.19382595],
       [0.50653111, 0.49346889],
       [0.64819235, 0.35180765],
       [0.60981602, 0.39018398],
       [0.79244863, 0.20755137],
       [0.63070859, 0.36929141],
       [0.65616529, 0.34383471],
       [0.63954805, 0.36045195],
       [0.68594502, 0.31405498],
       [0.89550516, 0.10449484],
       [0.67916107, 0.32083893],
       [0.56809418, 0.43190582],
       [0.86154211, 0.13845789],
       [0.77181234, 0.22818766],
       [0.73030981, 0.26969019],
       [0.36284433, 0.63715567],
       [0.69082537, 0.30917463],
       [0.58254078, 0.41745922],
       [0.91457703, 0.08542297],
       [0.70389972, 0.29610028],
       [0.43887584, 0.56112416],
       [0.79259365, 0.20740635],
       [0.65775483, 0.34224517],
       [0.59781175, 0.40218825],
       [0.56371842, 0.43628158],
       [0.74290994, 0.25709006],
       [0.67654242, 0.32345758],
       [0.8387374 , 0.1612626 ],
       [0.86035964, 0.13964036],
       [0.54089045, 0.45910955],
       [0.84270528, 0.15729472],
       [0.78603132, 0.21396868],
       [0.74568817, 0.25431183],
       [0.68661862, 0.31338138],
       [0.64206564, 0.35793436],
       [0.64525103, 0.35474897],
       [0.52247959, 0.47752041],
       [0.65240452, 0.34759548],
       [0.74628102, 0.25371898],
       [0.81151895, 0.18848105],
       [0.69648534, 0.30351466],
       [0.83337031, 0.16662969],
       [0.87134079, 0.12865921],
       [0.57918053, 0.42081947],
       [0.49147407, 0.50852593],
       [0.78506525, 0.21493475],
       [0.53327108, 0.46672892],
       [0.47647543, 0.52352457],
       [0.71988621, 0.28011379],
       [0.71226615, 0.28773385],
       [0.45932863, 0.54067137],
       [0.86798211, 0.13201789],
       [0.65552782, 0.34447218],
       [0.48359698, 0.51640302],
       [0.84711353, 0.15288647],
       [0.82124092, 0.17875908],
       [0.8926151 , 0.1073849 ],
       [0.70394522, 0.29605478],
       [0.85446255, 0.14553745],
       [0.42842545, 0.57157455],
       [0.78637226, 0.21362774],
       [0.81875467, 0.18124533],
       [0.61050829, 0.38949171],
       [0.6636857 , 0.3363143 ],
       [0.5620791 , 0.4379209 ],
       [0.75201525, 0.24798475],
       [0.63800146, 0.36199854],
       [0.70471704, 0.29528296],
       [0.63227911, 0.36772089],
       [0.69291275, 0.30708725],
       [0.7612707 , 0.2387293 ],
       [0.46468871, 0.53531129],
       [0.8425708 , 0.1574292 ],
       [0.56982394, 0.43017606],
       [0.61705182, 0.38294818],
       [0.7098472 , 0.2901528 ],
       [0.8590724 , 0.1409276 ],
       [0.65952993, 0.34047007],
       [0.54330948, 0.45669052],
       [0.65995559, 0.34004441],
       [0.83620457, 0.16379543],
       [0.83704178, 0.16295822],
       [0.80436205, 0.19563795],
       [0.75452383, 0.24547617],
       [0.8399504 , 0.1600496 ],
       [0.5457806 , 0.4542194 ],
       [0.83553884, 0.16446116],
       [0.65866748, 0.34133252],
       [0.82156374, 0.17843626],
       [0.34329824, 0.65670176],
       [0.61120011, 0.38879989],
       [0.63275058, 0.36724942],
       [0.65197439, 0.34802561],
       [0.90960354, 0.09039646],
       [0.57387359, 0.42612641],
       [0.6947955 , 0.3052045 ],
       [0.67031571, 0.32968429],
       [0.76534996, 0.23465004],
       [0.41118901, 0.58881099],
       [0.81688193, 0.18311807],
       [0.87265493, 0.12734507],
       [0.56329437, 0.43670563],
       [0.71687134, 0.28312866],
       [0.88522875, 0.11477125],
       [0.86503378, 0.13496622],
       [0.80293629, 0.19706371],
       [0.70597215, 0.29402785],
       [0.78886503, 0.21113497],
       [0.83782249, 0.16217751],
       [0.64071621, 0.35928379],
       [0.55478368, 0.44521632],
       [0.72582248, 0.27417752],
       [0.72735549, 0.27264451],
       [0.41067717, 0.58932283],
       [0.73985824, 0.26014176],
       [0.80230725, 0.19769275],
       [0.55622092, 0.44377908],
       [0.57896556, 0.42103444],
       [0.53889636, 0.46110364],
       [0.59658231, 0.40341769],
       [0.39792269, 0.60207731],
       [0.49937692, 0.50062308],
       [0.64635048, 0.35364952],
       [0.68599204, 0.31400796],
       [0.51807675, 0.48192325],
       [0.8703536 , 0.1296464 ],
       [0.85954686, 0.14045314],
       [0.65284485, 0.34715515],
       [0.65996455, 0.34003545],
       [0.48557889, 0.51442111],
       [0.56908929, 0.43091071],
       [0.57197412, 0.42802588],
       [0.40093793, 0.59906207],
       [0.78238201, 0.21761799],
       [0.78172776, 0.21827224],
       [0.52762582, 0.47237418],
       [0.64004598, 0.35995402],
       [0.63595194, 0.36404806],
       [0.81272968, 0.18727032],
       [0.65845878, 0.34154122],
       [0.46100513, 0.53899487],
       [0.64794141, 0.35205859],
       [0.65863775, 0.34136225],
       [0.58804217, 0.41195783],
       [0.54575926, 0.45424074],
       [0.70516645, 0.29483355],
       [0.49049325, 0.50950675],
       [0.81936864, 0.18063136],
       [0.8174811 , 0.1825189 ],
       [0.52762582, 0.47237418],
       [0.66327264, 0.33672736]])

It returns the probability of each class, one column per class; each row sums to 1.
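
We can quickly confirm that the rows sum to 1:

np.allclose(lg_mod.predict_proba(train_x).sum(axis=1), 1)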

Let’s make sure the classes are in the order we expect:

lg_mod.classes_
array([0, 1], dtype=int64)

Now we can just get the probability of 1 and save it:

train['score'] = lg_mod.predict_proba(train_x)[:, 1]
train
admit gre gpa rank score
0 0 380 3.61 3 0.202556
1 1 660 3.67 3 0.320445
3 1 640 3.19 4 0.142386
7 0 400 3.08 2 0.227637
8 1 540 3.39 3 0.224044
... ... ... ... ... ...
395 0 620 4.00 2 0.509507
396 0 560 3.04 3 0.180631
397 0 460 2.63 2 0.182519
398 0 700 3.65 2 0.472374
399 0 600 3.89 3 0.336727

320 rows × 5 columns

We can get the log odds themselves with decision_function:

lg_mod.decision_function(train_x)
array([-1.37039542e+00, -7.51729186e-01, -1.79561494e+00, -1.22170176e+00,
       -1.24225261e+00,  1.28692834e-01, -7.54553462e-01, -4.38775908e-01,
        8.98830850e-01, -6.15823364e-01, -1.31911185e+00, -9.10168099e-01,
       -2.34143256e+00,  1.79976401e-01,  2.86117553e-01, -1.51802873e+00,
       -2.09126306e-01, -2.20434038e+00,  6.78262841e-01,  2.70419364e-01,
       -2.15165319e-01, -8.33394940e-01, -1.11221364e+00, -1.76497429e-01,
       -6.71249784e-01, -1.24829162e+00, -1.29442793e-01, -1.71694562e+00,
       -9.75425854e-01, -1.91193796e+00, -1.48354363e+00, -6.45455808e-01,
       -7.55739813e-01, -1.39394270e+00, -7.39909397e-01, -1.35064045e+00,
       -5.21369775e-01, -2.27335053e+00, -2.82882732e+00, -6.23718606e-01,
       -2.25271360e+00, -1.43411629e+00, -4.87812784e-01, -1.04421769e+00,
       -2.97673108e-01, -1.37912648e+00, -1.96423572e+00, -7.16494339e-01,
       -2.20434038e+00, -6.88758055e-01, -1.84195977e+00, -7.92121079e-01,
       -5.51798108e-01, -3.65755138e-01, -4.34256625e-03,  2.60541882e-01,
        7.40305858e-01, -4.99632573e-01, -2.72158590e+00, -1.94148432e+00,
       -4.27522197e-02, -1.40360188e+00, -1.22845058e+00, -1.76497429e-01,
       -3.25449323e-01, -1.93147461e+00, -1.34768395e+00, -2.56888161e+00,
       -1.13690735e+00, -1.06798328e+00, -6.05063915e-01, -4.63251306e-01,
        4.89232410e-02,  3.12925722e-01, -3.77528780e-01, -4.75024948e-01,
       -1.50312643e+00, -6.25833072e-01, -7.93089126e-01, -1.85889045e+00,
       -1.00192964e+00, -2.19662351e+00, -7.54771766e-01,  7.44989119e-02,
       -6.32535747e-01,  3.87670495e-01, -1.93270711e+00, -7.32192528e-01,
       -1.81232732e+00, -2.04493342e+00, -1.15640764e-01, -5.48887754e-01,
       -8.12976314e-01, -3.94373388e-01, -1.42352900e+00, -1.76126206e+00,
       -1.68643121e+00, -4.44424458e-01, -1.82917193e+00,  1.67984454e-01,
       -1.10617463e+00, -8.41462339e-01, -2.19459513e+00, -6.97621343e-01,
       -9.94034397e-01, -5.74549504e-01, -1.59779833e+00, -1.01378936e+00,
       -1.16349721e+00, -1.83309648e+00, -3.78456894e-01, -4.93765718e-01,
        2.03437607e-01,  1.63804428e-02, -1.38195075e+00, -1.17126023e+00,
       -1.15260554e+00, -1.52993460e+00, -7.76508969e-01, -1.80456431e+00,
        2.46872080e-01,  7.49169146e-01,  2.52437308e-02, -1.09942581e+00,
       -1.39491075e+00, -3.33212339e-01, -5.37996079e-01, -7.73466389e-01,
       -3.64740944e-01, -1.90385537e-01, -1.28956548e+00,  7.77655171e-01,
       -5.41920627e-01, -7.82375824e-01, -9.17221306e-01, -1.41863640e+00,
       -1.75632332e+00, -8.79739767e-01, -1.03648817e-01, -1.77700640e+00,
       -5.20355582e-01, -1.38706165e+00, -1.22249765e+00, -1.13386477e+00,
       -2.06376027e+00, -7.35016804e-01, -1.46679131e+00,  1.99599138e-01,
       -2.65348615e-01, -2.43288971e+00,  4.70270804e-02, -1.03544048e+00,
       -9.07429903e-01, -8.61085076e-01, -7.12351487e-01, -1.96617803e+00,
       -4.46538924e-01, -1.28352647e+00, -9.08358018e-01, -1.40271991e+00,
        7.77655171e-01,  4.78199537e-01,  4.50727705e-01,  3.28537833e-01,
       -4.85916623e-01, -1.46070615e+00, -9.11354451e-01, -1.04915643e+00,
       -3.92345002e-01, -1.42533908e+00, -2.61259158e-02, -6.11102928e-01,
       -4.46538924e-01, -1.33974878e+00, -5.35257883e-01, -6.46251697e-01,
       -5.73403084e-01, -7.81229405e-01, -2.14825030e+00, -7.49919105e-01,
       -2.74079677e-01, -1.82815774e+00, -1.21857310e+00, -9.96195010e-01,
        5.63040096e-01, -8.03980800e-01, -3.33212339e-01, -2.37084670e+00,
       -8.65937738e-01,  2.45725660e-01, -1.34063075e+00, -6.53304903e-01,
       -3.96355628e-01, -2.56267022e-01, -1.06114838e+00, -7.37927158e-01,
       -1.64886359e+00, -1.81828026e+00, -1.63927899e-01, -1.67849603e+00,
       -1.30116697e+00, -1.07574630e+00, -7.84358064e-01, -5.84340906e-01,
       -5.98229014e-01, -8.99790141e-02, -6.29625393e-01, -1.07887496e+00,
       -1.45991026e+00, -8.30616812e-01, -1.60970419e+00, -1.91286607e+00,
       -3.19410311e-01,  3.41070189e-02, -1.29543234e+00, -1.33281261e-01,
        9.41677959e-02, -9.43897249e-01, -9.06415710e-01,  1.63045714e-01,
       -1.88323363e+00, -6.43427422e-01,  6.56356238e-02, -1.71213910e+00,
       -1.52477756e+00, -2.11773589e+00, -8.66156043e-01, -1.77003927e+00,
        2.88278166e-01, -1.30319535e+00, -1.50793295e+00, -4.49449278e-01,
       -6.79762541e-01, -2.49604278e-01, -1.10938937e+00, -5.66700409e-01,
       -8.69862285e-01, -5.42006706e-01, -8.13772203e-01, -1.15965874e+00,
        1.41480669e-01, -1.67748184e+00, -2.81132883e-01, -4.77053335e-01,
       -8.94642068e-01, -1.80760689e+00, -6.61200145e-01, -1.73673154e-01,
       -6.63096306e-01, -1.63025505e+00, -1.63638014e+00, -1.41378374e+00,
       -1.12288701e+00, -1.65785910e+00, -1.83636714e-01, -1.62540238e+00,
       -6.57361677e-01, -1.52697810e+00,  6.48630397e-01, -4.52359632e-01,
       -5.44035092e-01, -6.27729232e-01, -2.30880368e+00, -2.97673108e-01,
       -8.22635491e-01, -7.09613290e-01, -1.18223798e+00,  3.59052245e-01,
       -1.49536342e+00, -1.92463971e+00, -2.54543019e-01, -9.28994948e-01,
       -2.04290504e+00, -1.85774403e+00, -1.40474830e+00, -8.75901298e-01,
       -1.31809766e+00, -1.64211477e+00, -5.78474051e-01, -2.20017981e-01,
       -9.73529693e-01, -9.81246562e-01,  3.61166710e-01, -1.04523189e+00,
       -1.40077760e+00, -2.25838689e-01, -3.18528343e-01, -1.55900431e-01,
       -3.91244730e-01,  4.14128134e-01,  2.49233513e-03, -6.03035529e-01,
       -7.81447709e-01, -7.23385168e-02, -1.90408886e+00, -1.81153143e+00,
       -6.31567701e-01, -6.63136237e-01,  5.77004502e-02, -2.78136450e-01,
       -2.89910092e-01,  4.01558603e-01, -1.27960192e+00, -1.27576345e+00,
       -1.10615944e-01, -5.75563697e-01, -5.57837121e-01, -1.46784544e+00,
       -6.56433562e-01,  1.56296891e-01, -6.10002656e-01, -6.57229451e-01,
       -3.55877656e-01, -1.83550635e-01, -8.72022898e-01,  3.80315663e-02,
       -1.51207580e+00, -1.49937404e+00, -1.10615944e-01, -6.77912528e-01])
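
These log odds are the probabilities on another scale: applying the logistic function (the inverse of the logit) should recover the second column of predict_proba. A quick check using scipy.special.expit:

from scipy.special import expit
# the logistic of the log odds is the probability of the positive class
np.allclose(expit(lg_mod.decision_function(train_x)), lg_mod.predict_proba(train_x)[:, 1])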

Analysis with Test Data¶

Now we’ll predict our test data:

test_x = test[feat_cols]
test_y = test[out_col]
test_d = lg_mod.predict(test_x)

What’s the accuracy?

np.mean(test_y == test_d)
0.7

Practice: compute the precision and recall.
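
If you want to check your answers, SciKit-Learn’s precision_score and recall_score compute these directly (one way to do it, not the only one):

from sklearn.metrics import precision_score, recall_score
precision_score(test_y, test_d), recall_score(test_y, test_d)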

ROC Curve¶

We can plot the receiver operating characteristic by using SciKit-Learn’s roc_curve function to compare the outcomes with the decision function. It returns the false positive rates, true positive rates, and the corresponding thresholds:

test_lo = lg_mod.decision_function(test_x)
fpr, tpr, thresh = sklearn.metrics.roc_curve(test_y, test_lo)
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='grey', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
[Figure: ROC curve with a dotted diagonal reference line]

What’s the area under that curve?

sklearn.metrics.roc_auc_score(test_y, test_lo)
0.7170909090909092
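
Because AUC depends only on how the scores rank the test points, scoring with the predicted probabilities (a monotone transform of the log odds) gives the same value:

sklearn.metrics.roc_auc_score(test_y, lg_mod.predict_proba(test_x)[:, 1])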

I’m going to fill in the area under the curve now for the slide demo:

plt.fill_between(fpr, tpr, 0, color='lightgrey')
plt.plot(np.linspace(0, 1), np.linspace(0, 1), color='red', linestyle=':')
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC')
plt.show()
[Figure: ROC curve with the area under the curve shaded]

Let’s see a precision-recall curve:

precision, recall, thresholds = sklearn.metrics.precision_recall_curve(test_y, test_lo)
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.show()
[Figure: precision-recall curve]

Precision vs. Threshold¶

What if we decrease the threshold?

def precision_at_threshold(thresh):
    # select the points predicted positive at this threshold
    mask = test_lo >= thresh
    # the mean of their true labels is the fraction that are truly positive (the precision)
    return np.mean(test_y[mask])
thresh = np.linspace(-2, 1, 20)
precs = [precision_at_threshold(p) for p in thresh]
plt.plot(thresh, precs)
plt.xlim(1, -2)
plt.xlabel('Threshold (Log Odds)')
plt.ylabel('Precision')
Text(0, 0.5, 'Precision')
[Figure: precision vs. threshold, with the threshold decreasing along the x-axis]
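
As a cross-check on our hand-rolled function, we can compute the same curve with sklearn.metrics.precision_score on thresholded scores. A sketch; zero_division=0 covers thresholds where nothing is predicted positive (where our np.mean version would produce NaN):

from sklearn.metrics import precision_score
# predict positive wherever the log odds meet the threshold
precs_sk = [precision_score(test_y, (test_lo >= t).astype(int), zero_division=0) for t in thresh]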