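kerastorch is a general-purpose PyTorch model training template tool that lets you use PyTorch the way you use Keras. It offers the following advantages:
* **Good-looking** (elegant code, clean logs, built-in visualization)
* **Easy to use** (convenient to use; supports common features such as progress bars, evaluation metrics, and early stopping, plus extensions such as TensorBoard and wandb callbacks)
* **Easy to modify** (the core code is modular and only about 200 lines long, and many examples of modifications and usage are provided)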
## Usage 🍊🍊
Install kerastorch:
```
pip install kerastorch
```
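With kerastorch, you don't need to write your own PyTorch training loop. You only need two steps:
(1) Create your model architecture `net`, then pass it, together with a loss function, to `kerastorch.KerasModel` to build a `model`.
(2) Call the `fit` method of `model` to train on your training and validation data. The training data and validation data must each be wrapped in a DataLoader.
The core usage looks like this: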
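Both steps assume you already have a model `net` and two DataLoaders, `dl_train` and `dl_val`. If you just want to try kerastorch, the sketch below builds them from synthetic data using only standard PyTorch utilities. The dataset, its sizes, and the small MLP are arbitrary illustrative assumptions. It also assumes each batch is a `(features, labels)` pair, the common PyTorch convention; check kerastorch's source if your data has a different layout. Labels are float tensors of shape `(N, 1)` so that they match the single-logit output expected by `nn.BCEWithLogitsLoss`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic binary task: 1000 samples, 2 features,
# label is 1 when x0 + x1 > 0, else 0
X = torch.randn(1000, 2)
y = (X.sum(dim=1, keepdim=True) > 0).float()  # shape (1000, 1), float for BCEWithLogitsLoss

# 80/20 split into training and validation sets
ds_train = TensorDataset(X[:800], y[:800])
ds_val = TensorDataset(X[800:], y[800:])

dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=32)

# Small MLP that outputs one raw logit per sample
# (no sigmoid here: BCEWithLogitsLoss applies it internally)
net = nn.Sequential(
    nn.Linear(2, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

# Sanity check: one batch flows through the network and matches the label shape
features, labels = next(iter(dl_train))
print(features.shape, labels.shape, net(features).shape)
# → torch.Size([32, 2]) torch.Size([32, 1]) torch.Size([32, 1])
```

The variable names deliberately match those used in the core usage code, so the two snippets can be run one after the other.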
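```python
import torch
from torch import nn
import kerastorch
import torchmetrics

# net: your torch.nn.Module; dl_train / dl_val: your training and validation DataLoaders
model = kerastorch.KerasModel(net,
                              loss_fn=nn.BCEWithLogitsLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
                              metrics_dict={"acc": torchmetrics.Accuracy(task='binary')}
                              )

# Train for up to 20 epochs, monitoring validation accuracy ("val_acc", higher is better);
# training stops early if it does not improve for `patience` epochs, and the checkpoint
# is written to ckpt_path. plot=True shows live training curves in a notebook.
dfhistory = model.fit(train_data=dl_train,
                      val_data=dl_val,
                      epochs=20,
                      patience=3,
                      ckpt_path='checkpoint.pt',
                      monitor="val_acc",
                      mode="max",
                      plot=True,
                      )
```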
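The `monitor`, `mode`, and `patience` arguments control early stopping. As a rough illustration of the usual patience rule (a plain-Python sketch, not kerastorch's actual implementation), training stops once `patience` consecutive epochs pass without the monitored value improving on its best so far, where "improving" means strictly higher for `mode="max"` and strictly lower for `mode="min"`. The function name `early_stopping_epoch` and the `val_acc` values are hypothetical.

```python
def early_stopping_epoch(scores, patience, mode="max"):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch monitor values.

    Sketch of a common patience rule (an assumption, not kerastorch's code):
    stop once `patience` epochs have passed since the last strict improvement.
    Epochs are 1-indexed.
    """
    best_score = None
    best_epoch = 0
    for epoch, score in enumerate(scores, start=1):
        improved = (
            best_score is None
            or (mode == "max" and score > best_score)
            or (mode == "min" and score < best_score)
        )
        if improved:
            best_score, best_epoch = score, epoch
            # Assumption: this is where the best weights would be saved to ckpt_path
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    # Ran out of epochs without triggering early stopping
    return len(scores), best_epoch


# Hypothetical val_acc values for 8 epochs
val_acc = [0.70, 0.78, 0.81, 0.80, 0.81, 0.79, 0.85, 0.86]
print(early_stopping_epoch(val_acc, patience=3, mode="max"))  # → (6, 3)
```

With `patience=3`, the best `val_acc` (0.81) is reached at epoch 3. Epochs 4 to 6 fail to exceed it (tying 0.81 at epoch 5 does not count as an improvement), so training stops after epoch 6, even though later epochs would have improved.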
When you run the training code in a Jupyter notebook, you will see training visualizations and a progress-bar log similar to the following.
![](./data/kerastorch_plot.gif)
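## Main Features 🍉🍉
kerastorch supports the following features. The table below lists the version from which each feature is stably supported, and the library each feature depends on or is inspired by.
| Feature | Stable since version | Depends on / inspired by |
|:----|:-------------------:|:--------------|
|✅ Training progress bar | 3.0.0 | Depends on tqdm, inspired by keras |
|✅ Training evaluation metrics | 3.0.0 | Inspired by pytorch_lightning |
|✅ Built-in training visualization in notebooks | 3.8.0 | Inspired by fastai |
|✅ Early stopping | 3.0.0 | Inspired by keras |
|✅ GPU training | 3.0.0 | Depends on accelerate |
|✅ Multi-GPU training (DDP) | 3.6.0 | Depends on accelerate |
|✅ fp16/bf16 training | 3.6.0 | Depends on accelerate |
|✅ TensorBoard callback | 3.7.0 | Depends on tensorboard |
|✅ wandb callback | 3.7.0 | Depends on wandb |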
Raw data
{
"_id": null,
"home_page": "",
"name": "kerastorch",
"maintainer": "",
"docs_url": null,
"requires_python": "",
"maintainer_email": "",
"keywords": "machine-learning,deep-learning,ML,DL,pytorch,torch,llm",
"author": "kerastorch team",
"author_email": "1027763372@qq.com",
"download_url": "https://files.pythonhosted.org/packages/e3/3f/0cc24395ba6334111ed373267209eabe27bcc139d639b139c48cb5cc99d8/kerastorch-1.0.3.tar.gz",
"platform": "all",
"bugtrack_url": null,
"license": "BSD License",
"summary": "llm model process for pytorch",
"version": "1.0.3",
"project_urls": null,
"split_keywords": [
"machine-learning",
"deep-learning",
"ml",
"dl",
"pytorch",
"torch",
"llm"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "0f05b099d62a05132ef06cf30e73beec291d9425b34cbf0db9c5ca24d07766ba",
"md5": "fef9f1b9f8fd40a7505a64f931cf089e",
"sha256": "0c31766f88f77cc4084f372bc1e001d40cf66cf91a240a692ed6353eedb1ed8d"
},
"downloads": -1,
"filename": "kerastorch-1.0.3-py3-none-any.whl",
"has_sig": false,
"md5_digest": "fef9f1b9f8fd40a7505a64f931cf089e",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": null,
"size": 91104,
"upload_time": "2023-10-15T01:38:35",
"upload_time_iso_8601": "2023-10-15T01:38:35.184157Z",
"url": "https://files.pythonhosted.org/packages/0f/05/b099d62a05132ef06cf30e73beec291d9425b34cbf0db9c5ca24d07766ba/kerastorch-1.0.3-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "e33f0cc24395ba6334111ed373267209eabe27bcc139d639b139c48cb5cc99d8",
"md5": "2f6c1db4c676334954cf9e82eb08eeea",
"sha256": "005779bb620a595c0f3607738e77ca11485b128d1fe6dde08e0a1ceb88933cef"
},
"downloads": -1,
"filename": "kerastorch-1.0.3.tar.gz",
"has_sig": false,
"md5_digest": "2f6c1db4c676334954cf9e82eb08eeea",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 43148,
"upload_time": "2023-10-15T01:38:36",
"upload_time_iso_8601": "2023-10-15T01:38:36.650130Z",
"url": "https://files.pythonhosted.org/packages/e3/3f/0cc24395ba6334111ed373267209eabe27bcc139d639b139c48cb5cc99d8/kerastorch-1.0.3.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-10-15 01:38:36",
"github": false,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"lcname": "kerastorch"
}