rax

- Name: rax
- Version: 0.3.0
- Home page: https://github.com/google/rax
- Summary: Learning-to-Rank using JAX.
- Upload time: 2023-05-17 19:24:49
- Author: Google
- Requires Python: >=3.8
- License: Apache 2.0
- Keywords: learning-to-rank, jax, ranking
- Requirements: none recorded

# 🦖 **Rax**: Learning-to-Rank using JAX

[![Docs](https://readthedocs.org/projects/rax/badge/?version=latest)](https://rax.readthedocs.io/en/latest/?badge=latest)
[![PyPI](https://img.shields.io/pypi/v/rax?color=brightgreen)](https://pypi.org/project/rax/)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://github.com/google/rax/blob/main/LICENSE)

**Rax** is a Learning-to-Rank library written in JAX. It provides off-the-shelf
implementations of ranking losses, ranking metrics, and ranking-specific
transformations for use with JAX:

- Ranking losses (`rax.*_loss`): `rax.softmax_loss`,
  `rax.pairwise_logistic_loss`, ...
- Ranking metrics (`rax.*_metric`): `rax.mrr_metric`, `rax.ndcg_metric`, ...
- Transformations (`rax.*_t12n`, where *t12n* is short for "transformation"):
  `rax.approx_t12n`, `rax.gumbel_t12n`, ...

## Ranking

A ranking problem differs from traditional classification and regression
problems in that its objective is the correct **relative order** of a **list of
examples** (e.g., documents) for a given context (e.g., a query). **Rax**
provides support for ranking problems within the JAX ecosystem. Applications
include, but are not limited to:

- **Search**: ranking a list of documents with respect to a query.
- **Recommendation**: ranking a list of items given a user as context.
- **Question Answering**: finding the best answer from a list of candidates.
- **Dialogue Systems**: finding the best response from a list of responses.

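Because the objective is the relative order, a ranking metric such as NDCG is unchanged by any order-preserving rescaling of the scores, whereas a pointwise regression loss is not. The plain-Python sketch below illustrates this on a single list. The helpers `ndcg` and `squared_error` are hypothetical, and the gain `2**label - 1` with discount `1 / log2(rank + 1)` is the standard NDCG definition, assumed (not confirmed here) to match Rax's defaults.

```python
import math


def ndcg(scores, labels):
    """NDCG for one list, using gain 2**label - 1 and discount 1 / log2(rank + 1)."""
    # Sort item indices by descending score; only the resulting order matters.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    # enumerate() is 0-based, so `rank + 2` equals the 1-based rank plus one.
    dcg = sum((2 ** labels[i] - 1) / math.log2(rank + 2)
              for rank, i in enumerate(order))
    # Ideal DCG: the same labels ordered from most to least relevant.
    ideal = sorted(labels, reverse=True)
    idcg = sum((2 ** label - 1) / math.log2(rank + 2)
               for rank, label in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


def squared_error(scores, labels):
    """A pointwise regression loss, shown only for contrast."""
    return sum((s - y) ** 2 for s, y in zip(scores, labels)) / len(scores)


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
# An order-preserving (monotone increasing) rescaling of the scores.
rescaled = [10 * s + 3 for s in scores]

print(ndcg(scores, labels), ndcg(rescaled, labels))  # identical values
print(squared_error(scores, labels), squared_error(rescaled, labels))  # different values
```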
## Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute
various ranking losses and metrics:

```python
import jax
import jax.numpy as jnp
import rax

scores = jnp.array([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.array([1.0,  0.0, 0.0])  # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)  # computes a ranking metric.
# 0.63092977

rax.pairwise_hinge_loss(scores, labels)  # computes a ranking loss.
# 2.1
```
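The `2.1` above can be checked by hand. Only doc 1 is relevant, so the pairs that count are (doc 1, doc 2), whose score margin of 3.5 clears the hinge and costs 0, and (doc 1, doc 3), whose margin of -3.2 costs 1 + 3.2 = 4.2. Averaging over those two pairs gives 2.1. The plain-Python sketch below uses a hypothetical `pairwise_hinge_loss` helper and assumes that averaging rule because it reproduces the output above; Rax's own reduction and weighting options may behave differently.

```python
def pairwise_hinge_loss(scores, labels):
    """Mean of max(0, 1 - (s_i - s_j)) over pairs where label_i > label_j."""
    losses = [
        max(0.0, 1.0 - (scores[i] - scores[j]))
        for i in range(len(scores))
        for j in range(len(scores))
        if labels[i] > labels[j]
    ]
    # Averaging over the label-ordered pairs is an assumption of this sketch.
    return sum(losses) / len(losses) if losses else 0.0


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
print(pairwise_hinge_loss(scores, labels))  # approximately 2.1
```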

All of the Rax losses and metrics are purely functional and compose well with
standard JAX transformations. Additionally, Rax provides ranking-specific
transformations so you can build new ranking losses. An example is
`rax.approx_t12n`, which can be used to transform any (non-differentiable)
ranking metric into a differentiable loss. For example:

```python
loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)  # differentiable approx ndcg loss.
# -0.63282484

jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.
# [-0.01276882  0.00549765  0.00727116]
```
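To show conceptually how a hard ranking metric can be made differentiable, the plain-Python sketch below replaces each item's rank with a smooth rank, `1 + sum over j != i of sigmoid((s_j - s_i) / T)`, the ApproxNDCG-style approximation, assumed here with temperature `T = 1`. It negates the resulting NDCG so that it can be minimized, and it uses central finite differences as a stand-in for `jax.grad`. The helper names are hypothetical and this is not Rax's source code; on the synopsis inputs it yields a loss and gradient that agree closely with the output above.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def approx_ranks(scores, temperature=1.0):
    """Smooth 1-based ranks: 1 + sum over j != i of sigmoid((s_j - s_i) / T)."""
    n = len(scores)
    return [
        1.0 + sum(sigmoid((scores[j] - scores[i]) / temperature)
                  for j in range(n) if j != i)
        for i in range(n)
    ]


def approx_ndcg_loss(scores, labels, temperature=1.0):
    """Negative NDCG with each hard rank replaced by its smooth approximation."""
    ranks = approx_ranks(scores, temperature)
    dcg = sum((2 ** y - 1) / math.log2(1.0 + r) for y, r in zip(labels, ranks))
    # The ideal DCG depends only on the labels, so it uses exact ranks.
    ideal = sorted(labels, reverse=True)
    idcg = sum((2 ** y - 1) / math.log2(rank + 2) for rank, y in enumerate(ideal))
    return -dcg / idcg if idcg > 0 else 0.0


def numerical_grad(fn, scores, eps=1e-5):
    """Central finite differences, standing in for jax.grad in this sketch."""
    grad = []
    for i in range(len(scores)):
        up = list(scores)
        up[i] += eps
        down = list(scores)
        down[i] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
loss = approx_ndcg_loss(scores, labels)
grad = numerical_grad(lambda s: approx_ndcg_loss(s, labels), scores)
print(loss)  # close to -0.63282484
print(grad)  # close to [-0.01276882, 0.00549765, 0.00727116]
```

Raising the score of the relevant doc 1 lowers this loss, which is why its gradient entry is negative.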

## Installation

Rax requires JAX. See https://github.com/google/jax#installation for instructions on installing JAX.

We suggest installing the latest stable version of Rax by running:

```
$ pip install rax
```

## Examples

See the `examples/` directory in the [GitHub repository](https://github.com/google/rax) for complete examples of how to use Rax.

## Citing Rax

If you use Rax, please consider citing our
[paper](https://research.google/pubs/pub51453/):

```
@inproceedings{jagerman2022rax,
  title = {Rax: Composable Learning-to-Rank using JAX},
  author  = {Rolf Jagerman and Xuanhui Wang and Honglei Zhuang and Zhen Qin and
  Michael Bendersky and Marc Najork},
  year  = {2022},
  booktitle = {Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining}
}
```


Raw data

            {
    "_id": null,
    "home_page": "https://github.com/google/rax",
    "name": "rax",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.8",
    "maintainer_email": "",
    "keywords": "learning-to-rank jax ranking",
    "author": "Google",
    "author_email": "rax-dev@google.com",
    "download_url": "https://files.pythonhosted.org/packages/52/c6/e949ca17e4dfb551f0734c537e4898645ee2d796c1df4cc1a38bbb80cd93/rax-0.3.0.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "Apache 2.0",
    "summary": "Learning-to-Rank using JAX.",
    "version": "0.3.0",
    "project_urls": {
        "Homepage": "https://github.com/google/rax"
    },
    "split_keywords": [
        "learning-to-rank",
        "jax",
        "ranking"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "9ccfaa093e63ded3d3ae419707e168cbce68d74c7b3815df58e17a6e84e7e879",
                "md5": "666a2239fd44410e59d2dbde69d158c8",
                "sha256": "527654d733530b66595e386d979a347e9affb1dda68cf837455145f66950e240"
            },
            "downloads": -1,
            "filename": "rax-0.3.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "666a2239fd44410e59d2dbde69d158c8",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 78906,
            "upload_time": "2023-05-17T19:24:47",
            "upload_time_iso_8601": "2023-05-17T19:24:47.083779Z",
            "url": "https://files.pythonhosted.org/packages/9c/cf/aa093e63ded3d3ae419707e168cbce68d74c7b3815df58e17a6e84e7e879/rax-0.3.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "52c6e949ca17e4dfb551f0734c537e4898645ee2d796c1df4cc1a38bbb80cd93",
                "md5": "51f6ce64c76acb65ef0a6e66fca70ea8",
                "sha256": "9d0b86a22cfad2d87129d30403057c076b2505e7b95f626ab11f02416af240da"
            },
            "downloads": -1,
            "filename": "rax-0.3.0.tar.gz",
            "has_sig": false,
            "md5_digest": "51f6ce64c76acb65ef0a6e66fca70ea8",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 59452,
            "upload_time": "2023-05-17T19:24:49",
            "upload_time_iso_8601": "2023-05-17T19:24:49.011541Z",
            "url": "https://files.pythonhosted.org/packages/52/c6/e949ca17e4dfb551f0734c537e4898645ee2d796c1df4cc1a38bbb80cd93/rax-0.3.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-05-17 19:24:49",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "google",
    "github_project": "rax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "rax"
}
        