# ReLax
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
![Python](https://img.shields.io/pypi/pyversions/jax-relax.svg) ![CI
status](https://github.com/BirkhoffG/jax-relax/actions/workflows/test.yaml/badge.svg)
![Docs](https://github.com/BirkhoffG/jax-relax/actions/workflows/deploy.yaml/badge.svg)
![pypi](https://img.shields.io/pypi/v/jax-relax.svg) ![GitHub
License](https://img.shields.io/github/license/BirkhoffG/jax-relax.svg)
[**Overview**](#overview) \| [**Installation**](#installation) \|
[**Tutorials**](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html)
\| [**Documentation**](https://birkhoffg.github.io/jax-relax/) \|
[**Citing ReLax**](#citing-relax)
## Overview
`ReLax` (**Re**course Explanation **L**ibrary in J**ax**) is an
efficient and scalable benchmarking library for recourse and
counterfactual explanations, built on top of
[jax](https://jax.readthedocs.io/en/latest/). By leveraging
[jax](https://jax.readthedocs.io/en/latest/) primitives such as
*vectorization*, *parallelization*, and *just-in-time* compilation,
`ReLax` offers substantial speed improvements in generating individual
(or local) explanations for predictions made by machine learning models.
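To make these primitives concrete, the sketch below (an illustration only, not how `ReLax` implements any of its methods) writes a closed-form counterfactual for *one* instance of a hypothetical linear classifier, then uses `jax.vmap` to vectorize it over a batch and `jax.jit` to compile the batched function. The weights `W`, the bias `B`, and the helper `linear_cf` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear classifier: logit(x) = W . x + B.
W = jnp.array([1.5, -2.0, 0.5])
B = 0.1

def logit(x):
    return jnp.dot(x, W) + B

def linear_cf(x, eps=0.1):
    """Closed-form counterfactual for ONE instance of the linear model above.

    Moving x along W by (1 + eps) times its signed distance to the decision
    boundary lands just past the boundary: logit(cf) = -eps * logit(x).
    """
    return x - (1.0 + eps) * logit(x) / jnp.dot(W, W) * W

# Vectorization: turn the per-instance function into a batched one.
# JIT compilation: compile the batched function once with XLA, then reuse it.
batched_cf = jax.jit(jax.vmap(linear_cf))

xs = jax.random.normal(jax.random.PRNGKey(0), (16, 3))
cfs = batched_cf(xs)
print(cfs.shape)  # (16, 3)
```

For a linear model, moving an input along the weight vector is the shortest way (in L2 distance) to cross the decision boundary, which is why this case has a closed form; nonlinear models generally need an iterative search.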
Some of the key features are as follows:
- 🏃 **Fast and scalable** recourse generation.
- 🚀 **Accelerated** on `cpu`, `gpu`, and `tpu`.
- 🪓 **Comprehensive** set of recourse methods implemented for
  benchmarking.
- 👐 **Customizable** API for building entire modeling and
  interpretation pipelines for new recourse algorithms.
## Installation
``` bash
pip install jax-relax
# Or install the latest version of `jax-relax`
pip install git+https://github.com/BirkhoffG/jax-relax.git
```
To further unleash the power of accelerators (i.e., GPU/TPU), we suggest
first installing this library via `pip install jax-relax`, and then
following the steps in the [official install
guidelines](https://github.com/google/jax#installation) to install the
right version of `jax` for your GPU or TPU.
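After installing an accelerator-enabled `jax` build, a quick way to confirm that JAX (and hence `ReLax`, which runs on top of it) can see the GPU or TPU is to query the default backend and the visible devices. The check below uses only standard `jax` functions; the exact device names it prints depend on your hardware.

```python
import jax

# "cpu" here means JAX did not pick up a GPU/TPU build; re-check the jax install.
backend = jax.default_backend()
print("default backend:", backend)

# List every device JAX can dispatch computations to.
for device in jax.devices():
    print(device.platform, device.device_kind)
```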
## Dive into `ReLax`
`ReLax` is a recourse explanation library for explaining (any) JAX-based
ML models. We believe it is important to give users the flexibility to
choose how to use `ReLax`. You can
- use only the methods implemented in `ReLax` (as a recourse methods
  library); or
- build a pipeline with `ReLax` to define data modules, train ML
  models, and generate CF explanations (for constructing recourse
  benchmarking pipelines).
### `ReLax` as a Recourse Explanation Library
This section introduces basic use cases of the methods in `ReLax` for
generating recourse explanations. For more advanced usage of these
methods, see this [tutorial](tutorials/methods.ipynb).
``` python
from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax
```
Let’s first generate synthetic data:
``` python
xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)
```
Next, we fit an MLP model to this data. Note that this model can be any
model implemented in JAX. We use the
[`MLModule`](https://birkhoffg.github.io/jax-relax/ml_model.html#mlmodule)
in `ReLax` as an example.
``` python
model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
```
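Since the model can be any JAX model, here is a minimal sketch of building a custom one from scratch: a hypothetical logistic-regression classifier trained with full-batch gradient descent and exposed as a prediction function. The names `init_params`, `predict_proba`, and `train` are illustrative, and the sketch assumes that a prediction function is a pure JAX function mapping features to the probability of the positive class; check the [`MLModule`](https://birkhoffg.github.io/jax-relax/ml_model.html#mlmodule) documentation for the exact interface `ReLax` methods expect before passing a custom function to them.

```python
import functools as ft

import jax
import jax.numpy as jnp

LR = 0.5  # learning rate for full-batch gradient descent (assumed value)

def init_params(key, n_features):
    # Hypothetical logistic-regression parameters: small random weights, zero bias.
    return {"w": 0.01 * jax.random.normal(key, (n_features,)), "b": jnp.zeros(())}

def logits(params, x):
    return x @ params["w"] + params["b"]

def predict_proba(params, x):
    # P(y = 1 | x); works for a single instance (d,) or a batch (n, d).
    return jax.nn.sigmoid(logits(params, x))

def bce_loss(params, xs, ys):
    # Numerically stable binary cross-entropy computed from logits.
    z = logits(params, xs)
    return jnp.mean(jax.nn.softplus(z) - ys * z)

@jax.jit
def gd_step(params, xs, ys):
    # One gradient-descent update applied to every leaf of the parameter pytree.
    grads = jax.grad(bce_loss)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

def train(xs, ys, n_steps=200, seed=0):
    params = init_params(jax.random.PRNGKey(seed), xs.shape[1])
    for _ in range(n_steps):
        params = gd_step(params, xs, ys)
    return params

# Toy, linearly separable data (hypothetical).
key_x, key_w = jax.random.split(jax.random.PRNGKey(42))
xs = jax.random.normal(key_x, (200, 5))
ys = (xs @ jax.random.normal(key_w, (5,)) > 0).astype(jnp.float32)

params = train(xs, ys)
pred_fn = ft.partial(predict_proba, params)  # now a function of x alone
accuracy = jnp.mean((pred_fn(xs) > 0.5) == (ys == 1.0))
print(f"training accuracy: {float(accuracy):.2f}")
```

Binding the trained parameters with `functools.partial` yields a function of the input alone, which is the shape of function that the `ft.partial(..., pred_fn=...)` pattern used with `jax.vmap` works with.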
Generating recourse explanations is straightforward. We can simply call
`generate_cf` on an instance of a recourse method (here `VanillaCF`) to
generate *one* recourse explanation:
``` python
vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
```
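`VanillaCF` is listed below with Wachter et al. [1], whose core idea is gradient-based: start from the input and descend on a loss that rewards flipping the prediction while penalizing distance from the original instance. The sketch below is a minimal, hypothetical version of that idea in plain JAX, not `ReLax`'s actual implementation. The toy classifier, the cross-entropy validity term, the squared-L2 proximity term, and the weight `lam` are assumptions; `n_steps` and `lr` mirror the config keys passed to `VanillaCF` above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy classifier: logistic regression with fixed weights.
W = jnp.array([1.5, -2.0, 0.5])
B = 0.1

def pred_fn(x):
    """P(y = 1 | x) for a single instance x of shape (3,)."""
    return jax.nn.sigmoid(jnp.dot(x, W) + B)

def gradient_cf(x, pred_fn, n_steps=1000, lr=0.05, lam=0.1):
    # Aim for the opposite of the currently predicted class.
    y_target = 1.0 - jnp.round(pred_fn(x))

    def loss(cf):
        p = jnp.clip(pred_fn(cf), 1e-7, 1.0 - 1e-7)
        # Validity: cross-entropy between the prediction on cf and the target class.
        validity = -(y_target * jnp.log(p) + (1.0 - y_target) * jnp.log(1.0 - p))
        # Proximity: squared L2 distance keeps cf close to the original x.
        proximity = jnp.sum((cf - x) ** 2)
        return validity + lam * proximity

    grad_fn = jax.grad(loss)
    # Plain gradient descent for n_steps, starting from the original instance.
    return jax.lax.fori_loop(0, n_steps, lambda _, cf: cf - lr * grad_fn(cf), x)

x = jnp.array([-1.0, 1.0, 0.0])
cf = gradient_cf(x, pred_fn)
print(float(pred_fn(x)), float(pred_fn(cf)))
```

A larger `lam` keeps the counterfactual closer to `x`, at the risk that it no longer crosses the decision boundary.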
Or generate a batch of recourse explanations with `jax.vmap`:
``` python
generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
```
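Conceptually, `jax.vmap` gives the same result as calling a per-instance function in a Python loop and stacking the outputs, but runs as a single vectorized computation. The sketch below checks this with a hypothetical per-instance generator `masked_cf` (not a `ReLax` API) that changes only the features allowed by a mask. `functools.partial` binds the non-array `pred_fn`, and `in_axes=(0, None)` tells `jax.vmap` to map over the rows of `xs` while sharing one mask across all instances.

```python
import functools as ft

import jax
import jax.numpy as jnp

# Hypothetical toy classifier: logistic regression with fixed weights.
W = jnp.array([1.5, -2.0, 0.5])
B = 0.1

def pred_fn(x):
    return jax.nn.sigmoid(jnp.dot(x, W) + B)

def masked_cf(x, mutable_mask, pred_fn, step=0.5, n_steps=20):
    """Hypothetical per-instance generator: nudge x toward the other class,
    changing only the features where mutable_mask is 1."""
    direction = jnp.sign(0.5 - pred_fn(x))  # +1 raises P(y=1), -1 lowers it
    grad_p = jax.grad(pred_fn)

    def body(_, cf):
        return cf + step * direction * grad_p(cf) * mutable_mask

    return jax.lax.fori_loop(0, n_steps, body, x)

xs = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
mask = jnp.array([1.0, 1.0, 0.0])  # assume the last feature is immutable

generate_fn = ft.partial(masked_cf, pred_fn=pred_fn)
batched_fn = jax.jit(jax.vmap(generate_fn, in_axes=(0, None)))

cfs = batched_fn(xs, mask)                              # vectorized over rows
looped = jnp.stack([generate_fn(x, mask) for x in xs])  # one instance at a time
print(jnp.allclose(cfs, looped, atol=1e-4))
```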
### `ReLax` for Building Recourse Explanation Pipelines
The above example illustrates how to use the decoupled `relax.methods`
to generate recourse explanations. However, this approach requires users
to write boilerplate code for tasks such as data preprocessing, model
training, and generating recourse explanations with feature constraints.
`ReLax` additionally offers a high-level framework in which each step is
a one-liner, streamlining the process and helping users build a
standardized pipeline for generating recourse explanations. Three lines
of code are enough to benchmark recourse explanations:
``` python
data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
```
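`benchmark_cfs` reports evaluation metrics for the generated explanations; see the `ReLax` documentation for the metrics it actually computes. As a hedged illustration of what such metrics typically measure, the sketch below computes two that are common in the counterfactual literature: *validity* (the share of counterfactuals that flip the model's prediction) and *proximity* (mean L1 distance to the original inputs), for a hypothetical binary classifier that outputs P(y = 1). The function names are illustrative, not `ReLax` APIs, and `ReLax`'s own definitions may differ.

```python
import jax
import jax.numpy as jnp

def validity(pred_fn, xs, cfs):
    """Fraction of counterfactuals whose predicted class differs from the original's."""
    y = jnp.round(pred_fn(xs))
    y_cf = jnp.round(pred_fn(cfs))
    return jnp.mean(y != y_cf)

def proximity(xs, cfs):
    """Mean L1 distance between each instance and its counterfactual."""
    return jnp.mean(jnp.sum(jnp.abs(xs - cfs), axis=-1))

# Hypothetical classifier that only looks at the first feature.
w = jnp.array([1.0, 0.0])

def pred_fn(x):
    return jax.nn.sigmoid(x @ w)

xs = jnp.array([[1.0, 0.0], [-1.0, 0.0]])
cfs = jnp.array([[-1.0, 0.0], [-1.0, 0.0]])  # only the first row changes class
print(validity(pred_fn, xs, cfs))  # 0.5
print(proximity(xs, cfs))          # 1.0
```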
See [Getting Started with
ReLax](https://birkhoffg.github.io/jax-relax/tutorials/getting_started.html)
for an end-to-end example of using `ReLax`.
## Supported Recourse Methods
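`ReLax` currently provides implementations of 9 recourse explanation
methods. The *Type* column groups them by how much they rely on learned
models: non-parametric methods search for each explanation directly
without training an auxiliary model; semi-parametric methods train an
auxiliary model (e.g., a generative model) but still search for each
explanation separately; parametric methods train a model that generates
explanations directly.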
| Method | Type | Paper Title | Ref |
|--------------------------------------------------------------------------------------------|-----------------|------------------------------------------------------------------------------------------------|-------------------------------------------|
| [`VanillaCF`](https://birkhoffg.github.io/jax-relax/methods/vanilla.html#vanillacf) | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [\[1\]](https://arxiv.org/abs/1711.00399) |
| [`DiverseCF`](https://birkhoffg.github.io/jax-relax/methods/dice.html#diversecf) | Non-Parametric | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. | [\[2\]](https://arxiv.org/abs/1905.07697) |
| [`ProtoCF`](https://birkhoffg.github.io/jax-relax/methods/proto.html#protocf) | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes. | [\[3\]](https://arxiv.org/abs/1907.02584) |
| [`CounterNet`](https://birkhoffg.github.io/jax-relax/methods/counternet.html#counternet) | Parametric | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations. | [\[4\]](https://arxiv.org/abs/2109.07557) |
| [`GrowingSphere`](https://birkhoffg.github.io/jax-relax/methods/sphere.html#growingsphere) | Non-Parametric | Inverse Classification for Comparison-based Interpretability in Machine Learning. | [\[5\]](https://arxiv.org/abs/1712.08443) |
| [`CCHVAE`](https://birkhoffg.github.io/jax-relax/methods/cchvae.html#cchvae) | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [\[6\]](https://arxiv.org/abs/1910.09398) |
| [`VAECF`](https://birkhoffg.github.io/jax-relax/methods/vaecf.html#vaecf) | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [\[7\]](https://arxiv.org/abs/1912.03277) |
| [`CLUE`](https://birkhoffg.github.io/jax-relax/methods/clue.html#clue) | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [\[8\]](https://arxiv.org/abs/2006.06848) |
| [`L2C`](https://birkhoffg.github.io/jax-relax/methods/l2c.html#l2c) | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [\[9\]](https://arxiv.org/abs/2209.13446) |
## Citing `ReLax`
To cite this repository:
``` latex
@software{relax2023github,
author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
title = {{R}e{L}ax: Recourse Explanation Library in Jax},
url = {http://github.com/birkhoffg/jax-relax},
version = {0.2.0},
year = {2023},
}
```