# qpax

Package metadata (PyPI): version 0.0.9, uploaded 2024-09-16 17:04:00; author Kevin Tracy (ktracy@cmu.edu); homepage https://github.com/kevin-tracy/qpax; license MIT; requires Python >=3.7; keywords: optimization, automatic differentiation, jax.
Differentiable QP solver in [JAX](https://github.com/google/jax).

[![Paper](http://img.shields.io/badge/arXiv-2406.11749-B31B1B.svg)](https://arxiv.org/abs/2406.11749)

This package can be used for solving convex quadratic programs of the following form:

$$
\begin{align*}
\underset{x}{\text{minimize}} & \quad \frac{1}{2}x^TQx + q^Tx \\
\text{subject to} & \quad  Ax = b, \\
                  & \quad  Gx \leq h
\end{align*}
$$

where $Q \succeq 0$ (positive semidefinite), which makes the problem convex. The solver is compatible with JAX's `jit` and `vmap` transformations and can be differentiated with reverse-mode `grad`.
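Bound constraints are a common special case of $Gx \leq h$. The sketch below uses a hypothetical helper, `box_to_inequality` (not part of `qpax`), that stacks $x \leq u$ and $-x \leq -l$ into a single $G$ and $h$; the nonnegativity constraint in the batched example further below is the case $l = 0$ with no upper bound.

```python
import jax.numpy as jnp

def box_to_inequality(lower, upper):
    """Encode lower <= x <= upper as G x <= h (hypothetical helper, not part of qpax)."""
    n = lower.shape[0]
    eye = jnp.eye(n)
    # rows 0..n-1 encode x <= upper; rows n..2n-1 encode -x <= -lower
    G = jnp.concatenate([eye, -eye], axis=0)
    h = jnp.concatenate([upper, -lower])
    return G, h

G, h = box_to_inequality(jnp.array([0.0, -1.0]), jnp.array([2.0, 1.0]))
x = jnp.array([1.0, 0.5])
print(G.shape)                     # (4, 2): two rows per variable
print(bool(jnp.all(G @ x <= h)))   # True: x lies inside the box
```

The resulting `G` and `h` can be passed alongside any equality constraints `A`, `b`; if a variable has no upper bound, its row is simply omitted.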

The QP is solved with the primal-dual interior point algorithm detailed in [cvxgen](https://stanford.edu/~boyd/papers/pdf/code_gen_impl.pdf), and the linear systems that arise at each iteration are solved using the reduction techniques from [cvxopt](http://www.seas.ucla.edu/~vandenbe/publications/coneprog.pdf). At an approximate primal-dual solution, the primal variable $x$ is differentiated with respect to the problem parameters using the implicit function theorem, as shown in [optnet](https://arxiv.org/abs/1703.00443) and implemented in its PyTorch-based QP solver [qpth](https://github.com/locuslab/qpth).
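The implicit-function-theorem idea can be illustrated on the simplest case, a QP with only equality constraints, whose optimality (KKT) conditions are the linear system $\begin{bmatrix} Q & A^T \\ A & 0 \end{bmatrix}\begin{bmatrix} x \\ y \end{bmatrix} = \begin{bmatrix} -q \\ b \end{bmatrix}$. The JAX sketch below differentiates a loss $\ell(x)$ with respect to $q$ by solving one adjoint system with the same (symmetric) KKT matrix, and compares the result with `jax.grad` applied to a direct solve. It only illustrates the technique; `qpax` also handles inequality constraints with an interior point method, and its internals may differ. All function names here are hypothetical.

```python
import jax
import jax.numpy as jnp

def kkt_matrix(Q, A):
    m = A.shape[0]
    return jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])

def solve_eq_qp(Q, q, A, b):
    """Primal solution of min 1/2 x'Qx + q'x s.t. Ax = b via the KKT system."""
    n = Q.shape[0]
    w = jnp.linalg.solve(kkt_matrix(Q, A), jnp.concatenate([-q, b]))
    return w[:n]  # w = [x; y], keep only the primal part

def loss(x):
    return jnp.sum((x - 1.0) ** 2)

def grad_q_implicit(Q, q, A, b):
    """d loss / d q from one adjoint solve with the symmetric KKT matrix."""
    n, m = Q.shape[0], A.shape[0]
    x = solve_eq_qp(Q, q, A, b)
    rhs = jnp.concatenate([jax.grad(loss)(x), jnp.zeros(m)])
    lam = jnp.linalg.solve(kkt_matrix(Q, A), rhs)
    # the KKT right-hand side depends on q through -q, hence the minus sign
    return -lam[:n]

Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
q = jnp.array([1.0, -1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

g_implicit = grad_q_implicit(Q, q, A, b)
g_autodiff = jax.grad(lambda q_: loss(solve_eq_qp(Q, q_, A, b)))(q)
print(jnp.allclose(g_implicit, g_autodiff, atol=1e-4))  # True
```

The appeal of the implicit approach is that the gradient costs one extra linear solve with a matrix already available at the solution, instead of backpropagating through every solver iteration.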

## Installation

To install the latest release from PyPI using `pip`:

```bash
$ pip install qpax
```

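Alternatively, to install from source in editable mode, clone the repository and run `pip install` from its root:

```bash
$ git clone https://github.com/kevin-tracy/qpax.git
$ cd qpax
$ pip install -e .
```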

## Usage

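### Solving a QP
We can solve a QP with `qpax.solve_qp`, which is compatible with JAX's `jit` and `vmap`:

```python
import jax.numpy as jnp
import qpax

# problem data: Q (n,n), q (n,), A (m,n), b (m,), G (p,n), h (p,)
Q, q = jnp.eye(2), jnp.array([-1.0, -1.0])
A, b = jnp.array([[1.0, 1.0]]), jnp.array([1.0])
G, h = -jnp.eye(2), jnp.zeros(2)

# solve QP (this can be combined with jit or vmap)
x, s, z, y, converged, iters = qpax.solve_qp(Q, q, A, b, G, h)
```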
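To sanity-check a solution, one can evaluate the residuals of the KKT conditions. The sketch below assumes the cvxgen convention that the return names suggest but this README does not spell out: `s` is the slack in $Gx + s = h$, `z` is the inequality multiplier, and `y` is the equality multiplier. `kkt_residuals` is a hypothetical helper, not part of the `qpax` API. So that it runs without `qpax`, it is fed a hand-computed optimum of a small QP with $Q = I$, $q = (-1, -1)$, $x_1 + x_2 = 1$, and $x \geq 0$.

```python
import jax.numpy as jnp

def kkt_residuals(Q, q, A, b, G, h, x, s, z, y):
    """Max-abs residuals of the QP optimality conditions (hypothetical helper).

    Assumes Gx + s = h with s, z >= 0 and the Lagrangian
    1/2 x'Qx + q'x + y'(Ax - b) + z'(Gx - h).
    """
    stationarity = Q @ x + q + A.T @ y + G.T @ z
    return {
        "stationarity": jnp.max(jnp.abs(stationarity)),
        "equality": jnp.max(jnp.abs(A @ x - b)),
        "inequality": jnp.max(jnp.abs(G @ x + s - h)),
        "complementarity": jnp.max(jnp.abs(s * z)),
        "min_s_z": jnp.minimum(jnp.min(s), jnp.min(z)),  # should be >= 0
    }

# hand-computed optimum of: min 1/2|x|^2 - x1 - x2  s.t.  x1 + x2 = 1,  x >= 0
Q, q = jnp.eye(2), jnp.array([-1.0, -1.0])
A, b = jnp.array([[1.0, 1.0]]), jnp.array([1.0])
G, h = -jnp.eye(2), jnp.zeros(2)
x = jnp.array([0.5, 0.5])
s = h - G @ x          # slacks are 0.5, so both inequalities are inactive
z = jnp.zeros(2)       # inactive inequalities have zero multipliers
y = jnp.array([0.5])   # chosen so that Qx + q + A'y = 0

res = kkt_residuals(Q, q, A, b, G, h, x, s, z, y)
print({k: float(v) for k, v in res.items()})
```

With a real solve, one would pass the returned `x, s, z, y` instead; residuals on the order of the solver tolerance suggest convergence, provided the assumed sign conventions match the library's.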
### Solving a batch of QPs

Here we solve a batch of nonnegative least squares problems as QPs. This example illustrates two features of `qpax`: first, solving QPs without any equality constraints (by passing an `A` with zero rows and an empty `b`), and second, using `vmap` to solve a batch of QPs.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import qpax

"""
solve batched non-negative least squares (nnls) problems

min_x    |Fx - g|^2
st       x >= 0
"""

n = 5   # size of x
m = 10  # rows in F

# create data for N_qps random nnls problems
N_qps = 10000
Fs = jnp.array(np.random.randn(N_qps, m, n))
gs = jnp.array(np.random.randn(N_qps, m))

@jit
def form_qp(F, g):
  # convert the least squares problem to QP form:
  # 1/2 |Fx - g|^2 = 1/2 x'(F'F)x - (F'g)'x + const, so Q = F'F and q = -F'g
  n = F.shape[1]
  Q = F.T @ F
  q = -F.T @ g
  G = -jnp.eye(n)        # -x <= 0 encodes x >= 0
  h = jnp.zeros(n)
  A = jnp.zeros((0, n))  # no equality constraints: A has zero rows
  b = jnp.zeros(0)
  return Q, q, A, b, G, h

# create the QPs in a batched fashion
Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes=(0, 0))(Fs, gs)

# create function for solving a batch of QPs
batch_qp = jit(vmap(qpax.solve_qp_primal, in_axes=(0, 0, 0, 0, 0, 0)))

xs = batch_qp(Qs, qs, As, bs, Gs, hs)
```
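Because each QP here is a nonnegative least squares problem, its solution can be checked directly against the NNLS optimality conditions: with $f(x) = \tfrac{1}{2}\|Fx - g\|^2$, an optimum satisfies $x \geq 0$, $\nabla f(x) = F^T(Fx - g) \geq 0$, and $x_i \, [\nabla f(x)]_i = 0$ for every $i$. The sketch below vectorizes this check with `vmap` using a hypothetical helper, `nnls_kkt_gap`, that is not part of `qpax`; in the example above it could be applied as `vmap(nnls_kkt_gap)(Fs, gs, xs)`. Here it runs on a tiny batch whose solutions are known in closed form.

```python
import jax.numpy as jnp
from jax import vmap

def nnls_kkt_gap(F, g, x):
    """Largest violation of the NNLS optimality conditions (0 at an exact optimum).

    For f(x) = 1/2 |Fx - g|^2 subject to x >= 0, optimality requires
    x >= 0, grad f(x) >= 0, and x_i * grad_i = 0 for every i.
    """
    grad = F.T @ (F @ x - g)
    return jnp.max(jnp.concatenate([
        jnp.maximum(-x, 0.0),     # primal feasibility violation
        jnp.maximum(-grad, 0.0),  # dual feasibility violation
        jnp.abs(x * grad),        # complementarity violation
    ]))

# with F = I the NNLS solution is max(g, 0), so it is known in closed form
F_batch = jnp.stack([jnp.eye(2), jnp.eye(2)])
g_batch = jnp.array([[1.0, -1.0], [-2.0, 3.0]])
xs_exact = jnp.maximum(g_batch, 0.0)
xs_wrong = jnp.array([[1.0, 0.5], [0.0, 3.0]])

print(vmap(nnls_kkt_gap)(F_batch, g_batch, xs_exact))  # zero gaps
print(vmap(nnls_kkt_gap)(F_batch, g_batch, xs_wrong))  # first problem violates complementarity
```

Since an interior point method stops at an approximate solution, gaps from a real solve should be small but not exactly zero.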

### Differentiating a QP

If we only need the primal variable `x`, we can use `solve_qp_primal`, which supports automatic differentiation. The example below differentiates a loss on `x` with respect to all six problem parameters:

```python
import jax
import jax.numpy as jnp
import qpax

def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h)
    x_bar = jnp.ones(len(q))
    return jnp.dot(x - x_bar, x - x_bar)

# gradient of loss function with respect to every QP parameter
loss_grad = jax.grad(loss, argnums=(0, 1, 2, 3, 4, 5))

# compatible with jit
loss_grad_jit = jax.jit(loss_grad)

# example problem data (any QP of the form above works)
Q = jnp.eye(2)
q = jnp.array([-1.0, -1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = -jnp.eye(2)
h = jnp.zeros(2)

# calculate derivatives
derivs = loss_grad_jit(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs
```
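Because the derivatives are computed with the implicit function theorem at an approximate solution, it can be worth spot-checking them. The sketch below compares a reverse-mode gradient with a central finite-difference estimate along a random direction; `directional_fd_check` and `loss_q` are hypothetical names, not part of `qpax`. So that it runs on its own, it uses a closed-form unconstrained QP ($x^\star = -Q^{-1}q$) as a stand-in for the solver; with `qpax` installed, one could pass a function such as `lambda q_: loss(Q, q_, A, b, G, h)` instead. Finite differences are noisy in single precision, so the sketch enables JAX's 64-bit mode.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # finite differences need double precision

def directional_fd_check(f, theta, key, eps=1e-6):
    """Compare grad(f)(theta) . v with a central difference along a random v."""
    v = jax.random.normal(key, theta.shape, dtype=theta.dtype)
    ad = jnp.vdot(jax.grad(f)(theta), v)
    fd = (f(theta + eps * v) - f(theta - eps * v)) / (2 * eps)
    return ad, fd

# closed-form stand-in for a QP solve: minimizer of 1/2 x'Qx + q'x with Q > 0
Q = jnp.array([[3.0, 1.0], [1.0, 2.0]])
x_bar = jnp.ones(2)

def loss_q(q):
    x = jnp.linalg.solve(Q, -q)
    return jnp.dot(x - x_bar, x - x_bar)

ad, fd = directional_fd_check(loss_q, jnp.array([0.5, -1.0]), jax.random.PRNGKey(0))
print(float(ad), float(fd))  # the two numbers should agree to many digits
```

For a QP with active inequality constraints, expect somewhat looser agreement, since the solver stops at a tolerance and the solution map can be nonsmooth where constraints switch between active and inactive.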

## Citation
[![Paper](http://img.shields.io/badge/arXiv-2406.11749-B31B1B.svg)](https://arxiv.org/abs/2406.11749)
```bibtex
@misc{tracy2024differentiability,
    title={On the Differentiability of the Primal-Dual Interior-Point Method},
    author={Kevin Tracy and Zachary Manchester},
    year={2024},
    eprint={2406.11749},
    archivePrefix={arXiv},
    primaryClass={math.OC}
}
```
