| Field | Value |
|---|---|
| Name | pytree2safetensor |
| Version | 0.1.2 |
| home_page | None |
| Summary | A simple package to save and load JAX PyTrees to and from Safetensors |
| upload_time | 2024-08-23 07:40:38 |
| maintainer | None |
| docs_url | None |
| author | None |
| requires_python | >=3.10 |
| license | None |
| requirements | No requirements were recorded. |
# Pytree2Safetensor
Pytree2Safetensor is a simple package to save and load JAX PyTrees to and from
Safetensors, a popular file format for saving neural network weights.
To install, run
```
pip install --upgrade pytree2safetensor
```
Pytree2Safetensor depends on `jax`, `safetensors`, and `jaxtyping`, and requires Python 3.10 or later.
## Specification
### Serialising/Deserialising
#### `keypath2string(path: KeyPath) -> str`
Serializes a JAX key path (i.e., the path to a leaf in a pytree) to a string by joining a string representation of each key in the path. Each representation starts with a prefix that identifies the key type: a `GetAttrKey` is prefixed with `.`, a `DictKey` with `@`, and a `SequenceKey` with `#`. If the first key is a `GetAttrKey`, its leading `.` is omitted.
Examples:
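```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey
from pytree2safetensor import keypath2string  # assumes a top-level export

keypath2string((GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "layers#10@query"
keypath2string((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "#2.layers#10@query"
```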
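The encoding rules above can be sketched in a few lines. This is a minimal illustration, not the package's own code: `keypath_to_string` is a hypothetical name, and the sketch assumes that only `GetAttrKey`, `DictKey`, and `SequenceKey` appear in a path (any other JAX key type raises `TypeError`).

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def keypath_to_string(path):
    """Hypothetical sketch of the key path encoding (not the package's code)."""
    parts = []
    for key in path:
        # The prefix character records which kind of key this is.
        if isinstance(key, GetAttrKey):
            parts.append("." + key.name)
        elif isinstance(key, DictKey):
            parts.append("@" + str(key.key))
        elif isinstance(key, SequenceKey):
            parts.append("#" + str(key.idx))
        else:
            raise TypeError(f"unsupported key type: {type(key).__name__}")
    encoded = "".join(parts)
    # A leading attribute key drops its "." prefix.
    if path and isinstance(path[0], GetAttrKey):
        encoded = encoded[1:]
    return encoded
```

Applied to the two example paths, this sketch produces `"layers#10@query"` and `"#2.layers#10@query"`.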
#### `string2keypath(string: str) -> KeyPath`
Inverse of `keypath2string`: parses a serialized string back into a JAX key path.
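One possible way to invert the encoding is to split the string at each prefix character. The sketch below (hypothetical `string_to_keypath`) assumes that attribute names and dictionary keys never contain `.`, `@`, or `#`, that sequence indices are decimal integers, and that dictionary keys are restored as strings; the package may handle these cases differently.

```python
import re

from jax.tree_util import DictKey, GetAttrKey, SequenceKey

# One token: a prefix character followed by a non-empty run of other characters.
_TOKEN = re.compile(r"([.@#])([^.@#]+)")


def string_to_keypath(string):
    """Parse a key path string such as "layers#10@query" (illustrative sketch)."""
    # Restore the "." that the encoder drops from a leading attribute key.
    if string and string[0] not in ".@#":
        string = "." + string
    tokens = _TOKEN.findall(string)
    # Reject input the tokens do not fully cover, e.g. "a..b" or a trailing "#".
    if "".join(prefix + body for prefix, body in tokens) != string:
        raise ValueError(f"malformed key path string: {string!r}")
    path = []
    for prefix, body in tokens:
        if prefix == ".":
            path.append(GetAttrKey(body))
        elif prefix == "@":
            path.append(DictKey(body))  # assumption: dict keys come back as str
        else:
            path.append(SequenceKey(int(body)))
    return tuple(path)
```

Validating that the tokens cover the whole string is a design choice of this sketch: it turns ambiguous input into an error instead of silently dropping characters.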
#### `pytree2dict(tree: PyTree) -> dict`
Returns a dictionary that maps the serialized key path of each leaf (as produced by `keypath2string`) to the leaf itself.
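One way to build this dictionary is to flatten the tree with `jax.tree_util.tree_flatten_with_path` and encode each path. The sketch below uses hypothetical names (`pytree_to_dict`, `_encode_path`), repeats the encoding rules so that it runs on its own, and adds a collision check as an illustrative design choice: keys containing prefix characters could make two different paths encode to the same string.

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_flatten_with_path


def _encode_key(key):
    # Prefix each key with a character that records its type.
    if isinstance(key, GetAttrKey):
        return "." + key.name
    if isinstance(key, DictKey):
        return "@" + str(key.key)
    if isinstance(key, SequenceKey):
        return "#" + str(key.idx)
    raise TypeError(f"unsupported key type: {type(key).__name__}")


def _encode_path(path):
    encoded = "".join(_encode_key(key) for key in path)
    # A leading attribute key drops its "." prefix.
    return encoded[1:] if path and isinstance(path[0], GetAttrKey) else encoded


def pytree_to_dict(tree):
    """Map each leaf's encoded key path to the leaf (illustrative sketch)."""
    leaves_with_paths, _ = tree_flatten_with_path(tree)
    flat = {}
    for path, leaf in leaves_with_paths:
        name = _encode_path(path)
        if name in flat:
            raise ValueError(f"two leaves share the encoded path {name!r}")
        flat[name] = leaf
    return flat
```

For example, `pytree_to_dict({"layers": [{"query": 1.0}, {"query": 2.0}]})` returns `{"@layers#0@query": 1.0, "@layers#1@query": 2.0}`. The keys start with `@` because the root is a dict rather than an object with attributes.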
#### `dict2pytree(dictionary: dict) -> PyTree`
Inverse of `pytree2dict`, except that attributes are wrapped in `PyTreeContainer`s instead of being restored as the original objects, because the deserialiser has no way to know what the original objects were. To load weights into an initialized pytree of the original type, use `load_into_pytree`.
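The need for a container class comes from the fact that a decoded path only records *how* each child was accessed, not what type held it. The sketch below (hypothetical `rebuild_from_paths`) rebuilds nested containers from already-parsed `(key_path, leaf)` pairs and chooses each container from the kind of key that indexes into it: a dict for `DictKey`, a list for `SequenceKey`, and a `types.SimpleNamespace` standing in for `PyTreeContainer` for `GetAttrKey`. It is an illustration under those assumptions only; `SimpleNamespace` is not registered as a JAX pytree node, and the package's actual reconstruction may differ.

```python
from types import SimpleNamespace

from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def _new_container(key):
    # The key type is the only hint about what kind of node held the child.
    if isinstance(key, GetAttrKey):
        return SimpleNamespace()  # stand-in for PyTreeContainer
    if isinstance(key, DictKey):
        return {}
    return []  # assume SequenceKey


def _get_child(node, key):
    if isinstance(key, GetAttrKey):
        return getattr(node, key.name, None)
    if isinstance(key, DictKey):
        return node.get(key.key)
    return node[key.idx] if key.idx < len(node) else None


def _set_child(node, key, value):
    if isinstance(key, GetAttrKey):
        setattr(node, key.name, value)
    elif isinstance(key, DictKey):
        node[key.key] = value
    else:
        # Grow the list with None placeholders until index key.idx exists.
        node.extend([None] * (key.idx + 1 - len(node)))
        node[key.idx] = value


def rebuild_from_paths(items):
    """Rebuild nested containers from (key_path, leaf) pairs with non-empty paths."""
    items = list(items)
    if not items:
        raise ValueError("nothing to rebuild")
    root = _new_container(items[0][0][0])
    for path, leaf in items:
        node = root
        # Walk all but the last key, creating intermediate containers on demand.
        for key, next_key in zip(path, path[1:]):
            child = _get_child(node, key)
            if child is None:
                child = _new_container(next_key)
                _set_child(node, key, child)
            node = child
        _set_child(node, path[-1], leaf)
    return root
```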
#### `PyTreeContainer`
A minimal class that implements just enough for JAX to treat it as a pytree node.
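For a custom class to count as a pytree node, JAX needs functions that split it into children and rebuild it from them. Registering those functions with keys also makes JAX report the children as `GetAttrKey`s, which is the key type encoded with the `.` prefix. The class below is a hypothetical sketch of such a container built on `jax.tree_util.register_pytree_with_keys_class`; it is not the package's `PyTreeContainer`, whose implementation may differ.

```python
from jax.tree_util import (
    GetAttrKey,
    register_pytree_with_keys_class,
    tree_flatten_with_path,
    tree_map,
)


@register_pytree_with_keys_class
class AttrContainer:
    """Hypothetical minimal pytree node whose children are attributes."""

    def __init__(self, **children):
        self.__dict__.update(children)

    def tree_flatten_with_keys(self):
        # Sort names so the flattening order (and the treedef) is deterministic.
        names = tuple(sorted(vars(self)))
        children = [(GetAttrKey(name), getattr(self, name)) for name in names]
        return children, names  # names are the static auxiliary data

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(**dict(zip(names, children)))


# Usage: JAX maps over and flattens the container like any other node.
doubled = tree_map(lambda x: x * 2, AttrContainer(w=1.0, b=2.0))
paths = [path for path, _ in tree_flatten_with_path(doubled)[0]]
```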
### Saving
#### `save_pytree(tree: PyTree, path: str) -> None`
Saves the pytree as a Safetensors file at the given path. Equivalent to
```python
import safetensors.flax

safetensors.flax.save_file(pytree2dict(tree), path)
```
### Loading
#### `load_file`
Alias of `safetensors.flax.load_file`.
#### `load_pytree(path: str) -> PyTree`
Loads a file and uses `dict2pytree` to convert the safetensor dict to a pytree.
#### `set_weights(module: PyTree, dictionary: dict) -> PyTree`
Given a pytree `module` and a safetensor dict, loads the weights from the dict into the module, using `string2keypath` to determine where each weight belongs. Returns a new pytree.
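The effect of `set_weights` can be approximated with `jax.tree_util.tree_map_with_path`. This sketch works in the opposite direction from the description above: instead of decoding each dictionary key with `string2keypath`, it encodes every leaf path in `module` and looks the result up in the dictionary. The names (`set_weights_sketch`, `_encode_path`) are hypothetical; leaves without a matching entry are kept unchanged and unused entries raise `KeyError`, which may not match how the package handles missing or extra keys.

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


def _encode_path(path):
    # Same prefix scheme as keypath2string: "." attr, "@" dict key, "#" index.
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append("." + key.name)
        elif isinstance(key, DictKey):
            parts.append("@" + str(key.key))
        elif isinstance(key, SequenceKey):
            parts.append("#" + str(key.idx))
        else:
            raise TypeError(f"unsupported key type: {type(key).__name__}")
    encoded = "".join(parts)
    return encoded[1:] if path and isinstance(path[0], GetAttrKey) else encoded


def set_weights_sketch(module, weights):
    """Return a copy of `module` with leaves replaced from `weights` (sketch)."""
    used = set()

    def replace(path, leaf):
        name = _encode_path(path)
        if name in weights:
            used.add(name)
            return weights[name]
        return leaf  # keep leaves the dictionary does not mention

    new_module = tree_map_with_path(replace, module)
    unused = set(weights) - used
    if unused:
        raise KeyError(f"entries match no leaf in the module: {sorted(unused)}")
    return new_module
```

Because `tree_map_with_path` builds a new tree, the original `module` is left untouched, which matches the "returns a new pytree" behaviour described above.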
#### `load_into_pytree(module: PyTree, path: str) -> PyTree`
Equivalent to `set_weights(module, load_file(path))`.
Raw data
{
"_id": null,
"home_page": null,
"name": "pytree2safetensor",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.10",
"maintainer_email": null,
"keywords": null,
"author": null,
"author_email": "Joseph Camacho <camacho.joseph@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/44/d2/595554315b42a2f902e27f6a8402416d4b37f9883f983e5b1b9d2501bb6d/pytree2safetensor-0.1.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "A simple package to save and load JAX PyTrees to and from Safetensors",
"version": "0.1.2",
"project_urls": {
"Homepage": "https://github.com/cooljoseph1/pytree2safetensor",
"Issues": "https://github.com/cooljoseph1/pytree2safetensor/issues"
},
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "c020c67677dd5395770e12f4996b6908272dca62076c163c70a4a25a441977ba",
"md5": "c66cc0670a03bc9e37b65d8e0a098a9e",
"sha256": "43a9f302bc6e3251614189cb5f71334b1eaff69f142072bbe20af2f1ff31820e"
},
"downloads": -1,
"filename": "pytree2safetensor-0.1.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "c66cc0670a03bc9e37b65d8e0a098a9e",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.10",
"size": 5675,
"upload_time": "2024-08-23T07:40:36",
"upload_time_iso_8601": "2024-08-23T07:40:36.629610Z",
"url": "https://files.pythonhosted.org/packages/c0/20/c67677dd5395770e12f4996b6908272dca62076c163c70a4a25a441977ba/pytree2safetensor-0.1.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "44d2595554315b42a2f902e27f6a8402416d4b37f9883f983e5b1b9d2501bb6d",
"md5": "36346d15bebd76664312563083808e56",
"sha256": "70efb62b368c5c8255e8932a90a22a61666f54bd8a3b4c44faeb623cd3cd6fb7"
},
"downloads": -1,
"filename": "pytree2safetensor-0.1.2.tar.gz",
"has_sig": false,
"md5_digest": "36346d15bebd76664312563083808e56",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10",
"size": 3982,
"upload_time": "2024-08-23T07:40:38",
"upload_time_iso_8601": "2024-08-23T07:40:38.112174Z",
"url": "https://files.pythonhosted.org/packages/44/d2/595554315b42a2f902e27f6a8402416d4b37f9883f983e5b1b9d2501bb6d/pytree2safetensor-0.1.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-08-23 07:40:38",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "cooljoseph1",
"github_project": "pytree2safetensor",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "pytree2safetensor"
}