HAMUX
================
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
<img src="https://raw.githubusercontent.com/bhoov/hamux/main/assets/header.png" alt="HAMUX Logo" width="400"/>
Part proof-of-concept, part functional prototype, HAMUX is designed to
bridge modern AI architectures and biologically plausible Hopfield
Networks.
**HAMUX**: A **H**ierarchical **A**ssociative **M**emory **U**ser
e**X**perience
<div class="alert alert-info">
🚧 <strong>API is in rapid development</strong>. Remember to specify the version when building off of HAMUX.
</div>
## A Universal Abstraction for Hopfield Networks
HAMUX fully captures the energy fundamentals of Hopfield Networks
and enables anyone to:
- 🧠 Build **DEEP** Hopfield nets
- 🧱 With modular **ENERGY** components
- 🏆 That resemble modern DL operations
**Every** architecture built using HAMUX is a *dynamical system*
guaranteed to have a *tractable energy* function that *converges* to a
fixed point. Our deep [Hierarchical Associative
Memories](https://arxiv.org/abs/2107.06446) (HAMs) have several
additional advantages over traditional [Hopfield
Networks](https://en.wikipedia.org/wiki/Hopfield_network) (HNs):
| Hopfield Networks (HNs) | Hierarchical Associative Memories (HAMs) |
|--------------------------------------------------------|------------------------------------------------------------------------------------------------|
| HNs are only **two-layer** systems                      | HAMs connect **any number** of layers                                                            |
| HNs model only **linear relationships** between layers | HAMs model **any differentiable operation** (e.g., convolutions, pooling, attention, $\ldots$) |
| HNs use only **pairwise synapses** | HAMs use **many-body synapses** |
## How does HAMUX work?
> **HAMUX** is a
> <a href="https://en.wikipedia.org/wiki/Hypergraph" >hypergraph</a> of
> 🌀neurons connected via 🤝synapses, an abstraction sufficiently
> general to model the complexity of connections contained in the 🧠.
HAMUX defines two fundamental building blocks of energy: the **🌀neuron
layer** and the **🤝synapse**, connected via a **hypergraph**
<img src="https://raw.githubusercontent.com/bhoov/hamux/main/assets/fig1.png" alt="HAMUX Overview" width="700"/>
### 🌀Neuron Layers
Neuron layers are the recurrent unit of a HAM; that is, 🌀neurons keep a
state that changes over time according to the dynamics of the system.
These states always change to minimize the global energy function of the
system.
Anyone familiar with traditional Deep Learning architectures will
recognize nonlinear activation functions like the `ReLU` and the
`SoftMax`. A neuron layer in HAMUX is exactly that: a nonlinear
activation function defined on a layer of neurons. However, we need to
express the activation function as a convex **Lagrangian function**
$\mathcal{L}$: the integral of the desired non-linearity, such that the
**derivative of the Lagrangian** $\nabla \mathcal{L}$ recovers the
non-linearity itself. For example, consider the ReLU:
$$
\begin{align*}
\mathcal{L}(x) &:= \frac{1}{2} (\max(x, 0))^2\\
\nabla \mathcal{L}(x) &= \max(x, 0) = \mathrm{relu}(x)
\end{align*}
$$
We need to define our activation layer in terms of the *Lagrangian* of
the ReLU instead of the ReLU itself. Extending this constraint to other
nonlinearities makes it possible to define the scalar energy for any
neuron in a HAM. It turns out that many activation functions used in
today’s Deep Learning landscape are expressible as a Lagrangian. HAMUX
is “batteries-included” for many common activation functions including
`relu`s, `softmax`es, `sigmoid`s, `LayerNorm`s, etc. See our
[documentation on
Lagrangians](https://bhoov.github.io/hamux/lagrangians.html) for
examples on how to implement efficient activation functions from
Lagrangians in JAX. We show how to turn Lagrangians into usable energy
building blocks in our [documentation on neuron
layers](https://bhoov.github.io/hamux/layers.html).
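For instance, here is a minimal standalone sketch (plain JAX, not the HAMUX API) showing that autograd recovers the ReLU from its Lagrangian:

``` python
import jax
import jax.numpy as jnp

def lagr_relu(x):
    """Convex Lagrangian whose elementwise gradient is the ReLU."""
    return 0.5 * jnp.sum(jnp.maximum(x, 0) ** 2)

relu = jax.grad(lagr_relu)  # gradient of the Lagrangian = the activation

x = jnp.array([-1.0, 0.5, 2.0])
assert jnp.allclose(relu(x), jnp.maximum(x, 0))
```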
### 🤝Synapses
A 🤝synapse ONLY sees activations of connected 🌀neuron layers. Its one
job: report HIGH ⚡️energy if the connected activations are dissimilar
and LOW ⚡️energy when they are aligned. Synapses can resemble
convolutions, dense multiplications, even attention… Take a look at our
[documentation on
synapses](https://bhoov.github.io/hamux/synapses.html).
<div class="alert alert-info">
🚨 <strong>Point of confusion</strong>: modern AI frameworks have <code>AttentionLayer</code>s and <code>ConvolutionalLayer</code>s. In HAMUX, these would be more appropriately called <code>AttentionSynapse</code>s and <code>ConvolutionalSynapse</code>s.
</div>
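As a rough illustration (a hypothetical sketch, not the HAMUX synapse API), the energy of a pairwise dense synapse can be as simple as a negative bilinear alignment score between the two connected activations:

``` python
import jax.numpy as jnp

def dense_synapse_energy(W, g1, g2):
    """Hypothetical pairwise dense synapse: low energy when g1 and g2 are
    aligned through the weights W, high energy when they disagree."""
    return -jnp.einsum("i,ij,j->", g1, W, g2)
```

A real synapse might instead parameterize a convolution or an attention operation, but the contract is the same: a scalar energy of the connected activations.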
## Install
**From pip**:
``` sh
pip install hamux
```
If you are using accelerators beyond the CPU, you will need to
additionally install the corresponding `jax` and `jaxlib` versions
following [their
documentation](https://github.com/google/jax#installation). E.g.,
``` sh
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
**From source**:
After cloning:
``` sh
cd hamux
conda env create -f environment.yml
conda activate hamux
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html  # If using a GPU accelerator
pip install -e .
pip install -r requirements-dev.txt  # To run the examples
```
## How to Use
``` python
import hamux as hmx
import jax.numpy as jnp
import jax
import jax.tree_util as jtu
```
We can build a simple 4-layer HAM architecture using the following code:
``` python
layers = [
hmx.TanhLayer((32,32,3)), # e.g., CIFAR Images
hmx.SigmoidLayer((11,11,1000)), # CIFAR patches
hmx.SoftmaxLayer((10,)), # CIFAR Labels
hmx.SoftmaxLayer((1000,)), # Hidden Memory Layer
]
synapses = [
hmx.ConvSynapse((3,3), strides=3),
hmx.DenseSynapse(),
hmx.DenseSynapse(),
]
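# Each connection pairs the indices of the connected layers with the index
# of the synapse joining them: ([layer indices], synapse index)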
connections = [
([0,1], 0),
([1,3], 1),
([2,3], 2),
]
rng = jax.random.PRNGKey(0)
param_key, state_key, rng = jax.random.split(rng, 3)
states, ham = hmx.HAM(layers, synapses, connections).init_states_and_params(param_key, state_key=state_key)
```
Notice that we did not specify any output channel shapes in the
synapses. The desired output shape is computed from the layers connected
to each synapse during `hmx.HAM.init_states_and_params`.
We have two fundamental objects: `states` and `ham`. The `ham` object
contains the connectivity structure of the HAM (e.g.,
layer+synapse+hypergraph information) alongside the **parameters** of
the network. The `states` object is a list of length `nlayers` where
each item is a tensor representing the neuron states of the
corresponding layer.
``` python
assert len(states) == ham.n_layers
assert all([state.shape == layer.shape for state, layer in zip(states, ham.layers)])
```
We make it easy to run the dynamics of any HAM. Every `forward`
function is defined outside the memory itself and can be modified to
extract different memories from different layers, as desired. The
general steps for any forward function are:
1. Initialize the dynamic states.
2. Inject an initial state into the system.
3. Run the dynamics, calculating the energy gradient at every step.
4. Return the layer state/activation of interest.
``` python
def fwd(model, x, depth=15, dt=0.1):
    """Assuming a trained HAM, run association with the HAM on batched inputs `x`"""
    # 1. Initialize model states at t=0, accounting for the batch size
    xs = model.init_states(x.shape[0])

    # 2. Inject the initial state
    xs[0] = x

    # 3. Run the dynamics, descending the energy gradient at each step
    energies = []
    for i in range(depth):
        energies.append(model.venergy(xs))  # If desired, observe the energy
        dEdg = model.vdEdg(xs)              # Calculate the gradients
        xs = jtu.tree_map(lambda state, stepsize, grad: state - stepsize * grad, xs, model.alphas(dt), dEdg)

    # 4. Return the probabilities of our label layer
    probs = model.layers[-2].activation(xs[-2])
    return jnp.stack(energies), probs
```
``` python
batch_size=3
x = jax.random.normal(jax.random.PRNGKey(2), (batch_size, 32,32,3))
energies, probs = fwd(ham, x, depth=20, dt=0.3)
print(probs.shape) # batchsize, nclasses
assert jnp.allclose(probs.sum(-1), 1)
```
    (3, 10)
![](index_files/figure-gfm/cell-11-output-1.png)
## The Energy Function vs the Loss Function
We use JAX’s autograd to descend the energy function of our system AND
the loss function of our task. The derivative of the energy is always
taken wrt to our *states*; the derivative of the loss function is always
taken wrt our *parameters*. During training, we change our parameters to
optimize the *Loss Function*. During inference, we assume that
parameters are constant.
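As a toy illustration of this split (not the HAMUX training loop), JAX’s `argnums` lets the same autograd machinery differentiate an energy with respect to states and a loss with respect to parameters:

``` python
import jax
import jax.numpy as jnp

# Toy scalar functions for illustration only:
# W plays the role of parameters, x the role of states.
def energy(W, x):
    return 0.5 * x @ W @ x

def loss(W, x, target):
    return jnp.sum((W @ x - target) ** 2)

W, x, target = jnp.eye(3), jnp.ones(3), jnp.zeros(3)

dE_dx = jax.grad(energy, argnums=1)(W, x)        # inference: descend energy w.r.t. states
dL_dW = jax.grad(loss, argnums=0)(W, x, target)  # training: descend loss w.r.t. parameters
```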
**Autograd for Descending Energy**
Every [`HAM`](https://bhoov.github.io/hamux/ham.html#ham) defines the
energy function of our system, which is everything we need to compute
the memories of the system. Naively, we can calculate $\nabla_x E$: the
derivative of the energy function with respect to the *states* of each
layer:
``` python
stepsize = 0.01
fscore_naive = jax.grad(ham.energy)
next_states = jax.tree_util.tree_map(lambda state, score: state - stepsize * score, states, fscore_naive(states))
```
But it turns out we improve the efficiency of our network if we instead
take $\nabla_g E$: the derivative of the energy with respect to the
*activations* instead of the *states*. Both descents share the same
local minima, even though the trajectories to reach them differ. Because
each layer’s energy has the form $\langle x, g \rangle - \mathcal{L}(x)$,
its gradient with respect to $g$ is simply $x$, so some nice terms
cancel and we get:
$$\nabla_g E_\text{HAM} = x + \nabla_g E_\text{synapse}$$
``` python
stepsize = 0.01
def fscore_smart(xs):
    gs = ham.activations(xs)
    return jax.tree_util.tree_map(lambda x, nabla_g_Esyn: x + nabla_g_Esyn, xs, jax.grad(ham.synapse_energy)(gs))

next_states = jax.tree_util.tree_map(lambda state, score: state - stepsize * score, states, fscore_smart(states))
```
## Citation
This work is a collaboration between the [MIT-IBM Watson AI
Lab](https://mitibmwatsonailab.mit.edu/) and the
[PoloClub](https://poloclub.github.io/) @ GA Tech.