# qpax
Differentiable QP solver in [JAX](https://github.com/google/jax).

[![Paper](http://img.shields.io/badge/arXiv-2406.11749-B31B1B.svg)](https://arxiv.org/abs/2406.11749)

This package can be used to solve convex quadratic programs of the following form:
$$
\begin{align*}
\underset{x}{\text{minimize}} & \quad \frac{1}{2}x^TQx + q^Tx \\
\text{subject to} & \quad Ax = b, \\
& \quad Gx \leq h
\end{align*}
$$
where $Q \succeq 0$. The solver is compatible with JAX's `jit` and `vmap` transformations and can be differentiated with reverse-mode `grad`.

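
For reference, the standard optimality (KKT) conditions of this QP can be written with a slack variable $s$ for the inequalities and multipliers $z$ (inequalities) and $y$ (equalities), in the form used by cvxgen. This is a textbook statement, not necessarily the exact system implemented in qpax:

$$
\begin{align*}
Qx + q + A^T y + G^T z &= 0, \\
Ax - b &= 0, \\
Gx + s - h &= 0, \\
s \circ z &= 0, \quad s \geq 0, \quad z \geq 0.
\end{align*}
$$

A primal-dual interior-point method relaxes the complementarity condition to $s \circ z = \mu \mathbf{1}$ and drives $\mu \to 0$. Writing the equality conditions as $r(w, \theta) = 0$ with $w = (x, s, z, y)$ and $\theta = (Q, q, A, b, G, h)$, the implicit function theorem gives $\partial w / \partial \theta = -(\partial r / \partial w)^{-1} \, \partial r / \partial \theta$ wherever $\partial r / \partial w$ is invertible.
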
The QP is solved with a primal-dual interior-point algorithm detailed in [cvxgen](https://stanford.edu/~boyd/papers/pdf/code_gen_impl.pdf), with the linear systems solved using the reduction techniques from [cvxopt](http://www.seas.ucla.edu/~vandenbe/publications/coneprog.pdf). At an approximate primal-dual solution, the primal variable $x$ is differentiated with respect to the problem parameters using the implicit function theorem, as shown in [optnet](https://arxiv.org/abs/1703.00443) and its PyTorch-based QP solver [qpth](https://github.com/locuslab/qpth).
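
The implicit-function-theorem step is easiest to see in the equality-constrained special case (no $G$, $h$), where the KKT conditions reduce to one linear system $\begin{bmatrix} Q & A^T \\ A & 0 \end{bmatrix} \begin{bmatrix} x \\ y \end{bmatrix} = \begin{bmatrix} -q \\ b \end{bmatrix}$. The minimal sketch below is plain JAX, not qpax's implementation, and the function names and toy loss are hypothetical. It computes $\partial \ell / \partial q$ and $\partial \ell / \partial b$ from a single adjoint solve and checks them against JAX differentiating straight through `jnp.linalg.solve`:

```python
import jax
import jax.numpy as jnp

def eq_qp_solve(Q, q, A, b):
    """Solve min 0.5 x^T Q x + q^T x s.t. Ax = b via its linear KKT system
    [[Q, A^T], [A, 0]] [x; y] = [-q; b]. Returns the primal x and the KKT matrix."""
    n, p = Q.shape[0], A.shape[0]
    K = jnp.block([[Q, A.T], [A, jnp.zeros((p, p))]])
    w = jnp.linalg.solve(K, jnp.concatenate([-q, b]))
    return w[:n], K

def loss(x):
    # toy loss on the primal solution (hypothetical)
    return jnp.sum((x - 1.0) ** 2)

def ift_grads(Q, q, A, b):
    """dloss/dq and dloss/db via the implicit function theorem (one adjoint solve)."""
    x, K = eq_qp_solve(Q, q, A, b)
    n, p = Q.shape[0], A.shape[0]
    # adjoint system K^T lam = [dloss/dx; 0]; since the right-hand side is [-q; b],
    # dloss/dq = -lam_x and dloss/db = lam_y
    lam = jnp.linalg.solve(K.T, jnp.concatenate([jax.grad(loss)(x), jnp.zeros(p)]))
    return -lam[:n], lam[n:]

def autodiff_grads(Q, q, A, b):
    # reference: let JAX differentiate through the linear solve directly
    f = lambda q_, b_: loss(eq_qp_solve(Q, q_, A, b_)[0])
    return jax.grad(f, argnums=(0, 1))(q, b)

Q = jnp.diag(jnp.array([2.0, 3.0]))
q = jnp.array([1.0, -1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
print(ift_grads(Q, q, A, b))
print(autodiff_grads(Q, q, A, b))
```

Both calls should print matching gradients. With inequality constraints the system also involves $s$ and $z$ and is nonlinear, which is where the interior-point solution and the references above come in.
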
## Installation
To install the latest release from PyPI using `pip`:
```bash
$ pip install qpax
```
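Alternatively, to install from source in editable mode, clone the repository and install it from the repository root:
```bash
$ git clone https://github.com/kevin-tracy/qpax
$ cd qpax
$ pip install -e .
```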
## Usage
### Solving a QP
We can solve QPs with `qpax.solve_qp`, which works with JAX's `jit` and `vmap`:
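```python
import jax.numpy as jnp
import qpax
# small example: 2 variables, x1 + x2 = 1, and x >= 0 (written as -x <= 0)
Q, q = jnp.eye(2), jnp.array([1.0, -1.0])
A, b = jnp.array([[1.0, 1.0]]), jnp.array([1.0])
G, h = -jnp.eye(2), jnp.zeros(2)
# solve QP (this can be combined with jit or vmap)
x, s, z, y, converged, iters = qpax.solve_qp(Q, q, A, b, G, h)
```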
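
To sanity-check a solution returned by `solve_qp`, one can evaluate the KKT conditions directly. The helper below is hypothetical (not part of qpax) and assumes the returned `s`, `z`, `y` follow the cvxgen-style convention: `s` are inequality slacks with $Gx + s = h$, and `z`, `y` are the inequality and equality multipliers with $Qx + q + A^T y + G^T z = 0$.

```python
import jax.numpy as jnp

def kkt_residuals(Q, q, A, b, G, h, x, s, z, y):
    """Hypothetical helper: max-abs violation of each KKT condition, assuming
    Qx + q + A^T y + G^T z = 0, Ax = b, Gx + s = h, s * z = 0, s >= 0, z >= 0."""
    def max_abs(v):
        # initial=0.0 keeps this valid for empty arrays (e.g. no equality constraints)
        return jnp.max(jnp.abs(v), initial=0.0)

    return {
        "stationarity": max_abs(Q @ x + q + A.T @ y + G.T @ z),
        "primal_equality": max_abs(A @ x - b),
        "primal_inequality": max_abs(G @ x + s - h),
        "complementarity": max_abs(s * z),
        # amount by which s or z is negative (0 if both are nonnegative)
        "nonnegativity": max_abs(jnp.minimum(jnp.concatenate([s, z]), 0.0)),
    }
```

After a solve, `kkt_residuals(Q, q, A, b, G, h, x, s, z, y)` should return small values when `converged` is true, provided the sign conventions assumed above match qpax's. If only the stationarity residual is large, the multiplier signs likely differ.
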
### Solving a batch of QPs
Here we solve a batch of nonnegative least squares (NNLS) problems as QPs. This example shows two features of `qpax`: solving QPs without any equality constraints (by passing `A` and `b` with zero rows), and using `vmap` to solve a batch of QPs at once.
```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import qpax

"""
solve batched non-negative least squares (nnls) problems

min_x |Fx - g|^2
st    x >= 0
"""

n = 5   # size of x
m = 10  # rows in F

# create data for N_qps random nnls problems
N_qps = 10000
Fs = jnp.array(np.random.randn(N_qps, m, n))
gs = jnp.array(np.random.randn(N_qps, m))

@jit
def form_qp(F, g):
    # convert the least squares problem to QP form:
    # 0.5 x^T (F^T F) x - (F^T g)^T x equals 0.5 |Fx - g|^2 up to a constant
    n = F.shape[1]
    Q = F.T @ F
    q = -F.T @ g
    # x >= 0 is written as Gx <= h with G = -I, h = 0
    G = -jnp.eye(n)
    h = jnp.zeros(n)
    # no equality constraints: A and b have zero rows
    A = jnp.zeros((0, n))
    b = jnp.zeros(0)
    return Q, q, A, b, G, h

# create the QPs in a batched fashion
Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes=(0, 0))(Fs, gs)

# create function for solving a batch of QPs
batch_qp = jit(vmap(qpax.solve_qp_primal, in_axes=(0, 0, 0, 0, 0, 0)))

xs = batch_qp(Qs, qs, As, bs, Gs, hs)
```
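
The reformulation in `form_qp` rests on the identity $\frac{1}{2}\|Fx - g\|^2 = \frac{1}{2}x^T F^T F x - (F^T g)^T x + \frac{1}{2}\|g\|^2$, so $Q = F^T F$ and $q = -F^T g$ give the same minimizer. To cross-check the batched results, a different and much simpler method can solve the same NNLS problems. The sketch below uses projected gradient descent with step $1/L$, where $L$ is the largest eigenvalue of $F^T F$. It is a hypothetical reference solver, not part of qpax, and the fixed iteration count is an assumption:

```python
import jax
import jax.numpy as jnp

def nnls_projected_gradient(F, g, num_iters=500):
    """Hypothetical reference solver (not part of qpax): projected gradient
    descent for min_x 0.5 |Fx - g|^2 subject to x >= 0."""
    Q = F.T @ F
    q = -F.T @ g
    # fixed step 1/L, with L the largest eigenvalue of Q (eigvalsh sorts ascending)
    step = 1.0 / jnp.linalg.eigvalsh(Q)[-1]

    def body(_, x):
        # gradient step on the smooth objective, then project onto x >= 0
        return jnp.maximum(x - step * (Q @ x + q), 0.0)

    return jax.lax.fori_loop(0, num_iters, body, jnp.zeros(F.shape[1], dtype=F.dtype))

# map over the leading (batch) axis of F and g, then compile
batch_nnls = jax.jit(jax.vmap(nnls_projected_gradient))
```

With `Fs`, `gs`, and `xs` from the example above, `jnp.max(jnp.abs(xs - batch_nnls(Fs, gs)))` should be small. Projected gradient converges slowly on ill-conditioned problems, so a larger `num_iters` may be needed before the two agree closely.
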
### Differentiating a QP
If we only need the primal variable `x`, we can use `solve_qp_primal`, which supports automatic differentiation:
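```python
import jax
import jax.numpy as jnp
import qpax

# small example problem: 2 variables, x1 + x2 = 1, and x >= 0 (as -x <= 0)
Q, q = jnp.eye(2), jnp.array([1.0, -1.0])
A, b = jnp.array([[1.0, 1.0]]), jnp.array([1.0])
G, h = -jnp.eye(2), jnp.zeros(2)

def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h)
    x_bar = jnp.ones(len(q))
    # squared distance between the QP solution and a target point
    return jnp.dot(x - x_bar, x - x_bar)

# gradient of the loss with respect to all six problem parameters
loss_grad = jax.grad(loss, argnums=(0, 1, 2, 3, 4, 5))

# compatible with jit
loss_grad_jit = jax.jit(loss_grad)

# calculate derivatives
derivs = loss_grad_jit(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs
```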
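
Because the solution of a QP can be nonsmooth where the set of active constraints changes, and the derivatives are taken at an approximate interior-point solution, it can be worth comparing gradients against finite differences. The helper below is a hypothetical, generic sketch (not part of qpax) that compares `jax.grad` with a central difference along one direction:

```python
import jax
import jax.numpy as jnp

def directional_fd_check(f, theta, direction, eps=1e-3):
    """Hypothetical helper: compare jax.grad(f) with a central finite difference
    of the scalar function f along `direction`. Returns (autodiff, finite_diff)."""
    autodiff = jnp.vdot(jax.grad(f)(theta), direction)
    finite_diff = (f(theta + eps * direction) - f(theta - eps * direction)) / (2.0 * eps)
    return autodiff, finite_diff
```

For example, `directional_fd_check(lambda h_: loss(Q, q, A, b, G, h_), h, jnp.ones_like(h))` checks the derivative with respect to `h`. In float32, agreement to about three significant digits is typical, and larger gaps near active-set changes are not necessarily a bug.
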
## Citation
[![Paper](http://img.shields.io/badge/arXiv-2406.11749-B31B1B.svg)](https://arxiv.org/abs/2406.11749)
```
@misc{tracy2024differentiability,
title={On the Differentiability of the Primal-Dual Interior-Point Method},
author={Kevin Tracy and Zachary Manchester},
year={2024},
eprint={2406.11749},
archivePrefix={arXiv},
primaryClass={math.OC}
}
```