
enable batched intervention #605

Open
rfl-urbaniak wants to merge 11 commits into master from ru-batched-intervention

@rfl-urbaniak rfl-urbaniak commented Jun 30, 2026

Collaborator

1. Motivation

When searching for explanations (and in many other causal workflows), we need to
evaluate large batches of interventions. These interventions are
heterogeneous: different elements of the batch may intervene on different
variables, with different values.

ChiRho already provides two counterfactual handlers that sit at opposite ends of a spectrum:

  • TwinWorldCounterfactual (TWC): factual world plus one alternative
    world. All splits reuse a single shared dimension name, so the world axis
    has size 2. Cheap, but limited to one counterfactual scenario.

  • MultiWorldCounterfactual (MWC): factual world plus every combination
    of interventions. Each split allocates its own index-plate dimension, so
    the represented worlds are the full Cartesian product of all interventions.
    Memory grows exponentially in the number of intervened sites.

A batched intervention handler is the natural third point on this spectrum:
many alternative scenarios laid side-by-side on a single shared axis. This
gives access to many coordinated interventions without the compute and memory cost
of materializing every combination of them — combinations that, in practice,
are usually never inspected and are discarded anyway.

2. Intended semantics

For batch size N, world axis of size N+1:

  • Index 0 — factual world; no interventions applied.
  • Index i+1 — scenario i: for each site s,
    value[i+1] = act_{s,i} if scenario i intervenes on s, else the
    propagated obs[i+1] (the value implied by scenario i's upstream
    interventions; factual if it has none).
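
As a reference point for these semantics, here is a minimal plain-Python sketch that evaluates each world separately. The three-site chain z -> x -> y, its structural equations, and the helper names run_world and batched_reference are hypothetical and chosen only for illustration; the handler is meant to produce the same per-world values in one vectorized pass.

```python
# Hypothetical chain z -> x -> y. The equations below are made up for
# illustration and are not taken from this PR.
def run_world(interventions, z_factual=1.0):
    """Evaluate one world; `interventions` maps site name -> forced value."""
    z = interventions.get("z", z_factual)
    x = interventions.get("x", 2.0 * z)   # x depends on z unless intervened
    y = interventions.get("y", x + 1.0)   # y depends on x unless intervened
    return {"z": z, "x": x, "y": y}


def batched_reference(scenarios):
    """Return N+1 worlds: index 0 is factual, index i+1 is scenario i."""
    return [run_world({})] + [run_world(s) for s in scenarios]


# N = 2 heterogeneous scenarios: one intervenes on z, the other on x.
scenarios = [{"z": 5.0}, {"x": 0.0}]
worlds = batched_reference(scenarios)
```

In this sketch worlds[1]["x"] is not the factual x: scenario 0 does not intervene on x, so x takes the value implied by the upstream intervention on z.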

Equivalence property. With shared_noise=True, each batched world equals the
corresponding gathered MWC slice. For multi-site interventions with sampled
descendants, the batched single axis (size N+1) and MWC's Cartesian axes have
different shapes, so a downstream sample draws different noise per cell: the
worlds are then distributionally equivalent but not sample-path identical.

3. Requirements

  1. Linear memory in N: one shared axis of size N+1, vs MWC's exponential
    growth (one axis per intervened site) and the naive alternative of running
    TwinWorldCounterfactual once per scenario in a loop, which produces N separate
    size-2 tensors.

  2. One vectorized forward pass: vs N passes for looped do or TWC, and
    vs MWC, which vectorizes but at exponential memory cost.

  3. Correct downstream propagation: when a scenario does not intervene on a
    site, the site's value must reflect any upstream interventions in that scenario.

  4. An always-present factual world, addressable like TWC/MWC via
    gather(value, IndexSet(<name>={0})).

  5. Works with non-trivial event dimensions (to_event(k)).

  6. First-class status alongside TWC/MWC: a Messenger class usable as a context
    manager or decorator, composing with FactualConditioningMessenger, whose outputs
    are read with the standard chirho.indexed operations (gather, IndexSet,
    indices_of).

4. Mechanism

The single shared world axis is a named index-plate dimension of size N+1
tracked by chirho.indexed. With shared_noise=True it is created lazily by
the first intervention and propagates downstream, leaving upstream latents
shared/scalar. With shared_noise=False it is created eagerly on every latent
site by BatchedLatents(N+1, name=...), which handles event dimensions correctly
(a to_event(1) latent is placed at the registered plate position).

For each intervened site, the handler computes

value = torch.where(mask, act, obs)

where act and mask are first moved onto the named batch dimension with

act  = unbind_leftmost_dim(act,  name, size=N+1, event_dim=event_dim)
mask = unbind_leftmost_dim(mask, name, size=N+1, event_dim=event_dim)

event_dim is the site's event dimensionality, which flows through msg["kwargs"]
from Interventions._pyro_post_sample. This ensures the leftmost axis coincides
with the named plate dimension.

We reserve index 0 as the factual world: an extra batch slice whose mask is
False for every site, so the factual value propagates through unchanged. Scenarios
occupy indices 1..N. This matches the TWC/MWC convention
(gather(value, IndexSet(name={0})) is the factual world).
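
The packing of per-scenario dicts into mask and act for one site can be sketched in plain Python as below. This is an illustration under stated assumptions, not the handler's code: pack_site and where are hypothetical helpers, lists stand in for tensors, and the 0.0 placeholder for non-intervened slots is an arbitrary choice that is never selected.

```python
def pack_site(scenarios, site):
    """Build length-(N+1) mask and act rows for one site.

    Slot 0 is the reserved factual world, so its mask entry is always False.
    """
    mask = [False] + [site in s for s in scenarios]
    # Placeholder 0.0 where a scenario does not intervene on this site.
    act = [0.0] + [s.get(site, 0.0) for s in scenarios]
    return mask, act


def where(mask, act, obs):
    """Elementwise select, mirroring torch.where(mask, act, obs)."""
    return [a if m else o for m, a, o in zip(mask, act, obs)]


scenarios = [{"z": 5.0}, {"x": 0.0}, {"z": -1.0}]
mask, act = pack_site(scenarios, "z")
obs = [1.0] * (len(scenarios) + 1)  # observed value of z in every world
value = where(mask, act, obs)
```

Because the factual slot is masked out for every site, index 0 always carries the observed value through unchanged.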

5. API

BatchedWorldCounterfactual, the handler, subclasses
IndexPlatesMessenger. batched_do is a thin context-manager wrapper exposing the same
parameters.

Interventions are specified as a list of per-scenario {site: tensor} dicts;
see BatchedInterventions for the full accepted forms.

Results are read with the standard indexed operations:

gather(value, IndexSet(batched_interventions={k}), event_dim=...)

where k=0 is the factual world and k=1..N are the scenarios.

6. Benchmarks

Two benchmarks were run locally on a single thread. They live in the
corresponding test module and are skipped in CI.

The first measures wall-clock time for N=300 single-site interventions on a
three-site model with event size 300, comparing four approaches:

| Approach   | Time   | Ratio vs batched |
|------------|--------|------------------|
| batched_do | 1.6 ms | 1x               |
| MWC        | 19 ms  | 12x              |
| loop + do  | 30 ms  | 19x              |
| loop + TWC | 128 ms | 80x              |

The second measures the size in bytes of the output tensor as the number of
intervened sites K grows. MWC allocates one axis per site, so the output covers
2^K worlds; the batched handler keeps a single axis of size K+1.

| K  | MWC bytes | batched bytes | ratio |
|----|-----------|---------------|-------|
| 2  | 16        | 12            | ~1x   |
| 4  | 64        | 20            | ~3x   |
| 8  | 1024      | 36            | ~28x  |
| 12 | 16384     | 52            | ~315x |
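
The byte counts above are consistent with one 4-byte element per world, which is an assumption inferred from the numbers rather than something the benchmark states. Under that assumption, the table can be reproduced with a few lines of arithmetic:

```python
BYTES_PER_WORLD = 4  # assumption: one 4-byte (e.g. float32) element per world


def mwc_bytes(k):
    """One size-2 axis per intervened site, so 2**k worlds."""
    return BYTES_PER_WORLD * 2 ** k


def batched_bytes(k):
    """A single shared axis of size k + 1 (factual plus k scenarios)."""
    return BYTES_PER_WORLD * (k + 1)


rows = [(k, mwc_bytes(k), batched_bytes(k)) for k in (2, 4, 8, 12)]
for k, mwc, batched in rows:
    print(f"K={k:2d}  MWC={mwc:6d}  batched={batched:3d}  ratio={mwc / batched:.1f}x")
```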

7. Inspiration

PR #594 by @jfeser was a first push in this direction: batch interventions along a shared leading dimension
and use torch.where to select between the action and the observed value. This
work picks up from that starting point.

Getting from prototype to a first-class handler required a few extensions and modifications.
In the collection form, the loop that accumulated per-site action tensors into acts was
correct, but the stacking step at the end accidentally read from intv (the loop
variable, left pointing at the last scenario) rather than from acts. This meant
only the last scenario's values were used.
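
The pattern behind this bug can be shown with a small plain-Python sketch. It is a hypothetical reconstruction for illustration only, not the code from PR #594: the function names are invented, and lists stand in for the stacked tensors.

```python
def stack_acts_buggy(scenarios, site):
    acts = []
    for intv in scenarios:
        acts.append(intv[site])  # accumulation itself is correct
    # Bug: the stacking step reads the leftover loop variable `intv`,
    # which still points at the last scenario, instead of `acts`.
    return [intv[site]] * len(scenarios)


def stack_acts_fixed(scenarios, site):
    acts = [intv[site] for intv in scenarios]
    return acts  # stack from the accumulated per-scenario values


scenarios = [{"z": 1.0}, {"z": 2.0}, {"z": 3.0}]
buggy = stack_acts_buggy(scenarios, "z")
fixed = stack_acts_fixed(scenarios, "z")
```

The buggy version still returns a result of the right length, which is why this kind of error is easy to miss without a test that uses distinct values per scenario.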

A subtler issue was that the batch dimension was left as a plain leading axis on
act, not registered as a named index-plate. This made it invisible to
chirho.indexed, so gather and indices_of could not address it, and sites
with non-trivial event dimensions did not work. Routing through
unbind_leftmost_dim places the axis on the named plate so the rest of the
indexed machinery sees it normally.

The original prototype always drew independent noise per world, which is one
useful mode but not the only one. A shared_noise flag (default True) lets
users hold noise fixed across worlds and recover exact MWC-equivalent
counterfactuals. Similarly, the original had no reserved factual world: adding
index 0 as the unintervened run and composing FactualConditioningMessenger
brings the handler in line with how TWC and MWC work. For the same reason,
BatchedWorldCounterfactual subclasses IndexPlatesMessenger directly rather
than requiring users to add the wrapper themselves.

@rfl-urbaniak rfl-urbaniak requested a review from jfeser June 30, 2026 15:44

@jfeser jfeser left a comment

Contributor


I think batched interventions (intervention coproducts) are a useful addition and can turn into a powerful and general intervention interface down the line. I have some concerns about the API in this PR, however:

  1. BatchedWorldCounterfactual doesn't handle split, and it only handles intervene for its own batched type — so an ordinary intervene() in a model body wouldn't get the world-splitting semantics the other handlers give it, and I'd like to understand what those models do.
  2. BatchedWorldCounterfactual takes its interventions as a parameter instead of handling the effects raised by do.
  3. I'd lean toward fixing both factual and shared_noise to True and dropping the parameters. I feel less strongly about factual, but shared_noise=False diverges from the existing counterfactual semantics (noise is shared across worlds), so I worry that exposing it could be surprising. We can get the current shared_noise=False semantics by using BatchedLatents directly.

My main suggestion — and I think a lot of the above falls out naturally from it — is to handle split rather than intervene. A split handler that recognizes a BatchedActions type and creates a single shared dimension (instead of taking a product) should compose nicely with both TwinWorldCounterfactual and MultiWorldCounterfactual, and it'd reuse the existing machinery for shared noise for free. It also leaves room to extend to richer sum/product interventions later if we find we need them.

Concretely, this PR has us write:

with BatchedWorldCounterfactual([{"z": za}, {"x": xb}]):
    z, x, y = scm(event_shape)()

and I think we should write:

with MultiWorldCounterfactual(), BatchedWorldCounterfactual(), do(actions=BatchedActions({"z": za}, {"x": xb})):
    z, x, y = scm(event_shape)()

which parallels the existing interface:

with MultiWorldCounterfactual(), do(actions={"z": (za,), "x": (xb,)}):
    z, x, y = scm(event_shape)()
