# CHEM-MRL
Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.
The model employs [SentenceTransformers' (SBERT)](https://sbert.net/) [2D Matryoshka Sentence Embeddings](https://sbert.net/examples/training/matryoshka/README.html) (`Matryoshka2dLoss`) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
Datasets should consist of SMILES pairs and their corresponding [Morgan fingerprint](https://www.rdkit.org/docs/GettingStartedInPython.html#morgan-fingerprints-circular-fingerprints) Tanimoto similarity scores.
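For example, one training row can be constructed with RDKit as in the sketch below (the SMILES pair, fingerprint radius, and bit size are illustrative; the column names match the examples later in this README):

```python
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

# Illustrative SMILES pair
smiles_a, smiles_b = "CCO", "CCN"
mol_a, mol_b = Chem.MolFromSmiles(smiles_a), Chem.MolFromSmiles(smiles_b)

# Morgan (circular) fingerprints; radius=2 and nBits=2048 are common but arbitrary choices
fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, 2, nBits=2048)
fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, 2, nBits=2048)

row = {
    "smiles_a": smiles_a,
    "smiles_b": smiles_b,
    "similarity": DataStructs.TanimotoSimilarity(fp_a, fp_b),  # label in [0, 1]
}
```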
Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, [`TanimotoSentLoss`](https://github.com/emapco/chem-mrl/blob/main/chem_mrl/losses/TanimotoLoss.py), based on [CoSENTLoss](https://kexue.fm/archives/8847), outperforms losses based on [Tanimoto similarity](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-015-0069-3/tables/2), CoSENTLoss, [AnglELoss](https://arxiv.org/pdf/2309.12871), and cosine similarity.
## Installation
**Install with pip**
```bash
pip install chem-mrl
```
**Install from source code**
```bash
pip install -e .
```
## Usage
### Hydra & Training Scripts
Hydra configuration files are in `chem_mrl/conf`. The base config (`base.yaml`) defines shared arguments and includes model-specific configurations from `chem_mrl/conf/model`. Supported models: `chem_mrl`, `chem_2d_mrl`, `classifier`, and `dice_loss_classifier`.
**Training Examples:**
```bash
# Default (chem_mrl model)
python scripts/train_chem_mrl.py
# Specify model type
python scripts/train_chem_mrl.py model=chem_2d_mrl
python scripts/train_chem_mrl.py model=classifier
# Override parameters
python scripts/train_chem_mrl.py model=chem_mrl training_args.num_train_epochs=5 datasets[0].train_dataset.name=/path/to/data.parquet
# Use different custom config also located in `chem_mrl/conf`
python scripts/train_chem_mrl.py --config-name=my_custom_config.yaml
```
**Configuration Options:**
- **Command line overrides:** Use `model=<type>` and parameter overrides as shown above
- **Modify base.yaml:** Edit the `- /model: chem_mrl` line in the defaults section to change the default model, or modify any other parameters directly
- **Override config file:** Use `--config-name=<config_name>` to specify a different base configuration file instead of the default `base.yaml`
### Basic Training Workflow
To train a model, initialize the configuration with dataset paths and model parameters, then pass it to `ChemMRLTrainer` for training.
```python
from sentence_transformers import SentenceTransformerTrainingArguments
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer
dataset_config = DatasetConfig(
key="my_dataset",
train_dataset=SplitConfig(
name="train.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val.parquet",
split_key="train", # Use "train" for local files
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
test_dataset=SplitConfig(
name="test.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="smiles_a",
smiles_b_column_name="smiles_b",
label_column_name="similarity",
)
config = BaseConfig(
model=ChemMRLConfig(
model_name=BASE_MODEL_NAME, # Predefined model name - Can be any transformer model name or path that is compatible with sentence-transformers
n_dims_per_step=3, # Model-specific hyperparameter
use_2d_matryoshka=True, # Enable 2d MRL
# Additional parameters specific to 2D MRL models
n_layers_per_step=2,
kl_div_weight=0.7, # Weight for KL divergence regularization
kl_temperature=0.5, # Temperature parameter for KL loss
),
datasets=[dataset_config], # List of dataset configurations
training_args=SentenceTransformerTrainingArguments("training_output"),
)
# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# train() returns the test evaluation metric if a test dataset is provided;
# otherwise it returns the final validation evaluation metric.
test_eval_metric = trainer.train()
```
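After training, the saved model can be loaded like any other SentenceTransformer, and its Matryoshka embeddings can be truncated to a smaller dimensionality with little accuracy loss. The sketch below assumes the final model was saved to the `training_output` directory and uses an arbitrary truncation size of 256 dimensions:

```python
from sentence_transformers import SentenceTransformer, util

# Path is an assumption; point it at your saved checkpoint
model = SentenceTransformer("training_output")

embeddings = model.encode(["CCO", "c1ccccc1"], normalize_embeddings=True)

# Truncate the Matryoshka embeddings (e.g., keep the first 256 dimensions)
truncated = embeddings[:, :256]

# Cosine similarity between the truncated embeddings
print(util.cos_sim(truncated[0], truncated[1]))
```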
### Experimental
#### Train a Query Model
To train a query model, configure the model to use the specialized query tokenizer.
The query tokenizer supports the following query types:
- `similar`: computes SMILES similarity between two molecular structures; used for retrieving similar SMILES.
- `substructure`: determines whether the query substructure is present in the second SMILES string.
Supported query formats for `smiles_a` column:
- `similar {smiles}`
- `substructure {smiles}`
```python
from sentence_transformers import SentenceTransformerTrainingArguments
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer
dataset_config = DatasetConfig(
key="query_dataset",
train_dataset=SplitConfig(
name="train.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="query",
smiles_b_column_name="target_smiles",
label_column_name="similarity",
)
config = BaseConfig(
model=ChemMRLConfig(
model_name=BASE_MODEL_NAME,
use_query_tokenizer=True, # Train a query model
),
datasets=[dataset_config],
training_args=SentenceTransformerTrainingArguments("training_output"),
)
trainer = ChemMRLTrainer(config)
```
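Once the query model is trained, queries are plain strings that prefix a SMILES with the query type described above, so retrieval reduces to embedding queries and candidate SMILES with the same model. A minimal sketch (the load path and candidate SMILES are assumptions):

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("training_output")  # path to the trained query model

# Query strings follow the `similar {smiles}` / `substructure {smiles}` formats
query_embeddings = model.encode(["similar CCO", "substructure c1ccccc1"])
candidate_embeddings = model.encode(["CCO", "CCOc1ccccc1"])

# Rank candidates by cosine similarity to each query
print(util.cos_sim(query_embeddings, candidate_embeddings))
```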
#### Latent Attention Layer
The Latent Attention Layer model is an experimental component designed to enhance the representation learning of transformer-based models by introducing a trainable latent dictionary. This mechanism applies cross-attention between token embeddings and a set of learnable latent vectors before pooling. The output of this layer contributes to both **1D Matryoshka loss** (as the final layer output) and **2D Matryoshka loss** (by integrating into all-layer outputs). Note: initial tests suggest that, with the default configuration, the latent attention layer leads to overfitting.
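Conceptually, the mechanism resembles an NV-Embed-style latent attention block: token embeddings act as queries that attend to a small trainable latent array, and the attended outputs are then pooled. The sketch below is illustrative only and is not the package's implementation (the actual layer is configured via `LatentAttentionConfig`, which also exposes a separate `cross_head_dim`):

```python
import torch
from torch import nn


class LatentAttentionSketch(nn.Module):
    """Illustrative only: cross-attention from token embeddings to trainable latents."""

    def __init__(self, hidden_dim: int = 768, num_latents: int = 512, num_heads: int = 8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_dim))  # latent "dictionary"
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

    def forward(self, token_embeddings: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden_dim)
        latents = self.latents.unsqueeze(0).expand(token_embeddings.size(0), -1, -1)
        # Queries are the token embeddings; keys/values are the learnable latents
        attended, _ = self.cross_attn(token_embeddings, latents, latents)
        # Mean-pool the attended tokens into a sentence-level embedding
        return attended.mean(dim=1)
```

The configuration below shows how the layer is enabled during training: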
```python
from sentence_transformers import SentenceTransformerTrainingArguments
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import (
BaseConfig,
ChemMRLConfig,
DatasetConfig,
LatentAttentionConfig,
SplitConfig,
)
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer
dataset_config = DatasetConfig(
key="latent_dataset",
train_dataset=SplitConfig(
name="train.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="smiles_a",
smiles_b_column_name="smiles_b",
label_column_name="similarity",
)
config = BaseConfig(
model=ChemMRLConfig(
model_name=BASE_MODEL_NAME,
latent_attention_config=LatentAttentionConfig(
hidden_dim=768, # Transformer hidden size
num_latents=512, # Number of learnable latents
num_cross_heads=8, # Number of attention heads
cross_head_dim=32, # Dimensionality of each head
output_normalize=True, # Apply L2 normalization to outputs
),
use_2d_matryoshka=True,
),
datasets=[dataset_config],
training_args=SentenceTransformerTrainingArguments("training_output"),
)
# Train a model with latent attention
trainer = ChemMRLTrainer(config)
```
### Custom Callbacks
You can provide a list of `transformers.TrainerCallback` instances to execute during training.
```python
from typing import Any
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainingArguments,
)
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer
# Define a callback class for logging evaluation metrics
class EvalCallback(TrainerCallback):
def on_evaluate(
self,
args: SentenceTransformerTrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: dict[str, Any],
model: SentenceTransformer,
**kwargs,
) -> None:
"""
Event called after an evaluation phase.
"""
pass
dataset_config = DatasetConfig(
key="callback_dataset",
train_dataset=SplitConfig(
name="train.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="smiles_a",
smiles_b_column_name="smiles_b",
label_column_name="similarity",
)
config = BaseConfig(
model=ChemMRLConfig(
model_name=BASE_MODEL_NAME,
),
datasets=[dataset_config],
training_args=SentenceTransformerTrainingArguments("training_output"),
)
# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(callbacks=[EvalCallback()])
```
## Classifier
This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances based on SMILES and category features.
Hyperparameter tuning shows that cross-entropy loss (`softmax` option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.
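Conceptually, the classification head is a dropout layer followed by a linear projection over the pooled SMILES embedding, trained with cross-entropy (or Dice) loss. The sketch below is illustrative only and is not the package's exact module:

```python
import torch
from torch import nn


class ClassifierHeadSketch(nn.Module):
    """Illustrative only: dropout + linear head over a pooled SMILES embedding."""

    def __init__(self, embedding_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(embedding_dim, num_classes)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        # embedding: (batch, embedding_dim) produced by a trained Chem-MRL model
        return self.linear(self.dropout(embedding))  # class logits for the loss function
```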
### Usage
#### Basic Classification Training
To train a classifier, configure the model with dataset paths and column names, then initialize `ClassifierTrainer` to start training.
```python
from sentence_transformers import SentenceTransformerTrainingArguments
from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer
dataset_config = DatasetConfig(
key="classification_dataset",
train_dataset=SplitConfig(
name="train_classification.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val_classification.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="smiles",
smiles_b_column_name=None, # Not needed for classification
label_column_name="label",
)
# Define classification training configuration
config = BaseConfig(
model=ClassifierConfig(
model_name="path/to/trained_mrl_model", # Pretrained MRL model path
),
datasets=[dataset_config],
training_args=SentenceTransformerTrainingArguments("training_output"),
)
# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()
```
#### Training with Dice Loss
For imbalanced classification tasks, **Dice Loss** can improve performance by focusing on hard-to-classify samples. Below is a configuration that enables the self-adjusting Dice loss via `ClassifierConfig`, which introduces additional hyperparameters (`dice_reduction` and `dice_gamma`).
```python
from sentence_transformers import SentenceTransformerTrainingArguments
from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption, FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer
dataset_config = DatasetConfig(
key="dice_loss_dataset",
train_dataset=SplitConfig(
name="train_classification.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float32,
sample_size=1000,
),
val_dataset=SplitConfig(
name="val_classification.parquet",
split_key="train",
label_cast_type=FieldTypeOption.float16,
sample_size=500,
),
smiles_a_column_name="smiles",
smiles_b_column_name=None, # Not needed for classification
label_column_name="label",
)
# Define classification training configuration with Dice Loss
config = BaseConfig(
model=ClassifierConfig(
model_name="path/to/trained_mrl_model",
loss_func=ClassifierLossFctOption.selfadjdice,
dice_reduction=DiceReductionOption.sum, # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
dice_gamma=1.0, # Smoothing factor hyperparameter
),
datasets=[dataset_config],
training_args=SentenceTransformerTrainingArguments("training_output"),
)
# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
```
## References
- Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." _arXiv [Cs.LG]_, 2020. [Link](http://arxiv.org/abs/2010.09885).
- Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." _arXiv [Cs.LG]_, 2022. [Link](http://arxiv.org/abs/2209.01712).
- Kusupati, Aditya, et al. "Matryoshka Representation Learning." _arXiv [Cs.LG]_, 2022. [Link](https://arxiv.org/abs/2205.13147).
- Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." _arXiv [Cs.CL]_, 2024. [Link](http://arxiv.org/abs/2402.14776).
- Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" _J Cheminform_, 7, 20 (2015). [Link](https://doi.org/10.1186/s13321-015-0069-3).
- Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." _arXiv [Cs.CL]_, 2020. [Link](https://arxiv.org/abs/1911.02855).
- Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." _Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing_, 2019. [Link](https://arxiv.org/abs/1908.10084).
- Lee, Chankyu, et al. "NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models." _arXiv [Cs.CL]_, 2025. [Link](https://arxiv.org/abs/2405.17428).