enable batched intervention#605
Conversation
…tion # Conflicts: # chirho/interventional/handlers.py
jfeser
left a comment
There was a problem hiding this comment.
I think batched interventions (intervention coproducts) are a useful addition and can turn into a powerful and general intervention interface down the line. I have some concerns about the API in this PR, however:
BatchedWorldCounterfactualdoesn't handlesplit, and it only handlesintervenefor its own batched type — so an ordinaryintervene()in a model body wouldn't get the world-splitting semantics the other handlers give it, and I'd like to understand what those models do.BatchedWorldCounterfactualtakes its interventions as a parameter instead of handling the effects raised bydo.- I'd lean toward fixing both
factualandshared_noisetoTrueand dropping the parameters. I feel less strongly aboutfactual, butshared_noise=Falsediverges from the existing counterfactual semantics (noise is shared across worlds), so I worry that exposing it could be surprising. We can get the currentshared_noise=Falsesemantics by usingBatchedLatentsdirectly.
My main suggestion — and I think a lot of the above falls out naturally from it — is to handle split rather than intervene. A split handler that recognizes a BatchedActions type and creates a single shared dimension (instead of taking a product) should compose nicely with both TwinWorldCounterfactual and MultiWorldCounterfactual, and it'd reuse the existing machinery for shared noise for free. It also leaves room to extend to richer sum/product interventions later if we find we need them.
Concretely, the this PR has us write:
with BatchedWorldCounterfactual([{"z": za}, {"x": xb}]):
z, x, y = scm(event_shape)()and I think we should write:
with MultiWorldCounterfactual(), BatchedWorldCounterfactual(), do(actions=BatchedActions({"z": za}, {"x": xb})):
z, x, y = scm(event_shape)()which parallels the existing interface:
with MultiWorldCounterfactual(), do(actions={"z": (za,), "x": (xb,)}):
z, x, y = scm(event_shape)()
1. Motivation
In search for explanations (and many other causal workflows), we need to
evaluate large batches of interventions. These interventions are
heterogeneous: different elements of the batch may intervene on different
variables, with different values.
ChiRho already provides two counterfactual handlers living at two edges of a spectrum:
TwinWorldCounterfactual(TWC): factual world plus one alternativeworld. All
splits reuse a single shared dimension name, so the world axishas size 2. Cheap, but limited to one counterfactual scenario.
MultiWorldCounterfactual(MWC): factual world plus every combinationof interventions. Each
splitallocates its own index-plate dimension, sothe represented worlds are the full Cartesian product of all interventions.
Memory grows exponentially in the number of intervened sites.
A batched intervention handler is the natural third point on this spectrum:
many alternative scenarios laid side-by-side on a single shared axis. This
gives access to many coordinated interventions without the compute and memory cost
of materializing every combination of them — combinations that, in practice,
are usually never inspected and are discarded anyway.
2. Intended semantics
For batch size
N, world axis of sizeN+1:i+1— scenarioi: for each sites,value[i+1] = act_{s,i}if scenarioiintervenes ons, else thepropagated
obs[i+1](the value implied by scenarioi's upstreaminterventions; factual if it has none).
Equivalence property. With
shared_noise=True, each batched world equals thecorresponding gathered MWC slice. For multi-site interventions with sampled
descendants, the batched single axis (size
N+1) and MWC's Cartesian axes havedifferent shapes, so a downstream
sampledraws different noise per cell: theworlds are then distributionally equivalent but not sample-path identical.
3. Requirements
Linear memory in N: one shared axis of size
N+1, vs MWC's exponentialgrowth (one axis per intervened site) and the naive alternative of running
TwinWorldCounterfactualonce per scenario in a loop, which produces N separatesize-2 tensors.
One vectorized forward pass: vs
Npasses for loopeddoor TWC, andvs MWC, which vectorizes but at exponential memory cost.
Correct downstream propagation: when a scenario does not intervene on a
site, the site's value must reflect any upstream interventions in that scenario.
An always-present factual world, addressable like TWC/MWC via
gather(value, IndexSet(<name>={0})).Works with non-trivial event dimensions (
to_event(k)).Citizen status alongside TWC/MWC: a
Messengerclass usable as a contextmanager or decorator, composing with
FactualConditioningMessenger, whose outputsare read with the standard
chirho.indexedoperations (gather,IndexSet,indices_of).4. Mechanism
The single shared world axis is a named index-plate dimension of size
N+1tracked by
chirho.indexed. Withshared_noise=Trueit is created lazily bythe first intervention and propagates downstream, leaving upstream latents
shared/scalar. With
shared_noise=Falseit is created eagerly on every latentsite by
BatchedLatents(N+1, name=...), which handles event dimensions correctly(a
to_event(1)latent is placed at the registered plate position).For each intervened site, the handler computes
where
actandmaskare first moved onto the named batch dimension withevent_dimis the site's event dimensionality, which flows throughmsg["kwargs"]from
Interventions._pyro_post_sample. This ensures the leftmost axis coincideswith the named plate dimension.
We reserve index 0 as the factual world: an extra batch slice whose
maskisFalsefor every site, so the factual value propagates through unchanged. Scenariosoccupy indices
1..N. This matches the TWC/MWC convention(
gather(value, IndexSet(name={0}))is the factual world).5. API
BatchedWorldCounterfactual, the handler, subclassesIndexPlatesMessenger.batched_dois a thin context-manager wrapper exposing the sameparameters.
Interventions are specified as a list of per-scenario
{site: tensor}dicts;see
BatchedInterventionsfor the full accepted forms.Results are read with the standard indexed operations:
where
k=0is the factual world andk=1..Nare the scenarios.6. Benchmarks
Two benchmarks were run locally (single thread, skipped in CI). They live skipped in the
corresponding test module.
The first measures wall-clock time for N=300 single-site interventions on a
three-site model with event size 300, comparing four approaches:
batched_dodoThe second measures the size in bytes of the output tensor as the number of
intervened sites K grows. MWC allocates one axis per site, so the output covers
2^K worlds; the batched handler keeps a single axis of size K+1.
7. Inspiration
PR #594 by @jfeser
was a first push in this direction: batch interventions along a shared leading dimension
and use
torch.whereto select between the action and the observed value. Thiswork picks up from that starting point.
Getting from prototype to a first-class handler required a few extensions and modifications.
In the collection form, the loop that accumulated per-site action tensors into
actswascorrect, but the stacking step at the end accidentally read from
intv(the loopvariable, left pointing at the last scenario) rather than from
acts. This meantonly the last scenario's values were used.
A subtler issue was that the batch dimension was left as a plain leading axis on
act, not registered as a named index-plate. This made it invisible tochirho.indexed, sogatherandindices_ofcould not address it, and siteswith non-trivial event dimensions did not work. Routing through
unbind_leftmost_dimplaces the axis on the named plate so the rest of theindexed machinery sees it normally.
The original prototype always drew independent noise per world, which is one
useful mode but not the only one. A
shared_noiseflag (defaultTrue) letsusers hold noise fixed across worlds and recover exact MWC-equivalent
counterfactuals. Similarly, the original had no reserved factual world: adding
index 0 as the unintervened run and composing
FactualConditioningMessengerbrings the handler in line with how TWC and MWC work. For the same reason,
BatchedWorldCounterfactualsubclassesIndexPlatesMessengerdirectly ratherthan requiring users to add the wrapper themselves.