| Field | Value |
| --- | --- |
| Name | hyperstate |
| Version | 0.4.4 |
| Summary | Library for managing hyperparameters and mutable state of machine learning training systems. |
| upload_time | 2023-08-04 16:25:10 |
| docs_url | None |
| author | Clemens Winter |
| requires_python | `>=3.7.1,<3.11.0` |
| license | MIT |
| requirements | No requirements were recorded. |
| Travis-CI | No Travis. |
| coveralls test coverage | No coveralls. |
# HyperState
[PyPI](https://pypi.org/project/hyperstate/)
[Documentation](https://hyperstate.readthedocs.io/en/latest/?badge=latest)
Opinionated library for managing hyperparameter configs and mutable program state of machine learning training systems.
**Key Features**:
- (De)serialize nested Python dataclasses as [Rusty Object Notation](https://github.com/ron-rs/ron)
- Override any config value from the command line
- Automatic checkpointing and restoration of full program state
- Checkpoints are (partially) human readable and can be modified in a text editor
- Powerful tools for versioning and schema evolution that can detect breaking changes and make it easy to restructure your program while remaining backwards compatible with old checkpoints
- Large binary objects in checkpoints can be loaded lazily only when accessed
- Fermented-vegetable free
- DSL for hyperparameter schedules
- (planned) Edit hyperparameters of running experiments on the fly without restarts
## Quick start guide
Install with pip:
```
pip install hyperstate
```
All you need to use HyperState is a (nested) dataclass for your hyperparameters:
```python
from dataclasses import dataclass


@dataclass
class OptimizerConfig:
    lr: float = 0.003
    batch_size: int = 512


@dataclass
class NetConfig:
    hidden_size: int = 128
    num_layers: int = 2


@dataclass
class Config:
    optimizer: OptimizerConfig
    net: NetConfig
    steps: int = 100
```
The `hyperstate.load` function can load values from a config file and/or apply specific overrides from the command line.
```python
import argparse
import hyperstate

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None, help="Path to config file")
    parser.add_argument("--hps", nargs="+", help="Override hyperparameter value")
    args = parser.parse_args()
    config = hyperstate.load(Config, file=args.config, overrides=args.hps)
    print(config)
```
```shell
$ python main.py --hps net.num_layers=96 steps=50
Config(optimizer=OptimizerConfig(lr=0.003, batch_size=512), net=NetConfig(hidden_size=128, num_layers=96), steps=50)
```
```shell
$ cat config.ron
Config(
    optimizer: (
        lr: 0.05,
        batch_size: 4096,
    ),
)
$ python main.py --config=config.ron
Config(optimizer=OptimizerConfig(lr=0.05, batch_size=4096), net=NetConfig(hidden_size=128, num_layers=2), steps=100)
```
The full code for this example can be found in [examples/basic-config](examples/basic-config).
Learn more about:
- [Configs](#configs)
- [Versioning and schema evolution](#versioning)
- [Serializing complex objects](#unstable-feature-serializable)
- [Checkpointing and schedules](#unstable-feature-hyperstate)
- [Example application](examples/mnist)
## Configs
HyperState supports a strictly typed subset of Python objects:
- dataclasses
- containers: `Dict`, `List`, `Tuple`, `Optional`
- primitives: `int`, `float`, `str`, `Enum`
- objects with custom serialization logic: [`hyperstate.Serializable`](#unstable-feature-serializable)
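To make the supported subset concrete, here is a small checker that walks a type annotation and reports whether it stays within the types listed above. `is_supported` is a hypothetical helper, not part of HyperState; it ignores `Serializable` and only recognizes the `typing` spellings (`Optional[X]`, not `X | None`). It needs Python 3.8+ for `get_origin`/`get_args`.

```python
import enum
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints


def is_supported(tp: Any) -> bool:
    """Hypothetical check: is `tp` inside the typed subset listed above (Serializable not covered)?"""
    if tp in (int, float, str):
        return True
    if isinstance(tp, type) and issubclass(tp, enum.Enum):
        return True
    if isinstance(tp, type) and is_dataclass(tp):
        # Resolve annotations (including string ones) and check every field recursively.
        hints = get_type_hints(tp)
        return all(is_supported(hints[f.name]) for f in fields(tp))
    origin, args = get_origin(tp), get_args(tp)
    if origin is Union:
        # Only Optional[X], i.e. Union[X, None], is allowed; general unions are not.
        return len(args) == 2 and type(None) in args and all(
            a is type(None) or is_supported(a) for a in args
        )
    if origin in (list, dict):
        return len(args) > 0 and all(is_supported(a) for a in args)
    if origin is tuple:
        # Tuple[int, ...] stores a literal Ellipsis as its second argument.
        return len(args) > 0 and all(a is Ellipsis or is_supported(a) for a in args)
    return False


class Color(enum.Enum):
    RED = 1
    GREEN = 2


@dataclass
class Example:
    color: Color
    widths: List[int]
    names: Dict[str, Optional[str]]
    shape: Tuple[int, ...]


print(is_supported(Example))  # True
print(is_supported(Union[int, str]))  # False
```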
Use `hyperstate.dump` to serialize configs.
The second argument to `dump` is a path to a file, and can be omitted to return the serialized config as a string instead of saving it to a file:
```python
>>> print(hyperstate.dump(Config(lr=0.1, batch_size=256)))
Config(
    lr: 0.1,
    batch_size: 256,
)
```
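For intuition about the text format, here is a minimal sketch of a serializer that produces output in the same shape as the `dump` example above (one field per line, trailing commas). `to_ron` is a hypothetical helper: it covers only dataclasses, lists, bools, strings, ints, and floats, uses 4-space indentation by assumption, and does not reproduce how HyperState writes `Optional`, `Enum`, dicts, or versioned configs.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

INDENT = "    "


def to_ron(value: Any, depth: int = 0) -> str:
    """Render a value in a RON-like layout: one entry per line, trailing commas."""
    pad, close = INDENT * (depth + 1), INDENT * depth
    if is_dataclass(value) and not isinstance(value, type):
        body = "".join(
            f"{pad}{f.name}: {to_ron(getattr(value, f.name), depth + 1)},\n" for f in fields(value)
        )
        return f"{type(value).__name__}(\n{body}{close})"
    if isinstance(value, list):
        body = "".join(f"{pad}{to_ron(v, depth + 1)},\n" for v in value)
        return f"[\n{body}{close}]"
    if isinstance(value, bool):  # check before int: bool is a subclass of int
        return "true" if value else "false"
    if isinstance(value, str):
        return '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"'
    return repr(value)  # int and float


@dataclass
class Config:
    lr: float
    batch_size: int


print(to_ron(Config(lr=0.1, batch_size=256)))
```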
Use `hyperstate.load` to deserialize configs.
The `load` method takes the type of the config as its first argument, and optionally accepts the path to a config file and/or a `List[str]` of overrides:
```python
@dataclass
class OptimizerConfig:
    lr: float
    batch_size: int


@dataclass
class Config:
    optimizer: OptimizerConfig
    steps: int


config = hyperstate.load(Config, file="config.ron", overrides=["optimizer.lr=0.1", "steps=100"])
```
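One plausible way dotted overrides such as `optimizer.lr=0.1` map onto nested dataclasses is sketched below, reusing the classes from the `load` example above. `apply_override` and `apply_overrides` are hypothetical names; HyperState's own parser handles more value types and its error messages may differ. This sketch splits each override on the first `=`, parses only `int`/`float`/`str` fields, and rejects unknown field names, so a misspelled key such as `optimiser.lr` fails loudly instead of being silently ignored.

```python
from dataclasses import dataclass, is_dataclass, replace
from typing import Any, List, get_type_hints


def apply_override(obj: Any, path: List[str], raw: str) -> Any:
    """Return a copy of `obj` with the field at the dotted `path` set from the string `raw`."""
    hints = get_type_hints(type(obj)) if is_dataclass(obj) else {}
    name = path[0]
    if name not in hints:
        raise KeyError(f"{type(obj).__name__} has no field {name!r}")
    if len(path) > 1:
        # Descend into the nested dataclass and rebuild each parent on the way back up.
        return replace(obj, **{name: apply_override(getattr(obj, name), path[1:], raw)})
    field_type = hints[name]
    if field_type not in (int, float, str):
        raise TypeError(f"this sketch only parses int/float/str fields, not {field_type}")
    return replace(obj, **{name: field_type(raw)})


def apply_overrides(config: Any, overrides: List[str]) -> Any:
    for override in overrides:
        key, _, raw = override.partition("=")
        config = apply_override(config, key.split("."), raw)
    return config


@dataclass
class OptimizerConfig:
    lr: float
    batch_size: int


@dataclass
class Config:
    optimizer: OptimizerConfig
    steps: int


base = Config(optimizer=OptimizerConfig(lr=0.003, batch_size=512), steps=10)
config = apply_overrides(base, ["optimizer.lr=0.1", "steps=100"])
print(config)  # Config(optimizer=OptimizerConfig(lr=0.1, batch_size=512), steps=100)
```

Because `dataclasses.replace` builds new objects, the original `base` config is left unchanged.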
## Versioning
Versioning allows you to modify your `Config` class while remaining compatible with checkpoints recorded at a previous version.
To benefit from versioning, your config must inherit from `hyperstate.Versioned` and implement its `version` classmethod:
```python
@dataclass
class Config(hyperstate.Versioned):
    lr: float
    batch_size: int

    @classmethod
    def version(clz) -> int:
        return 0
```
When serializing the config, hyperstate will now record an additional `version` field with the value of the current version.
Any snapshots that contain configs without a version field are assumed to have a version of `0`.
### `RewriteRule`
Now suppose you modify your `Config` class, e.g. by renaming the `lr` field to `learning_rate`.
To still be able to load old configs that are using `lr` instead of `learning_rate`, you increase the `version` to `1` and add an entry to the dictionary returned by `upgrade_rules` that tells HyperState to change `lr` to `learning_rate` when upgrading configs from version `0`.
```python
from dataclasses import dataclass
from typing import Dict, List

from hyperstate import Versioned
from hyperstate.schema.rewrite_rule import RenameField, RewriteRule


@dataclass
class Config(Versioned):
    learning_rate: float
    batch_size: int

    @classmethod
    def version(clz) -> int:
        return 1

    @classmethod
    def upgrade_rules(clz) -> Dict[int, List[RewriteRule]]:
        """
        Maps each old version to the rewrite rules that upgrade a config
        from that version to the next one.
        """
        return {
            0: [RenameField(old_field=("lr",), new_field=("learning_rate",))],
        }
```
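To see how version numbers and rewrite rules fit together, here is a sketch of an upgrade loop over a config that has already been parsed into nested dicts. `Rename` and `upgrade` are hypothetical stand-ins, not HyperState's `RenameField` or its real upgrade machinery. Following the text above, we assume that the rules stored under key `v` turn a version-`v` config into version `v + 1`, and that a config without a `version` field is at version `0`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class Rename:
    """Hypothetical stand-in for a rename rule: move the value at path `old` to path `new`."""
    old: Tuple[str, ...]
    new: Tuple[str, ...]

    def apply(self, data: Dict[str, Any]) -> None:
        *old_parents, old_leaf = self.old
        src = data
        for key in old_parents:
            src = src[key]
        if old_leaf not in src:
            return  # field absent in this config: nothing to rename
        value = src.pop(old_leaf)
        *new_parents, new_leaf = self.new
        dst = data
        for key in new_parents:
            dst = dst.setdefault(key, {})
        dst[new_leaf] = value


def upgrade(data: Dict[str, Any], rules: Dict[int, List[Rename]], current: int) -> Dict[str, Any]:
    """Upgrade a parsed config in place, one version at a time, up to `current`."""
    version = data.pop("version", 0)  # a config without a version field counts as version 0
    for v in range(version, current):
        for rule in rules.get(v, []):  # rules[v] turns a version-v config into version v + 1
            rule.apply(data)
    data["version"] = current
    return data


rules = {0: [Rename(old=("lr",), new=("learning_rate",))]}
print(upgrade({"lr": 0.1, "batch_size": 256}, rules, current=1))
# {'batch_size': 256, 'learning_rate': 0.1, 'version': 1}
```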
In the majority of cases, you don't actually have to manually write out `RewriteRule`s.
Instead, they are generated for you automatically by the [Schema Evolution CLI](#schema-evolution-cli).
### Schema evolution CLI
HyperState comes with a command line tool for managing changes to your config schema.
To access the CLI, simply add the following code to the Python file defining your config:
```python
# config.py
from hyperstate import schema_evolution_cli

if __name__ == "__main__":
    schema_evolution_cli(Config)
```
Run `python config.py` to see a list of available commands, described in more detail below.
#### `dump-schema`
The `dump-schema` command creates a file describing the schema of your config.
This file should be committed to version control; it is used to detect changes to the config schema and to perform automatic upgrades.
#### `check-schema`
The `check-schema` command compares your config class to a schema file and detects any backwards incompatible changes.
It also emits a suggested list of [`RewriteRule`](#rewriterule)s that can upgrade old configs to the new schema.
HyperState does not always guess the correct `RewriteRule`s, so you still need to verify them.
```
$ python config.py check-schema
WARN field renamed to learning_rate: lr
WARN schema changed but version identical
Schema incompatible

Proposed mitigations
- add upgrade rules:
    0: [
        RenameField(old_field=('lr',), new_field=('learning_rate',)),
    ],
- bump version to 1
```
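HyperState's actual detection logic is not described here. The toy heuristic below illustrates why suggested rules need review: it proposes a rename whenever a removed field has exactly one added counterpart of the same type, which can be wrong (for example, if `lr` was deleted and an unrelated `float` field was added). `schema_of` and `propose_renames` are hypothetical helpers; nested dataclasses are not handled.

```python
from dataclasses import dataclass, fields
from typing import Dict, List, Tuple, get_type_hints


def schema_of(cls: type) -> Dict[str, str]:
    """Flatten a dataclass into {field name: type name} (nested dataclasses not handled)."""
    hints = get_type_hints(cls)
    return {f.name: getattr(hints[f.name], "__name__", str(hints[f.name])) for f in fields(cls)}


def propose_renames(old: Dict[str, str], new: Dict[str, str]) -> List[Tuple[str, str]]:
    """Pair each removed field with the only added field of the same type, if there is one."""
    removed = [name for name in old if name not in new]
    added = [name for name in new if name not in old]
    renames = []
    for name in removed:
        candidates = [a for a in added if new[a] == old[name]]
        if len(candidates) == 1:  # ambiguous or missing matches are left for a human
            renames.append((name, candidates[0]))
            added.remove(candidates[0])
    return renames


@dataclass
class OldConfig:
    lr: float
    batch_size: int


@dataclass
class NewConfig:
    learning_rate: float
    batch_size: int


for old_name, new_name in propose_renames(schema_of(OldConfig), schema_of(NewConfig)):
    print(f"WARN field renamed to {new_name}: {old_name}")
```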
#### `upgrade-schema`
The `upgrade-schema` command works much like `check-schema`, but also updates your schema files once all backwards-incompatibility issues have been addressed.
#### `upgrade-config`
The `upgrade-config` command takes a list of paths to config files, and upgrades them to the latest version.
### Automated Tests
To prevent accidental backwards-incompatible modifications of your `Config` class, you can use the following code as an automated test that checks your config class against a schema file created with [`dump-schema`](#dump-schema):
```python
from hyperstate.schema.schema_change import Severity
from hyperstate.schema.schema_checker import SchemaChecker
from hyperstate.schema.types import load_schema
from config import Config


def test_schema():
    old = load_schema("config-schema.ron")
    checker = SchemaChecker(old, Config)
    if checker.severity() >= Severity.WARN:
        checker.print_report()
    assert checker.severity() == Severity.INFO
```
## _[unstable feature]_ `Serializable`
You can define custom serialization logic for a class by inheriting from `hyperstate.Serializable` and implementing the `serialize` and `deserialize` methods.
```python
from dataclasses import dataclass
from typing import Any

import torch.nn as nn
import hyperstate


@dataclass
class Config:
    inputs: int


class LinearRegression(nn.Module, hyperstate.Serializable):
    def __init__(self, inputs: int):
        super().__init__()
        self.fc1 = nn.Linear(inputs, 1)

    def forward(self, x):
        return self.fc1(x)

    # `serialize` should return a representation of the object consisting only of
    # primitives, containers, numpy arrays, and torch tensors.
    def serialize(self) -> Any:
        return self.state_dict()

    # `deserialize` should take a serialized representation of the object and
    # return an instance of the class. The `ctx` argument allows you to pass
    # additional information to the deserialization function.
    @classmethod
    def deserialize(clz, state_dict, ctx):
        net = clz(ctx["config"].inputs)
        # `load_state_dict` returns a report of missing/unexpected keys, not the
        # module, so return `net` itself.
        net.load_state_dict(state_dict)
        return net


@dataclass
class State:
    net: LinearRegression


config = hyperstate.load(Config, file="config.ron")
state = hyperstate.load(State, file="state.ron", ctx={"config": config})
```
Objects that implement `Serializable` are stored in separate files using a binary encoding.
In the above example, calling `hyperstate.dump(state, "checkpoint/state.ron")` will result in the following file structure:
```
checkpoint
├── state.net.blob
└── state.ron
```
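The file names above suggest that each `Serializable` field gets its own blob file named after the state file's stem plus the field name. The sketch below reproduces only that layout, under that assumption. `BinaryField` and `dump_with_blobs` are hypothetical; HyperState's binary encoding and the way the `.ron` file refers to the blob are not reproduced, and the `.ron` file written here is just a `repr` placeholder.

```python
import os
import tempfile
from dataclasses import dataclass, fields
from typing import Any


class BinaryField:
    """Hypothetical marker: values of this type are written to their own .blob file."""

    def to_bytes(self) -> bytes:
        raise NotImplementedError


def dump_with_blobs(obj: Any, path: str) -> None:
    """Write `obj` to `path`, moving BinaryField values into `<stem>.<field>.blob` files."""
    directory = os.path.dirname(path)
    stem = os.path.splitext(os.path.basename(path))[0]  # "state" for "checkpoint/state.ron"
    if directory:
        os.makedirs(directory, exist_ok=True)
    inline = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        if isinstance(value, BinaryField):
            with open(os.path.join(directory, f"{stem}.{f.name}.blob"), "wb") as fh:
                fh.write(value.to_bytes())
        else:
            inline[f.name] = value
    with open(path, "w") as fh:
        fh.write(repr(inline))  # placeholder for the real text encoding of the remaining fields


@dataclass
class Weights(BinaryField):
    values: bytes

    def to_bytes(self) -> bytes:
        return self.values


@dataclass
class State:
    net: Weights
    step: int


with tempfile.TemporaryDirectory() as tmp:
    dump_with_blobs(State(net=Weights(b"\x00\x01"), step=3), os.path.join(tmp, "checkpoint", "state.ron"))
    written = sorted(os.listdir(os.path.join(tmp, "checkpoint")))
print(written)  # ['state.net.blob', 'state.ron']
```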
### _[unstable feature]_ `Lazy`
If you inherit from `hyperstate.Lazy`, any fields with `Serializable` types will only be loaded/deserialized when accessed. If the `.blob` file for a field is missing, HyperState will not raise an error unless the corresponding field is accessed.
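The lazy behavior described above can be illustrated in plain Python with `__getattr__` and a cache. This is not how `hyperstate.Lazy` is implemented (that is not shown here); `LazyFields` is a hypothetical class, the `<stem>.<name>.blob` naming follows the file layout above, and `decode` is a caller-supplied function standing in for real deserialization.

```python
import os
import tempfile
from typing import Any, Callable, Dict, Iterable


class LazyFields:
    """Hypothetical sketch: read `<directory>/<stem>.<name>.blob` only when `name` is first accessed."""

    def __init__(self, directory: str, stem: str, names: Iterable[str], decode: Callable[[bytes], Any]):
        self._directory = directory
        self._stem = stem
        self._names = set(names)
        self._decode = decode
        self._cache: Dict[str, Any] = {}

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only when normal attribute lookup fails.
        if name.startswith("_") or name not in self._names:
            raise AttributeError(name)
        if name not in self._cache:
            path = os.path.join(self._directory, f"{self._stem}.{name}.blob")
            with open(path, "rb") as fh:  # a missing file raises here, not at construction
                self._cache[name] = self._decode(fh.read())
        return self._cache[name]


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "state.net.blob"), "wb") as fh:
        fh.write(bytes([1, 2, 3]))
    # No state.optimizer.blob exists, yet construction succeeds.
    state = LazyFields(tmp, "state", ["net", "optimizer"], decode=list)
    print(state.net)  # [1, 2, 3]
```

Accessing `state.optimizer` inside the `with` block would raise `FileNotFoundError`, matching the "no error unless accessed" behavior described above.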
### _[unstable feature]_ `blob`
To include objects in your state that do not directly implement `hyperstate.Serializable`, you can implement `hyperstate.Serializable` separately and use the `blob` function to mix in the `Serializable` implementation:
```python
from dataclasses import dataclass
from typing import Any

import torch.nn as nn
import torch.optim as optim
import hyperstate
from hyperstate import blob  # assumed import location of the `blob` function


class SerializableOptimizer(hyperstate.Serializable):
    def serialize(self):
        return self.state_dict()

    @classmethod
    def deserialize(clz, state_dict: Any, config: "Config", state: "State") -> optim.Optimizer:
        # Rebuild the optimizer around the restored network's parameters, then load its state.
        optimizer = blob(optim.Adam, mixin=SerializableOptimizer)(state.net.parameters())
        optimizer.load_state_dict(state_dict)
        return optimizer


@dataclass
class State(hyperstate.Lazy):
    net: nn.Module
    optimizer: blob(optim.Adam, mixin=SerializableOptimizer)
```
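We assume `blob(cls, mixin=...)` yields a type that behaves like `cls` but also carries the mixin's methods. One plain-Python way to get that effect is to create a subclass on the fly with `type()`, as sketched here with a hypothetical `with_mixin` helper and a stand-in `Counter` class in place of an optimizer; this is a guess at the idea, not HyperState's implementation.

```python
def with_mixin(cls: type, mixin: type) -> type:
    """Create a subclass of both `cls` and `mixin`; methods on `cls` win on name clashes."""
    return type(cls.__name__, (cls, mixin), {})


class Counter:
    """Stand-in for a third-party class such as an optimizer."""

    def __init__(self, start: int = 0):
        self.count = start

    def state_dict(self) -> dict:
        return {"count": self.count}


class SerializableMixin:
    def serialize(self) -> dict:
        # Relies on the host class providing state_dict(), as torch optimizers do.
        return self.state_dict()


SerializableCounter = with_mixin(Counter, SerializableMixin)
counter = SerializableCounter(5)
print(counter.serialize())  # {'count': 5}
```

Because the host class comes first in the bases, its constructor and methods are used unchanged, and the mixin only adds what is missing.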
## _[unstable feature]_ `HyperState`
To unlock the full power of HyperState, you must inherit from the `HyperState` class.
This class combines an immutable config and mutable state, and provides automatic checkpointing, hyperparameter schedules, and (not implemented yet) on-the-fly changes to the config and state.
```python
from dataclasses import dataclass
from typing import Any, List, Optional

import torch.nn as nn
import hyperstate


@dataclass
class Config:
    inputs: int
    steps: int


class LinearRegression(nn.Module, hyperstate.Serializable):
    def __init__(self, inputs: int):
        super().__init__()
        self.fc1 = nn.Linear(inputs, 1)

    def forward(self, x):
        return self.fc1(x)

    def serialize(self) -> Any:
        return self.state_dict()

    @classmethod
    def deserialize(clz, state_dict, ctx):
        net = clz(ctx["config"].inputs)
        net.load_state_dict(state_dict)  # returns a key report, not the module
        return net


@dataclass
class State:
    net: LinearRegression
    step: int


class Trainer(hyperstate.HyperState[Config, State]):
    def __init__(
        self,
        # Path to the config file
        initial_config: str,
        # Optional path to the checkpoint directory, which enables automatic checkpointing.
        # If any checkpoint files are present, they will be used to initialize the state.
        checkpoint_dir: Optional[str] = None,
        # List of manually specified config overrides.
        config_overrides: Optional[List[str]] = None,
    ):
        super().__init__(Config, State, initial_config, checkpoint_dir, overrides=config_overrides)

    def initial_state(self) -> State:
        """
        This function is called to initialize the state if no checkpoint files are found.
        """
        return State(net=LinearRegression(self.config.inputs), step=0)

    def train(self) -> None:
        for step in range(self.state.step, self.config.steps):
            # training code...

            # Store the index of the next step to run, so that a run resumed from
            # this checkpoint does not repeat the step that just finished.
            self.state.step = step + 1
            # At the end of each iteration, call `self.step()` to checkpoint the state and apply hyperparameter schedules.
            self.step()
```
### _[unstable feature]_ Checkpointing
When using the `HyperState` object, the config and state are automatically checkpointed to the configured directory when calling the `step` method.
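The checkpoint layout and write strategy that HyperState uses are not described here. As an illustration of checkpoint-on-step and resume, here is a toy version with a hypothetical `ToyCheckpointer` that stores a flat dataclass as JSON. Writing to a temporary file and then calling `os.replace` is a common way to avoid leaving a half-written checkpoint if the process dies mid-write; whether HyperState does the same is an assumption we do not make.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class State:
    step: int = 0
    loss: float = 0.0


class ToyCheckpointer:
    """Hypothetical stand-in: save state to `<dir>/state.json` on every step(), resume from it."""

    def __init__(self, checkpoint_dir: str):
        self.path = os.path.join(checkpoint_dir, "state.json")
        if os.path.exists(self.path):  # resume from an existing checkpoint
            with open(self.path) as fh:
                self.state = State(**json.load(fh))
        else:
            self.state = State()

    def step(self) -> None:
        tmp = self.path + ".tmp"
        with open(tmp, "w") as fh:
            json.dump(asdict(self.state), fh)
        os.replace(tmp, self.path)  # swap in the finished file in one rename


with tempfile.TemporaryDirectory() as ckpt:
    run = ToyCheckpointer(ckpt)
    for step in range(run.state.step, 3):
        # ... training code would go here ...
        run.state.step = step + 1  # index of the next step to run
        run.step()
    resumed = ToyCheckpointer(ckpt).state  # what a restarted process would load
print(resumed)  # State(step=3, loss=0.0)
```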
### _[unstable feature]_ Schedules
Any `int`/`float` fields in the config can also be set to a schedule that will be updated at each step.
For example, the following config defines a schedule that linearly decays the learning rate from 1.0 to 0.1 over 1000 steps:
```rust
Config(
    lr: Schedule(
        key: "state.step",
        schedule: [
            (0, 1.0),
            "lin",
            (1000, 0.1),
        ],
    ),
    batch_size: 256,
)
```
When you call `step()`, all config values that are schedules will be updated.
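For the linear schedule above, the value at step 500 is the midpoint between 1.0 and 0.1, i.e. 1.0 + (0.1 - 1.0) × 500 / 1000 = 0.55. The sketch below evaluates schedules written in the same list form, under stated assumptions: values interpolate linearly between listed points and stay constant before the first and after the last point. `evaluate_schedule` is a hypothetical helper; HyperState's DSL may support interpolation modes other than `"lin"`, which this sketch rejects.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple, Union

Point = Tuple[float, float]


def evaluate_schedule(schedule: Sequence[Union[Point, str]], x: float) -> float:
    """Evaluate a schedule such as [(0, 1.0), "lin", (1000, 0.1)] at key value `x`."""
    if any(item != "lin" for item in schedule if isinstance(item, str)):
        raise ValueError("this sketch only supports linear ('lin') segments")
    points: List[Point] = [item for item in schedule if not isinstance(item, str)]
    keys = [k for k, _ in points]
    if x <= keys[0]:
        return points[0][1]
    if x >= keys[-1]:
        return points[-1][1]  # assumption: hold the last value after the final point
    i = bisect_right(keys, x)  # keys[i - 1] <= x < keys[i]
    (x0, y0), (x1, y1) = points[i - 1], points[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


lr_schedule = [(0, 1.0), "lin", (1000, 0.1)]
for step in (0, 500, 1000, 2000):
    print(step, round(evaluate_schedule(lr_schedule, step), 6))
```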
## License
HyperState is dual-licensed under the MIT license and Apache License (Version 2.0).
See LICENSE-MIT and LICENSE-APACHE for more information.
Raw data
{
"_id": null,
"home_page": "",
"name": "hyperstate",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.7.1,<3.11.0",
"maintainer_email": "",
"keywords": "",
"author": "Clemens Winter",
"author_email": "clemenswinter1@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/63/5f/186c1a4d3be1095701bdc88cd81835f04d4d75950719ec6a70bc1c99dd1c/hyperstate-0.4.4.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "Library for managing hyperparameters and mutable state of machine learning training systems.",
"version": "0.4.4",
"project_urls": null,
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "5e50b7740a2ab161143f4277d1a2c58a49e953468f175062f0b3e42493a70334",
"md5": "4ccc5461d35c5ed9e75de0982892011b",
"sha256": "c93b0c4630503fc3bebbcf837784f1449604e096050c5e6b8e269f44ea0b5bcb"
},
"downloads": -1,
"filename": "hyperstate-0.4.4-py3-none-any.whl",
"has_sig": false,
"md5_digest": "4ccc5461d35c5ed9e75de0982892011b",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.7.1,<3.11.0",
"size": 38640,
"upload_time": "2023-08-04T16:25:08",
"upload_time_iso_8601": "2023-08-04T16:25:08.545747Z",
"url": "https://files.pythonhosted.org/packages/5e/50/b7740a2ab161143f4277d1a2c58a49e953468f175062f0b3e42493a70334/hyperstate-0.4.4-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "635f186c1a4d3be1095701bdc88cd81835f04d4d75950719ec6a70bc1c99dd1c",
"md5": "98ec85475abfb955316f244ee4a11ac0",
"sha256": "eebdc988582717c9a75159c413a716d9265afe070b077fc2b713cf2770d696c1"
},
"downloads": -1,
"filename": "hyperstate-0.4.4.tar.gz",
"has_sig": false,
"md5_digest": "98ec85475abfb955316f244ee4a11ac0",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.7.1,<3.11.0",
"size": 39358,
"upload_time": "2023-08-04T16:25:10",
"upload_time_iso_8601": "2023-08-04T16:25:10.242415Z",
"url": "https://files.pythonhosted.org/packages/63/5f/186c1a4d3be1095701bdc88cd81835f04d4d75950719ec6a70bc1c99dd1c/hyperstate-0.4.4.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-08-04 16:25:10",
"github": false,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"lcname": "hyperstate"
}