folx


* Name: folx
* Version: 0.2.6
* Home page: https://github.com/microsoft/folx
* Summary: Forward Laplacian for JAX
* Upload time: 2024-04-25 12:37:51
* Maintainer: Nicholas Gao
* Author: Nicholas Gao
* Requires Python: <4.0,>=3.10
* License: MIT
* Keywords: jax, laplacian, numeric

# `folx` - Forward Laplacian for JAX

This submodule implements the forward laplacian from https://arxiv.org/abs/2307.08214. It is implemented as a [custom interpreter for Jaxprs](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html).

## Install

Either clone the repository and install it locally via
```bash
poetry install
```
or
```bash
pip install .
```
or install the released package with `pip`:
```bash
pip install folx
```

## Example
### Dense example
For simple usage, one can decorate any function with `forward_laplacian`.
```python
import numpy as np
from folx import forward_laplacian

def f(x):
    return (x**2).sum()

fwd_f = forward_laplacian(f)
result = fwd_f(np.arange(3, dtype=float))
result.x # f(x) 5
result.jacobian.dense_array # J_f(x) [0, 2, 4]
result.laplacian # tr(H_f(x)) 6
```
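As a cross-check that does not depend on `folx`, the same three quantities can be computed with standard JAX transforms. This is only a reference sketch using `jax.grad` and `jax.hessian`; it materializes the full Hessian, which is exactly what the forward laplacian is meant to avoid.
```python
import jax
import jax.numpy as jnp

def f(x):
    return (x**2).sum()

x = jnp.arange(3, dtype=jnp.float32)      # [0, 1, 2]

value = f(x)                               # 0 + 1 + 4 = 5
gradient = jax.grad(f)(x)                  # 2 * x = [0, 2, 4]
laplacian = jnp.trace(jax.hessian(f)(x))   # trace of 2 * I_3 = 6
```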
### Sparsity example
A big feature of `folx` is that it automatically works with sparse jacobians to accelerate computations. Note that the results are still **exact**. To enable this feature, simply supply a maximum sparsity threshold. Compile times may increase significantly because tracing the sparsity patterns of the jacobians is a lengthy process. Here is an example with an MLP operating independently on individual node features.
```python
import folx
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        for _ in range(10):
            x = nn.Dense(100)(x)
            x = nn.silu(x)
        return nn.Dense(1)(x).sum()

mlp = MLP()
x = jnp.ones((20, 100, 4))
params = mlp.init(jax.random.PRNGKey(0), x)
def fwd(x):
    return mlp.apply(params, x)

# Traditional loop implementation
lapl = jax.jit(jax.vmap(folx.LoopLaplacianOperator()(fwd)))
%time jax.block_until_ready(lapl(x)) # Wall time: 1.42 s
%timeit jax.block_until_ready(lapl(x)) # 224 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Forward laplacian without sparsity
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0)(fwd)))
%time jax.block_until_ready(lapl(x)) # Wall time: 2.66 s
%timeit jax.block_until_ready(lapl(x)) # 48.7 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# Forward laplacian with sparsity
lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(4)(fwd)))
%time jax.block_until_ready(lapl(x)) # Wall time: 5.05 s
%timeit jax.block_until_ready(lapl(x)) # 2.59 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
For electronic wave functions like FermiNet or PsiFormer, `sparsity_threshold=6` is a recommended value, but tuning this hyperparameter may accelerate computations further.

## Introduction
To avoid custom wrappers for all of JAX's commands, the forward laplacian is implemented as a custom interpreter for Jaxpr.
This means if you have a function
```python
class Fn(Protocol):
    def __call__(self, *args: PyTree[Array]) -> PyTree[Array]:
        ...
```
the resulting function will have the signature:
```python
class LaplacianFn(Protocol):
    def __call__(self, *args: PyTree[Array]) -> PyTree[FwdLaplArray]:
        ...
```
where `FwdLaplArray` is a triplet of
```python
FwdLaplArray.x # jax.Array f(x) f(x).shape
FwdLaplArray.jacobian # FwdJacobian J_f(x)
FwdLaplArray.laplacian # jax.Array tr(H_f(x)) f(x).shape
```
The jacobian is implemented by a custom class as the forward laplacian supports automatic sparsity. To get the full jacobian:
```python
FwdLaplArray.jacobian.dense_array # jax.Array (*f(x).shape, x.size)
```
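To make these shapes concrete, here is a small sketch that builds the same three quantities with plain JAX transforms instead of `folx`. The function `f`, the input shape, and the reshapes are assumptions chosen for illustration; only the shape convention `(*f(x).shape, x.size)` is taken from the description above.
```python
import jax
import jax.numpy as jnp

def f(x):
    # Maps a (2, 3) input to a (2,) output.
    return jnp.tanh(x).sum(axis=-1)

x = jnp.arange(6.0).reshape(2, 3)

out = f(x)                                               # shape (2,)
# jacfwd returns (*out.shape, *x.shape); flatten the input axes.
dense_jacobian = jax.jacfwd(f)(x).reshape(*out.shape, x.size)
# The laplacian is the trace of the Hessian over the flattened input axes.
hessian = jax.hessian(f)(x).reshape(*out.shape, x.size, x.size)
laplacian = jnp.trace(hessian, axis1=-2, axis2=-1)       # shape (2,)
```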

## Implementation idea
The idea is to rely on the original function and automatic differentiation to propagate `FwdLaplArray` forward instead of the regular `jax.Array`. The rules for updating `FwdLaplArray` are described by the following pseudocode:
```python
x # FwdLaplArray
y = FwdLaplArray(
    x=f(x.x),
    jacobian=jvp(f, (x.x,), (x.jacobian,)),
    laplacian=tr_vhv(f, x.jacobian) + jvp(f, (x.x,), (x.laplacian,))
)
# tr_vhv is tr(J_f H_f J_f^T)
```
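The pseudocode can be turned into a small runnable sketch for dense jacobians. This is not how `folx` implements the rule; `propagate` is a hypothetical helper, and it stores the jacobian with the input-direction axis first, shape `(k, n)`, so that `jax.vmap` can push each direction through `f`. The result is compared against the trace of the full Hessian of the composed function.
```python
import jax
import jax.numpy as jnp

def propagate(f, x, jac, lap):
    """Push (x, jac, lap) through f following the update rule above.

    x:   (n,)    current value
    jac: (k, n)  one row per input direction, i.e. d x / d z_k
    lap: (n,)    laplacian of x with respect to the inputs z
    """
    # Value and the first-order term J_f @ lap in a single jvp.
    y, lap_linear = jax.jvp(f, (x,), (lap,))
    # New jacobian: push every input direction through f.
    jac_y = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(jac)

    # Second-order term: sum_k v_k^T H_f v_k via a nested jvp.
    def vhv(v):
        df = lambda p: jax.jvp(f, (p,), (v,))[1]
        return jax.jvp(df, (x,), (v,))[1]

    lap_y = jax.vmap(vhv)(jac).sum(axis=0) + lap_linear
    return y, jac_y, lap_y

def g(z):
    return jnp.sin(z) * jnp.roll(z, 1)

def f(x):
    return jnp.sum(x**3)

z = jnp.array([0.1, 0.2, 0.3])
# Start from the inputs themselves: identity jacobian, zero laplacian.
state = (z, jnp.eye(z.size), jnp.zeros_like(z))
state = propagate(g, *state)
value, jac, lap = propagate(f, *state)

# Reference: trace of the full Hessian of the composition.
reference = jnp.trace(jax.hessian(lambda z: f(g(z)))(z))
```
Here `lap` and `reference` agree up to floating-point error, while `propagate` never forms the Hessian of the composed function.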

## Implementation

When you call the function returned by `forward_laplacian(fn)`, we first use `jax.make_jaxpr` to obtain the jaxpr for `fn`.
But instead of using the [standard evaluation pipeline](https://github.com/google/jax/blob/776baba0a3fca15a909cb7d108eea830cbe3fc1d/jax/_src/core.py#L436), we use a custom interpreter that replaces every operation with one that propagates `FwdLaplArray` forward instead of a regular `jax.Array`.
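
For readers unfamiliar with Jaxpr interpreters, the following minimal sketch shows the general pattern, loosely following the JAX tutorial on writing custom interpreters. It is not the `folx` interpreter: `eval_jaxpr` here simply re-executes each primitive, and the comment marks the place where a custom interpreter would dispatch to its own rule.
```python
import jax
import jax.numpy as jnp

def eval_jaxpr(jaxpr, consts, *args):
    """Evaluate a jaxpr equation by equation without transforming it."""
    env = {}

    def read(var):
        # Literals carry their value directly; variables are looked up.
        return var.val if hasattr(var, "val") else env[var]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # A custom interpreter would call its own rule for eqn.primitive here.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]

def fn(x):
    return jnp.sin(x) * 2.0 + x

x = jnp.linspace(0.0, 1.0, 4)
closed = jax.make_jaxpr(fn)(x)
(result,) = eval_jaxpr(closed.jaxpr, closed.consts, x)
```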

### Package structure
The general structure of the package is
* `interpreter.py` contains the evaluation of the jaxpr and the exported function decorator.
* `wrapper.py` contains the subfunction decorator that maps a function that takes `jax.Array`s to a function that accepts `FwdLaplArray`s instead.
* `wrapped_functions.py` contains a registry of predefined functions as well as utility functions to add new functions to the registry.
* `jvp.py` contains logic for jacobian vector products.
* `hessian.py` contains logic for tr(JHJ^T).
* `custom_hessian.py` contains special treatment logic for tr(JHJ^T).
* `api.py` contains general interfaces shared in the package.
* `operators.py` contains a forward laplacian operator as well as alternatives.
* `utils.py` contains several small utility functions.
* `tree_utils.py` contains several utility functions for PyTrees.
* `vmap.py` contains a batched vmap implementation to reduce memory usage by going through a batch sequentially in chunks.


### Function Annotations
There is a default interpreter that simply applies the rules outlined above, but if additional information about a function is available, e.g., that it applies elementwise like `jnp.tanh`, we can do better.
These additional annotations are available in `wrapped_functions.py`'s `_LAPLACE_FN_REGISTRY`.
Specifically, to augment a function `fn` to accept `FwdLaplArray` instead of regular `jax.Array`, we wrap it with `wrap_forward_laplacian` from `wrapper.py`:
```python
wrap_forward_laplacian(jnp.tanh, in_axes=())
```
In this case, we annotate the function to be applied elementwise, i.e., `()` indicates that none of the axes are relevant for the function.

If we know nothing about which axes might be essential, we must pass `None` (the default value) to mark all axes as important, e.g.,
```python
wrap_forward_laplacian(jnp.sum, in_axes=None, flags=FunctionFlags.LINEAR)
```
However, in this case we know that a summation is a linear operation. This information is useful for fast hessian computations.
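
The benefit of such annotations can be illustrated without `folx`. In the sketch below (the helper names are hypothetical), the general second-order term needs one nested `jvp` per jacobian row, whereas for an elementwise function it collapses to `f''(x) * sum_k jac[k]**2`, and for a linear function it vanishes entirely.
```python
import jax
import jax.numpy as jnp

def hessian_term_general(f, x, jac):
    """sum_k v_k^T H_f v_k for the rows v_k of jac, via nested jvps."""
    def vhv(v):
        df = lambda p: jax.jvp(f, (p,), (v,))[1]
        return jax.jvp(df, (x,), (v,))[1]
    return jax.vmap(vhv)(jac).sum(axis=0)

def hessian_term_elementwise(f_scalar, x, jac):
    """Same quantity when f acts elementwise: f''(x_i) * sum_k jac[k, i]**2."""
    second = jax.vmap(jax.grad(jax.grad(f_scalar)))(x)
    return second * (jac**2).sum(axis=0)

x = jnp.array([0.1, -0.4, 0.7])
jac = jnp.array([[1.0, 0.5, 0.0],
                 [0.0, 2.0, -1.0]])

general = hessian_term_general(jnp.tanh, x, jac)
shortcut = hessian_term_elementwise(jnp.tanh, x, jac)   # matches `general`
linear = hessian_term_general(jnp.sum, x, jac)          # zero for a linear function
```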

If you want to define rules for a function and add it to the registry, you can do the following:
```python
import jax
import jax.numpy as jnp
from folx import register_function, wrap_forward_laplacian

register_function(jax.lax.cos_p, wrap_forward_laplacian(jnp.cos, in_axes=()))
# Now the tracer is aware that the cosine function is applied elementwise.
```
We can do even more by defining custom rules:
```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from folx import forward_laplacian, register_function, wrap_forward_laplacian

# the jit is important
@jax.jit
def f(x):
    return x

# define a custom jacobian hessian jacobian product rule
def custom_jac_hessian_jac(args, extra_args, merge, materialize_idx):
    return jtu.tree_map(lambda x: jnp.full_like(x, 10), args.x)

# make sure to use the same name here as above
register_function("f", wrap_forward_laplacian(f, custom_jac_hessian_jac=custom_jac_hessian_jac))

@forward_laplacian
def g(x):
    return f(x)

g(jnp.ones(())).laplacian # 10
```


### Sparsity
Sparsity is detected at compile time. This has the advantage of avoiding expensive index computations at runtime and enables efficient reductions. However, it completely prohibits dynamic indexing, i.e., if indices are data-dependent, we simply default to full jacobians.

As we know a lot about the sparsity structure a priori, e.g., that we are only sparse in one dimension, we use custom sparsity operations that are more efficient than relying on JAX's default `BCOO` (further, at the time of writing, the support for `jax.experimental.sparse` is quite bad).
The sparsity data format is implemented in `FwdJacobian` in `api.py`. Instead of storing a dense array `(m, n)` for a function `f:R^n -> R^m`, we store only the non-zero data in a `(m,d)` array where `d<n` is the maximum number of non-zero inputs any output depends on.
To be able to recreate the larger `(m,n)` array from the `(m,d)` array, we additionally keep track of the indices in the last dimension in a mask, an `(m,d)` dimensional array of integers `0<mask_ij<n`.
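
The storage scheme can be illustrated with a small NumPy sketch. This is not the `FwdJacobian` implementation; `to_dense` is a hypothetical helper, and the example assumes an elementwise function so that every output depends on exactly one input (`d = 1`).
```python
import numpy as np

def to_dense(data, idx, n):
    """Scatter compact (m, d) jacobian entries into a dense (m, n) array."""
    m, d = data.shape
    dense = np.zeros((m, n), dtype=data.dtype)
    rows = np.broadcast_to(np.arange(m)[:, None], (m, d))
    # add.at accumulates, so zero-valued padding entries are harmless.
    np.add.at(dense, (rows, idx), data)
    return dense

# f(x) = x**2 applied elementwise to n = 4 inputs: output i depends only on input i.
x = np.array([0.5, 1.0, 1.5, 2.0])
data = (2 * x)[:, None]        # non-zero derivatives, shape (4, 1)
idx = np.arange(4)[:, None]    # column of each stored entry, shape (4, 1)

dense = to_dense(data, idx, n=4)   # equals the diagonal matrix diag(2 * x)
```
The compact form stores `m * d` values plus as many integer indices instead of `m * n` values, which is where the savings come from when `d` is much smaller than `n`.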

Masks are treated as compile-time static and will be traced automatically. If the tracing is not possible, e.g., due to data-dependent indexing, we will fall back to a dense implementation. These propagation rules are implemented in `jvp.py`.


### Memory efficiency
The forward laplacian uses more GPU memory due to the full materialization of the Jacobian matrix.
To compensate for this, it is recommended to loop over the batch size (while other implementations typically loop over the Hessian).
We provide an easy-to-use utility for this via `folx.batched_vmap`, which works like `jax.vmap` but chunks the input into batches and operates on these sequentially.
```python
from folx import batched_vmap

def f(x):
    return x**2

batched_f = batched_vmap(f, max_batch_size=64)
```
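The chunking idea itself can be sketched in a few lines of plain JAX. The helper `chunked_vmap` below is hypothetical and is not how `folx.batched_vmap` is implemented; it uses a Python loop over slices of the leading axis, so it only handles a single array argument and is meant for eager use.
```python
import jax
import jax.numpy as jnp

def chunked_vmap(f, max_batch_size):
    """Apply jax.vmap(f) to consecutive slices of the leading axis."""
    vf = jax.vmap(f)

    def wrapped(x):
        outs = [
            vf(x[start:start + max_batch_size])
            for start in range(0, x.shape[0], max_batch_size)
        ]
        return jnp.concatenate(outs, axis=0)

    return wrapped

def f(x):
    return (x**2).sum()

x = jnp.arange(30.0).reshape(10, 3)
chunked = chunked_vmap(f, max_batch_size=4)(x)   # chunks of 4, 4 and 2 rows
full = jax.vmap(f)(x)                            # same values, one large batch
```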

## Citation
If you find this work helpful, please consider citing it as
```
@software{gao2023folx,
  author = {Nicholas Gao and Jonas Köhler and Adam Foster},
  title = {folx - Forward Laplacian for JAX},
  url = {http://github.com/microsoft/folx},
  version = {0.2.5},
  year = {2023},
}
```
as well as the original forward laplacian:
```
@article{li2023forward,
  title={Forward Laplacian: A New Computational Framework for Neural Network-based Variational Monte Carlo},
  author={Li, Ruichen and Ye, Haotian and Jiang, Du and Wen, Xuelan and Wang, Chuwei and Li, Zhe and Li, Xiang and He, Di and Chen, Ji and Ren, Weiluo and Wang, Liwei},
  journal={arXiv preprint arXiv:2307.08214},
  year={2023}
}
```

## Contributing

This project welcomes contributions and suggestions.  Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.

            
