pytree2safetensors


Name: pytree2safetensors
Version: 0.1.4
Summary: A simple package to save and load JAX PyTrees to and from Safetensors
Upload time: 2024-08-26 10:40:57
Requires Python: >=3.10

# Pytree2Safetensors
Pytree2Safetensors is a simple package to save and load JAX PyTrees to and from
Safetensors, a popular file format for saving neural network weights.

To install, run

```
pip install --upgrade pytree2safetensors
```

Pytree2Safetensors depends on `jax`, `safetensors`, and `jaxtyping`, and requires
Python 3.10 or later.

## Specification
### Serialising/Deserialising

#### `keypath2string(path: KeyPath) -> str`
Serializes a JAX key path (i.e., a path to a leaf in a pytree) to a string by concatenating a string representation of each key in the path. Each representation carries a prefix that identifies the key type: a `GetAttrKey` is prefixed with ".", a `DictKey` with "@", and a `SequenceKey` with "#". If the first key is a `GetAttrKey`, its leading "." is omitted.

Examples:
```python
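from jax.tree_util import DictKey, GetAttrKey, SequenceKey
from pytree2safetensors import keypath2string  # assumed to be a top-level export

keypath2string((GetAttrKey("layers"), SequenceKey(10), DictKey("query"),))
# => "layers#10@query"
keypath2string((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query"),))
# => "#2.layers#10@query"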
```

#### `string2keypath(string: str) -> KeyPath`
Inverse of `keypath2string`.
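
To make the prefix scheme concrete, here is a minimal sketch of the round trip between key paths and strings. It is not the package's implementation. The function names `keypath_to_string` and `string_to_keypath` are hypothetical, and the three key classes are small stand-ins that mirror the field names of `jax.tree_util.GetAttrKey`, `SequenceKey`, and `DictKey` so that the block runs without JAX. It also assumes that attribute names and dictionary keys are strings containing none of ".", "#", or "@".

```python
import re
from dataclasses import dataclass


# Stand-ins mirroring the field names of the key classes in jax.tree_util
# (an assumption made so the sketch has no dependencies).
@dataclass(frozen=True)
class GetAttrKey:
    name: str


@dataclass(frozen=True)
class SequenceKey:
    idx: int


@dataclass(frozen=True)
class DictKey:
    key: str


def keypath_to_string(path):
    """Join the prefixed representation of every key in the path."""
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append("." + key.name)
        elif isinstance(key, SequenceKey):
            parts.append("#" + str(key.idx))
        elif isinstance(key, DictKey):
            parts.append("@" + str(key.key))
        else:
            raise TypeError(f"unsupported key type: {type(key).__name__}")
    # The "." of a leading attribute key is left off.
    return "".join(parts).removeprefix(".")


# One token is a prefix character followed by everything up to the next prefix.
_TOKEN = re.compile(r"([.#@])([^.#@]+)")


def string_to_keypath(string):
    """Parse a string produced by keypath_to_string back into a key path."""
    # Restore the "." that was dropped from a leading attribute key.
    if string and string[0] not in ".#@":
        string = "." + string
    path = []
    for prefix, body in _TOKEN.findall(string):
        if prefix == ".":
            path.append(GetAttrKey(body))
        elif prefix == "#":
            path.append(SequenceKey(int(body)))
        else:
            path.append(DictKey(body))
    return tuple(path)
```

In this sketch a dictionary key is always read back as a string, so an integer dictionary key would not survive the round trip. Whether the package handles that case differently is not stated here.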

#### `pytree2dict(tree: PyTree) -> dict`
Returns a dictionary of serialized key paths mapping to leaves of the tree.
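
The flattening idea can be sketched in plain Python. This is a simplified stand-in, not the package's code: it only understands nested dicts, lists, and tuples, whereas a JAX-based implementation would presumably rely on something like `jax.tree_util.tree_flatten_with_path` and also handle attribute nodes. The name `flatten_with_paths` is hypothetical, and plain floats stand in for arrays.

```python
def flatten_with_paths(tree, prefix=""):
    """Flatten nested dicts/lists/tuples into {path string: leaf}."""
    if isinstance(tree, dict):
        items = [("@" + str(k), v) for k, v in tree.items()]
    elif isinstance(tree, (list, tuple)):
        items = [("#" + str(i), v) for i, v in enumerate(tree)]
    else:
        # Anything that is not a container is treated as a leaf.
        return {prefix: tree}
    flat = {}
    for key, child in items:
        flat.update(flatten_with_paths(child, prefix + key))
    return flat


# Floats stand in for weight arrays.
params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}
flat = flatten_with_paths(params)
# flat == {"@layers#0@w": 1.0, "@layers#0@b": 0.0,
#          "@layers#1@w": 2.0, "@layers#1@b": 0.5}
```

A flat dictionary of this shape, with arrays as values, is what a safetensors file stores, which is why the key path has to be encoded into the tensor name.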

#### `dict2pytree(dictionary: dict) -> PyTree`
Inverse of `pytree2dict`, except that attribute nodes are wrapped in `PyTreeContainer`s rather than restored as the original objects, because the deserialiser has no way of knowing what the original object was. To load weights into an already initialized pytree, use `load_into_pytree`.

#### `PyTreeContainer`
A class that implements the bare minimum JAX requires of a pytree node.

### Saving
#### `save_pytree(tree: PyTree, path: str) -> None`
Saves the pytree as a safetensors file at the given path. Equivalent to
```python
safetensors.flax.save_file(pytree2dict(tree), path)
```

### Loading
#### `load_file`
Alias of `safetensors.flax.load_file`.

#### `load_pytree(path: str) -> PyTree`
Loads a file and uses `dict2pytree` to convert the safetensors dict to a pytree.

#### `set_weights(module: PyTree, dictionary: dict) -> PyTree`
Given a pytree module and a safetensors dict, loads the weights from the dict into the module, using `string2keypath` to determine their paths. Returns a new pytree.

#### `load_into_pytree(module: PyTree, path: str) -> PyTree`
Equivalent to `set_weights(module, load_file(path))`.
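
As an illustration of what loading into an existing structure involves, the following sketch walks each serialized path and replaces the leaf it points to. It is an assumption-laden stand-in rather than the package's implementation: `set_weights_sketch`, `_get`, and `_set` are hypothetical names, the "module" is a `types.SimpleNamespace` holding lists and dicts instead of a real JAX pytree, floats stand in for arrays, and a new object is produced by deep-copying and then mutating, which may differ from how the package builds its result. Paths are assumed to be non-empty.

```python
import copy
import re
from types import SimpleNamespace

# One token is a prefix character followed by everything up to the next prefix.
_TOKEN = re.compile(r"([.#@])([^.#@]+)")


def _get(node, prefix, body):
    """Step from a node to one of its children."""
    if prefix == ".":
        return getattr(node, body)
    if prefix == "#":
        return node[int(body)]
    return node[body]


def _set(node, prefix, body, value):
    """Replace one child of a node."""
    if prefix == ".":
        setattr(node, body, value)
    elif prefix == "#":
        node[int(body)] = value
    else:
        node[body] = value


def set_weights_sketch(module, flat):
    """Return a copy of `module` with its leaves replaced from `flat`."""
    new_module = copy.deepcopy(module)  # leave the input untouched
    for path_string, value in flat.items():
        # Restore the "." that was dropped from a leading attribute key.
        if path_string[0] not in ".#@":
            path_string = "." + path_string
        tokens = _TOKEN.findall(path_string)
        node = new_module
        for prefix, body in tokens[:-1]:  # walk to the leaf's parent
            node = _get(node, prefix, body)
        _set(node, *tokens[-1], value)
    return new_module


# An "initialized" model with placeholder weights.
model = SimpleNamespace(layers=[{"w": 0.0}, {"w": 0.0}])
loaded = set_weights_sketch(model, {"layers#0@w": 1.5, "layers#1@w": 2.5})
```

Because the target structure already exists, the original container types are preserved, which is the advantage of this route over rebuilding a tree from the file alone.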
            

Raw data

{
    "_id": null,
    "home_page": null,
    "name": "pytree2safetensors",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.10",
    "maintainer_email": null,
    "keywords": null,
    "author": null,
    "author_email": "Joseph Camacho <camacho.joseph@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/05/7d/f4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5/pytree2safetensors-0.1.4.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "A simple package to save and load JAX PyTrees to and from Safetensors",
    "version": "0.1.4",
    "project_urls": {
        "Homepage": "https://github.com/cooljoseph1/pytree2safetensors",
        "Issues": "https://github.com/cooljoseph1/pytree2safetensors/issues"
    },
    "split_keywords": [],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a9b1a6ca7d1339a88126199b07fb02863832d96b2a08af6870e58a798cf1695a",
                "md5": "3323d611078558552e0063bc43d4f902",
                "sha256": "5713fcb9b5f1fab0717c88bcfc3d7fe72b2dc39d865288f26e66fe5fad391657"
            },
            "downloads": -1,
            "filename": "pytree2safetensors-0.1.4-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "3323d611078558552e0063bc43d4f902",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.10",
            "size": 6171,
            "upload_time": "2024-08-26T10:40:56",
            "upload_time_iso_8601": "2024-08-26T10:40:56.280326Z",
            "url": "https://files.pythonhosted.org/packages/a9/b1/a6ca7d1339a88126199b07fb02863832d96b2a08af6870e58a798cf1695a/pytree2safetensors-0.1.4-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "057df4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5",
                "md5": "470d04195338b4140cf13159556f59ec",
                "sha256": "4d7c9bed9ed5b4bf6cc5dc6615b87c8bc96ed4ede82b7e9c1c22531e47babc65"
            },
            "downloads": -1,
            "filename": "pytree2safetensors-0.1.4.tar.gz",
            "has_sig": false,
            "md5_digest": "470d04195338b4140cf13159556f59ec",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.10",
            "size": 4404,
            "upload_time": "2024-08-26T10:40:57",
            "upload_time_iso_8601": "2024-08-26T10:40:57.164295Z",
            "url": "https://files.pythonhosted.org/packages/05/7d/f4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5/pytree2safetensors-0.1.4.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-08-26 10:40:57",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "cooljoseph1",
    "github_project": "pytree2safetensors",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "lcname": "pytree2safetensors"
}
        