chem-mrl 0.7.0

Summary: SMILES-based Matryoshka Representation Learning Embedding Model
Homepage: https://github.com/emapco/chem-mrl
Author: Emmanuel Cortes <manny@derifyai.com>
Requires Python: >=3.11
License: Apache 2.0
Keywords: cheminformatics, machine-learning, transformers, smiles, embeddings, matryoshka-representation-learning
Uploaded: 2025-07-22 09:18:29
# CHEM-MRL

Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.

The model employs [SentenceTransformers' (SBERT)](https://sbert.net/) [2D Matryoshka Sentence Embeddings](https://sbert.net/examples/training/matryoshka/README.html) (`Matryoshka2dLoss`) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
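
Once trained, a Chem-MRL model can be loaded with sentence-transformers and its embeddings truncated at encode time. A minimal sketch, assuming a hypothetical model path and a 256-dimension Matryoshka prefix:

```python
from sentence_transformers import SentenceTransformer

# Hypothetical path to a trained Chem-MRL model; truncate_dim keeps only the
# first 256 dimensions of each embedding (a Matryoshka prefix).
model = SentenceTransformer("path/to/trained_mrl_model", truncate_dim=256)

embeddings = model.encode(["CCO", "c1ccccc1"])  # SMILES strings as input
print(embeddings.shape)  # (2, 256)
```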

Datasets should consist of SMILES pairs and their corresponding [Morgan fingerprint](https://www.rdkit.org/docs/GettingStartedInPython.html#morgan-fingerprints-circular-fingerprints) Tanimoto similarity scores.
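
For reference, the Tanimoto similarity of two bit-vector fingerprints $A$ and $B$ is $|A \cap B| \, / \, (|A| + |B| - |A \cap B|)$. A dataset of this form can be generated with RDKit; the following is a minimal sketch (the fingerprint radius and size are common defaults rather than values mandated by chem-mrl, and the column names match the configuration examples below):

```python
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem


def morgan_fp(smiles: str, radius: int = 2, n_bits: int = 2048):
    """Morgan (circular) fingerprint as an RDKit bit vector."""
    return AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles), radius, nBits=n_bits
    )


pairs = [("CCO", "CCN"), ("c1ccccc1", "Cc1ccccc1")]
rows = [
    {
        "smiles_a": a,
        "smiles_b": b,
        "similarity": DataStructs.TanimotoSimilarity(morgan_fp(a), morgan_fp(b)),
    }
    for a, b in pairs
]
pd.DataFrame(rows).to_parquet("train.parquet")
```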

Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, [`TanimotoSentLoss`](https://github.com/emapco/chem-mrl/blob/main/chem_mrl/losses/TanimotoLoss.py), based on [CoSENTLoss](https://kexue.fm/archives/8847), outperforms losses based on [Tanimoto similarity](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-015-0069-3/tables/2), CoSENTLoss, [AnglELoss](https://arxiv.org/pdf/2309.12871), and cosine similarity.

## Installation

**Install with pip**

```bash
pip install chem-mrl
```

**Install from source code**

```bash
pip install -e .
```

## Usage

### Hydra & Training Scripts

Hydra configuration files are in `chem_mrl/conf`. The base config (`base.yaml`) defines shared arguments and includes model-specific configurations from `chem_mrl/conf/model`. Supported models: `chem_mrl`, `chem_2d_mrl`, `classifier`, and `dice_loss_classifier`.

**Training Examples:**

```bash
# Default (chem_mrl model)
python scripts/train_chem_mrl.py

# Specify model type
python scripts/train_chem_mrl.py model=chem_2d_mrl
python scripts/train_chem_mrl.py model=classifier

# Override parameters
python scripts/train_chem_mrl.py model=chem_mrl training_args.num_train_epochs=5 datasets[0].train_dataset.name=/path/to/data.parquet

# Use a different custom config, also located in `chem_mrl/conf`
python scripts/train_chem_mrl.py --config-name=my_custom_config.yaml
```

**Configuration Options:**
- **Command line overrides:** Use `model=<type>` and parameter overrides as shown above
- **Modify base.yaml:** Edit the `- /model: chem_mrl` line in the defaults section to change the default model, or modify any other parameters directly
- **Override config file:** Use `--config-name=<config_name>` to specify a different base configuration file instead of the default `base.yaml`
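
The fully resolved configuration can also be inspected programmatically with Hydra's compose API, which accepts the same overrides as the command line. A minimal sketch, assuming it is run from a script at the repository root:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is interpreted relative to the calling file; adjust as needed.
with initialize(version_base=None, config_path="chem_mrl/conf"):
    cfg = compose(config_name="base", overrides=["model=chem_2d_mrl"])

# Print the fully resolved configuration as YAML
print(OmegaConf.to_yaml(cfg))
```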

### Basic Training Workflow

To train a model, initialize the configuration with dataset paths and model parameters, then pass it to `ChemMRLTrainer` for training.

```python
from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="my_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",  # Use "train" for local files
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    test_dataset=SplitConfig(
        name="test.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,  # Predefined base model; can be any transformer model name or path compatible with sentence-transformers
        n_dims_per_step=3,  # Model-specific hyperparameter
        use_2d_matryoshka=True,  # Enable 2d MRL
        # Additional parameters specific to 2D MRL models
        n_layers_per_step=2,
        kl_div_weight=0.7,  # Weight for KL divergence regularization
        kl_temperature=0.5,  # Temperature parameter for KL loss
    ),
    datasets=[dataset_config],  # List of dataset configurations
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# Returns the test evaluation metric if a test dataset is provided;
# otherwise returns the final validation evaluation metric.
test_eval_metric = trainer.train()
```

### Experimental

#### Train a Query Model

To train a query model, configure the model to use the specialized query tokenizer.

The query tokenizer supports the following query types:

- `similar`: Computes SMILES similarity between two molecular structures. Useful for retrieving similar SMILES.
- `substructure`: Determines whether a substructure is present in the second SMILES string.

Supported query formats for the `smiles_a` column:

- `similar {smiles}`
- `substructure {smiles}`
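
For example, query-format training rows can be assembled by prefixing the query type to the SMILES string. A minimal sketch with placeholder similarity labels:

```python
import pandas as pd

rows = [
    # `query` holds "<query_type> <smiles>"; similarity values are placeholders
    {"query": "similar CCO", "target_smiles": "CCN", "similarity": 0.33},
    {"query": "substructure c1ccccc1", "target_smiles": "Cc1ccccc1", "similarity": 1.0},
]
pd.DataFrame(rows).to_parquet("train.parquet")
```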

```python
from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="query_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="query",
    smiles_b_column_name="target_smiles",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        use_query_tokenizer=True,  # Train a query model
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)
trainer = ChemMRLTrainer(config)
trainer.train()
```

#### Latent Attention Layer

The latent attention layer is an experimental component designed to enhance the representation learning of transformer-based models by introducing a trainable latent dictionary. The mechanism applies cross-attention between token embeddings and a set of learnable latent vectors before pooling. The output of this layer contributes to both **1D Matryoshka loss** (as the final layer output) and **2D Matryoshka loss** (by integrating into all-layer outputs). Note: initial tests suggest that, with the default configuration, the latent attention layer leads to overfitting.
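
The mechanism can be pictured with a short PyTorch sketch. This is a simplification for illustration, not the package's implementation: it uses a stock multi-head attention module and elides the `cross_head_dim` projection and the pooling step.

```python
import torch
import torch.nn as nn


class LatentCrossAttention(nn.Module):
    """Token embeddings attend to a trainable latent dictionary before pooling."""

    def __init__(self, hidden_dim: int = 768, num_latents: int = 512, num_heads: int = 8):
        super().__init__()
        # Learnable latent vectors, shared across the batch
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_dim))
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

    def forward(self, token_embeddings: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden_dim)
        batch = token_embeddings.size(0)
        latents = self.latents.unsqueeze(0).expand(batch, -1, -1)
        # Cross-attention: tokens are queries, latents are keys and values
        out, _ = self.cross_attn(token_embeddings, latents, latents)
        return out  # pooled downstream (e.g., mean over seq_len) to form the embedding
```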

```python
from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import (
    BaseConfig,
    ChemMRLConfig,
    DatasetConfig,
    LatentAttentionConfig,
    SplitConfig,
)
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer

dataset_config = DatasetConfig(
    key="latent_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        latent_attention_config=LatentAttentionConfig(
            hidden_dim=768,  # Transformer hidden size
            num_latents=512,  # Number of learnable latents
            num_cross_heads=8,  # Number of attention heads
            cross_head_dim=32,  # Dimensionality of each head
            output_normalize=True,  # Apply L2 normalization to outputs
        ),
        use_2d_matryoshka=True,
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Train a model with latent attention
trainer = ChemMRLTrainer(config)
trainer.train()
```

### Custom Callbacks

You can provide a list of `transformers.TrainerCallback` instances to execute during training.

```python
from typing import Any

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
)
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ChemMRLTrainer


# Define a callback class for logging evaluation metrics
class EvalCallback(TrainerCallback):
    def on_evaluate(
        self,
        args: SentenceTransformerTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, Any],
        model: SentenceTransformer,
        **kwargs,
    ) -> None:
        """
        Event called after an evaluation phase.
        """
        # Log the evaluation metrics at the current training step
        print(f"Step {state.global_step}: {metrics}")


dataset_config = DatasetConfig(
    key="callback_dataset",
    train_dataset=SplitConfig(
        name="train.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Train with callback
trainer = ChemMRLTrainer(config)
val_eval_metric = trainer.train(callbacks=[EvalCallback()])
```

## Classifier

This repository includes code for training a linear classifier with optional dropout regularization. The classifier predicts a substance's category from its SMILES representation.

Hyperparameter tuning shows that cross-entropy loss (`softmax` option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.

### Usage

#### Basic Classification Training

To train a classifier, configure the model with dataset paths and column names, then initialize `ClassifierTrainer` to start training.

```python
from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer

dataset_config = DatasetConfig(
    key="classification_dataset",
    train_dataset=SplitConfig(
        name="train_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles",
    smiles_b_column_name=None,  # Not needed for classification
    label_column_name="label",
)

# Define classification training configuration
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()
```

#### Training with Dice Loss

For imbalanced classification tasks, **Dice Loss** can improve performance by focusing on hard-to-classify samples. Below is a configuration that selects the self-adjusting dice loss, which introduces additional hyperparameters.
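
As a rough sketch of the idea from Li et al. (2020), the self-adjusting dice coefficient for a sample with predicted probability $p$ and binary label $y$ is

```math
\mathrm{DSC} = \frac{2\,(1 - p)\,p\,y + \gamma}{(1 - p)\,p + y + \gamma}
```

where the loss is $1 - \mathrm{DSC}$, the $(1 - p)$ factor down-weights easy, well-classified samples, and $\gamma$ is the smoothing term exposed as `dice_gamma` below.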

```python
from sentence_transformers import SentenceTransformerTrainingArguments

from chem_mrl.schemas import BaseConfig, ClassifierConfig, DatasetConfig, SplitConfig
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption, FieldTypeOption
from chem_mrl.trainers import ClassifierTrainer

dataset_config = DatasetConfig(
    key="dice_loss_dataset",
    train_dataset=SplitConfig(
        name="train_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float32,
        sample_size=1000,
    ),
    val_dataset=SplitConfig(
        name="val_classification.parquet",
        split_key="train",
        label_cast_type=FieldTypeOption.float16,
        sample_size=500,
    ),
    smiles_a_column_name="smiles",
    smiles_b_column_name=None,  # Not needed for classification
    label_column_name="label",
)

# Define classification training configuration with Dice Loss
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",
        loss_func=ClassifierLossFctOption.selfadjdice,
        dice_reduction=DiceReductionOption.sum,  # Reduction method for Dice Loss (e.g., 'mean' or 'sum')
        dice_gamma=1.0,  # Smoothing factor hyperparameter
    ),
    datasets=[dataset_config],
    training_args=SentenceTransformerTrainingArguments("training_output"),
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
```

## References

- Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." _arXiv [Cs.LG]_, 2020. [Link](http://arxiv.org/abs/2010.09885).
- Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." _arXiv [Cs.LG]_, 2022. [Link](http://arxiv.org/abs/2209.01712).
- Kusupati, Aditya, et al. "Matryoshka Representation Learning." _arXiv [Cs.LG]_, 2022. [Link](https://arxiv.org/abs/2205.13147).
- Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." _arXiv [Cs.CL]_, 2024. [Link](http://arxiv.org/abs/2402.14776).
- Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" _J Cheminform_, 7, 20 (2015). [Link](https://doi.org/10.1186/s13321-015-0069-3).
- Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." _arXiv [Cs.CL]_, 2020. [Link](https://arxiv.org/abs/1911.02855).
- Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." _Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing_, 2019. [Link](https://arxiv.org/abs/1908.10084).
- Lee, Chankyu, et al. "NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models." _arXiv [Cs.CL]_, 2025. [Link](https://arxiv.org/abs/2405.17428).

            
