# Auto_HookPoint
Auto_HookPoint is a Python library that seamlessly integrates arbitrary PyTorch models with transformer_lens. It provides an `auto_hook` function that automatically wraps your PyTorch model, attaching a HookPoint to every `nn.Module` and to select `nn.Parameter` instances within the model structure.
## Features
- Works with both `nn.Module` and `nn.Parameter` operations
- Can be used either as a class decorator or on an already instantiated model
- Removes hook boilerplate: no manual `HookPoint` attributes or wrapper calls in `forward`
## Installation
```bash
pip install Auto_HookPoint
```
## Usage
### Usage as decorator
```python
from Auto_HookPoint import auto_hook
import torch.nn as nn

@auto_hook
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        # self.fc1_hook_point = HookPoint()  <- no longer needed

    def forward(self, x):
        # return self.fc1_hook_point(self.fc1(x))  <- no longer needed
        return self.fc1(x)

model = MyModel()
print(model.hook_dict.items())
# dict_items([('hook_point', HookPoint()), ('fc1.hook_point', HookPoint())])

orig_model = model.unwrap()  # get back the original model
```
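Because the wrapped model behaves like a transformer_lens `HookedRootModule` (see the typing hack later in this README), the standard utilities apply. A minimal sketch using `run_with_cache`; the cache key matches the `hook_dict` output above:
```python
import torch

# run_with_cache is inherited from transformer_lens' HookedRootModule
out, cache = model.run_with_cache(torch.randn(2, 10))
print(cache['fc1.hook_point'].shape)  # torch.Size([2, 10])
```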
### Wrap an instance
Auto_HookPoint can also work with models that use `nn.Parameter`, such as this AutoEncoder example:
```python
from functools import partial
from typing import cast

import torch
from torch import nn
from transformer_lens.hook_points import HookedRootModule

from Auto_HookPoint import auto_hook

# taken from Neel Nanda's excellent autoencoder tutorial:
# https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        dtype = torch.float32
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = torch.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        return x_reconstruct
autoencoder = auto_hook(AutoEncoder({"d_mlp": 10, "dict_mult": 10, "l1_coeff": 10, "seed": 1}))
print(autoencoder.hook_dict.items())
# dict_items([('hook_point', HookPoint()), ('W_enc.hook_point', HookPoint()), ('W_dec.hook_point', HookPoint()), ('b_enc.hook_point', HookPoint()), ('b_dec.hook_point', HookPoint())])
input_kwargs = {'x': torch.randn(10, 10)}
def hook_fn(x, hook=None, hook_name=None):
    print('hello from hook:', hook_name)
    return x

autoencoder.run_with_hooks(
    **input_kwargs,
    fwd_hooks=[
        (hook_name, partial(hook_fn, hook_name=hook_name))
        for hook_name in autoencoder.hook_dict.keys()
    ]
)

# If you want full typing support after hooking your model,
# a hacky solution is:
class Model(HookedRootModule, AutoEncoder):
    pass

autoencoder = cast(Model, autoencoder)
# autoencoder.forward() is now type hinted in vscode
```
## auto_hook + huggingface transformers
`auto_hook` also works with Hugging Face models:
```python
from Auto_HookPoint import auto_hook
from transformers.models.mamba.modeling_mamba import MambaForCausalLM, MambaConfig
import torch

mamba_cfg = MambaConfig()
model = MambaForCausalLM(mamba_cfg)
model = auto_hook(model)
print('model.hook_dict', model.hook_dict.keys())
```
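From here the usual transformer_lens machinery applies. A hedged sketch; the hook name `'backbone.embeddings.hook_point'` is an assumption based on Mamba's module tree, so check `model.hook_dict.keys()` for the actual names:
```python
def print_shape_hook(tensor, hook):
    # fires with the embedding output during the forward pass
    print(hook.name, tuple(tensor.shape))
    return tensor

input_ids = torch.randint(0, mamba_cfg.vocab_size, (1, 16))
model.run_with_hooks(
    input_ids,
    fwd_hooks=[('backbone.embeddings.hook_point', print_shape_hook)],
)
```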
## auto_hook + manual hookpointing
Because `auto_hook` does not hook arbitrary tensor-manipulation functions, manual hooking is sometimes necessary, for instance when using `torch.relu()` instead of `nn.ReLU()`. Luckily, `auto_hook` leaves existing hook points untouched, so you can still use them.
```python
import torch
from torch import nn
from transformer_lens.hook_points import HookPoint

from Auto_HookPoint import auto_hook

@auto_hook
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.relu_hook_point = HookPoint()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu_hook_point(torch.relu(x))
        return x
```
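The manual hook point then sits alongside the auto-generated ones in `hook_dict`; the exact key names below are illustrative, following the naming pattern from the earlier examples:
```python
model = TestModel()
print(model.hook_dict.keys())
# e.g. dict_keys(['hook_point', 'linear.hook_point', 'relu_hook_point'])
```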
## Train SAEs with sae_lens
With `auto_hook` you can train a sparse autoencoder on any Hugging Face transformers model via `sae_lens`:
```python
# Most of the credit for this example goes to https://gist.github.com/joelburget
# See https://github.com/HP2706/Auto_HookPoint/blob/main/examples/sae_lens.py for a complete example
from Auto_HookPoint import HookedTransformerAdapter
# install via: pip install sae_lens
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    model_name=model_name,
    hook_name="model.norm.hook_point",
    ...
)

hooked_model = HookedTransformerAdapter(Cfg(device="cuda", n_ctx=512), hf_model_name=model_name)
sparse_autoencoder = SAETrainingRunner(cfg, override_model=hooked_model).run()
```
### Note on SAE-Lens Integration
1. Not all hook points are compatible: `run_with_cache` only works for hook points that return pure tensors, which most HF transformers blocks do not. This limitation can only be removed by changing transformer_lens, sae_lens, or both.
2. SAE-Lens expects activations of shape `[batch, sequence_length, hidden_size]`. Some hook points (e.g., the MixtralSparseMoeBlock gate) may not work because their outputs have different shapes.
3. If your model has more than one `nn.Embedding` attribute, specify which one is the input embedding via the `input_embedding_name` parameter of `HookedTransformerAdapter` (see the sketch after this list). Note that after the model is hooked, the `self.model.embed_tokens` (`nn.Embedding`) attribute is reachable as `self.model._module.model._module.embed_tokens._module.weight`.
4. `auto_hook` does not yet support stopping early via `stop_at_layer` in the forward pass, which makes building the `activation_store` in sae_lens impractical for very large models.
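For point 3, a hedged sketch: only `input_embedding_name` is taken from the note above, while `Cfg`, `model_name`, and the attribute name `embed_tokens` are carried over from the earlier example:
```python
hooked_model = HookedTransformerAdapter(
    Cfg(device="cuda", n_ctx=512),
    hf_model_name=model_name,
    input_embedding_name="embed_tokens",  # the nn.Embedding that maps input ids
)
```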
## Note
To ensure comprehensive coverage and identify potential edge cases, the `check_auto_hook` function is provided. This utility runs the model class through a suite of internal tests, helping to validate the auto-hooking process and catch unexpected behaviors or unsupported scenarios.
Note, however, that the results are not always informative; in particular, the bwd_hook test should generally be ignored.
```python
from Auto_HookPoint import check_auto_hook

input_kwargs = {'x': torch.randn(10, 10)}
init_kwargs = {'cfg': {'d_mlp': 10, 'dict_mult': 10, 'l1_coeff': 10, 'seed': 1}}
check_auto_hook(AutoEncoder, input_kwargs, init_kwargs)
```
If `strict` is set to `True`, a runtime error is raised when a test fails; otherwise a warning is issued.
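A sketch, assuming `strict` is a keyword argument of `check_auto_hook`:
```python
# raises a RuntimeError instead of warning if an internal test fails
check_auto_hook(AutoEncoder, input_kwargs, init_kwargs, strict=True)
```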
## Note on Backward Hooks (bwd_hooks)
Some issues can occur when using backward hooks. Because `auto_hook` hooks anything that is an instance of `nn.Module`, modules that return non-tensor objects will also be hooked. It is advised to use backward hooks only on hook points whose output is a tensor, as in the sketch below.
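A minimal sketch of a safe backward hook, reusing the `AutoEncoder` from above, whose output (and hence its top-level `hook_point`) is a plain tensor; `add_hook` and `reset_hooks` come from transformer_lens' `HookedRootModule`:
```python
def grad_hook(grad, hook):
    # fires during backward with the gradient flowing through the hook point
    print(f'grad at {hook.name}: {tuple(grad.shape)}')

autoencoder.add_hook('hook_point', grad_hook, dir='bwd')
out = autoencoder(torch.randn(10, 10))
out.sum().backward()
autoencoder.reset_hooks()
```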