jaxDiversity


Name: jaxDiversity
Version: 0.0.2
home_page: https://github.com/NonlinearArtificialIntelligenceLab/jaxDiversity
Summary: jax implementation for metalearning neuronal diversity
upload_time: 2023-07-18 18:45:13
docs_url: None
author: Anil Radhakrishnan
requires_python: >=3.9
license: Apache Software License 2.0
keywords: nbdev jupyter notebook python
requirements: No requirements were recorded.
Travis-CI: No Travis.
coveralls test coverage: No coveralls.
# jaxDiversity


This is an updated JAX implementation of the paper *Neural networks
embrace diversity*.

## Authors

Anshul Choudhary, Anil Radhakrishnan, John F. Lindner, Sudeshna Sinha,
and William L. Ditto

## Link to paper

- [arXiv](https://arxiv.org/abs/2204.04348)

## Key Results

- We construct neural networks with learnable activation functions and
  see that they quickly diversify from each other under training.
- These activations subsequently outperform their *pure* counterparts on
  classification tasks.
- The neuronal sub-networks instantiate the neurons and meta-learning
  adjusts their weights and biases to find efficient spanning sets of
  nonlinear activations.
- These improved neural networks provide quantitative examples of the
  emergence of diversity and insight into its advantages.
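The Key Results describe activation functions that are themselves small neuronal sub-networks with trainable weights and biases. As a minimal NumPy sketch, assuming each activation is a scalar-to-scalar MLP with one hidden layer of `tanh` units applied elementwise, such an activation could look like the block below. The names `init_subnet` and `subnet_activation`, the hidden width, and the `tanh` hidden units are illustrative assumptions, not the package's `mlp_afunc`.

``` python
import numpy as np

def init_subnet(rng, hidden=8):
    """Parameters of a hypothetical 1 -> hidden -> 1 activation sub-network."""
    return {
        "w1": rng.normal(size=hidden),           # input -> hidden weights
        "b1": rng.normal(size=hidden),           # hidden biases
        "w2": rng.normal(size=hidden) / hidden,  # hidden -> output weights
        "b2": 0.0,                               # output bias
    }

def subnet_activation(params, x):
    """Apply the sub-network elementwise: each scalar in x maps to one scalar."""
    x = np.asarray(x, dtype=float)
    h = np.tanh(x[..., None] * params["w1"] + params["b1"])  # shape x.shape + (hidden,)
    return h @ params["w2"] + params["b2"]                    # back to x.shape

rng = np.random.default_rng(0)
params = init_subnet(rng)
z = np.linspace(-2.0, 2.0, 5)
print(subnet_activation(params, z).shape)  # (5,)
```

Because `w1`, `b1`, `w2`, and `b2` are ordinary arrays, an outer optimization loop can update them like any other parameters.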

## Install

``` sh
pip install jaxDiversity
```

## How to use

The codebase has four main components:

- dataloading: Contains tools for loading the datasets mentioned in the
  manuscript. We use PyTorch dataloaders with a custom NumPy collate
  function to use this data in JAX.

- losses: Our loss implementations handle both traditional MLPs and
  Hamiltonian neural networks with minimal changes.

- mlp: Contains a custom MLP that takes in multiple activations and uses
  them *intralayer* to create a diverse network. It also contains the
  activation neural networks.

- loops: Contains the inner and outer loops for metalearning, which
  optimize the activation functions in tandem with the supervised
  learning task.
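The dataloading component uses PyTorch dataloaders with a custom NumPy collate function. A common pattern for this, popularized by the JAX documentation's PyTorch data-loading example, is a recursive collate that stacks samples into NumPy arrays instead of torch tensors. The sketch below follows that pattern under the assumption that samples are arrays or `(x, y)` tuples; it is not necessarily identical to the code in `jaxDiversity.dataloading`.

``` python
import numpy as np

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (no torch tensors).

    Arrays are stacked along a new leading batch axis; tuples/lists of fields
    (e.g. (x, y) pairs) are collated field by field; scalars become a 1-D array.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # transpose [(x0, y0), (x1, y1), ...] -> [(x0, x1, ...), (y0, y1, ...)]
        return [numpy_collate(list(field)) for field in zip(*batch)]
    return np.array(batch)

# Usage with PyTorch (assumes torch is installed):
# from torch.utils.data import DataLoader
# loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)
```

Returning plain NumPy arrays lets each batch be passed straight into jitted JAX functions without a torch-to-NumPy conversion step.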

### Minimal example

``` python
import jax
import optax

from jaxDiversity.utilclasses import InnerConfig, OuterConfig # simple utility classes for configuration consistency
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.mlp import mlp_afunc, MultiActMLP, init_linear_weight, xavier_normal_init, save
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn
from jaxDiversity.loops import inner_opt, outer_opt
```

#### Inner optimization (standard training loop) with the baseline activations

``` python
dev_inner_config = InnerConfig(test_train_split=0.8,
                            input_dim=2,
                            output_dim=2,
                            hidden_layer_sizes=[18],
                            batch_size=64,
                            epochs=2,
                            lr=1e-3,
                            mu=0.9,
                            n_fns=2,
                            l2_reg=1e-1,
                            seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
# baseline activations (n_fns=2): x**2 and the identity
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)

# optax documents `decay` as RMSProp's decay rate for the running average of
# squared gradients (not weight decay); here it is set from `l2_reg`
opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim, dev_inner_config.hidden_layer_sizes, model_key, bias=False)
# standard training of the MLP with the fixed baseline activations
baselineNN, opt_state, inner_results = inner_opt(model=model,
                                                 train_data=train_dataloader,
                                                 test_data=test_dataloader,
                                                 afuncs=afuncs,
                                                 opt=opt,
                                                 loss_fn=compute_loss_baseline,
                                                 config=dev_inner_config,
                                                 training=True,
                                                 verbose=True)
```
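`MultiActMLP` receives a list of activations (`afuncs`, here `x**2` and the identity) and uses them within a layer. One simple way to picture this intralayer diversity is to split a hidden layer's units into contiguous groups, one group per activation. The NumPy sketch below assumes that near-even, contiguous split; how `MultiActMLP` actually assigns activations to units is not stated here, so treat `multi_act_layer` as hypothetical.

``` python
import numpy as np

def multi_act_layer(W, b, x, afuncs):
    """Hypothetical dense layer whose units are shared out among `afuncs`.

    W: (out_dim, in_dim), b: (out_dim,), x: (in_dim,). The pre-activations are
    split into len(afuncs) contiguous, near-equal groups (np.array_split), and
    group i is passed through afuncs[i].
    """
    z = W @ x + b                                # pre-activations, shape (out_dim,)
    groups = np.array_split(z, len(afuncs))      # one group of units per activation
    return np.concatenate([f(g) for f, g in zip(afuncs, groups)])

afuncs = [lambda x: x**2, lambda x: x]           # same baseline pair as above
W = np.eye(4)                                    # 4 hidden units, 4 inputs
b = np.zeros(4)
print(multi_act_layer(W, b, np.array([1.0, -2.0, 3.0, -4.0]), afuncs).tolist())
# [1.0, 4.0, 3.0, -4.0]
```

The first two units are squared and the last two pass through unchanged, so a single layer mixes two kinds of nonlinearity.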

#### Metalearning with Hamiltonian neural networks

``` python
inner_config = InnerConfig(test_train_split=0.8,
                            input_dim=2,
                            output_dim=1,
                            hidden_layer_sizes=[32],
                            batch_size=64,
                            epochs=5,
                            lr=1e-3,
                            mu=0.9,
                            n_fns=2,
                            l2_reg=1e-1,
                            seed=42)
# outer loop: configuration of the learnable activation networks (1 input, 1 output)
outer_config = OuterConfig(input_dim=1,
                            output_dim=1,
                            hidden_layer_sizes=[18],
                            batch_size=1,
                            steps=2,
                            print_every=1,
                            lr=1e-3,
                            mu=0.9,
                            seed=24)
# the targets are 2-D (presumably dq/dt and dp/dt), while the HNN itself
# outputs a scalar Hamiltonian (output_dim=1 in inner_config)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)

opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

# metalearning: outer_opt updates the activation networks with meta_opt;
# opt and inner_config are used for the inner training
HNN_acts, HNN_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_hnn,
                                inner_config, outer_config, opt, meta_opt, save_path=None)
```
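In the HNN example the network outputs a single scalar (`output_dim=1`) while the dataset targets are two-dimensional. This is consistent with the standard Hamiltonian neural network setup, in which the network learns a scalar Hamiltonian H(q, p) and the predicted dynamics come from its gradient: dq/dt = dH/dp and dp/dt = -dH/dq. The JAX sketch below shows such a loss for a generic Hamiltonian function `H`. It illustrates the technique under that assumption and is not the code of `jaxDiversity.hnn.compute_loss`; the names `hamiltonian_field` and `hnn_mse` are hypothetical.

``` python
import jax
import jax.numpy as jnp

def hamiltonian_field(H, state):
    """Symplectic gradient of a scalar Hamiltonian H at state = [q, p]."""
    grad_H = jax.grad(H)(state)                  # [dH/dq, dH/dp]
    return jnp.stack([grad_H[1], -grad_H[0]])    # [dq/dt, dp/dt]

def hnn_mse(H, states, targets):
    """Mean squared error between predicted and target time derivatives.

    states and targets both have shape (batch, 2).
    """
    preds = jax.vmap(lambda s: hamiltonian_field(H, s))(states)
    return jnp.mean((preds - targets) ** 2)

# Harmonic oscillator H = (q**2 + p**2) / 2, whose exact dynamics are dq/dt = p, dp/dt = -q
def H_true(state):
    return 0.5 * jnp.sum(state ** 2)

states = jnp.array([[1.0, 2.0], [3.0, -1.0]])
targets = jnp.stack([states[:, 1], -states[:, 0]], axis=1)
print(float(hnn_mse(H_true, states, targets)))  # 0.0
```

In a trained HNN, `H` would be the network evaluated on a single state. `jax.grad` requires it to return a scalar, which matches `output_dim=1` in the configuration above.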

Link to the older PyTorch codebase with the classification problem:
[DiversityNN](https://github.com/NonlinearArtificialIntelligenceLab/DiversityNN)

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/NonlinearArtificialIntelligenceLab/jaxDiversity",
    "name": "jaxDiversity",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "nbdev jupyter notebook python",
    "author": "Anil Radhakrishnan",
    "author_email": "aradhak5@ncsu.edu",
    "download_url": "https://files.pythonhosted.org/packages/36/ca/f42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e/jaxDiversity-0.0.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "Apache Software License 2.0",
    "summary": "jax implementation for metalearning neuronal diversity",
    "version": "0.0.2",
    "project_urls": {
        "Homepage": "https://github.com/NonlinearArtificialIntelligenceLab/jaxDiversity"
    },
    "split_keywords": [
        "nbdev",
        "jupyter",
        "notebook",
        "python"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "4f990cf1fbeb14f042fc0ae9145ba1456da709c5338ffb0a5603f599f15f104a",
                "md5": "21e1ddfbb07dbfa69644468271e84fb3",
                "sha256": "79dcbea6fcf78b3876613aa47497677d43fb8db3447b5f8cb0dc630f8405ea44"
            },
            "downloads": -1,
            "filename": "jaxDiversity-0.0.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "21e1ddfbb07dbfa69644468271e84fb3",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 15838,
            "upload_time": "2023-07-18T18:45:11",
            "upload_time_iso_8601": "2023-07-18T18:45:11.635865Z",
            "url": "https://files.pythonhosted.org/packages/4f/99/0cf1fbeb14f042fc0ae9145ba1456da709c5338ffb0a5603f599f15f104a/jaxDiversity-0.0.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "36caf42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e",
                "md5": "0bd76e2a5a3e88dcb5455a82d84e5a75",
                "sha256": "40f72d1d3b7a464ea09cd8247e05deacf1d32a32ac9ac9a344487841d2a31121"
            },
            "downloads": -1,
            "filename": "jaxDiversity-0.0.2.tar.gz",
            "has_sig": false,
            "md5_digest": "0bd76e2a5a3e88dcb5455a82d84e5a75",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 16052,
            "upload_time": "2023-07-18T18:45:13",
            "upload_time_iso_8601": "2023-07-18T18:45:13.109431Z",
            "url": "https://files.pythonhosted.org/packages/36/ca/f42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e/jaxDiversity-0.0.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-07-18 18:45:13",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "NonlinearArtificialIntelligenceLab",
    "github_project": "jaxDiversity",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "jaxdiversity"
}
        