fastshap


- Name: fastshap
- Version: 0.3.1
- Summary: Fast SHAP kernel explainer
- Home page: https://github.com/AnotherSamWilson/fastshap
- Author: Samuel Wilson
- License: MIT
- Requires Python: >=3.7
- Keywords: shap, model explainability
- Upload time: 2023-03-12 14:09:15

[![CodeCov](https://codecov.io/gh/AnotherSamWilson/fastshap/branch/master/graphs/badge.svg?branch=master&service=github)](https://codecov.io/gh/AnotherSamWilson/fastshap)

## fastshap: A fast, approximate shap kernel


Calculating shap values can take an extremely long time. `fastshap` was
designed to be as fast as possible: it uses inner and outer batches to
keep the calculations inside vectorized operations as often as it can,
including the model evaluation itself. If your model scores one batch of
100 samples faster than ten batches of 10, this vectorization can have
enormous benefits.
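To make the batching idea concrete, here is a minimal NumPy sketch of the quantity a kernel explainer evaluates for a single coalition: the model's average output when the features in the coalition are taken from the row being explained and the remaining features from each background row. The model `f`, its weights, and the helper `coalition_value` are hypothetical illustrations, not part of the `fastshap` API. The point is that the whole background set can go through the model in one call instead of a Python loop.

```python
import numpy as np

def f(X):
    # Hypothetical model: any function mapping an (n, p) array to n predictions.
    return X @ np.array([1.0, -2.0, 0.5]) + 3.0

def coalition_value(model, x, background, mask):
    """Average model output with features in `mask` fixed to `x` and the
    remaining features taken from each background row."""
    # np.where broadcasts the (p,) mask and row against the (n_bg, p) background,
    # building every hybrid row at once so the model is called a single time.
    hybrid = np.where(mask, x, background)
    return model(hybrid).mean()

rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))
x = np.array([0.3, 1.2, -0.7])
mask = np.array([True, False, True])

vectorized = coalition_value(f, x, background, mask)

# Equivalent row-by-row loop: same result, but one model call per background row.
looped = np.mean([f(np.where(mask, x, b)[None, :])[0] for b in background])
print(np.isclose(vectorized, looped))  # True
```

`fastshap` goes further by also batching over rows and coalitions, but the trade-off is the same: fewer, larger model calls in exchange for more memory.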

**This package specifically offers a kernel explainer**. Kernel
explainers can approximate the shap values of f(X) for any function f,
because they only need to evaluate f on modified inputs. Much faster,
model-specific shap solutions are available for gradient boosted trees
and deep neural networks.
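For reference, the standard Shapley kernel that gives Kernel SHAP its name weights a coalition of size s out of M features by (M - 1) / (C(M, s) * s * (M - s)). The sketch below computes these weights. It illustrates the textbook formula only; `fastshap` may weight or sample coalitions differently internally, and `shapley_kernel_weight` is a hypothetical helper name.

```python
from math import comb

def shapley_kernel_weight(M: int, s: int) -> float:
    """Standard Kernel SHAP weight for a coalition of size s out of M features.

    The empty (s == 0) and full (s == M) coalitions get infinite weight; in practice
    they are enforced as constraints rather than weighted, so they are rejected here.
    """
    if s <= 0 or s >= M:
        raise ValueError("weight is only finite for 0 < s < M")
    return (M - 1) / (comb(M, s) * s * (M - s))

# Example: 4 features, as in the iris example in Basic Usage.
M = 4
weights = {s: shapley_kernel_weight(M, s) for s in range(1, M)}
print(weights)  # {1: 0.25, 2: 0.125, 3: 0.25}
```

Small and large coalitions receive the most weight per coalition, which is why they are usually evaluated first.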

A kernel explainer is ideal in situations where:

1)  The model you are using does not have model-specific methods
    available (for example, a support vector machine).
2)  You need to explain a modeling pipeline which includes variable
    transformations.
3)  The model has a link function or some other target transformation.
    For example, you wish to explain the raw probabilities in a
    classification model, instead of the log-odds.
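Because a kernel explainer only needs a function that maps a dataset to predictions, cases 2) and 3) come down to wrapping the pipeline and choosing the output scale yourself. A minimal sketch using scikit-learn; the dataset, the pipeline, and the function names `predict_proba_positive` and `predict_log_odds` are illustrative assumptions. The first function includes the scaling step and returns probabilities rather than log-odds, so those are what would be explained.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Pipeline with a variable transformation (scaling) before the model.
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
pipe.fit(X, y)

def predict_proba_positive(data):
    # Explain the whole pipeline on the probability scale (not the log-odds).
    return pipe.predict_proba(data)[:, 1]

def predict_log_odds(data):
    # Alternative: explain the linear (log-odds) scale instead.
    return pipe.decision_function(data)

probs = predict_proba_positive(X)
print(probs.shape)           # (569,)
print(np.round(probs[:5], 3))
```

Either function could then be passed as the `model` argument to `KernelExplainer`, as in the Basic Usage section; the choice decides which scale the shap values add up to.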

### Features

Advantages of `fastshap`:

  - Fast. See the benchmarks for comparisons.  
  - Native handling of both numpy arrays and pandas DataFrames, including
    principled treatment of categorical variables.  
  - Easy, built-in stratification of the background set.  
  - Can plot categorical variables in dependence plots.  
  - Can determine categorical variable interactions in shap
    values.  
  - Can plot missing values in the interaction variable.

Disadvantages of `fastshap`:

  - Only dependence plots are supported so far.  
  - Does not support feature groups yet.  
  - Does not support weights yet.

### Installation

This package can be installed using pip:

``` bash
# Using pip
$ pip install fastshap --no-cache-dir
```

You can also install the latest development version from this
repository. If you want to install from GitHub with conda, you must
first run `conda install pip git`.

``` bash
$ pip install git+https://github.com/AnotherSamWilson/fastshap.git
```

### Benchmarks

These benchmarks compare the `KernelExplainer` in the `shap` package to
the one in `fastshap`. All code is in `./benchmarks`. We left out
model-specific shap explainers, because they are usually orders of
magnitude faster and more efficient than kernel explainers.

##### Iris Dataset

The iris dataset is a table of 150 rows and 5 columns (4 features, one
target). This benchmark measured the time to calculate the shap values
for different row counts. The iris dataset was concatenated to itself to
get the desired dataset size:  
<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/iris_benchmark_time.png" width="600px" />

<table class=" lightable-minimal" style='font-family: "Trebuchet MS", verdana, sans-serif; margin-left: auto; margin-right: auto;'>

<thead>

<tr>

<th style="empty-cells: hide;" colspan="1">

</th>

<th style="padding-bottom:0; padding-left:3px;padding-right:3px;text-align: center; " colspan="2">

<div style="border-bottom: 2px solid #00000050; ">

Avg Times

</div>

</th>

<th style="padding-bottom:0; padding-left:3px;padding-right:3px;text-align: center; " colspan="2">

<div style="border-bottom: 2px solid #00000050; ">

StDev Times

</div>

</th>

<th style="empty-cells: hide;" colspan="1">

</th>

</tr>

<tr>

<th style="text-align:right;">

rows

</th>

<th style="text-align:right;">

fastshap

</th>

<th style="text-align:right;">

shap

</th>

<th style="text-align:right;">

fastshap

</th>

<th style="text-align:right;">

shap

</th>

<th style="text-align:right;">

Relative Difference (shap / fastshap)

</th>

</tr>

</thead>

<tbody>

<tr>

<td style="text-align:right;">

150

</td>

<td style="text-align:right;">

0.27

</td>

<td style="text-align:right;">

5.57

</td>

<td style="text-align:right;">

0.02

</td>

<td style="text-align:right;">

0.02

</td>

<td style="text-align:right;">

20.41

</td>

</tr>

<tr>

<td style="text-align:right;">

300

</td>

<td style="text-align:right;">

0.54

</td>

<td style="text-align:right;">

11.30

</td>

<td style="text-align:right;">

0.06

</td>

<td style="text-align:right;">

0.27

</td>

<td style="text-align:right;">

21.11

</td>

</tr>

<tr>

<td style="text-align:right;">

450

</td>

<td style="text-align:right;">

0.81

</td>

<td style="text-align:right;">

17.57

</td>

<td style="text-align:right;">

0.07

</td>

<td style="text-align:right;">

0.59

</td>

<td style="text-align:right;">

21.57

</td>

</tr>

<tr>

<td style="text-align:right;">

600

</td>

<td style="text-align:right;">

1.05

</td>

<td style="text-align:right;">

23.30

</td>

<td style="text-align:right;">

0.03

</td>

<td style="text-align:right;">

0.45

</td>

<td style="text-align:right;">

22.18

</td>

</tr>

<tr>

<td style="text-align:right;">

750

</td>

<td style="text-align:right;">

1.49

</td>

<td style="text-align:right;">

30.06

</td>

<td style="text-align:right;">

0.13

</td>

<td style="text-align:right;">

0.67

</td>

<td style="text-align:right;">

20.17

</td>

</tr>

</tbody>

</table>

##### California Housing Dataset

The California Housing dataset is a table of 20640 rows and 9 columns (8
features, one target). This benchmark measured the time it took to
calculate shap values on the first 2000 rows for different sizes of the
background dataset.  
<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/cali_benchmark_time.png" width="600px" />

<table class=" lightable-minimal" style='font-family: "Trebuchet MS", verdana, sans-serif; margin-left: auto; margin-right: auto;'>

<thead>

<tr>

<th style="text-align:right;">

background rows

</th>

<th style="text-align:right;">

fastshap

</th>

<th style="text-align:right;">

shap

</th>

<th style="text-align:right;">

Relative Difference (shap / fastshap)

</th>

</tr>

</thead>

<tbody>

<tr>

<td style="text-align:right;">

42

</td>

<td style="text-align:right;">

14.61

</td>

<td style="text-align:right;">

128.48

</td>

<td style="text-align:right;">

8.79

</td>

</tr>

<tr>

<td style="text-align:right;">

52

</td>

<td style="text-align:right;">

19.33

</td>

<td style="text-align:right;">

156.86

</td>

<td style="text-align:right;">

8.12

</td>

</tr>

<tr>

<td style="text-align:right;">

69

</td>

<td style="text-align:right;">

24.79

</td>

<td style="text-align:right;">

203.43

</td>

<td style="text-align:right;">

8.21

</td>

</tr>

<tr>

<td style="text-align:right;">

104

</td>

<td style="text-align:right;">

38.32

</td>

<td style="text-align:right;">

290.76

</td>

<td style="text-align:right;">

7.59

</td>

</tr>

<tr>

<td style="text-align:right;">

207

</td>

<td style="text-align:right;">

72.80

</td>

<td style="text-align:right;">

515.27

</td>

<td style="text-align:right;">

7.08

</td>

</tr>

<tr>

<td style="text-align:right;">

413

</td>

<td style="text-align:right;">

146.65

</td>

<td style="text-align:right;">

979.44

</td>

<td style="text-align:right;">

6.68

</td>

</tr>

<tr>

<td style="text-align:right;">

826

</td>

<td style="text-align:right;">

313.18

</td>

<td style="text-align:right;">

1903.28

</td>

<td style="text-align:right;">

6.08

</td>

</tr>

</tbody>

</table>
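The Relative Difference columns appear to be the ratio of the two average times (shap divided by fastshap). This is our reading of the numbers; the benchmark code is not quoted here. Recomputing the ratio from the rounded averages in both tables reproduces the listed values to within rounding:

```python
# Average times copied from the two tables above (units are not stated there).
iris = {  # rows: (fastshap, shap, listed relative difference)
    150: (0.27, 5.57, 20.41),
    300: (0.54, 11.30, 21.11),
    450: (0.81, 17.57, 21.57),
    600: (1.05, 23.30, 22.18),
    750: (1.49, 30.06, 20.17),
}
cali = {  # background rows: (fastshap, shap, listed relative difference)
    42: (14.61, 128.48, 8.79),
    52: (19.33, 156.86, 8.12),
    69: (24.79, 203.43, 8.21),
    104: (38.32, 290.76, 7.59),
    207: (72.80, 515.27, 7.08),
    413: (146.65, 979.44, 6.68),
    826: (313.18, 1903.28, 6.08),
}

for name, table in [("iris", iris), ("california", cali)]:
    for rows, (fast, slow, listed) in table.items():
        ratio = slow / fast
        # Small mismatches come from the averages themselves being rounded.
        print(f"{name:>10} {rows:>4}: shap/fastshap = {ratio:5.2f} (listed {listed})")
```

Read this way, `fastshap` is roughly 20x faster on iris and 6-9x faster on California Housing, with the advantage shrinking as the background set grows.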

##### Effect of Outer Batch Sizes

Increasing the outer batch size can have a significant effect on the run
time of the process:  
<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/batch_size_times.png" width="600px" />

### Basic Usage

We will use the iris dataset for this example. Here, we load the data
and train a simple LightGBM model on it:

``` python
from sklearn.datasets import load_iris
import pandas as pd
import lightgbm as lgb
import numpy as np

# Define our dataset and target variable
data = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1)
data.rename({"target": "species"}, inplace=True, axis=1)
data["species"] = data["species"].astype("category")
target = data.pop("sepal length (cm)")

# Train our model
dtrain = lgb.Dataset(data=data, label=target)
lgbmodel = lgb.train(
    params={"seed": 1, "verbose": -1},
    train_set=dtrain,
    num_boost_round=10
)

# Define the function we wish to build shap values for.
model = lgbmodel.predict

preds = model(data)
```

We now have a `model` that takes a pandas DataFrame and returns
predictions. We can create an explainer that uses `data` as the
background dataset to calculate the shap values of any dataset we wish:

``` python
from fastshap import KernelExplainer

ke = KernelExplainer(model, data)
sv = ke.calculate_shap_values(data, verbose=False)

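# Row sums of the shap values (including the expected-value column) recover
# the predictions; compare with a tolerance rather than exact float equality.
print(np.allclose(preds, sv.sum(1)))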
```

    ## True
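The check above relies on the additivity property: each row's shap values plus the expected value sum to the model's prediction. For a linear model, the exact shap values relative to the background mean are known in closed form, phi_i = w_i * (x_i - mean(background_i)), so the property can be verified directly. This NumPy sketch uses made-up weights and data, and mirrors the layout this README describes, with one column per feature and the expected value in the last column.

```python
import numpy as np

rng = np.random.default_rng(42)
w, b = np.array([0.5, -1.0, 2.0]), 0.25  # hypothetical linear model f(x) = w.x + b
background = rng.normal(size=(200, 3))
X = rng.normal(size=(10, 3))

def f(data):
    return data @ w + b

expected_value = f(background).mean()
# Exact shap values of a linear model relative to the background mean.
phi = w * (X - background.mean(axis=0))
# One column per feature, expected value last.
sv_linear = np.column_stack([phi, np.full(len(X), expected_value)])

print(np.allclose(sv_linear.sum(axis=1), f(X)))  # True
```

A kernel explainer only approximates these values for general models, but the additivity check holds by construction, which makes it a cheap sanity test.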

### Plotting

Dependence plots can be created by passing the shap values, the data,
the variable to plot, and optionally an interaction variable to
`plot_variable_effect_on_output`:

``` python
from fastshap.plotting import plot_variable_effect_on_output
plot_variable_effect_on_output(
    sv, data,
    variable="sepal width (cm)",
    interaction_variable="auto"
)
```

<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/depgraph.png" width="800px" />

The type of plot that is generated depends on the model output, the
variable type, and the interaction variable type. For example, plotting
the effect of a categorical variable shows the following:

``` python
from fastshap.plotting import plot_variable_effect_on_output
plot_variable_effect_on_output(
    sv, data,
    variable="species",
    interaction_variable="auto"
)
```

<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/depgraph_cat.png" width="800px" />

### Stratifying the Background Set

We can select a subset of our data to act as the background set. By
stratifying the background set on the model output, we usually get very
similar results while drastically reducing the calculation time.

``` python
ke.stratify_background_set(5)
sv2 = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  verbose=False
)

print(np.abs(sv2 - sv).mean(0))
```

    ## [1.74764532e-03 1.61829094e-02 1.99534408e-03 4.02640884e-16
    ##  1.71084747e-02]

Here, we split the background set into 5 folds, stratified by the model
output, and used the first fold as the background set. We then computed
the per-column mean absolute difference between these shap values and
the shap values obtained using the entire background set.
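A minimal sketch of the stratification idea, assuming it means sorting rows by model output and dealing them into folds in turn, so that every fold spans the full range of predictions. The helper `stratified_folds` is hypothetical, and `fastshap`'s `stratify_background_set` may assign folds differently.

```python
import numpy as np

def stratified_folds(preds: np.ndarray, n_folds: int) -> np.ndarray:
    """Assign each row a fold id so every fold covers the full range of `preds`."""
    order = np.argsort(preds)                        # rows sorted by model output
    folds = np.empty(len(preds), dtype=int)
    folds[order] = np.arange(len(preds)) % n_folds   # deal sorted rows round-robin
    return folds

rng = np.random.default_rng(0)
preds = rng.normal(loc=5.8, scale=0.8, size=150)     # stand-in for model(data)
folds = stratified_folds(preds, n_folds=5)

background_idx = np.flatnonzero(folds == 0)  # analogous to background_fold_to_use=0
print(len(background_idx))                   # 30
print(preds.mean(), preds[background_idx].mean())
```

Because each fold contains rows from across the whole prediction range, a single fold tends to produce an expected value close to that of the full background set, at a fifth of the cost.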

### Choosing Batch Sizes

If the entire process were vectorized, it would require an array of
shape (`# Samples * # Coalitions * # Background samples`, `# Columns`),
where `# Coalitions` is the total number of coalitions to be evaluated.
Even for small datasets, this becomes enormous. `fastshap` splits this
array into chunks by running the process as a series of batches.
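To see how quickly that fully vectorized array grows, the sketch below counts its elements, assuming every non-trivial coalition (all subsets except the empty and full ones, 2^M - 2 of them for M columns) is evaluated. That assumption is ours; `fastshap` may evaluate fewer coalitions (compare its `n_coalition_sizes` argument). The function name is hypothetical.

```python
def full_vectorization_elements(n_samples: int, n_background: int, n_columns: int) -> int:
    """Elements in a (samples * coalitions * background, columns) array,
    assuming all 2**M - 2 non-trivial coalitions are enumerated."""
    n_coalitions = 2**n_columns - 2
    return n_samples * n_coalitions * n_background * n_columns

# Iris example: 150 rows explained, 150 background rows, 4 columns.
print(full_vectorization_elements(150, 150, 4))  # 1260000
# Doubling the columns to 8 multiplies the coalition count by ~18.
print(full_vectorization_elements(150, 150, 8))  # 45720000
```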

This is a list of the large arrays and their maximum sizes:

  - Global
      - Mask Matrix (`# Coalitions`, `# Columns`)
  - Outer Batch
      - Linear Targets (`Total Coalition Combinations`, `Outer Batch Size`,
        `Output Dimension`)
  - Inner Batch
      - Model Evaluation Features (`Inner Batch Size`, `# Background Samples`)

The shap values are returned with the same datatype as the model's
output.

These theoretical sizes can be calculated directly so that the user can
determine appropriate batch sizes for their machine:

``` python
# Combines our background data back into 1 DataFrame
ke.stratify_background_set(1)
(
    mask_matrix_size, 
    linear_target_size, 
    inner_model_eval_set_size
) = ke.get_theoretical_array_expansion_sizes(
    data=data,
    outer_batch_size=150,
    inner_batch_size=150,
    n_coalition_sizes=3,
    background_fold_to_use=None
)

print(
  np.prod(linear_target_size) + np.prod(inner_model_eval_set_size)
)
```

    ## 452100

For the iris dataset, even if we sent the entire set (150 rows) through
as one batch, we would only need 452,100 elements stored in arrays. This
is manageable on most machines. However, this number ***grows extremely
quickly*** with the number of samples and columns. We strongly recommend
choosing a good batch scheme before running this process.

Another way to choose batch sizes is the method
`.get_theoretical_minimum_memory_requirements()`, which returns the
number of gigabytes needed to build each of the arrays above:

``` python
# Memory (in GB) needed for each of the three arrays
(
    mask_matrix_GB, 
    linear_target_GB, 
    inner_model_eval_set_GB
) = ke.get_theoretical_minimum_memory_requirements(
    data=data,
    outer_batch_size=150,
    inner_batch_size=150,
    n_coalition_sizes=3,
    background_fold_to_use=None
)

total_GB_needed = mask_matrix_GB + linear_target_GB + inner_model_eval_set_GB
print(f"We need {total_GB_needed} GB to calculate shap values with these batch sizes.")
```

    ## We need 0.003368459641933441 GB to calculate shap values with these batch sizes.
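The gigabyte figure can be tied back to the element count printed earlier. Assuming 8-byte (float64) elements for the linear targets and model evaluation set, a 14 x 4 boolean mask matrix (14 = 2^4 - 2 coalitions of the 4 columns), and 1 GB = 2^30 bytes, the arithmetic below reproduces the value above. These dtype and size choices are our inference from the numbers, not documented behavior.

```python
# Element count printed by get_theoretical_array_expansion_sizes above
# (linear targets + model evaluation set).
data_elements = 452_100
mask_elements = 14 * 4  # assumed (# coalitions, # columns) boolean mask

bytes_needed = data_elements * 8 + mask_elements * 1  # assumed float64 = 8 bytes, bool = 1 byte
gb_needed = bytes_needed / 2**30                       # assumed 2**30-byte gigabytes
print(gb_needed)
```

The same conversion gives a quick rule of thumb: every million float64 elements costs roughly 7.5 MB (about 0.0075 GB).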

### Specifying a Custom Linear Model

Any linear model from `sklearn.linear_model` can be used to calculate
the shap values. If you want sparse shap values, you can use Lasso
regression:

``` python
from sklearn.linear_model import Lasso

# Use our entire background set
ke.stratify_background_set(1)
sv_lasso = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  linear_model=Lasso(alpha=0.1),
  verbose=False
)

print(sv_lasso[0,:])
```

    ## [-0.         -0.33797832 -0.         -0.14634971  5.84333333]

The default model used is `sklearn.linear_model.LinearRegression`.
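Why Lasso produces exact zeros: its L1 penalty shrinks small coefficients all the way to zero, whereas ordinary least squares only estimates them as small non-zero numbers. The sketch below fits both estimators to synthetic coalition data (random 0/1 masks with made-up true effects). It illustrates the estimator choice only and does not reproduce `fastshap`'s internal regression, which may also apply coalition weights.

```python
import numpy as np
from sklearn.linear_model import Lasso, LinearRegression

rng = np.random.default_rng(0)
Z = rng.integers(0, 2, size=(500, 4)).astype(float)  # synthetic coalition masks
true_effects = np.array([0.0, 2.0, 0.05, -1.5])       # two negligible effects
y = Z @ true_effects + 1.0 + rng.normal(scale=0.01, size=500)

ols = LinearRegression().fit(Z, y)
lasso = Lasso(alpha=0.1).fit(Z, y)

print("OLS coefficients:  ", ols.coef_)
print("Lasso coefficients:", lasso.coef_)
print("non-zero (OLS, Lasso):", np.count_nonzero(ols.coef_), np.count_nonzero(lasso.coef_))
```

Lasso should set the two negligible effects exactly to zero and shrink the others toward zero, which is the same pattern as the `-0.` entries in the output above. Larger `alpha` values zero out more features.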

### Multiclass Outputs

If the model returns multiple outputs, the shap values are returned as
an array of shape (`rows`, `columns + 1`, `outputs`). Therefore, to get
the shap values for the effects on the second class, slice the result
with `shap_values[:,:,1]`. Here is an example:

``` python
multi_features = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1)
multi_features.rename({"target": "species"}, inplace=True, axis=1)
target = multi_features.pop("species")

dtrain = lgb.Dataset(data=multi_features, label=target)
lgbmodel = lgb.train(
    params={"seed": 1, "objective": "multiclass", "num_class": 3, "verbose": -1},
    train_set=dtrain,
    num_boost_round=10
)
model = lgbmodel.predict

explainer_multi = KernelExplainer(model, multi_features)
shap_values_multi = explainer_multi.calculate_shap_values(multi_features, verbose=False)

# Shape is (rows, columns + 1, outputs); shap_values_multi[:, :, 1] holds the second class.
print(shap_values_multi.shape)
```

    ## (150, 5, 3)
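Working with the three-dimensional array is mostly a matter of indexing. The sketch below uses a synthetic placeholder array with the same `(rows, columns + 1, outputs)` layout to pull out one class and split it into feature contributions and the expected value. We assume, as in the single-output examples, that the expected value is the last column. The random values are not real shap values.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
feature_names = ["sepal length (cm)", "sepal width (cm)",
                 "petal length (cm)", "petal width (cm)"]
# Placeholder shaped like shap_values_multi: (150 rows, 4 features + 1, 3 classes).
shap_values_multi_demo = rng.normal(size=(150, len(feature_names) + 1, 3))

class_idx = 1                                       # second class
sv_class = shap_values_multi_demo[:, :, class_idx]  # shape (150, 5)
contributions = pd.DataFrame(sv_class[:, :-1], columns=feature_names)
expected_value = sv_class[:, -1]                    # assumed: last column is the expected value

# For real shap values, this sum would equal the model output for this class.
reconstructed = contributions.sum(axis=1).to_numpy() + expected_value
print(contributions.shape, expected_value.shape)    # (150, 4) (150,)
```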

Our shap values are a numpy array of shape `(150, 5, 3)`: 150 rows, 4
feature columns plus the expected value, and 3 output dimensions.
When plotting multiclass outputs, the classes are treated as a
categorical variable. However, it is possible to plot variable
interactions with *one* of the output classes, as shown below.  
We can plot a variable's shap values for each of the output classes:

``` python
plot_variable_effect_on_output(
    shap_values_multi,
    multi_features,
    variable=2
)
```

<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/multiclass_depgraph.png" width="800px" />

We can also look at interactions if we are interested in a specific
class. For instance, to see the effect that `sepal width (cm)` had on
the first class, we could do:

``` python
plot_variable_effect_on_output(shap_values_multi, multi_features, variable="sepal width (cm)", output_index=0)
```

<img src="https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/mc_so_depgraph.png" width="800px" />



            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/AnotherSamWilson/fastshap",
    "name": "fastshap",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.7",
    "maintainer_email": "",
    "keywords": "shap,model explainability",
    "author": "Samuel Wilson",
    "author_email": "samwilson303@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/10/0d/ed76b59b95a9ea1a75e1b6b8860e99d9e8e22ea821347825df2f13b224eb/fastshap-0.3.1.tar.gz",
    "platform": null,
    "description": "\n<!-- [![Build Status](https://app.travis-ci.com/AnotherSamWilson/fastshap.svg?branch=main)](https://app.travis-ci.com/github/AnotherSamWilson/fastshap) -->\n\n[![CodeCov](https://codecov.io/gh/AnotherSamWilson/fastshap/branch/master/graphs/badge.svg?branch=master&service=github)](https://codecov.io/gh/AnotherSamWilson/fastshap)\n\n## fastshap: A fast, approximate shap kernel\n\n<!-- <a href='https://github.com/AnotherSamWilson/miceforest'><img src='https://i.imgur.com/nbrAQso.png' align=\"right\" height=\"300\" /></a> -->\n\nCalculating shap values can take an extremely long time. `fastshap` was\ndesigned to be as fast as possible by utilizing inner and outer batch\nassignments to keep the calculations inside vectorized operations as\noften as it can. This includes the model evaluation. If the model in\nquestion is more efficient for 100 samples than 10, then this sort of\nvectorization can have enormous benefits.\n\n**This package specifically offers a kernel explainer**. Kernel\nexplainers can calculate approximate shap values of f(X) towards y for\nany function f.\u00c2\u00a0Much faster shap solutions are available specifically\nfor gradient boosted trees and deep neural networks.\n\nA kernel explainer is ideal in situations where:\n\n1)  The model you are using does not have model-specific methods\n    available (for example, support vector machine)\n2)  You need to explain a modeling pipeline which includes variable\n    transformations.\n3)  The model has a link function or some other target transformation.\n    For example, you wish to explain the raw probabilities in a\n    classification model, instead of the log-odds.\n\n### Features\n\nAdvantages of `fastshap`:\n\n  - Fast. See benchmarks for comparisons.  \n  - Native handling of both numpy arrays and pandas dataframes including\n    principled treatment of categories.  \n  - Easy built in stratification of background set.  
\n  - Capable of plotting categorical variables in dependence plots.  \n  - Capable of determining categorical variable interactions in shap\n    values.  \n  - Capable of plotting missing values in interaction variable.\n\nDisadvantages of `fastshap`:\n\n  - Only dependency plotting is supported as of now.  \n  - Does not support feature groups yet.  \n  - Does not support weights yet.\n\n### Installation\n\nThis package can be installed using pip:\n\n``` bash\n# Using pip\n$ pip install fastshap --no-cache-dir\n```\n\nYou can also download the latest development version from this\nrepository. If you want to install from github with conda, you must\nfirst run `conda install pip git`.\n\n``` bash\n$ pip install git+https://github.com/AnotherSamWilson/fastshap.git\n```\n\n### Benchmarks\n\nThese benchmarks compare the `shap` package `KernelExplainer` to the one\nin `fastshap`. All code is in `./benchmarks`. We left out model-specific\nshap explainers, because they are usually orders of magnitued faster and\nmore efficient than kernel explainers.\n\n##### Iris Dataset\n\nThe iris dataset is a table of 150 rows and 5 columns (4 features, one\ntarget). This benchmark measured the time to calculate the shap values\nfor different row counts. 
The iris dataset was concatenated to itself to\nget the desired dataset size:  \n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/iris_benchmark_time.png\" width=\"600px\" />\n\n<table class=\" lightable-minimal\" style='font-family: \"Trebuchet MS\", verdana, sans-serif; margin-left: auto; margin-right: auto;'>\n\n<thead>\n\n<tr>\n\n<th style=\"empty-cells: hide;\" colspan=\"1\">\n\n</th>\n\n<th style=\"padding-bottom:0; padding-left:3px;padding-right:3px;text-align: center; \" colspan=\"2\">\n\n<div style=\"border-bottom: 2px solid #00000050; \">\n\nAvg Times\n\n</div>\n\n</th>\n\n<th style=\"padding-bottom:0; padding-left:3px;padding-right:3px;text-align: center; \" colspan=\"2\">\n\n<div style=\"border-bottom: 2px solid #00000050; \">\n\nStDev Times\n\n</div>\n\n</th>\n\n<th style=\"empty-cells: hide;\" colspan=\"1\">\n\n</th>\n\n</tr>\n\n<tr>\n\n<th style=\"text-align:right;\">\n\nrows\n\n</th>\n\n<th style=\"text-align:right;\">\n\nfastshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nfastshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nRelative Difference\n\n</th>\n\n</tr>\n\n</thead>\n\n<tbody>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n150\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.27\n\n</td>\n\n<td style=\"text-align:right;\">\n\n5.57\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.02\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.02\n\n</td>\n\n<td style=\"text-align:right;\">\n\n20.41\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n300\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.54\n\n</td>\n\n<td style=\"text-align:right;\">\n\n11.30\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.06\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.27\n\n</td>\n\n<td style=\"text-align:right;\">\n\n21.11\n\n</td>\n\n</tr>\n\n<tr>\n\n<td 
style=\"text-align:right;\">\n\n450\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.81\n\n</td>\n\n<td style=\"text-align:right;\">\n\n17.57\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.07\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.59\n\n</td>\n\n<td style=\"text-align:right;\">\n\n21.57\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n600\n\n</td>\n\n<td style=\"text-align:right;\">\n\n1.05\n\n</td>\n\n<td style=\"text-align:right;\">\n\n23.30\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.03\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.45\n\n</td>\n\n<td style=\"text-align:right;\">\n\n22.18\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n750\n\n</td>\n\n<td style=\"text-align:right;\">\n\n1.49\n\n</td>\n\n<td style=\"text-align:right;\">\n\n30.06\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.13\n\n</td>\n\n<td style=\"text-align:right;\">\n\n0.67\n\n</td>\n\n<td style=\"text-align:right;\">\n\n20.17\n\n</td>\n\n</tr>\n\n</tbody>\n\n</table>\n\n##### California Housing Dataset\n\nThe California Housing dataset is a table of 20640 rows and 9 columns (8\nfeatures, one target). This benchmark measured the time it took to\ncalculate shap values on the first 2000 rows for different sizes of the\nbackground dataset.  
\n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/cali_benchmark_time.png\" width=\"600px\" />\n\n<table class=\" lightable-minimal\" style='font-family: \"Trebuchet MS\", verdana, sans-serif; margin-left: auto; margin-right: auto;'>\n\n<thead>\n\n<tr>\n\n<th style=\"text-align:right;\">\n\nrows\n\n</th>\n\n<th style=\"text-align:right;\">\n\nfastshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nshap\n\n</th>\n\n<th style=\"text-align:right;\">\n\nRelative Difference\n\n</th>\n\n</tr>\n\n</thead>\n\n<tbody>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n42\n\n</td>\n\n<td style=\"text-align:right;\">\n\n14.61\n\n</td>\n\n<td style=\"text-align:right;\">\n\n128.48\n\n</td>\n\n<td style=\"text-align:right;\">\n\n8.79\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n52\n\n</td>\n\n<td style=\"text-align:right;\">\n\n19.33\n\n</td>\n\n<td style=\"text-align:right;\">\n\n156.86\n\n</td>\n\n<td style=\"text-align:right;\">\n\n8.12\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n69\n\n</td>\n\n<td style=\"text-align:right;\">\n\n24.79\n\n</td>\n\n<td style=\"text-align:right;\">\n\n203.43\n\n</td>\n\n<td style=\"text-align:right;\">\n\n8.21\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n104\n\n</td>\n\n<td style=\"text-align:right;\">\n\n38.32\n\n</td>\n\n<td style=\"text-align:right;\">\n\n290.76\n\n</td>\n\n<td style=\"text-align:right;\">\n\n7.59\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n207\n\n</td>\n\n<td style=\"text-align:right;\">\n\n72.80\n\n</td>\n\n<td style=\"text-align:right;\">\n\n515.27\n\n</td>\n\n<td style=\"text-align:right;\">\n\n7.08\n\n</td>\n\n</tr>\n\n<tr>\n\n<td style=\"text-align:right;\">\n\n413\n\n</td>\n\n<td style=\"text-align:right;\">\n\n146.65\n\n</td>\n\n<td style=\"text-align:right;\">\n\n979.44\n\n</td>\n\n<td style=\"text-align:right;\">\n\n6.68\n\n</td>\n\n</tr>\n\n<tr>\n\n<td 
style=\"text-align:right;\">\n\n826\n\n</td>\n\n<td style=\"text-align:right;\">\n\n313.18\n\n</td>\n\n<td style=\"text-align:right;\">\n\n1903.28\n\n</td>\n\n<td style=\"text-align:right;\">\n\n6.08\n\n</td>\n\n</tr>\n\n</tbody>\n\n</table>\n\n##### Effect of Outer Batch Sizes\n\nIncreasing the outer batch size can have a significant effect on the run\ntime of the process:  \n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/benchmarks/batch_size_times.png\" width=\"600px\" />\n\n### Basic Usage\n\nWe will use the iris dataset for this example. Here, we load the data\nand train a simple lightgbm model on the dataset:\n\n``` python\nfrom sklearn.datasets import load_iris\nimport pandas as pd\nimport lightgbm as lgb\nimport numpy as np\n\n# Define our dataset and target variable\ndata = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1)\ndata.rename({\"target\": \"species\"}, inplace=True, axis=1)\ndata[\"species\"] = data[\"species\"].astype(\"category\")\ntarget = data.pop(\"sepal length (cm)\")\n\n# Train our model\ndtrain = lgb.Dataset(data=data, label=target)\nlgbmodel = lgb.train(\n    params={\"seed\": 1, \"verbose\": -1},\n    train_set=dtrain,\n    num_boost_round=10\n)\n\n# Define the function we wish to build shap values for.\nmodel = lgbmodel.predict\n\npreds = model(data)\n```\n\nWe now have a `model` which takes a Pandas dataframe, and returns\npredictions. 
We can create an explainer that will use `data` as a\nbackground dataset to calculate the shap values of any dataset we wish:\n\n``` python\nfrom fastshap import KernelExplainer\n\nke = KernelExplainer(model, data)\nsv = ke.calculate_shap_values(data, verbose=False)\n\nprint(all(preds == sv.sum(1)))\n```\n\n    ## True\n\n### Plotting\n\nDependence plots can be created by passing the shap values and variable\n/ interaction information to `plot_variable_effect_on_output`:\n\n``` python\nfrom fastshap.plotting import plot_variable_effect_on_output\nplot_variable_effect_on_output(\n    sv, data,\n    variable=\"sepal width (cm)\",\n    interaction_variable=\"auto\"\n)\n```\n\n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/depgraph.png\" width=\"800px\" />\n\nThe type of plot that is generated depends on the model output, the\nvariable type, and the interaction variable type. For example, plotting\nthe effect of a categorical variable shows the following:\n\n``` python\nfrom fastshap.plotting import plot_variable_effect_on_output\nplot_variable_effect_on_output(\n    sv, data,\n    variable=\"species\",\n    interaction_variable=\"auto\"\n)\n```\n\n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/depgraph_cat.png\" width=\"800px\" />\n\n### Stratifying the Background Set\n\nWe can select a subset of our data to act as a background set. By\nstratifying the background set on the results of the model output, we\nwill usually get very similar results, while decreasing the caculation\ntime drastically.\n\n``` python\nke.stratify_background_set(5)\nsv2 = ke.calculate_shap_values(\n  data, \n  background_fold_to_use=0,\n  verbose=False\n)\n\nprint(np.abs(sv2 - sv).mean(0))\n```\n\n    ## [1.74764532e-03 1.61829094e-02 1.99534408e-03 4.02640884e-16\n    ##  1.71084747e-02]\n\nWhat we did is break up our background set into 10 different sets,\nstratified by the model output. 
We then used the first of these sets as\nour background set. We then compared the average difference between\nthese shap values, and the shap values we obtained from using the entire\ndataset.\n\n### Choosing Batch Sizes\n\nIf the entire process was vectorized, it would require an array of size\n(`# Samples * # Coalitions * # Background samples`, `# Columns`). Where\n`# Coalitions` is the sum of the total number of coalitions that are\ngoing to be run. Even for small datasets, this becomes enormous.\n`fastshap` breaks this array up into chunks by splitting the process\ninto a series of batches.\n\nThis is a list of the large arrays and their maximum size:\n\n  - Global\n      - Mask Matrix (`# Coalitions`, `# Columns`)\n  - Outer Batch\n      - Linear Targets (`Total Coalition Combinations`, `Outer Batch\n        Size`, `Output Dimension`)\\`\n  - Inner Batch\n      - Model Evaluation Features (`Inner Batch Size`, `# Background\n        Samples`)\\`\n\nThe final, returned shap values will also be returned as the datatype\nreturned by the model.\n\nThese theoretical sizes can be calculated directly so that the user can\ndetermine appropriate batch sizes for their machine:\n\n``` python\n# Combines our background data back into 1 DataFrame\nke.stratify_background_set(1)\n(\n    mask_matrix_size, \n    linear_target_size, \n    inner_model_eval_set_size\n) = ke.get_theoretical_array_expansion_sizes(\n    data=data,\n    outer_batch_size=150,\n    inner_batch_size=150,\n    n_coalition_sizes=3,\n    background_fold_to_use=None\n)\n\nprint(\n  np.product(linear_target_size) + np.product(inner_model_eval_set_size)\n)\n```\n\n    ## 452100\n\nFor the iris dataset, even if we sent the entire set (150 rows) through\nas one batch, we only need 92100 elements stored in arrays. This is\nmanageable on most machines. However, this number ***grows extremely\nquickly*** with the samples and number of columns. 
It is highly advised\nto determine a good batch scheme before running this process.\n\nAnother way to determine optimal batch sizes is to use the function\n`.get_theoretical_minimum_memory_requirements()`. This returns a list of\nGigabytes needed to build the arrays above:\n\n``` python\n# Combines our background data back into 1 DataFrame\n(\n    mask_matrix_GB, \n    linear_target_GB, \n    inner_model_eval_set_GB\n) = ke.get_theoretical_minimum_memory_requirements(\n    data=data,\n    outer_batch_size=150,\n    inner_batch_size=150,\n    n_coalition_sizes=3,\n    background_fold_to_use=None\n)\n\ntotal_GB_needed = mask_matrix_GB + linear_target_GB + inner_model_eval_set_GB\nprint(f\"We need {total_GB_needed} GB to calculate shap values with these batch sizes.\")\n```\n\n    ## We need 0.003368459641933441 GB to calculate shap values with these batch sizes.\n\n### Specifying a Custom Linear Model\n\nAny linear model available from sklearn.linear\\_model can be used to\ncalculate the shap values. If you wish for some sparsity in the shap\nvalues, you can use Lasso regression:\n\n``` python\nfrom sklearn.linear_model import Lasso\n\n# Use our entire background set\nke.stratify_background_set(1)\nsv_lasso = ke.calculate_shap_values(\n  data, \n  background_fold_to_use=0,\n  linear_model=Lasso(alpha=0.1),\n  verbose=False\n)\n\nprint(sv_lasso[0,:])\n```\n\n    ## [-0.         -0.33797832 -0.         
-0.14634971  5.84333333]\n\nThe default model is `sklearn.linear_model.LinearRegression`.\n\n### Multiclass Outputs\n\nIf the model returns multiple outputs, the shap values are\nreturned as an array of shape (`rows`, `columns + 1`, `outputs`).\nTo get the shap values for the effects on the second class,\nslice the result with `shap_values[:, :, 1]`.\nHere is an example:\n\n``` python\nmulti_features = pd.concat(load_iris(as_frame=True, return_X_y=True), axis=1)\nmulti_features.rename({\"target\": \"species\"}, inplace=True, axis=1)\ntarget = multi_features.pop(\"species\")\n\ndtrain = lgb.Dataset(data=multi_features, label=target)\nlgbmodel = lgb.train(\n    params={\"seed\": 1, \"objective\": \"multiclass\", \"num_class\": 3, \"verbose\": -1},\n    train_set=dtrain,\n    num_boost_round=10\n)\nmodel = lgbmodel.predict\n\nexplainer_multi = KernelExplainer(model, multi_features)\nshap_values_multi = explainer_multi.calculate_shap_values(multi_features, verbose=False)\n\n# Shape is (rows, columns + 1, outputs)\nprint(shap_values_multi.shape)\n\n# Shap values for the effects on the second class\nshap_values_class_1 = shap_values_multi[:, :, 1]\n```\n\n    ## (150, 5, 3)\n\nThe shap values are a numpy array of shape `(150, 5, 3)`: one entry for each\nof our 150 rows, 4 columns plus the expected value, and 3 output dimensions.\nWhen plotting multiclass outputs, the classes are treated as\na categorical variable. It is also possible to plot variable\ninteractions with *one* of the output classes (see below).\n\nWe can plot a variable's shap values for each of the output classes:\n\n``` python\nplot_variable_effect_on_output(\n    shap_values_multi,\n    multi_features,\n    variable=2\n)\n```\n\n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/multiclass_depgraph.png\" width=\"800px\" />\n\nWe can also look at interactions if we are interested in a specific\nclass. 
For instance, to see the effect that `sepal width (cm)`\nhas on the first class, we can do:\n\n``` python\nplot_variable_effect_on_output(\n    shap_values_multi,\n    multi_features,\n    variable=\"sepal width (cm)\",\n    output_index=0\n)\n```\n\n<img src=\"https://raw.githubusercontent.com/AnotherSamWilson/fastshap/master/graphics/mc_so_depgraph.png\" width=\"800px\" />\n\n\n",
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Fast SHAP kernel explainer",
    "version": "0.3.1",
    "split_keywords": [
        "shap",
        "model explainability"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "5e4f2a9f5d237156d62ffdd455cc1ee0a5738d194cdb787ccbc95e77ebe641fb",
                "md5": "25f5ec8733d9f80b49af10820e793a74",
                "sha256": "735e0534e976f74b5b016907691050f8a6fb82bde22b8e325f4e8ff4ad2334c4"
            },
            "downloads": -1,
            "filename": "fastshap-0.3.1-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "25f5ec8733d9f80b49af10820e793a74",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 20716,
            "upload_time": "2023-03-12T14:09:14",
            "upload_time_iso_8601": "2023-03-12T14:09:14.252613Z",
            "url": "https://files.pythonhosted.org/packages/5e/4f/2a9f5d237156d62ffdd455cc1ee0a5738d194cdb787ccbc95e77ebe641fb/fastshap-0.3.1-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "100ded76b59b95a9ea1a75e1b6b8860e99d9e8e22ea821347825df2f13b224eb",
                "md5": "a2becdd6268fb85ad303221eda422fd6",
                "sha256": "6ae4cebdd2aceb2d6c1d2857e70aa102a9b56cd34f2e4a5aae831286af528eca"
            },
            "downloads": -1,
            "filename": "fastshap-0.3.1.tar.gz",
            "has_sig": false,
            "md5_digest": "a2becdd6268fb85ad303221eda422fd6",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 26821,
            "upload_time": "2023-03-12T14:09:15",
            "upload_time_iso_8601": "2023-03-12T14:09:15.626363Z",
            "url": "https://files.pythonhosted.org/packages/10/0d/ed76b59b95a9ea1a75e1b6b8860e99d9e8e22ea821347825df2f13b224eb/fastshap-0.3.1.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-03-12 14:09:15",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "github_user": "AnotherSamWilson",
    "github_project": "fastshap",
    "travis_ci": true,
    "coveralls": false,
    "github_actions": false,
    "lcname": "fastshap"
}
        