# graphpatch 0.2.3

- **Summary:** `graphpatch` is a library for activation patching on PyTorch neural network models.
- **Homepage:** https://www.graphpatch.dev
- **Author:** Evan Lloyd
- **License:** MIT
- **Requires Python:** >=3.8.1, <3.13
- **Released:** 2024-10-19
- **Keywords:** mechanistic interpretability, interpretability, pytorch, torch, activation patch, ablation, transformer, large language model, llm

Documentation is hosted on [Read the Docs](https://graphpatch.readthedocs.io/en/stable).

## Overview

`graphpatch` is a library for [activation patching](https://graphpatch.readthedocs.io/en/stable/what_is_activation_patching.html#what-is-activation-patching) (often
also referred to as “ablation”) on [PyTorch](https://pytorch.org/docs/stable/index.html) neural network models. You use
it by first wrapping your model in a [`PatchableGraph`](https://graphpatch.readthedocs.io/en/stable/patchable_graph.html#graphpatch.PatchableGraph) and then running operations in a context
created by [`PatchableGraph.patch()`](https://graphpatch.readthedocs.io/en/stable/patchable_graph.html#graphpatch.PatchableGraph.patch):

```python
from graphpatch import PatchableGraph
from graphpatch.patch import ProbePatch, ZeroPatch

pg = PatchableGraph(model, **inputs, use_cache=False)
# Applies patches to the multiplication result within the activation function of the
# MLP in the 18th transformer layer. ProbePatch records the last observed value at the
# given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
   output = pg(**inputs)
# Patches are applied in order. probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)
```

In contrast to [other approaches](#related-work), `graphpatch` can patch (or record) any
intermediate tensor value without manual modification of the underlying model’s code. See [Working with graphpatch](https://graphpatch.readthedocs.io/en/stable/working_with_graphpatch.html#working-with-graphpatch) for
some tips on how to use the generated graphs.

Note that `graphpatch` activation patches are compatible with [AutoGrad](https://pytorch.org/docs/stable/autograd.html)!
This means that, for example, you can perform optimizations over the `value` parameter to
[`AddPatch`](https://graphpatch.readthedocs.io/en/stable/patch.html#graphpatch.patch.AddPatch):

```python
import torch

from graphpatch.patch import AddPatch

delta = torch.zeros(size, requires_grad=True, device="cuda")
optimizer = torch.optim.Adam([delta], lr=0.5)
for _ in range(num_steps):
   optimizer.zero_grad()
   with graph.patch({node_name: AddPatch(value=delta)}):
      logits = graph(**prompt_inputs)
   loss = my_loss_function(logits)
   loss.backward()
   optimizer.step()
```

For a practical usage example, see the [demo](https://github.com/evan-lloyd/graphpatch/tree/main/demos/ROME) of using `graphpatch` to replicate [ROME](https://rome.baulab.info/).

## Prerequisites

The only mandatory requirements are `torch>=2` and `numpy>=1.17`. Version 2+ of `torch` is required
because `graphpatch` leverages [`torch.compile()`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch.compile), which was introduced in `2.0.0`, to extract computational graphs from models.
CUDA support is not required. `numpy` is required for full `compile()` support.

Python 3.8–3.12 are supported. Note that `torch` versions prior to `2.1.0` do not support compilation
on Python 3.11, and versions prior to `2.4.0` do not support compilation on Python 3.12;
you will get an exception when trying to use `graphpatch` with such a configuration. No version of
`torch` yet supports compilation on Python 3.13.
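
As a quick sanity check, the constraints above can be encoded in a few lines. This is just an illustrative sketch (assuming the third-party `packaging` library is installed), not part of the `graphpatch` API:

```python
import sys

import torch
from packaging.version import Version

# Strip any local build suffix such as "+cu121" before comparing.
torch_version = Version(torch.__version__.split("+")[0])
if sys.version_info >= (3, 13):
   raise RuntimeError("No torch release yet supports compilation on Python 3.13")
elif sys.version_info >= (3, 12) and torch_version < Version("2.4.0"):
   raise RuntimeError("Compilation on Python 3.12 requires torch>=2.4.0")
elif sys.version_info >= (3, 11) and torch_version < Version("2.1.0"):
   raise RuntimeError("Compilation on Python 3.11 requires torch>=2.1.0")
```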

## Installation

`graphpatch` is available on PyPI, and can be installed via `pip`:

```default
pip install graphpatch
```

Note that you will likely want to do this in an environment that already has `torch`, since `pip` may not resolve
`torch` to a CUDA-enabled version by default. You don’t need to do anything special to make `graphpatch` compatible
with `transformers`, `accelerate`, and `bitsandbytes`; their presence is detected at run-time. However, for convenience,
you can install `graphpatch` with the “transformers” extra, which will install known compatible versions of these libraries along
with some of their optional dependencies that are otherwise mildly inconvenient to set up:

```default
pip install graphpatch[transformers]
```

## Model compatibility

For full functionality, `graphpatch` depends on being able to call [`torch.compile()`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch.compile) on your
model. Compilation currently supports only a subset of possible Python operations; for example, it doesn't
support context managers. `graphpatch` implements some workarounds for situations that a native
`compile()` can't handle, but this coverage isn't complete. Where compilation fails entirely, `graphpatch`
falls back gracefully, with a user experience that should be no worse than using module hooks.
In that case, you will only be able to patch an uncompilable submodule’s inputs, outputs,
parameters, and buffers. See [Notes on compilation](https://graphpatch.readthedocs.io/en/stable/notes_on_compilation.html#notes-on-compilation) for more discussion.
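
For example, a submodule that could not be compiled can still have its output patched through the same `pg.patch()` interface. The module name below is hypothetical; real node paths depend on your model's structure:

```python
from graphpatch.patch import ZeroPatch

# "flaky_submodule" is a hypothetical name; under the graceful fallback, only
# its inputs, outputs, parameters, and buffers appear as patchable nodes.
with pg.patch({"flaky_submodule.output": ZeroPatch()}):
   output = pg(**inputs)
```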

## `transformers` integration

`graphpatch` is theoretically compatible with any model in Hugging Face's [transformers](https://huggingface.co/docs/transformers/main/en/index)
library, but note that there may be edge cases in specific model code that it can’t yet handle. For
example, it is not (yet!) compatible with the key-value caching implementation, so if you want full
compilation of such models you should pass `use_cache=False` as part of the example inputs.

`graphpatch` is compatible with models loaded via [accelerate](https://huggingface.co/docs/accelerate/main/en/index) and with 8-bit parameters
quantized by [bitsandbytes](https://pypi.org/project/bitsandbytes/). This means that you can run `graphpatch` on
multiple GPUs and/or with quantized inference very easily on models provided by `transformers`:

```python
import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM

from graphpatch import PatchableGraph

model = LlamaForCausalLM.from_pretrained(
   model_path,
   device_map="auto",
   quantization_config=BitsAndBytesConfig(load_in_8bit=True),
   torch_dtype=torch.float16,
)
pg = PatchableGraph(model, **example_inputs, use_cache=False)
```

For `transformers` models supporting the [`GenerationMixin`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin) protocol, you will
also be able to use convenience functions like [`generate()`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) in
combination with activation patching:

```python
# Prevent Llama from outputting "Paris"
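# The slice argument is equivalent to indexing logits[:, :, 3681]: it zeroes
# the logit for token id 3681 (the id this example uses for "Paris") at every
# sequence position.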
with pg.patch({"lm_head.output": ZeroPatch(slice=(slice(None), slice(None), 3681))}):
   output_tokens = pg.generate(**inputs, max_length=20, use_cache=False)
```

### Version compatibility

`graphpatch` should be compatible with all versions of optional libraries matching the minimum
version requirements, but this is a highly ambitious claim to make for a Python library. If you end
up with errors that seem related to `graphpatch`’s integration with these libraries, you might try
changing their versions to those listed below. This list was automatically generated as part of the
`graphpatch` release process. It reflects the versions used while testing `graphpatch 0.2.3`:

```default
accelerate==1.0.0
bitsandbytes==0.44.1
numpy==1.24.4 (Python 3.8)
numpy==2.0.2 (Python 3.9)
numpy==2.1.1 (later Python versions)
sentencepiece==0.2.0
transformer-lens==2.4.1
transformers==4.45.2
```
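
To reproduce that tested environment, one option is simply to pin the optional libraries to the listed versions (shown with the `numpy` pin for Python 3.10+; substitute the pin matching your interpreter):

```default
pip install accelerate==1.0.0 bitsandbytes==0.44.1 numpy==2.1.1 \
    sentencepiece==0.2.0 transformer-lens==2.4.1 transformers==4.45.2
```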

<a id="related-work"></a>

## Alternatives

[`Module hooks`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) are built into `torch` and can be used for activation
patching. You can even add them to existing models without modifying their code. However, this will only give you
access to module inputs and outputs; accessing or patching intermediate values still requires a manual rewrite.
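
As a rough sketch of that limitation in plain `torch` (the GPT-2-style module path is illustrative), a forward hook can observe or replace a module's output, but nothing computed inside it:

```python
import torch

def zero_output_hook(module, inputs, output):
   # Returning a value from a forward hook replaces the module's output.
   return torch.zeros_like(output)

# The hook sees only the module boundary; intermediate values inside the MLP
# (e.g., the result of an individual multiplication) remain out of reach.
handle = model.transformer.h[17].mlp.register_forward_hook(zero_output_hook)
output = model(**inputs)
handle.remove()
```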

[TransformerLens](https://transformerlensorg.github.io/TransformerLens/index.html) provides the
[`HookPoint`](https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookPoint) class, which can record and patch intermediate
activations. However, this requires manually rewriting your model’s code to wrap the values you want to make
patchable.
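
For comparison, a minimal TransformerLens sketch (assuming a model architecture that `HookedTransformer` supports) looks roughly like this; the `hook_post` point exists only because the model's code was rewritten to include it:

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

def zero_hook(activation, hook):
   return torch.zeros_like(activation)

# Zero the MLP activation output in the first block via its HookPoint.
logits = model.run_with_hooks(
   model.to_tokens("The Eiffel Tower is in"),
   fwd_hooks=[("blocks.0.mlp.hook_post", zero_hook)],
)
```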

[TorchLens](https://github.com/johnmarktaylor91/torchlens) records and outputs visualizations for every intermediate
activation. However, it is currently unable to perform any activation patching.

[nnsight](https://github.com/ndif-team/nnsight) offers a nice activation patching API, but is limited to
module inputs and outputs.

[pyvene](https://github.com/stanfordnlp/pyvene) offers fine-grained control over activation patches (for example, down to
a specific attention head), and a description language/serialization format to allow specification of reproducible
experiments.

## Documentation index

* [API](https://graphpatch.readthedocs.io/en/stable/api.html)
  * [ExtractionOptions](https://graphpatch.readthedocs.io/en/stable/extraction_options.html)
  * [Patch](https://graphpatch.readthedocs.io/en/stable/patch.html)
  * [PatchableGraph](https://graphpatch.readthedocs.io/en/stable/patchable_graph.html)
* [Data structures](https://graphpatch.readthedocs.io/en/stable/data_structures.html)
  * [CompiledGraphModule](https://graphpatch.readthedocs.io/en/stable/compiled_graph_module.html)
  * [MultiplyInvokedModule](https://graphpatch.readthedocs.io/en/stable/multiply_invoked_module.html)
  * [NodePath](https://graphpatch.readthedocs.io/en/stable/node_path.html)
  * [OpaqueGraphModule](https://graphpatch.readthedocs.io/en/stable/opaque_graph_module.html)
* [Notes on compilation](https://graphpatch.readthedocs.io/en/stable/notes_on_compilation.html)
* [What is activation patching?](https://graphpatch.readthedocs.io/en/stable/what_is_activation_patching.html)
* [Working with `graphpatch`](https://graphpatch.readthedocs.io/en/stable/working_with_graphpatch.html)

* [Full index](https://graphpatch.readthedocs.io/en/stable/genindex.html)

            
