| Name | pytree2safetensors |
|---|---|
| Version | 0.1.4 |
| Summary | A simple package to save and load JAX PyTrees to and from Safetensors |
| Upload time | 2024-08-26 10:40:57 |
| Requires Python | >=3.10 |
| License | None |
| Requirements | No requirements were recorded. |
# Pytree2Safetensors
Pytree2Safetensors is a simple package to save and load JAX PyTrees to and from
Safetensors, a popular file format for saving neural network weights.

To install, run:
```
pip install --upgrade pytree2safetensors
```
Pytree2Safetensors depends on `jax`, `safetensors`, and `jaxtyping`, and requires
Python 3.10 or later.
## Specification
### Serialising/Deserialising
#### `keypath2string(path: KeyPath) -> str`
Serializes a JAX key path (i.e., a path to a leaf in a pytree) to a string by concatenating a string representation of each key in the path. The prefix of each representation indicates the key type: a `GetAttrKey` is prefixed with `.`, a `DictKey` with `@`, and a `SequenceKey` with `#`. If the first key is a `GetAttrKey`, its leading `.` is omitted.

Examples:
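```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey
from pytree2safetensors import keypath2string  # assumed top-level export

keypath2string((GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "layers#10@query"
keypath2string((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query")))
# => "#2.layers#10@query"
```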
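The encoding rules above fit in a few lines of code. This is a minimal sketch, not the package's implementation: `keypath_to_string` is a hypothetical name, and it assumes attribute names and dictionary keys never contain `.`, `@`, or `#` (otherwise the resulting string would be ambiguous).

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def keypath_to_string(path):
    """Hypothetical sketch of the documented prefix encoding."""
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append(f".{key.name}")
        elif isinstance(key, DictKey):
            parts.append(f"@{key.key}")
        elif isinstance(key, SequenceKey):
            parts.append(f"#{key.idx}")
        else:
            raise TypeError(f"unsupported key type: {key!r}")
    encoded = "".join(parts)
    # Only a leading GetAttrKey produces a leading ".", which is dropped.
    return encoded[1:] if encoded.startswith(".") else encoded


print(keypath_to_string((SequenceKey(2), GetAttrKey("layers"), SequenceKey(10), DictKey("query"))))
# #2.layers#10@query
```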
#### `string2keypath(string: str) -> KeyPath`
Inverse of `keypath2string`: parses a serialized string back into a key path.
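Decoding only requires splitting the string at each prefix character. The sketch below is hypothetical (`string_to_keypath` is not the package's function) and assumes, as above, that names contain no `.`, `@`, or `#`. It also assumes every `DictKey` held a string, since the encoded form cannot distinguish `1` from `"1"`.

```python
import re

from jax.tree_util import DictKey, GetAttrKey, SequenceKey

# One token = a prefix character followed by everything up to the next prefix.
_TOKEN = re.compile(r"([.@#])([^.@#]*)")


def string_to_keypath(encoded):
    """Hypothetical sketch inverting the documented prefix encoding."""
    # A path that starts with an attribute has its "." omitted; put it back.
    if encoded and encoded[0] not in ".@#":
        encoded = "." + encoded
    keys = []
    for prefix, body in _TOKEN.findall(encoded):
        if prefix == ".":
            keys.append(GetAttrKey(body))
        elif prefix == "@":
            keys.append(DictKey(body))  # dictionary keys come back as strings
        else:
            keys.append(SequenceKey(int(body)))
    return tuple(keys)
```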
#### `pytree2dict(tree: PyTree) -> dict`
Returns a dictionary that maps each leaf's serialized key path to that leaf.
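A dictionary like this can be built from `jax.tree_util.tree_flatten_with_path`, which yields every leaf together with its key path. The sketch below is an assumption about how such a function could work, not the package's code; `pytree_to_dict` and `encode_path` are hypothetical names, and `encode_path` repeats the prefix rules described under `keypath2string`.

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_flatten_with_path


def encode_path(path):
    """Prefix encoding from the keypath2string section (hypothetical helper)."""
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append(f".{key.name}")
        elif isinstance(key, DictKey):
            parts.append(f"@{key.key}")
        elif isinstance(key, SequenceKey):
            parts.append(f"#{key.idx}")
        else:
            raise TypeError(f"unsupported key type: {key!r}")
    encoded = "".join(parts)
    return encoded[1:] if encoded.startswith(".") else encoded


def pytree_to_dict(tree):
    """Map each leaf's encoded key path to the leaf itself."""
    leaves_with_paths, _treedef = tree_flatten_with_path(tree)
    return {encode_path(path): leaf for path, leaf in leaves_with_paths}


flat = pytree_to_dict({"embed": 0.5, "layers": [{"w": 1.0}, {"w": 2.0}]})
print(sorted(flat))  # ['@embed', '@layers#0@w', '@layers#1@w']
```

Because the top-level container in this example is a dict, every key starts with `@`.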
#### `dict2pytree(dictionary: dict) -> PyTree`
Inverse of `pytree2dict`, except that it wraps attributes in `PyTreeContainer`s instead of recreating the original objects, because the deserialiser has no way of knowing what those objects were. To load weights into an already-initialized pytree instead, use `load_into_pytree`.
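The loss of type information is easiest to see by rebuilding a tree from a flat dictionary: the only clue about each intermediate node is the prefix of the key that follows it. The sketch below is hypothetical (`dict_to_tree`, and `Container` as a plain stand-in for `PyTreeContainer`, are not the package's code). It turns `.` nodes into `Container` objects, `@` nodes into dicts, and `#` nodes into lists, and assumes the same delimiter-free names as above.

```python
import re

_TOKEN = re.compile(r"([.@#])([^.@#]*)")


class Container:
    """Hypothetical stand-in for PyTreeContainer: just holds attributes."""


class _Seq(dict):
    """Temporary index -> child mapping, converted to a list at the end."""


def _parse(encoded):
    # Restore the "." that is omitted before a leading attribute name.
    if encoded and encoded[0] not in ".@#":
        encoded = "." + encoded
    return [(p, int(b) if p == "#" else b) for p, b in _TOKEN.findall(encoded)]


def _new_node(prefix):
    # The next key's prefix is the only hint about what kind of node to build.
    return {".": Container, "@": dict, "#": _Seq}[prefix]()


def _finalize(node):
    if isinstance(node, _Seq):  # check before dict, since _Seq subclasses dict
        return [_finalize(node[i]) for i in sorted(node)]
    if isinstance(node, dict):
        return {k: _finalize(v) for k, v in node.items()}
    if isinstance(node, Container):
        for name, child in list(vars(node).items()):
            setattr(node, name, _finalize(child))
    return node  # leaves (and finished Containers) are returned as-is


def dict_to_tree(flat):
    root = None
    for encoded, leaf in flat.items():
        steps = _parse(encoded)
        if root is None:
            root = _new_node(steps[0][0])
        node = root
        for i, (prefix, key) in enumerate(steps):
            if i == len(steps) - 1:  # last key: store the leaf
                if prefix == ".":
                    setattr(node, key, leaf)
                else:
                    node[key] = leaf
            elif prefix == ".":
                if not hasattr(node, key):
                    setattr(node, key, _new_node(steps[i + 1][0]))
                node = getattr(node, key)
            else:
                node = node.setdefault(key, _new_node(steps[i + 1][0]))
    return _finalize(root)


tree = dict_to_tree({"layers#0@w": 1.0, "layers#1@w": 2.0, "bias": 0.5})
print(type(tree).__name__, tree.layers, tree.bias)
# Container [{'w': 1.0}, {'w': 2.0}] 0.5
```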
#### `PyTreeContainer`
A minimal class that implements just enough for JAX to treat it as a node in a pytree.
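For reference, one way to make a class a pytree node whose children appear under `GetAttrKey` paths is `jax.tree_util.register_pytree_with_keys`. The class below is a hypothetical minimal example of that idea, not the package's `PyTreeContainer`.

```python
import jax
from jax.tree_util import GetAttrKey, register_pytree_with_keys


class AttrContainer:
    """Hypothetical minimal attribute container (not the package's class)."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def _flatten_with_keys(node):
    names = tuple(sorted(vars(node)))  # fixed order so flatten/unflatten agree
    children = [(GetAttrKey(name), getattr(node, name)) for name in names]
    return children, names  # `names` is the static auxiliary data


def _flatten(node):
    names = tuple(sorted(vars(node)))
    return [getattr(node, name) for name in names], names


def _unflatten(names, children):
    return AttrContainer(**dict(zip(names, children)))


register_pytree_with_keys(AttrContainer, _flatten_with_keys, _unflatten, _flatten)

params = AttrContainer(w=1.0, b=2.0)
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
paths = [path for path, _ in jax.tree_util.tree_flatten_with_path(params)[0]]
print(doubled.w, doubled.b)  # 2.0 4.0
```

Sorting the attribute names keeps the child order deterministic, so flattening and unflattening always agree.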
### Saving
#### `save_pytree(tree: PyTree, path: str) -> None`
Saves the pytree as a Safetensors file at the given path. Equivalent to:
```python
import safetensors.flax
safetensors.flax.save_file(pytree2dict(tree), path)
```
### Loading
#### `load_file`
Alias of `safetensors.flax.load_file`.
#### `load_pytree(path: str) -> PyTree`
Loads a Safetensors file and uses `dict2pytree` to convert the resulting dict to a pytree.
#### `set_weights(module: PyTree, dictionary: dict) -> PyTree`
Given a pytree module and a Safetensors dict, loads the weights from the dict into the module, using `string2keypath` to determine where each weight belongs. Returns a new pytree.
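A comparable effect can be sketched with `jax.tree_util.tree_map_with_path`. Note that this sketch works in the opposite direction from the description above: rather than decoding each dict key with `string2keypath`, it encodes each leaf's path in the module and looks it up. `set_weights_sketch` and `encode_path` are hypothetical names, and raising `KeyError` for a leaf with no matching entry is an assumption, not documented behavior.

```python
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


def encode_path(path):
    """Hypothetical helper using the keypath2string prefix rules."""
    parts = []
    for key in path:
        if isinstance(key, GetAttrKey):
            parts.append(f".{key.name}")
        elif isinstance(key, DictKey):
            parts.append(f"@{key.key}")
        elif isinstance(key, SequenceKey):
            parts.append(f"#{key.idx}")
        else:
            raise TypeError(f"unsupported key type: {key!r}")
    encoded = "".join(parts)
    return encoded[1:] if encoded.startswith(".") else encoded


def set_weights_sketch(module, flat):
    """Return a new tree whose leaves are looked up in `flat` by encoded path."""
    return tree_map_with_path(lambda path, _leaf: flat[encode_path(path)], module)


module = {"w": jnp.zeros(2), "layers": [jnp.zeros(1)]}
weights = {"@w": jnp.ones(2), "@layers#0": jnp.full((1,), 5.0)}
updated = set_weights_sketch(module, weights)
print(updated["layers"][0])
```

Because `tree_map_with_path` builds a new tree, the original `module` is left unchanged.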
#### `load_into_pytree(module: PyTree, path: str) -> PyTree`
Equivalent to `set_weights(module, load_file(path))`.
Raw data
{
"_id": null,
"home_page": null,
"name": "pytree2safetensors",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.10",
"maintainer_email": null,
"keywords": null,
"author": null,
"author_email": "Joseph Camacho <camacho.joseph@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/05/7d/f4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5/pytree2safetensors-0.1.4.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "A simple package to save and load JAX PyTrees to and from Safetensors",
"version": "0.1.4",
"project_urls": {
"Homepage": "https://github.com/cooljoseph1/pytree2safetensors",
"Issues": "https://github.com/cooljoseph1/pytree2safetensors/issues"
},
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "a9b1a6ca7d1339a88126199b07fb02863832d96b2a08af6870e58a798cf1695a",
"md5": "3323d611078558552e0063bc43d4f902",
"sha256": "5713fcb9b5f1fab0717c88bcfc3d7fe72b2dc39d865288f26e66fe5fad391657"
},
"downloads": -1,
"filename": "pytree2safetensors-0.1.4-py3-none-any.whl",
"has_sig": false,
"md5_digest": "3323d611078558552e0063bc43d4f902",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.10",
"size": 6171,
"upload_time": "2024-08-26T10:40:56",
"upload_time_iso_8601": "2024-08-26T10:40:56.280326Z",
"url": "https://files.pythonhosted.org/packages/a9/b1/a6ca7d1339a88126199b07fb02863832d96b2a08af6870e58a798cf1695a/pytree2safetensors-0.1.4-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "057df4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5",
"md5": "470d04195338b4140cf13159556f59ec",
"sha256": "4d7c9bed9ed5b4bf6cc5dc6615b87c8bc96ed4ede82b7e9c1c22531e47babc65"
},
"downloads": -1,
"filename": "pytree2safetensors-0.1.4.tar.gz",
"has_sig": false,
"md5_digest": "470d04195338b4140cf13159556f59ec",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10",
"size": 4404,
"upload_time": "2024-08-26T10:40:57",
"upload_time_iso_8601": "2024-08-26T10:40:57.164295Z",
"url": "https://files.pythonhosted.org/packages/05/7d/f4953f902c9291bbdc9bbc2a3f41962338385524144d86322e0af0a205a5/pytree2safetensors-0.1.4.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-08-26 10:40:57",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "cooljoseph1",
"github_project": "pytree2safetensors",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "pytree2safetensors"
}