6. SHAP

SHAP’s goal is to explain machine learning model output using a game-theoretic approach. A primary use of SHAP is to understand, both visually and quantitatively, how variables and their values influence predictions. The SHAP API is organized around explainers, each of which is appropriate only for certain classes of algorithms; for example, you should use the TreeExplainer for tree-based models. Below, we take a look at three of these explainers.

Note that SHAP is part of the movement to promote explainable artificial intelligence (AI). There are other APIs available that do things similar to SHAP.

A great book on explainable AI, or interpretable machine learning, is available online.

6.1. Linear explainer

The LinearExplainer is used to understand the outputs of linear predictors (e.g. linear regression). We will generate some data and use the LinearRegression model to learn the parameters from the data.

[1]:
%matplotlib inline
import numpy as np
import pandas as pd
from patsy import dmatrices
from numpy.random import normal
import matplotlib.pyplot as plt

np.random.seed(37)

n = 100

x_0 = normal(10, 1, n)
x_1 = normal(5, 2.5, n)
x_2 = normal(20, 1, n)
y = 3.2 + (2.7 * x_0) - (4.8 * x_1) + (1.3 * x_2) + normal(0, 1, n)

df = pd.DataFrame(np.hstack([
    x_0.reshape(-1, 1),
    x_1.reshape(-1, 1),
    x_2.reshape(-1, 1),
    y.reshape(-1, 1)]), columns=['x0', 'x1', 'x2', 'y'])

y, X = dmatrices('y ~ x0 + x1 + x2 - 1', df, return_type='dataframe')
print(f'X shape = {X.shape}, y shape {y.shape}')
X shape = (100, 3), y shape (100, 1)
[2]:
from sklearn.linear_model import LinearRegression

model = LinearRegression()
model.fit(X, y)
[2]:
LinearRegression()

Before you can use SHAP’s interactive visualizations, you must initialize its JavaScript library.

[3]:
import shap
shap.initjs()

Here, we create the LinearExplainer. We must pass in the data X, which the explainer uses as background data for computing the expected value.

[4]:
explainer = shap.LinearExplainer(model, X)
shap_values = explainer.shap_values(X)

A force plot can be used to explain each individual data point’s prediction. Below, we look at the force plots of the first, second and third observations (indexed 0, 1, 2); a sketch after this list shows how to read the same contributions off numerically.

  • First observation prediction explanation: the values of x1 and x2 are pushing the prediction value downward.

  • Second observation prediction explanation: the x0 value is pushing the prediction value higher, while x1 and x2 are pushing the value lower.

  • Third observation prediction explanation: the x0 and x1 values are pushing the prediction value lower, and the x2 value is slightly nudging it lower as well.
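For the LinearExplainer, shap_values is an array with one row per observation and one signed, per-feature contribution per column. A minimal sketch that prints the contributions for these three observations:

# Per-feature SHAP contributions; positive values push the prediction
# higher, negative values push it lower.
for i in range(3):
    print(f'observation {i}:')
    print(pd.Series(shap_values[i, :], index=X.columns))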

[5]:
shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:])
[6]:
shap.force_plot(explainer.expected_value, shap_values[1,:], X.iloc[1,:])
[7]:
shap.force_plot(explainer.expected_value, shap_values[2,:], X.iloc[2,:])

The force plot can also be used to visualize explanations across all observations at once.

[8]:
shap.force_plot(explainer.expected_value, shap_values, X)

The summary plot is a way to understand variable importance.

[9]:
shap.summary_plot(shap_values, X)
_images/shap_15_0.png

For comparison, note that the variable importances shown in the summary plot coincide with the coefficients of the linear regression model.

[10]:
s = pd.Series(model.coef_[0], index=X.columns)
s
[10]:
x0    2.536395
x1   -4.742098
x2    1.415208
dtype: float64
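The comparison can be made quantitative with the mean absolute SHAP value per feature, a common global importance measure. For a linear model, this tracks the coefficient magnitude scaled by the feature’s spread, which is why x1 (the largest coefficient and the largest standard deviation) dominates. A minimal sketch:

# Global importance as the mean absolute SHAP value per feature.
importance = pd.Series(np.abs(shap_values).mean(axis=0), index=X.columns)
print(importance.sort_values(ascending=False))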

6.2. Tree explainer

The TreeExplainer is appropriate for tree-based algorithms. Here, we generate data for a classification problem and use a RandomForestClassifier as the model that we want to explain.

[11]:
from scipy.stats import binom

def make_classification(n=100):
    X = np.hstack([
        np.array([1 for _ in range(n)]).reshape(n, 1),
        normal(0.0, 1.0, n).reshape(n, 1),
        normal(0.0, 1.0, n).reshape(n, 1)
    ])
    z = np.dot(X, np.array([1.0, 2.0, 3.0])) + normal(0.0, 1.0, n)
    p = 1.0 / (1.0 + np.exp(-z))
    y = binom.rvs(1, p)

    df = pd.DataFrame(np.hstack([X, y.reshape(-1, 1)]), columns=['intercept', 'x0', 'x1', 'y'])
    return df

df = make_classification()

y, X = dmatrices('y ~ x0 + x1 - 1', df, return_type='dataframe')
print(f'X shape = {X.shape}, y shape {y.shape}')
X shape = (100, 2), y shape (100, 1)
[12]:
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=100, random_state=37)
model.fit(X, y.values.reshape(1, -1)[0])
[12]:
RandomForestClassifier(random_state=37)
[13]:
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
shap_interaction_values = explainer.shap_interaction_values(X)
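For classifiers, this version of SHAP returns a list with one array of SHAP values per class, which is why the cells below index with [1]: it selects the contributions toward the positive class. A quick sanity check of the shapes, assuming that list-of-arrays output:

# shap_values is a list of (n_samples, n_features) arrays, one per class;
# shap_interaction_values adds an (n_features, n_features) matrix per sample.
print(len(shap_values), shap_values[1].shape)
print(shap_interaction_values[1].shape)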

Here are the force plots for three observations (indexed 0, 1 and 95).

[14]:
shap.force_plot(explainer.expected_value[1], shap_values[1][0,:], X.iloc[0,:])
[15]:
shap.force_plot(explainer.expected_value[1], shap_values[1][1,:], X.iloc[1,:])
[16]:
shap.force_plot(explainer.expected_value[1], shap_values[1][95,:], X.iloc[95,:])

Here is the force plot for all observations.

[17]:
shap.force_plot(explainer.expected_value[1], shap_values[1], X)
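The force plots are a visual statement of SHAP’s local accuracy property: the expected value plus the per-row sum of SHAP values should reproduce the model’s predicted probability of the positive class. A quick check, up to numerical tolerance:

# expected_value[1] plus the sum of per-feature SHAP values should equal
# the predicted probability of class 1 for each observation.
reconstructed = explainer.expected_value[1] + shap_values[1].sum(axis=1)
print(np.allclose(reconstructed, model.predict_proba(X)[:, 1]))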

Below is the summary plot.

[18]:
shap.summary_plot(shap_values[1], X)
_images/shap_29_0.png

Below are dependence plots, which show how a feature’s SHAP value varies with the feature’s value; with interaction values, they show the effect of a pair of features.

[19]:
shap.dependence_plot('x0', shap_values[1], X)
_images/shap_31_0.png
[20]:
shap.dependence_plot('x1', shap_values[1], X)
_images/shap_32_0.png
[21]:
shap.dependence_plot(('x0', 'x0'), shap_interaction_values[1], X)
_images/shap_33_0.png
[22]:
shap.dependence_plot(('x0', 'x1'), shap_interaction_values[1], X)
_images/shap_34_0.png
[23]:
shap.dependence_plot(('x1', 'x1'), shap_interaction_values[1], X)
_images/shap_35_0.png
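SHAP interaction values decompose each SHAP value into main and pairwise interaction effects, so summing each observation’s interaction matrix across one feature axis should recover the per-feature SHAP values. A quick check:

# Row sums of the (n_features, n_features) interaction matrices should
# reproduce the per-feature SHAP values for class 1.
print(np.allclose(shap_interaction_values[1].sum(axis=2), shap_values[1]))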

Lastly, the summary plot of the interaction values.

[24]:
shap.summary_plot(shap_interaction_values[1], X)
_images/shap_37_0.png

6.3. Kernel explainer

The KernelExplainer is the general-purpose, model-agnostic explainer. Here, we use it to explain a LogisticRegression model. Notice the link parameter, which can be identity or logit; it specifies the link function used to connect the feature importance values to the model output.

[25]:
from sklearn.linear_model import LogisticRegression

df = make_classification(n=10000)
X = df[['x0', 'x1']]
y = df.y

model = LogisticRegression(fit_intercept=True, solver='saga', random_state=37)
model.fit(X, y.values.reshape(1, -1)[0])
[25]:
LogisticRegression(random_state=37, solver='saga')
[26]:
df = make_classification()
X = df[['x0', 'x1']]
y = df.y

Observe that we pass the probabilistic prediction function, predict_proba, to the KernelExplainer.

[27]:
explainer = shap.KernelExplainer(model.predict_proba, link='logit', data=X)
shap_values = explainer.shap_values(X)
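Because we specified link='logit', the expected value and SHAP values live in log-odds space; pushing their sum through the logistic function should approximately recover the predicted probability of the positive class. A sketch of that check, using scipy’s expit (the inverse logit):

# In logit space, expected_value[1] plus the row sums of the SHAP values
# give the log-odds of class 1; expit maps them back to probabilities.
from scipy.special import expit

log_odds = explainer.expected_value[1] + shap_values[1].sum(axis=1)
print(np.allclose(expit(log_odds), model.predict_proba(X)[:, 1], atol=1e-3))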

Again, here are example force plots for a few observations.

[28]:
shap.force_plot(explainer.expected_value[1], shap_values[1][0,:], X.iloc[0,:], link='logit')
[29]:
shap.force_plot(explainer.expected_value[1], shap_values[1][1,:], X.iloc[1,:], link='logit')
[30]:
shap.force_plot(explainer.expected_value[1], shap_values[1][99,:], X.iloc[99,:], link='logit')

The force plot over all observations.

[31]:
shap.force_plot(explainer.expected_value[1], shap_values[1], X, link='logit')

Lastly, the summary plot.

[32]:
shap.summary_plot(shap_values[1], X)
_images/shap_50_0.png