Fix JaxArrayWrapper TypeError with JAX tracers during JIT#211
Open
Chessing234 wants to merge 1 commit intogoogle-deepmind:mainfrom
Open
Fix JaxArrayWrapper TypeError with JAX tracers during JIT#211Chessing234 wants to merge 1 commit intogoogle-deepmind:mainfrom
Chessing234 wants to merge 1 commit intogoogle-deepmind:mainfrom
Conversation
Add jax.core.Tracer to the isinstance checks in JaxArrayWrapper's __array_ufunc__ method, the unwrap function, and _WRAPPED_TYPES. During JIT-traced execution (e.g. GenCast autoregressive rollout), JAX uses tracer objects (DynamicJaxprTracer) instead of concrete jax.Array instances. These tracers were not recognized by the type checks in xarray_jax.py, causing __array_ufunc__ to return NotImplemented and raising a TypeError when numpy ufuncs like multiply were applied to a mix of raw tracers and JaxArrayWrapper- wrapped tracers. Fixes google-deepmind#203 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
TypeError: operand type(s) all returned NotImplemented from __array_ufunc__that occurs during GenCast autoregressive rollout underjax.jittracingjax.core.Tracerto three locations inxarray_jax.pywhere JAX array types are checked at runtime:_WRAPPED_TYPEStuple, so tracers are properly wrapped when constructing xarray structures during tracingJaxArrayWrapper.__array_ufunc__isinstance check, so numpy ufuncs (e.g.multiply) correctly dispatch when one operand is a raw tracer and the other is aJaxArrayWrapper-wrapped tracerunwrap()function, so tracers are passed through (likejax.Array) instead of falling to the error/passthrough branchRoot cause
During JIT-traced execution, JAX replaces concrete
jax.Arrayobjects withDynamicJaxprTracerinstances. These inherit fromjax.core.Tracer, notjax.Array. The existing type checks inxarray_jax.pydid not account for tracer types, causing__array_ufunc__to returnNotImplementedwhen a raw tracer scalar was combined with aJaxArrayWrapper-wrapped tracer array.Test plan
xarray_jax_test.pytests passTypeErrorno longer occursFixes #203
🤖 Generated with Claude Code