# jaxDiversity
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
This is an updated JAX implementation of the paper *Neural networks
embrace diversity*.
## Authors
Anshul Choudhary, Anil Radhakrishnan, John F. Lindner, Sudeshna Sinha,
and William L. Ditto
## Link to paper
- [arXiv](https://arxiv.org/abs/2204.04348)
## Key Results
- We construct neural networks with learnable activation functions and
  see that they quickly diversify from each other under training.
- These activations subsequently outperform their *pure* counterparts on
classification tasks.
- The neuronal sub-networks instantiate the neurons, and meta-learning
  adjusts their weights and biases to find efficient spanning sets of
  nonlinear activations.
- These improved neural networks provide quantitative examples of the
emergence of diversity and insight into its advantages.
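The idea of an activation function that is itself a small neural network, whose weights and biases are tuned by meta-learning, can be illustrated with a toy JAX sketch. It is not the library's `mlp_afunc` or `outer_opt`. Every name here is hypothetical. The sketch also makes several simplifying assumptions: tanh hidden units, a single learned activation, a plain gradient-descent inner loop, and differentiation through the unrolled inner loop.

```python
import jax
import jax.numpy as jnp


def init_act_params(key, hidden=8):
    """Weights and biases of a tiny 1 -> hidden -> 1 MLP used as an activation."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.5 * jax.random.normal(k1, (1, hidden)),
        "b1": jnp.zeros((hidden,)),
        "w2": 0.5 * jax.random.normal(k2, (hidden, 1)),
        "b2": jnp.zeros((1,)),
    }


def act_net(params, x):
    """Apply the activation sub-network to every entry of x (any shape)."""
    z = x.reshape(-1, 1)                           # each entry is a scalar input
    h = jnp.tanh(z @ params["w1"] + params["b1"])  # (n, hidden)
    out = h @ params["w2"] + params["b2"]          # (n, 1)
    return out.reshape(x.shape)


def task_loss(w, act_params, x, y):
    """Supervised MSE of a one-layer network whose neurons all use act_net."""
    return jnp.mean((act_net(act_params, x @ w) - y) ** 2)


def inner_train(w, act_params, x, y, steps=5, lr=0.1):
    """Inner loop: gradient descent on the network weights, activation held fixed."""
    for _ in range(steps):
        w = w - lr * jax.grad(task_loss)(w, act_params, x, y)
    return w


def outer_loss(act_params, w0, train, val):
    """Outer objective: validation loss after inner training."""
    w = inner_train(w0, act_params, *train)
    return task_loss(w, act_params, *val)


# Meta-gradient w.r.t. the activation's weights and biases, obtained by
# differentiating through the unrolled inner loop.
meta_grad = jax.jit(jax.grad(outer_loss))

k_act, k_w, k_tr, k_va = jax.random.split(jax.random.PRNGKey(0), 4)
act_params = init_act_params(k_act)
w0 = 0.1 * jax.random.normal(k_w, (2, 2))
x_tr, x_va = jax.random.normal(k_tr, (32, 2)), jax.random.normal(k_va, (32, 2))
train, val = (x_tr, jnp.sin(x_tr)), (x_va, jnp.sin(x_va))  # toy regression target

grads = meta_grad(act_params, w0, train, val)
# one outer (meta) step of plain gradient descent on the activation parameters
act_params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, act_params, grads)
```

Unrolling keeps the sketch short. For long inner loops, memory grows with the number of unrolled steps, which is one reason a real implementation might truncate or approximate the meta-gradient.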
## Install
``` sh
pip install jaxDiversity
```
## How to use
The codebase has four main components:

- dataloading: Contains tools for loading the datasets mentioned in the
  manuscript. We use PyTorch dataloaders with a custom NumPy collate
  function to use this data in JAX.
- losses: Our loss implementations handle both traditional MLPs and
  Hamiltonian neural networks with minimal changes.
- mlp: Contains a custom MLP that takes in multiple activations and uses
  them *intralayer* to create a diverse network. Also contains the
  activation neural networks.
- loops: Contains the inner and outer loops for metalearning, which
  optimize the activation functions in tandem with the supervised
  learning task.
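The dataloading approach uses PyTorch loaders that yield NumPy arrays for JAX. This usually rests on a collate function like the one below, which follows the pattern in the JAX documentation's tutorial on training with PyTorch data loading. `numpy_collate` is an illustrative name, and jaxDiversity's own `NumpyLoader` may be implemented differently. The PyTorch wiring appears only as a comment so the snippet runs with NumPy alone.

```python
import numpy as np


def numpy_collate(batch):
    """Stack a list of samples into NumPy arrays (instead of torch tensors).

    Samples may be arrays, or tuples/lists of arrays such as (x, y) pairs.
    """
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # transpose [(x0, y0), (x1, y1), ...] into [(x0, x1, ...), (y0, y1, ...)]
        return [numpy_collate(list(field)) for field in zip(*batch)]
    return np.array(batch)


# With PyTorch installed, the function plugs into a standard DataLoader:
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)

samples = [(np.full(2, i, dtype=np.float32), np.zeros(2, dtype=np.float32)) for i in range(4)]
xs, ys = numpy_collate(samples)  # xs and ys are (4, 2) float32 NumPy arrays
```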
### Minimum example
``` python
import jax
import optax
from jaxDiversity.utilclasses import InnerConfig, OuterConfig # simple utility classes for configuration consistency
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.mlp import mlp_afunc, MultiActMLP, init_linear_weight, xavier_normal_init, save
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn
from jaxDiversity.loops import inner_opt, outer_opt
```
#### Inner optimization, or a standard training loop with the baseline activations
``` python
dev_inner_config = InnerConfig(test_train_split=0.8,
                               input_dim=2,
                               output_dim=2,
                               hidden_layer_sizes=[18],
                               batch_size=64,
                               epochs=2,
                               lr=1e-3,
                               mu=0.9,
                               n_fns=2,
                               l2_reg=1e-1,
                               seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)

# fixed (non-learnable) activations, used intralayer by the hidden neurons
afuncs = [lambda x: x**2, lambda x: x]

# dummy datasets with the configured input/output dimensions
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)

# note: in optax.rmsprop, `decay` is the moving-average rate for squared
# gradients, not an L2 weight-decay penalty
opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim,
                    dev_inner_config.hidden_layer_sizes, model_key, bias=False)
baselineNN, opt_state, inner_results = inner_opt(model=model,
                                                 train_data=train_dataloader,
                                                 test_data=test_dataloader,
                                                 afuncs=afuncs,
                                                 opt=opt,
                                                 loss_fn=compute_loss_baseline,
                                                 config=dev_inner_config,
                                                 training=True,
                                                 verbose=True)
```
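One way a layer can use several activations *intralayer*, as `MultiActMLP` does with `afuncs` above, is to split its neurons into equal groups and give each group its own function. The sketch below assumes contiguous, equal-sized groups and a bias-free network (mirroring `bias=False`). `multi_act` and `forward` are hypothetical names, and the library's actual neuron-to-activation assignment may differ.

```python
import jax
import jax.numpy as jnp


def multi_act(pre, afuncs):
    """Apply afuncs[i] to the i-th contiguous block of neurons along the last axis.

    Hypothetical layout: the layer width must be divisible by len(afuncs).
    """
    width, n = pre.shape[-1], len(afuncs)
    if width % n != 0:
        raise ValueError(f"layer width {width} is not divisible by {n} activations")
    blocks = jnp.split(pre, n, axis=-1)
    return jnp.concatenate([f(b) for f, b in zip(afuncs, blocks)], axis=-1)


def forward(params, x, afuncs):
    """Bias-free MLP with one hidden layer that mixes activations intralayer."""
    h = multi_act(x @ params["hidden"], afuncs)
    return h @ params["out"]


# Sizes mirror dev_inner_config: input_dim=2, hidden_layer_sizes=[18], output_dim=2.
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
params = {
    "hidden": 0.1 * jax.random.normal(k1, (2, 18)),
    "out": 0.1 * jax.random.normal(k2, (18, 2)),
}
afuncs = [lambda x: x**2, lambda x: x]
x = jnp.ones((64, 2))
y = forward(params, x, afuncs)  # 9 hidden neurons use x**2, the other 9 the identity
```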
#### Metalearning with Hamiltonian Neural Networks
``` python
inner_config = InnerConfig(test_train_split=0.8,
                           input_dim=2,
                           output_dim=1,
                           hidden_layer_sizes=[32],
                           batch_size=64,
                           epochs=5,
                           lr=1e-3,
                           mu=0.9,
                           n_fns=2,
                           l2_reg=1e-1,
                           seed=42)
outer_config = OuterConfig(input_dim=1,
                           output_dim=1,
                           hidden_layer_sizes=[18],
                           batch_size=1,
                           steps=2,
                           print_every=1,
                           lr=1e-3,
                           mu=0.9,
                           seed=24)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)

opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

HNN_acts, HNN_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_hnn,
                                inner_config, outer_config, opt, meta_opt, save_path=None)
```
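The HNN example sets `output_dim=1`, while the dummy datasets have 2-dimensional targets. That matches the standard Hamiltonian neural network setup. There, the network outputs a scalar Hamiltonian H(q, p), and the predicted dynamics come from Hamilton's equations, dq/dt = dH/dp and dp/dt = -dH/dq. The minimal sketch below assumes that `compute_loss_hnn` follows this recipe with a mean-squared error on the time derivatives; this README does not confirm it. `hamiltonian`, `hnn_vector_field`, and `hnn_loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def hamiltonian(params, state):
    """Scalar H(q, p) from a small bias-free MLP; state has shape (2,) = (q, p)."""
    h = jnp.tanh(state @ params["w1"])
    return (h @ params["w2"])[0]


def hnn_vector_field(params, state):
    """Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq."""
    dH_dq, dH_dp = jax.grad(hamiltonian, argnums=1)(params, state)
    return jnp.stack([dH_dp, -dH_dq])


def hnn_loss(params, states, targets):
    """MSE between predicted and observed time derivatives; shapes (batch, 2)."""
    preds = jax.vmap(hnn_vector_field, in_axes=(None, 0))(params, states)
    return jnp.mean((preds - targets) ** 2)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {"w1": 0.5 * jax.random.normal(k1, (2, 32)),  # input_dim=2, hidden 32
          "w2": 0.5 * jax.random.normal(k2, (32, 1))}  # output_dim=1 (scalar H)
states = jax.random.normal(k3, (64, 2))   # batch of (q, p)
targets = jax.random.normal(k4, (64, 2))  # batch of (dq/dt, dp/dt)
loss, grads = jax.value_and_grad(hnn_loss)(params, states, targets)
```

Because the vector field is built from the gradient of a single scalar, it is always orthogonal to that gradient. As a result, the learned H is conserved along the predicted dynamics by construction.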
Link to the older PyTorch codebase, which includes the classification
problem:
[DiversityNN](https://github.com/NonlinearArtificialIntelligenceLab/DiversityNN)
Raw data
{
"_id": null,
"home_page": "https://github.com/NonlinearArtificialIntelligenceLab/jaxDiversity",
"name": "jaxDiversity",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": "",
"keywords": "nbdev jupyter notebook python",
"author": "Anil Radhakrishnan",
"author_email": "aradhak5@ncsu.edu",
"download_url": "https://files.pythonhosted.org/packages/36/ca/f42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e/jaxDiversity-0.0.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "Apache Software License 2.0",
"summary": "jax implementation for metalearning neuronal diversity",
"version": "0.0.2",
"project_urls": {
"Homepage": "https://github.com/NonlinearArtificialIntelligenceLab/jaxDiversity"
},
"split_keywords": [
"nbdev",
"jupyter",
"notebook",
"python"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "4f990cf1fbeb14f042fc0ae9145ba1456da709c5338ffb0a5603f599f15f104a",
"md5": "21e1ddfbb07dbfa69644468271e84fb3",
"sha256": "79dcbea6fcf78b3876613aa47497677d43fb8db3447b5f8cb0dc630f8405ea44"
},
"downloads": -1,
"filename": "jaxDiversity-0.0.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "21e1ddfbb07dbfa69644468271e84fb3",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 15838,
"upload_time": "2023-07-18T18:45:11",
"upload_time_iso_8601": "2023-07-18T18:45:11.635865Z",
"url": "https://files.pythonhosted.org/packages/4f/99/0cf1fbeb14f042fc0ae9145ba1456da709c5338ffb0a5603f599f15f104a/jaxDiversity-0.0.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "36caf42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e",
"md5": "0bd76e2a5a3e88dcb5455a82d84e5a75",
"sha256": "40f72d1d3b7a464ea09cd8247e05deacf1d32a32ac9ac9a344487841d2a31121"
},
"downloads": -1,
"filename": "jaxDiversity-0.0.2.tar.gz",
"has_sig": false,
"md5_digest": "0bd76e2a5a3e88dcb5455a82d84e5a75",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 16052,
"upload_time": "2023-07-18T18:45:13",
"upload_time_iso_8601": "2023-07-18T18:45:13.109431Z",
"url": "https://files.pythonhosted.org/packages/36/ca/f42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e/jaxDiversity-0.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-07-18 18:45:13",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "NonlinearArtificialIntelligenceLab",
"github_project": "jaxDiversity",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [],
"lcname": "jaxdiversity"
}