fodnet

- Name: fodnet
- Version: 1.1.0
- Summary: FOD-Net Reimplementation.
- Upload time: 2023-10-03 15:59:15
- Author: Matthew Lyon
- Requires Python: >=3.8
- License: MIT License
# FODNet

A reimplementation of FOD-Net with training and inference pipelines. This module uses the FODNet model originally implemented [here](https://github.com/ruizengalways/FOD-Net).

If you use this code for your research, please cite:

FOD-Net: A Deep Learning Method for Fiber Orientation Distribution Angular Super Resolution.<br>
[Rui Zeng](https://sites.google.com/site/ruizenghomepage/), Jinglei Lv, He Wang, Luping Zhou, Michael Barnett, Fernando Calamante\*, Chenyu Wang\*. In [Medical Image Analysis](https://www.sciencedirect.com/science/article/abs/pii/S1361841522000822). (* equal contributions) [[Bibtex]](bib.txt).

## Requirements

This module requires the following Python packages:
- `torch >= 2.0.0`
- `lightning >= 2.0.0`
- `numpy`
- `einops`
- `npy-patcher`
- `scikit-image`
- `nibabel`

These are installed automatically along with this package. However, it is recommended to install PyTorch first by following the official PyTorch installation instructions, so that a build with the correct hardware optimizations (e.g. CUDA support) is used.

## Installation

```
pip install fodnet
```

## Training

The instructions below describe how to train the FODNet model.


### Data Preprocessing

This training pipeline requires data to be saved in `.npy` format. Additionally, the spherical harmonic (SH) dimension must be the first axis of each 4D array, because this module uses [npy-patcher](https://github.com/m-lyon/npy-cpp-patches) to extract training patches at runtime. Below is an example of how to convert `NIfTI` files into `.npy` using [nibabel](https://nipy.org/nibabel/).

```python
import numpy as np
import nibabel as nib

img = nib.load('/path/to/fod.nii.gz')
data = np.asarray(img.dataobj, dtype=np.float32)  # Load FOD data into memory
data = data.transpose(3, 0, 1, 2)  # Move the SH dimension to 0
np.save('/path/to/fod.npy', data, allow_pickle=False)  # Save in npy format. Ensure this is on an SSD.
```
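Before saving, it can help to sanity-check the array layout. The sketch below is a minimal example, assuming the FODs use the even-order (symmetric) spherical harmonic basis used by tools such as MRtrix3, where an expansion up to order `lmax` has `(lmax + 1)(lmax + 2) / 2` coefficients (45 for `lmax = 8`). The helpers `to_sh_first`, `n_even_sh_coeffs`, and `check_sh_first` are hypothetical and are not part of `fodnet`.

```python
import numpy as np


def n_even_sh_coeffs(lmax: int) -> int:
    """Number of coefficients in an even-order SH expansion up to order lmax."""
    if lmax < 0 or lmax % 2:
        raise ValueError("lmax must be a non-negative even integer")
    return (lmax + 1) * (lmax + 2) // 2


def to_sh_first(data: np.ndarray) -> np.ndarray:
    """Move the trailing SH axis of an (X, Y, Z, SH) array to axis 0 as float32."""
    if data.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {data.shape}")
    # Same axis order as data.transpose(3, 0, 1, 2); the copy is C-contiguous,
    # matching the C-order layout that np.save writes by default.
    return np.ascontiguousarray(np.moveaxis(data, -1, 0), dtype=np.float32)


def check_sh_first(arr: np.ndarray, max_lmax: int = 16) -> int:
    """Return the lmax implied by arr.shape[0], or raise if no even order matches."""
    n = arr.shape[0]
    for lmax in range(0, max_lmax + 1, 2):
        if n_even_sh_coeffs(lmax) == n:
            return lmax
    raise ValueError(f"axis 0 has length {n}; is the SH axis really first?")


if __name__ == "__main__":
    vol = np.zeros((4, 5, 6, 45), dtype=np.float32)  # synthetic (X, Y, Z, SH) volume
    sh_first = to_sh_first(vol)
    print(sh_first.shape, check_sh_first(sh_first))  # (45, 4, 5, 6) 8
```

Note that `np.ascontiguousarray` makes a full copy, so peak memory is roughly twice the size of the volume.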

**N.B.** *Patches are read lazily from disk, so it is **highly** recommended to store the training data on an SSD, as an HDD will bottleneck data loading during training.*
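`npy-patcher`'s own API is not shown here. As a rough illustration of why disk speed matters, the sketch below uses NumPy's memory mapping (`np.load(..., mmap_mode="r")`) to read a single patch from an SH-first `.npy` file. This is an analogy under that assumption, not how `fodnet` or `npy-patcher` are implemented, and `read_patch` is a hypothetical helper.

```python
import os
import tempfile

import numpy as np


def read_patch(path, corner, size):
    """Read one (SH, px, py, pz) patch from an SH-first .npy file via a memory map."""
    # mmap_mode="r" parses only the header; voxel data stays on disk until sliced
    arr = np.load(path, mmap_mode="r")
    x, y, z = corner
    px, py, pz = size
    # Copying the slice reads only the file pages that back this patch. With SH
    # on axis 0, those pages are spread across many small regions of the file,
    # the random-access pattern that HDDs handle poorly and SSDs handle well.
    return np.array(arr[:, x:x + px, y:y + py, z:z + pz])


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "fod.npy")
        np.save(path, np.random.rand(45, 16, 16, 16).astype(np.float32), allow_pickle=False)
        patch = read_patch(path, corner=(2, 3, 4), size=(8, 8, 8))
        print(patch.shape)  # (45, 8, 8, 8)
```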

### Training

```python
import lightning.pytorch as pl

from fodnet.core.model import FODNetLightningModel
from fodnet.core.dataset import Subject, FODNetDataModule

# Collect dataset filepaths
subj1 = Subject('/path/to/lowres_fod1.npy', '/path/to/highres_fod1.npy', '/path/to/mask1.npy')
subj2 = Subject('/path/to/lowres_fod2.npy', '/path/to/highres_fod2.npy', '/path/to/mask2.npy')
subj3 = Subject('/path/to/lowres_fod3.npy', '/path/to/highres_fod3.npy', '/path/to/mask3.npy')

# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = FODNetDataModule(
    train_subjects=(subj1, subj2),
    val_subjects=(subj3,),  # Trailing comma makes this a one-element tuple
    batch_size=16,  # Batch size per device
    num_workers=8,  # Number of CPU workers that load the data
)

# Load FODNet lightning model
model = FODNetLightningModel()

# Create `pl.Trainer` instance. `FODNetDataModule` can also be used with the DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)

# Start training
trainer.fit(model, data_module)
```
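Because `batch_size` is set per device, the number of samples behind each optimizer step grows when training is scaled out. Assuming standard Lightning DDP behaviour (one process per device, each loading its own batch of `batch_size` samples) and the Trainer's `devices`, `num_nodes`, and `accumulate_grad_batches` arguments, the arithmetic is sketched below; `effective_batch_size` is a hypothetical helper. It is worth checking when changing the device count, since the learning rate may need retuning for a different effective batch size.

```python
def effective_batch_size(per_device: int, devices: int = 1, num_nodes: int = 1,
                         accumulate_grad_batches: int = 1) -> int:
    """Samples that contribute to each optimizer step under DDP."""
    factors = {
        "per_device": per_device,
        "devices": devices,
        "num_nodes": num_nodes,
        "accumulate_grad_batches": accumulate_grad_batches,
    }
    for name, value in factors.items():
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    # Each DDP process draws its own batch of `per_device` samples, gradients are
    # averaged across processes, and accumulation adds that many batches per step.
    return per_device * devices * num_nodes * accumulate_grad_batches


if __name__ == "__main__":
    # batch_size=16 on one GPU, as in the training example, versus four GPUs on one node
    print(effective_batch_size(16))             # 16
    print(effective_batch_size(16, devices=4))  # 64
```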

#### Customization

This implementation uses a different optimizer, loss function, and learning rate from the [original implementation](https://github.com/ruizengalways/FOD-Net): specifically `AdamW`, `L1 Loss`, and `0.003`, respectively.

Changing these hyperparameters is straightforward: create a new class that inherits from `FODNetLightningModel`, override the property and method shown below, and use this class instead of `FODNetLightningModel` when training.

```python
import torch

from fodnet.core.model import FODNetLightningModel


class MyCustomModel(FODNetLightningModel):

    @property
    def loss_func(self):
        '''Different loss function'''
        return torch.nn.functional.mse_loss

    def configure_optimizers(self):
        '''Different optimizer and learning rate'''
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer
```
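The default L1 loss and the `mse_loss` used in `MyCustomModel` weigh residuals differently: L1 grows linearly with the error and MSE quadratically, so MSE gives large errors much more influence. A NumPy sketch of both, using the `'mean'` reduction that `torch.nn.functional.l1_loss` and `torch.nn.functional.mse_loss` apply by default (the helper names are illustrative only):

```python
import numpy as np


def l1_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute error (the 'mean' reduction of an L1 loss)."""
    return float(np.mean(np.abs(pred - target)))


def mse_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error (the 'mean' reduction of an MSE loss)."""
    return float(np.mean((pred - target) ** 2))


if __name__ == "__main__":
    target = np.zeros(4)
    pred = np.array([1.0, 1.0, 1.0, 5.0])  # three small errors and one outlier
    print(l1_loss(pred, target))   # 2.0: the outlier is 5/8 of the summed error
    print(mse_loss(pred, target))  # 7.0: the outlier is 25/28 of the summed error
```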

## Prediction

```python
from fodnet.core.model import FODNetLightningModel
from fodnet.core.prediction import FODNetPredictionProcessor

model = FODNetLightningModel.load_from_checkpoint('/path/to/my/checkpoint.ckpt')
predict = FODNetPredictionProcessor(batch_size=32, num_workers=8, accelerator='gpu')
predict.run_subject(
    model,
    '/path/to/my/brainmask.nii.gz',
    '/path/to/lowres_fod.nii.gz',
    '/path/to/dest/highres_fod.nii.gz',
    tmp_dir=None,  # Optionally specify a temporary directory to save the FOD file during processing
)
```

**N.B.** *Patches are read lazily from disk, so it is recommended to place `tmp_dir` on an SSD.*
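How much scratch space `tmp_dir` needs is not stated here. Assuming the input FOD is written to `tmp_dir` as an uncompressed float32 array (an assumption based on the `.npy` workflow used for training), its size is the product of its dimensions times 4 bytes. The hypothetical helpers below estimate that size and create a scratch directory under a chosen SSD mount only if it has enough free space; the returned path could then be passed as `tmp_dir`.

```python
import math
import shutil
import tempfile


def fod_scratch_bytes(shape, itemsize=4):
    """Uncompressed size in bytes of an array with `shape` (itemsize=4 for float32)."""
    return math.prod(shape) * itemsize


def make_scratch_dir(root, shape, headroom=1.5):
    """Create a temporary directory under `root` after checking its free space.

    `headroom` scales the estimate to leave room for any extra files. The caller
    is responsible for deleting the directory afterwards (e.g. shutil.rmtree).
    """
    needed = int(fod_scratch_bytes(shape) * headroom)
    free = shutil.disk_usage(root).free
    if free < needed:
        raise OSError(f"{root} has {free} bytes free, roughly {needed} needed")
    return tempfile.mkdtemp(dir=root)


if __name__ == "__main__":
    shape = (45, 145, 174, 145)  # hypothetical (SH, X, Y, Z) volume
    print(f"{fod_scratch_bytes(shape) / 1e9:.2f} GB")  # 0.66 GB
```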
