# Spatio-Angular Convolutions for Super-resolution in Diffusion MRI
![Model Architecture](resources/figure1.png)
[![PyPI version](https://badge.fury.io/py/dmri-pcconv.svg)](https://badge.fury.io/py/dmri-pcconv)
This project performs angular super-resolution of dMRI data through a parametric continuous convolutional neural network (PCCNN). This codebase is associated with the following paper. Please cite the paper if you use this model:
[Spatio-Angular Convolutions for Super-resolution in Diffusion MRI](https://arxiv.org/abs/2306.00854) [NeurIPS 2023]
## Table of contents
* [Installation](#installation)
* [Training](#training)
* [Prediction](#prediction)
## Installation
`dmri-pcconv` can be installed via pip:
```bash
pip install dmri-pcconv
```
### Requirements
`dmri-pcconv` uses [PyTorch](https://pytorch.org/) as the deep learning framework.
Listed below are the requirements for this package; they are installed automatically when installing via pip.
* `torch`
* `lightning`
* `npy-patcher`
* `einops`
* `nibabel`
## Training
Follow the instructions below to train a PCCNN model for dMRI angular super-resolution.
### Data Preprocessing
This training pipeline requires dMRI data to be saved in `.npy` format, with the angular dimension as the first dimension of each 4D array. This is because the module uses [npy-patcher](https://github.com/m-lyon/npy-cpp-patches) to extract training patches at runtime. Below is an example of how to convert NIfTI files to `.npy` using [nibabel](https://nipy.org/nibabel/).
```python
import numpy as np

from dmri_pcconv.core.io import load_nifti

data, _ = load_nifti('/path/to/data.nii.gz')  # Load dMRI data into memory
data = data.transpose(3, 0, 1, 2)  # Move the angular dimension from last to first
np.save('/path/to/data.npy', data, allow_pickle=False)  # Save in .npy format. Ensure this is on an SSD.
```
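Since `npy-patcher` reads patches directly from the `.npy` file, it is worth confirming the saved array really does have the angular dimension first. A quick check using numpy's memory-mapped loading, so the full volume is never read into memory:

```python
import numpy as np

# mmap_mode='r' reads only the array header, not the data itself
data = np.load('/path/to/data.npy', mmap_mode='r')
print(data.shape)  # Expect (q, x, y, z): angular dimension first
```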
**N.B.** *Patches are read lazily from disk, so it is **highly** recommended to store the training data on an SSD; an HDD will significantly bottleneck data loading during training.*
Additionally, `xmax` normalisation values must be computed before training, because patches are extracted lazily at runtime as described above. Below is an example of how to compute and save `xmax` values for a given subject.
```python
from dmri_pcconv.core.io import load_bval, load_nifti
from dmri_pcconv.core.normalisation import TrainingNormaliser

bvals = load_bval('/path/to/bvals')
dmri, _ = load_nifti('/path/to/dmri.nii.gz')
mask, _ = load_nifti('/path/to/brain_mask.nii.gz')

xmax_dict = TrainingNormaliser.calculate_xmax(dmri, bvals, mask)
TrainingNormaliser.save_xmax('/path/to/xmax.json', xmax_dict)
```
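The same steps extend to multiple subjects with a simple loop. A minimal sketch, assuming each subject directory follows the layout used elsewhere in this README (the paths are hypothetical):

```python
from dmri_pcconv.core.io import load_bval, load_nifti
from dmri_pcconv.core.normalisation import TrainingNormaliser

# Hypothetical subject directories; adapt to your own layout
for subject_dir in ('/path/to/first', '/path/to/second', '/path/to/third'):
    bvals = load_bval(f'{subject_dir}/bvals')
    dmri, _ = load_nifti(f'{subject_dir}/dmri.nii.gz')
    mask, _ = load_nifti(f'{subject_dir}/brain_mask.nii.gz')
    xmax_dict = TrainingNormaliser.calculate_xmax(dmri, bvals, mask)
    TrainingNormaliser.save_xmax(f'{subject_dir}/xmax.json', xmax_dict)
```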
### Model Training
Below is an example of how to train the `PCCNN` model using the `lightning` model class `PCCNNLightningModel` and data module `PCCNNDataModule`. The `PCCNN-Bv`, `PCCNN-Sp`, and `PCCNN-Bv-Sp` variants each have their own corresponding model and data module classes.
```python
import lightning.pytorch as pl

from dmri_pcconv.core.qspace import QSpaceInfo
from dmri_pcconv.core.model import PCCNNLightningModel
from dmri_pcconv.core.training import Subject, PCCNNDataModule

# Collect dataset filepaths
subj1 = Subject(
    '/path/to/first/dmri.npy',
    '/path/to/first/bvecs',
    '/path/to/first/bvals',
    '/path/to/first/brain_mask.nii.gz',
    '/path/to/first/xmax.json'
)
subj2 = Subject(
    '/path/to/second/dmri.npy',
    '/path/to/second/bvecs',
    '/path/to/second/bvals',
    '/path/to/second/brain_mask.nii.gz',
    '/path/to/second/xmax.json'
)
subj3 = Subject(
    '/path/to/third/dmri.npy',
    '/path/to/third/bvecs',
    '/path/to/third/bvals',
    '/path/to/third/brain_mask.nii.gz',
    '/path/to/third/xmax.json'
)

# Assign Q-space training parameters
qinfo = QSpaceInfo(
    q_in_min=6,   # Minimum number of q-space samples each training example will hold
    q_in_max=20,  # Maximum number; training samples within this range
    q_out=10,     # Number of output samples per training example
    shells=(1000, 2000, 3000),  # Shells used in training and prediction
    seed=12345,   # Optionally provide a random seed for sampling
)

# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = PCCNNDataModule(
    train_subjects=(subj1, subj2),
    val_subjects=(subj3,),  # Note the trailing comma: a one-element tuple
    qinfo=qinfo,
    batch_size=16,  # Batch size per device
    num_workers=8,  # Number of CPU workers that load the data
    seed=12345,     # Optionally provide a random seed for sampling
)

# Load PCCNN lightning model
model = PCCNNLightningModel()

# Create a `pl.Trainer` instance. `PCCNNDataModule` is compatible with the DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)

# Start training
trainer.fit(model, data_module)
```
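Since `PCCNNDataModule` supports the DDP strategy, scaling to multiple GPUs only requires changing the `pl.Trainer` arguments. A minimal sketch using standard `lightning` options, continuing from the example above:

```python
import lightning.pytorch as pl

# Train across 4 GPUs on a single node with DDP; each device receives
# its own batches, so the effective batch size is batch_size * devices.
trainer = pl.Trainer(devices=4, accelerator='gpu', strategy='ddp', max_epochs=100)
trainer.fit(model, data_module)
```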
## Prediction
Here we outline how to perform prediction after training.
```python
import torch

from dmri_pcconv.core.weights import get_weights
from dmri_pcconv.core.model import PCCNNBvLightningModel
from dmri_pcconv.core.prediction import PCCNNBvPredictionProcessor

# Load your pretrained weights

## From the original paper
weights = torch.load(get_weights('pccnn-bv'))
model = PCCNNBvLightningModel()
model.load_state_dict(weights)

## Or from a pytorch_lightning checkpoint
model = PCCNNBvLightningModel.load_from_checkpoint('/path/to/my/checkpoint.ckpt')

# Run prediction
predict = PCCNNBvPredictionProcessor(batch_size=4, num_workers=8, accelerator='gpu')
predict.run_subject(
    model=model,
    dmri_in='/path/to/context_dmri.nii.gz',
    bvec_in='/path/to/context_bvecs',
    bval_in='/path/to/context_bvals',
    bvec_out='/path/to/target_bvecs',
    bval_out='/path/to/target_bvals',
    mask='/path/to/brain_mask.nii.gz',
    out_fpath='/path/to/predicted_dmri.nii.gz',
)
```
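Once prediction completes, a quick sanity check is that the output contains one 3D volume per target direction. A minimal sketch using this package's I/O helpers (the axis ordering is an assumption based on standard NIfTI dMRI conventions, which store the angular dimension last):

```python
from dmri_pcconv.core.io import load_bval, load_nifti

pred, _ = load_nifti('/path/to/predicted_dmri.nii.gz')
target_bvals = load_bval('/path/to/target_bvals')

# Expect one predicted volume per target b-value
assert pred.shape[-1] == len(target_bvals)
```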
**N.B.** *Weights provided by the `get_weights` function are saved within `~/.dmri_pcconv` by default. Set the `DMRI_PCCONV_DIR` environment variable to override the save directory.*
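For example, to redirect the weights cache to a shared location (the path below is hypothetical):

```bash
export DMRI_PCCONV_DIR=/shared/models/dmri_pcconv
```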