# ESM-Efficient
[![pypi](https://img.shields.io/pypi/v/esm-efficient.svg)](https://pypi.python.org/pypi/esm-efficient)
[![DOI:10.1101/2024.10.22.619563](http://img.shields.io/badge/DOI-10.1101/2024.10.22.619563-B31B1B.svg)](https://doi.org/10.1101/2024.10.22.619563)
Efficient implementation of the ESM family of models: ESM1b, ESM1v, ESM2, ESMC.
<img src="docs/methods.png" width="300" /> <img src="docs/speedup.png" width="400" />
## Installation
Install the appropriate version of [PyTorch](https://pytorch.org/get-started/locally/) for your system, then install flash-attn and the package:
```bash
pip install flash-attn --no-build-isolation
pip install esm-efficient
```
## Basic Usage
```python
from esme import ESM
model = ESM.from_pretrained('esmc') # or 'esm1b', 'esm1v', 'esm2', 'esm2_8m', ...
```
This will download the model weights from the HuggingFace model hub and load the model. See the getting-started documentation for more details.
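The weights can also be placed on a GPU at load time. A minimal sketch, assuming `ESM.from_pretrained` accepts the same `device` argument used with `ESM2.from_pretrained` in the sections below:
```python
from esme import ESM

# Assumption: `device` is accepted here as in the ESM2 examples below,
# placing the weights on GPU 0 at load time.
model = ESM.from_pretrained('esm2', device=0)
```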
## Tokenization and Predicting Log Probabilities
Predict the log probabilities of a sequence of tokens using the model.
```python
import torch
from esme import ESM2
from esme.alphabet import tokenize
# load the pretrained model
model = ESM2.from_pretrained("{model}.safetensors", device=0)
tokens = tokenize(['MEEPQSDPSVEPPLSQESTFSLDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
# predict logits
logits = model(tokens)
# logits.shape = (2, seq_len, embed_size)
# predict log probabilities
log_probs = model.predict_log_prob(tokens)
# log_probs.shape = (2, seq_len, embed_size)
```
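To score the residues actually present in each sequence, the per-position log probabilities can be gathered at the observed token indices. A plain PyTorch sketch on top of the outputs above, assuming the last dimension of `log_probs` indexes the token vocabulary:
```python
import torch

# Gather the log probability assigned to the observed token at each position.
# Assumption: the last dimension of `log_probs` indexes the token vocabulary.
token_log_probs = torch.gather(
    log_probs, dim=-1, index=tokens.unsqueeze(-1)
).squeeze(-1)
# token_log_probs.shape = (2, seq_len)

# Average per sequence (padding positions, if any, still contribute here).
per_seq_score = token_log_probs.mean(dim=-1)
```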
## Tokenization without Padding
```python
from esme.alphabet import tokenize_unpad
# tokenize without padding (more efficient: avoids computation on padding tokens)
tokens, indices, cu_lens, max_len = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
tokens = tokens.to(0)
cu_lens = cu_lens.to(0)
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, embed_size)
```
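Because the unpadded output concatenates all sequences along the first dimension, `cu_lens` can be used to split it back into per-protein tensors. A minimal sketch, assuming `cu_lens` holds cumulative sequence lengths in the usual flash-attention style `[0, len_1, len_1 + len_2, ...]`:
```python
# Split the concatenated output back into one tensor per protein.
# Assumption: cu_lens = [0, len_1, len_1 + len_2, ...] (flash-attention style).
per_protein = [
    log_probs[start:end]
    for start, end in zip(cu_lens[:-1], cu_lens[1:])
]
# per_protein[0].shape = (seq_len_protein1, embed_size)
```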
## Predict effect of variants
```python
from esme.variant import predict_mask_margin
seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# ... pd.DataFrame({
# ...     'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ...     'score': [-0.1, -0.2, ..., -0.3]
# ... }).set_index('variant')
```
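The returned DataFrame is indexed by variant name, so standard pandas operations can rank or look up substitutions; for example (assuming lower scores correspond to stronger predicted effects):
```python
# Rank variants from lowest to highest mask-margin score.
most_deleterious = df.sort_values('score').head(10)

# Look up a specific substitution by its variant name.
score_m1a = df.loc['M1A', 'score']
```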
## Fine-tune the model with LoRA adapters
```python
# add LoRA adapters; only the added LoRA weights are trainable by default
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])
# mark only the LoRA parameters as trainable (called by default when adding LoRA)
model.mark_only_lora_as_trainable()
# save the model with the lora weights
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])
# load the model with the lora weights
model.load_lora('<path>.safetensors')
```
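Since `mark_only_lora_as_trainable` leaves gradients enabled only for the LoRA parameters, an optimizer for fine-tuning can be built over just those parameters with plain PyTorch:
```python
import torch

# Only LoRA parameters require gradients after mark_only_lora_as_trainable(),
# so the optimizer sees a small parameter set.
trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable), 'trainable parameters')

optimizer = torch.optim.AdamW(trainable, lr=1e-4)
```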
## Quantization of the model
```python
model = ESM2.from_pretrained('8M.safetensors', quantization='4bit', device=0)
```
Activation checkpointing of each transformer layer:
```python
model = ESM2.from_pretrained('8M.safetensors', checkpointing=True)
```
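The two options are independent keyword arguments to `from_pretrained`; whether they compose in a single call is an assumption in the sketch below, so verify the memory and accuracy trade-off for your setup:
```python
# Assumption: quantization and checkpointing can be combined to further
# reduce memory at the cost of extra recomputation in the backward pass.
model = ESM2.from_pretrained(
    '8M.safetensors', quantization='4bit', checkpointing=True, device=0
)
```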
## Training the model
We provide a PyTorch Lightning trainer for training the model. The following code trains the model with the masked language modeling objective:
```python
from esme import ESM2
from esme.data import MaskedFastaTokenDataModule
from esme.trainer import MaskedPLM
trainer = MaskedPLM(model)  # pytorch lightning trainer
datamodule = MaskedFastaTokenDataModule(
    train_fasta='train.fasta',
    val_fasta='val.fasta',
    token_per_batch=50_000,
)  # data module for training
trainer.fit(datamodule)
```
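The trainer can also be combined with the LoRA workflow above to fine-tune only the adapter weights; a sketch using only the calls shown in this README (treating their composition as an assumption):
```python
from esme import ESM2
from esme.data import MaskedFastaTokenDataModule
from esme.trainer import MaskedPLM

# Sketch: fine-tune only LoRA adapters with the masked-LM trainer.
# Assumption: a LoRA-adapted model can be passed to MaskedPLM unchanged.
model = ESM2.from_pretrained('8M.safetensors', device=0)
model.add_lora(rank=16, layers=('query', 'key', 'value'))

trainer = MaskedPLM(model)
datamodule = MaskedFastaTokenDataModule(
    train_fasta='train.fasta',
    val_fasta='val.fasta',
    token_per_batch=50_000,
)
trainer.fit(datamodule)
```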
# Model Weights
The model weights can be downloaded from HuggingFace: [https://huggingface.co/mhcelik/esm-efficient/tree/main](https://huggingface.co/mhcelik/esm-efficient/tree/main)
# Evaluation
To perform the evaluation reported in the paper, run the following command:
```bash
snakemake -n --use-conda
```
This will download the data, train the models, and evaluate them. The results will be saved in the `results` directory.
See the `workflow/Snakefile` for more details.
To generate a specific figure from the paper, run the following command:
```bash
snakemake reports/paper_figures/figure-2.pdf -n --use-conda
```
# Citation
Manuscript for the efficient implementation: [https://www.biorxiv.org/content/10.1101/2024.10.22.619563v1](https://www.biorxiv.org/content/10.1101/2024.10.22.619563v1)
```bib
@article{Celik2024.10.22.619563,
  author = {Celik, Muhammed Hasan and Xie, Xiaohui},
  title = {Efficient Inference, Training, and Fine-tuning of Protein Language Models},
  elocation-id = {2024.10.22.619563},
  year = {2024},
  doi = {10.1101/2024.10.22.619563},
  publisher = {Cold Spring Harbor Laboratory},
  URL = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563},
  eprint = {https://www.biorxiv.org/content/early/2024/10/25/2024.10.22.619563.full.pdf},
  journal = {bioRxiv}
}
```
Also cite the original ESM papers for the corresponding models: [https://github.com/facebookresearch/esm](https://github.com/facebookresearch/esm)
## LICENSE
This code implements ESM models from scratch and is licensed under the MIT License. Refer to the [esm](https://github.com/evolutionaryscale/esm) and [fair-esm](https://github.com/facebookresearch/esm) repositories for the licenses for the model weights.