jax-relax

- Name: jax-relax
- Version: 0.2.8
- Home page: https://github.com/birkhoffg/jax-relax
- Summary: JAX-based Recourse Explanation Library
- Upload time: 2024-08-29 02:26:43
- Author: BirkhoffG
- Requires Python: >=3.9
- License: Apache Software License 2.0
- Keywords: jax, recourse, explanation, interpretability, machine learning

# ReLax


![Python](https://img.shields.io/pypi/pyversions/jax-relax.svg) ![CI
status](https://github.com/BirkhoffG/jax-relax/actions/workflows/test.yaml/badge.svg)
![Docs](https://github.com/BirkhoffG/jax-relax/actions/workflows/deploy.yaml/badge.svg)
![pypi](https://img.shields.io/pypi/v/jax-relax.svg) ![GitHub
License](https://img.shields.io/github/license/BirkhoffG/jax-relax.svg)

[**Overview**](#overview) \| [**Installation**](#installation) \|
[**Tutorials**](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html)
\| [**Documentation**](https://birkhoffg.github.io/jax-relax/) \|
[**Citing ReLax**](#citing-relax)

## Overview

`ReLax` (**Re**course Explanation **L**ibrary in J**ax**) is an
efficient and scalable benchmarking library for recourse and
counterfactual explanations, built on top of
[jax](https://jax.readthedocs.io/en/latest/). By leveraging the
*vectorization*, *parallelization*, and *just-in-time* compilation
transformations provided by
[jax](https://jax.readthedocs.io/en/latest/), `ReLax` substantially
speeds up the generation of individual (or local) explanations for
predictions made by machine learning models.

Some of the key features are as follows:

- πŸƒ **Fast and scalable** recourse generation.

- πŸš€ **Accelerated** over `cpu`, `gpu`, `tpu`.

- πŸͺ“ **Comprehensive** set of recourse methods implemented for
  benchmarking.

- πŸ‘ **Customizable** API to enable the building of entire modeling and
  interpretation pipelines for new recourse algorithms.

## Installation

``` bash
pip install jax-relax
# Or install the latest version of `jax-relax`
pip install git+https://github.com/BirkhoffG/jax-relax.git 
```

To take full advantage of accelerators (i.e., GPU/TPU), we suggest
first installing this library via `pip install jax-relax` and then
following the steps in the [official install
guidelines](https://github.com/google/jax#installation) to install the
JAX version that matches your GPU or TPU.
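
After installing an accelerator-enabled JAX build, a quick way to
confirm that JAX actually sees the device is to query its backend. The
snippet below uses only standard JAX calls (`jax.default_backend` and
`jax.devices`) and does not touch `ReLax` itself; on a CPU-only machine
it simply reports `cpu`.

```python
import jax

# Name of the platform JAX uses by default, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# All devices visible to JAX on that platform.
devices = jax.devices()

print(f"default backend: {backend}")
print(f"visible devices: {devices}")
```

If the GPU/TPU install succeeded, the reported backend should be `gpu`
or `tpu` rather than `cpu`.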

## Dive into `ReLax`

`ReLax` is a recourse explanation library for explaining (any) JAX-based
ML models. We believe it is important to give users flexibility in how
they use `ReLax`. You can

- use only the methods implemented in `ReLax` (as a library of recourse
  methods); or
- build a pipeline with `ReLax` that defines data modules, trains ML
  models, and generates counterfactual (CF) explanations (to construct a
  recourse benchmarking pipeline).

### `ReLax` as a Recourse Explanation Library

Below are basic use cases for generating recourse explanations with the
methods in `ReLax`. For more advanced usage, see this
[tutorial](tutorials/methods.ipynb).

``` python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax
```

Let’s first generate synthetic data:

``` python
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```

Next, we fit an MLP model on this data. The model can be any model
implemented in JAX; here we use the
[`MLModule`](https://birkhoffg.github.io/jax-relax/ml_model.html#mlmodule)
in `ReLax` as an example.

``` python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
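
Because the model can be any JAX model, it may help to see what a
hand-written one looks like. The sketch below trains a plain logistic
regression with `jax.grad` on the same synthetic data and exposes a
single-argument prediction function analogous in spirit to
`model.pred_fn`. The names `predict`, `loss_fn`, `gd_step`, and
`pred_fn` here are hypothetical, and whether `ReLax` methods accept such
a function unchanged (for example, which output shape they expect) is an
assumption to check against the `ReLax` documentation.

```python
import jax
import jax.numpy as jnp
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Same synthetic data as above, cast to float32 for JAX.
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
train_xs = jnp.asarray(train_xs, dtype=jnp.float32)
train_ys = jnp.asarray(train_ys, dtype=jnp.float32)


def predict(params, x):
    """Probability of the positive class under a logistic-regression model."""
    w, b = params
    return jax.nn.sigmoid(x @ w + b)


def loss_fn(params, x, y):
    """Mean binary cross-entropy, with probabilities clipped for stability."""
    p = jnp.clip(predict(params, x), 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))


@jax.jit
def gd_step(params, x, y):
    """One full-batch gradient-descent step with a fixed learning rate of 0.1."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = (jnp.zeros(train_xs.shape[1]), jnp.zeros(()))
initial_loss = loss_fn(params, train_xs, train_ys)
for _ in range(200):
    params = gd_step(params, train_xs, train_ys)
final_loss = loss_fn(params, train_xs, train_ys)

# Hypothetical single-argument prediction function (hides the parameters).
pred_fn = lambda x: predict(params, x)
print(f"training loss: {float(initial_loss):.3f} -> {float(final_loss):.3f}")
```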

Generating recourse explanations is straightforward. We can simply call
`generate_cf` on any implemented recourse method to generate *one*
recourse explanation:

``` python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
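
`VanillaCF` refers to the gradient-based approach of Wachter et al.
(reference \[1\] in the Supported Recourse Methods table). As a rough
illustration of that technique, and not `ReLax`'s actual implementation,
the sketch below runs `n_steps` of gradient descent with learning rate
`lr` on a candidate counterfactual, minimizing a loss that pushes the
prediction toward the opposite class plus an L1 penalty on the distance
from the original input. The linear `pred_fn`, its weights, the function
name `vanilla_cf_sketch`, and the trade-off weight `lam` are
hypothetical.

```python
import jax
import jax.numpy as jnp


def pred_fn(x):
    """Hypothetical fixed classifier: P(y=1 | x) from a linear score."""
    w = jnp.array([1.5, -2.0, 0.5])
    return jax.nn.sigmoid(x @ w - 0.25)


def vanilla_cf_sketch(x, pred_fn, n_steps=1000, lr=0.05, lam=0.1):
    """Wachter-style counterfactual search (illustrative sketch only)."""
    target = 1.0 - jnp.round(pred_fn(x))  # flip the predicted class

    def loss(cf):
        p = jnp.clip(pred_fn(cf), 1e-7, 1 - 1e-7)
        # Cross-entropy toward the flipped class ...
        pred_loss = -(target * jnp.log(p) + (1 - target) * jnp.log(1 - p))
        # ... plus an L1 penalty that keeps the counterfactual close to x.
        return pred_loss + lam * jnp.sum(jnp.abs(cf - x))

    grad_fn = jax.grad(loss)

    def step(cf, _):
        return cf - lr * grad_fn(cf), None

    # Run n_steps of plain gradient descent starting from the original input.
    cf, _ = jax.lax.scan(step, x, None, length=n_steps)
    return cf


x = jnp.array([1.0, 0.5, 0.0])
cf = vanilla_cf_sketch(x, pred_fn)
print(pred_fn(x), pred_fn(cf))
```

Larger values of `lam` keep the counterfactual closer to `x` at the risk
of not flipping the prediction, which is the usual proximity/validity
trade-off in this family of methods.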

Alternatively, generate a batch of recourse explanations with
`jax.vmap`:

``` python
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
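
`jax.vmap` turns a function written for one input row into one that
maps over the leading axis of a batch, while `functools.partial` binds
the argument that should not be batched (here `pred_fn`). The toy
`generate_one` function below is a hypothetical stand-in for
`vcf.generate_cf`; the point is only that the vectorized (and
`jax.jit`-compiled) call matches a per-row Python loop.

```python
import functools as ft

import jax
import jax.numpy as jnp


def pred_fn(x):
    """Scalar score for a single row: a fixed linear model with a sigmoid."""
    return jax.nn.sigmoid(jnp.dot(x, jnp.array([1.0, -1.0, 2.0])))


def generate_one(x, pred_fn, step=0.5):
    """Hypothetical per-example 'explainer': one gradient step lowering pred_fn(x)."""
    return x - step * jax.grad(pred_fn)(x)


xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 rows with 3 features

# Bind the non-batched argument, then vectorize over the leading (row) axis.
generate_fn = ft.partial(generate_one, pred_fn=pred_fn)
batched = jax.jit(jax.vmap(generate_fn))(xs)

# Reference: the same computation with an explicit Python loop.
looped = jnp.stack([generate_fn(row) for row in xs])
print(batched.shape)  # (4, 3)
```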

### `ReLax` for Building Recourse Explanation Pipelines

The example above uses the decoupled `relax.methods` module to generate
recourse explanations directly. With this approach, however, users must
write boilerplate code for tasks such as data preprocessing, model
training, and generating recourse explanations under feature constraints.
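
As an example of such boilerplate, a common way to enforce simple
feature constraints is to project each candidate counterfactual back
onto the allowed set: reset immutable features (for example, age) to
their original values and clip the rest to a valid range. The sketch
below does this with `jnp.where` and `jnp.clip`; the function name
`apply_constraints`, the mask, and the bounds are hypothetical and are
not `ReLax`'s constraint API.

```python
import jax.numpy as jnp


def apply_constraints(x, cf, immutable_mask, lower, upper):
    """Project a candidate counterfactual onto simple feature constraints.

    Features flagged in `immutable_mask` are reset to their original values;
    all other features are clipped to the [lower, upper] range.
    """
    cf = jnp.clip(cf, lower, upper)
    return jnp.where(immutable_mask, x, cf)


x = jnp.array([0.2, 0.7, 0.5])    # original input
cf = jnp.array([0.9, 1.3, -0.4])  # unconstrained counterfactual
immutable_mask = jnp.array([True, False, False])  # feature 0 may not change
lower, upper = jnp.zeros(3), jnp.ones(3)          # valid range per feature

constrained = apply_constraints(x, cf, immutable_mask, lower, upper)
print(constrained)
```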

`ReLax` additionally offers a high-level framework that streamlines this
process and helps users build a standardized pipeline for generating
recourse explanations. Benchmarking recourse explanations takes just
three lines of code:

``` python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
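
`benchmark_cfs` reports evaluation metrics for the generated
explanations. Two metrics commonly used in recourse benchmarks are
*validity* (the fraction of counterfactuals that change the model's
predicted class) and *proximity* (the average distance between each
input and its counterfactual). The sketch below computes both by hand
for a hypothetical linear `pred_fn`; which metrics `benchmark_cfs`
computes by default, and their exact definitions (e.g., L1 vs. L2
distance), are assumptions to verify in the `ReLax` documentation.

```python
import jax
import jax.numpy as jnp


def pred_fn(x):
    """Hypothetical classifier returning P(y=1 | x) for a batch of rows."""
    return jax.nn.sigmoid(x @ jnp.array([1.0, -1.0]))


def validity(xs, cfs, pred_fn):
    """Fraction of counterfactuals whose predicted class differs from the original."""
    y_orig = jnp.round(pred_fn(xs))
    y_cf = jnp.round(pred_fn(cfs))
    return jnp.mean(y_orig != y_cf)


def proximity(xs, cfs):
    """Mean L1 distance between each input and its counterfactual."""
    return jnp.mean(jnp.sum(jnp.abs(xs - cfs), axis=1))


xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
cfs = jnp.array([[0.0, 1.0], [0.0, 1.5], [-1.0, 0.0]])
print(validity(xs, cfs, pred_fn), proximity(xs, cfs))
```

Here the second row was already predicted as class 0 and stays there, so
only two of the three counterfactuals count as valid.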

See [Getting Started with
ReLax](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html)
for an end-to-end example of using `ReLax`.

## Supported Recourse Methods

`ReLax` currently provides implementations of 9 recourse explanation
methods.

| Method                                                                                     | Type            | Paper Title                                                                                    | Ref                                       |
|--------------------------------------------------------------------------------------------|-----------------|------------------------------------------------------------------------------------------------|-------------------------------------------|
| [`VanillaCF`](https://birkhoffg.github.io/jax-relax/methods/vanilla.html#vanillacf)        | Non-Parametric  | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR.   | [\[1\]](https://arxiv.org/abs/1711.00399) |
| [`DiverseCF`](https://birkhoffg.github.io/jax-relax/methods/dice.html#diversecf)           | Non-Parametric  | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations.           | [\[2\]](https://arxiv.org/abs/1905.07697) |
| [`ProtoCF`](https://birkhoffg.github.io/jax-relax/methods/proto.html#protocf)              | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes.                                | [\[3\]](https://arxiv.org/abs/1907.02584) |
| [`CounterNet`](https://birkhoffg.github.io/jax-relax/methods/counternet.html#counternet)   | Parametric      | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations.               | [\[4\]](https://arxiv.org/abs/2109.07557) |
| [`GrowingSphere`](https://birkhoffg.github.io/jax-relax/methods/sphere.html#growingsphere) | Non-Parametric  | Inverse Classification for Comparison-based Interpretability in Machine Learning.              | [\[5\]](https://arxiv.org/abs/1712.08443) |
| [`CCHVAE`](https://birkhoffg.github.io/jax-relax/methods/cchvae.html#cchvae)               | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data.                          | [\[6\]](https://arxiv.org/abs/1910.09398) |
| [`VAECF`](https://birkhoffg.github.io/jax-relax/methods/vaecf.html#vaecf)                  | Parametric      | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [\[7\]](https://arxiv.org/abs/1912.03277) |
| [`CLUE`](https://birkhoffg.github.io/jax-relax/methods/clue.html#clue)                     | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates.                                 | [\[8\]](https://arxiv.org/abs/2006.06848) |
| [`L2C`](https://birkhoffg.github.io/jax-relax/methods/l2c.html#l2c)                        | Parametric      | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations.         | [\[9\]](https://arxiv.org/abs/2209.13446) |

## Citing `ReLax`

To cite this repository:

``` latex
@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
  year = {2023},
}
```

            
