causal-tracer

Name: causal-tracer
Version: 1.1.0
Home page: https://github.com/chanind/causal-tracer
Author: David Chanin
Requires Python: <4.0,>=3.10
Upload time: 2024-04-02 13:37:31

# Causal Tracer

[![ci](https://img.shields.io/github/actions/workflow/status/chanind/causal-tracer/ci.yaml?branch=main)](https://github.com/chanind/causal-tracer)
[![Codecov](https://img.shields.io/codecov/c/github/chanind/causal-tracer/main)](https://codecov.io/gh/chanind/causal-tracer)
[![PyPI](https://img.shields.io/pypi/v/causal-tracer?color=blue)](https://pypi.org/project/causal-tracer/)

Causal trace plots for transformer language models.

Demo:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rOA_r7Gv6bGjXNfUvrqk9Gt3dLwZNvGJ?usp=sharing)

![rome_knows_fact](https://github.com/chanind/causal-tracer/assets/200725/e621e179-ee87-48a7-9493-1a1ed422f036)

## About

This library generates causal trace plots for transformer language models like Llama and GPT2, and should work with any decoder-only model on Hugging Face. It is based on the causal tracing code from [ROME](https://rome.baulab.info/), and it packages and improves on their excellent work. Thank you to these authors! There are some notable differences between the original ROME causal tracing code and this library, such as support for batch processing, automatic noise calculation, more processing options, and a slightly different API.

Causal tracing is a technique for finding which activations, at which layers, are causally important for the model to generate a given output. It works by scrambling the subject tokens, then restoring activations in the scrambled computation graph one at a time and observing whether restoring an activation brings the model closer to its original answer.

For instance, if we prompt a language model with "Rome is located in the country of", it will output "Italy". To understand how the model generated that answer, we can scramble the tokens for "Rome" by adding Gaussian noise, so the model now sees gibberish instead, like "@#(\* is located in the country of". After this scrambling, the model has no way to output "Italy", since the subject is just noise. However, we can take this corrupted computation graph and start replacing activations in it with the original uncorrupted activations, and see whether the model starts outputting "Italy" again. If it does, we know that activation is important to the computation!
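
To make the procedure concrete, here is a minimal toy sketch in plain Python. It is not how causal-tracer is implemented: the "model" is a hypothetical four-layer network whose hidden state is a single float per token, and `run_model`, `causal_trace`, and `LAYER_WEIGHTS` are illustrative names. The sketch corrupts the subject token's input with Gaussian noise, then restores one clean hidden state at a time and records the probability of the original answer, averaged over several noise samples, producing a tokens-by-layers score grid like the heatmaps this library plots.

```python
import math
import random

# Hypothetical toy "decoder": one float of hidden state per token. Each layer
# adds a squashed running mean of earlier positions (a crude stand-in for
# causal attention), so information only flows left to right.
LAYER_WEIGHTS = [0.9, -0.4, 0.7, 0.5]


def run_model(embeddings, patch=None, clean_states=None):
    """Return (answer probability, states), where states[layer][token] is the
    hidden state entering `layer`. If patch=(token, layer) is given, that one
    state is overwritten with the clean run's value before the layer runs."""
    h = list(embeddings)
    states = []
    for layer, weight in enumerate(LAYER_WEIGHTS):
        if patch is not None and patch[1] == layer:
            token = patch[0]
            h[token] = clean_states[layer][token]
        states.append(list(h))
        h = [
            h[t] + weight * math.tanh(sum(h[: t + 1]) / (t + 1))
            for t in range(len(h))
        ]
    # Toy "probability" of the original answer, read off the last token.
    return 1 / (1 + math.exp(-h[-1])), states


def causal_trace(embeddings, subject_range, noise_scale=3.0, samples=10, seed=0):
    """Return (high_score, low_score, scores) with scores[token][layer]."""
    rng = random.Random(seed)
    high_score, clean_states = run_model(embeddings)
    start, end = subject_range
    scores = [[0.0] * len(LAYER_WEIGHTS) for _ in embeddings]
    low_score = 0.0
    for _ in range(samples):
        # Corrupt only the subject tokens with Gaussian noise.
        corrupted = [
            e + rng.gauss(0.0, noise_scale) if start <= t < end else e
            for t, e in enumerate(embeddings)
        ]
        low_score += run_model(corrupted)[0] / samples
        # Restore one clean hidden state at a time and re-run the corrupted input.
        for t in range(len(embeddings)):
            for layer in range(len(LAYER_WEIGHTS)):
                prob, _ = run_model(corrupted, patch=(t, layer), clean_states=clean_states)
                scores[t][layer] += prob / samples
    return high_score, low_score, scores


# Token 0 plays the role of the subject (e.g. "Rome").
high, low, scores = causal_trace([1.2, 0.8, -0.3, 0.5], subject_range=(0, 1))
for t, row in enumerate(scores):
    print(t, [round(s, 3) for s in row])
```

Two sanity checks fall out of the design: restoring the subject token at the very first layer recovers the clean score exactly, while "restoring" a non-subject token at the first layer changes nothing, because its input was never corrupted.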

For more info on causal tracing, check out the original ROME paper, [Locating and Editing Factual Associations in GPT](https://arxiv.org/pdf/2202.05262.pdf).

## Installation

```
pip install causal-tracer
```

## Basic usage

If you're generating causal traces for a Llama-based model or GPT2, you don't need any further configuration.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from causal_tracer import CausalTracer, plot_hidden_flow_heatmap

model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tracer = CausalTracer(model, tokenizer)

# perform causal tracing across hidden layers (residual stream) of the model
hidden_layer_flow = tracer.calculate_hidden_flow(
  "The Space Needle is located in the city of",
  subject="The Space Needle",
)
# plot the result
plot_hidden_flow_heatmap(hidden_layer_flow)
```

You can also generate causal traces of MLP layers or attention layers in the transformer.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from causal_tracer import CausalTracer, plot_hidden_flow_heatmap

model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tracer = CausalTracer(model, tokenizer)

# perform causal tracing across MLP layers of the model
mlp_layer_flow = tracer.calculate_hidden_flow(
  "The Space Needle is located in the city of",
  subject="The Space Needle",
  kind="mlp",
  window=10,
)
plot_hidden_flow_heatmap(mlp_layer_flow)

# perform causal tracing across attention layers of the model
attn_layer_flow = tracer.calculate_hidden_flow(
  "The Space Needle is located in the city of",
  subject="The Space Needle",
  kind="attention",
  window=10,
)
plot_hidden_flow_heatmap(attn_layer_flow)
```

When generating MLP or attention causal traces, you should typically set a window size. The ROME paper uses a window of 10, which means the MLP or attention outputs of 10 adjacent layers are restored together as a group. This smooths the trace and makes it easier to see the impact of changes that would be too small to detect when patching a single layer.
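
The window can be pictured as a band of adjacent layers centred on each layer being traced. The sketch below assumes the centring and edge clipping used in the original ROME causal tracing code; `window_layers` is a hypothetical helper, and causal-tracer's exact boundary handling may differ.

```python
def window_layers(center: int, window: int, num_layers: int) -> list[int]:
    """Layers patched together when tracing `center` with a given window size.

    Hypothetical helper: the window is centred on `center` and clipped to the
    model's valid layer range, so windows near the first or last layer shrink.
    """
    start = max(0, center - window // 2)
    # center - (-window // 2) equals center + ceil(window / 2).
    end = min(num_layers, center - (-window // 2))
    return list(range(start, end))


# gpt2-medium has 24 transformer blocks.
print(window_layers(12, window=10, num_layers=24))  # layers 7..16
print(window_layers(0, window=10, num_layers=24))   # clipped: layers 0..4
```

Because of the clipping, the windows at the edges of the model cover fewer layers than the ones in the middle.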

## Batching and sampling

By default, causal traces are calculated by scrambling the subject tokens with 10 different noise samples, and processing runs in batches of size 32. You can improve the quality of a causal trace by increasing the number of samples. If you run out of memory during processing, try decreasing the batch size.

```python
hidden_layer_flow = tracer.calculate_hidden_flow(
  "The Space Needle is located in the city of",
  subject="The Space Needle",
  samples=50,
  batch_size=8,
)
```
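
As a rough guide to why `samples` and `batch_size` matter, the back-of-the-envelope sketch below counts patched runs and batches for a full trace. It assumes, hypothetically, that every (token, layer, noise sample) combination is one row in a batch; the library's actual batching may group work differently, so treat the numbers as an order-of-magnitude estimate only.

```python
import math


def estimate_cost(num_tokens: int, num_layers: int, samples: int,
                  batch_size: int) -> tuple[int, int]:
    """Return (patched runs, batches) for a full tokens-by-layers trace.

    Hypothetical cost model: one patched forward pass per (token, layer, sample),
    packed into batches of `batch_size` rows.
    """
    runs = num_tokens * num_layers * samples
    return runs, math.ceil(runs / batch_size)


# Hypothetical 10-token prompt on a 24-layer model.
print(estimate_cost(10, 24, samples=10, batch_size=32))  # (2400, 75)
print(estimate_cost(10, 24, samples=50, batch_size=8))   # (12000, 1500)
```

Under this model, raising `samples` scales the total work linearly, while lowering `batch_size` keeps the total work the same but trades peak memory for more, smaller batches.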

## Limiting patching for performance

Causal tracing can be slow because it requires many passes through the model to generate a single trace. You can speed it up by calculating causal traces for only certain layers, or by patching only the subject tokens themselves. The results won't be complete if you do this, but depending on the use case, that might be fine.

```python
hidden_layer_flow = tracer.calculate_hidden_flow(
  "The Space Needle is located in the city of",
  subject="The Space Needle",
  start_layer=10,
  end_layer=15,
  patch_subject_tokens_only=True,
)
```
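
To see how much work these options save, the sketch below enumerates the (token, layer) positions that would be patched. `patch_positions` is a hypothetical helper; it assumes `end_layer` is exclusive like Python's `range`, and that `subject_range` is a half-open `(start, end)` token range, so check the library's docstrings for its actual conventions.

```python
def patch_positions(num_tokens, num_layers, subject_range, start_layer=None,
                    end_layer=None, patch_subject_tokens_only=False):
    """Hypothetical: list the (token, layer) pairs a trace would patch."""
    first = 0 if start_layer is None else start_layer
    last = num_layers if end_layer is None else end_layer  # assumed exclusive
    tokens = range(*subject_range) if patch_subject_tokens_only else range(num_tokens)
    return [(t, layer) for t in tokens for layer in range(first, last)]


full = patch_positions(10, 24, subject_range=(0, 3))
limited = patch_positions(10, 24, subject_range=(0, 3), start_layer=10,
                          end_layer=15, patch_subject_tokens_only=True)
print(len(full), len(limited))  # 240 15
```

In this example, combining both restrictions cuts the number of patched positions by a factor of 16.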

## Custom layer configs

If you're using a model that isn't automatically detected by the library, you'll need to add a `LayerConfig` to tell CausalTracer where to find the embeddings, attention, MLP, and hidden layers within the model. To do this, create a `LayerConfig` object and pass it in when creating a `CausalTracer` object.

```python
from causal_tracer import CausalTracer, LayerConfig

custom_layer_config = LayerConfig(
  hidden_layers_matcher="h.{num}",
  attention_layers_matcher="h.{num}.attn",
  mlp_layers_matcher="h.{num}.mlp",
  embedding_layer="wte",
)
tracer = CausalTracer(model, tokenizer, layer_config=custom_layer_config)
```

Note that `hidden_layers_matcher`, `attention_layers_matcher`, and `mlp_layers_matcher` are template strings containing a `{num}` placeholder. During processing, `{num}` is replaced with the layer number. These strings correspond to the named modules of the transformer. You can list all the named modules of a PyTorch model by running `model.named_modules()`.
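
The template mechanics are ordinary Python string formatting. The sketch below expands each matcher and looks the result up in a hand-written, abridged list of GPT-2 module names; in practice you would collect the names with `[name for name, _ in model.named_modules()]`. The suffix-matching rule in the hypothetical `find_module` helper is an assumption for illustration, not necessarily how causal-tracer resolves names.

```python
from typing import Optional


def expand_template(template: str, num_layers: int) -> list[str]:
    """Fill the {num} placeholder for every layer index."""
    return [template.format(num=i) for i in range(num_layers)]


def find_module(module_names: list[str], target: str) -> Optional[str]:
    """Hypothetical suffix match: "h.1.mlp" matches "transformer.h.1.mlp"."""
    for name in module_names:
        if name == target or name.endswith("." + target):
            return name
    return None


# A few of the names GPT-2's named_modules() yields (abridged by hand).
module_names = [
    "transformer.wte",
    "transformer.h.0", "transformer.h.0.attn", "transformer.h.0.mlp",
    "transformer.h.1", "transformer.h.1.attn", "transformer.h.1.mlp",
    "lm_head",
]

for template in ["h.{num}", "h.{num}.attn", "h.{num}.mlp"]:
    for target in expand_template(template, num_layers=2):
        print(target, "->", find_module(module_names, target))
```

Matching on a `"."`-prefixed suffix keeps `h.1` from accidentally matching `h.11` in deeper models.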

## Using hidden flow results directly

If you want to use the results of the `tracer.calculate_hidden_flow()` method in downstream tasks instead of just making a plot, the returned `HiddenFlow` object contains several fields that can be analyzed further. The fields of the `HiddenFlow` dataclass are listed below:

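```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Literal

import torch

LayerKind = Literal["hidden", "attention", "mlp"]


@dataclass
class HiddenFlow:
    scores: torch.Tensor
    low_score: float
    high_score: float
    input_ids: torch.Tensor
    input_tokens: list[str]
    subject_range: tuple[int, int]
    answer: str
    kind: LayerKind  # one of "hidden", "attention", or "mlp"
    layer_outputs: OrderedDict[str, torch.Tensor]
```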

Of particular interest, the `scores` attribute contains the full matrix of causal tracing scores. The `layer_outputs` attribute contains the uncorrupted layer activations for each layer of the type being analyzed.
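
As an example of downstream use, the sketch below locates the strongest (token, layer) entry and rescales scores between `low_score` and `high_score`. It works on plain nested lists, such as you might get from `flow.scores.tolist()`, and assumes rows index tokens and columns index layers, with `low_score` being the corrupted baseline and `high_score` the clean-run score as in ROME. The helper names and the sample numbers are made up for illustration.

```python
def strongest_patch(scores: list[list[float]]) -> tuple[int, int, float]:
    """Return (token_index, layer_index, score) for the largest entry."""
    return max(
        ((t, layer, s) for t, row in enumerate(scores) for layer, s in enumerate(row)),
        key=lambda item: item[2],
    )


def normalize(scores, low_score, high_score):
    """Rescale so 0.0 = corrupted baseline and 1.0 = clean run."""
    span = high_score - low_score
    return [[(s - low_score) / span for s in row] for row in scores]


# Made-up numbers standing in for flow.scores.tolist(), flow.low_score, flow.high_score.
scores = [
    [0.10, 0.35, 0.60],
    [0.05, 0.08, 0.12],
    [0.06, 0.20, 0.45],
]
low_score, high_score = 0.05, 0.65

token, layer, score = strongest_patch(scores)
print(token, layer, score)  # 0 2 0.6
print(normalize(scores, low_score, high_score)[token][layer])
```

Normalizing against the two baselines makes traces from different prompts easier to compare, since raw answer probabilities can vary widely between prompts.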

## Contributing

Contributions are welcome! If you submit code, please add or update test coverage along with your change. This repo uses Black for code formatting, MyPy for type checking, and Flake8 for linting.


            
