Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions examples/deepscaler/math_eval_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
if mesh_config is None:
# Default: 4-way tensor parallelism
mesh_config = [[1, 4], ["fsdp", "tp"]]
self.mesh = jax.make_mesh(*mesh_config, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_config[0]))
self.mesh = jax.make_mesh(*mesh_config, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_config[0])) # pyrefly: ignore[bad-argument-type]
self.tokenizer = None
self.model = None
self.sampler = None
Expand Down Expand Up @@ -259,7 +259,7 @@ def model_from_orbax_ckpt(self):
),
)
graphdef, _ = nnx.split(abs_model)
new_state = nnx.State(ckpt.model_params)
new_state = nnx.State(ckpt.model_params) # pyrefly: ignore[missing-attribute]
self.model = nnx.merge(graphdef, new_state)

def load_model(self):
Expand Down Expand Up @@ -292,7 +292,7 @@ def load_model(self):

if self.sampler_type == "vanilla":
self.sampler_vanilla = sampler_lib.Sampler(
transformer=self.model,
transformer=self.model, # pyrefly: ignore[bad-argument-type]
tokenizer=self.tokenizer,
cache_config=cache_config,
)
Expand All @@ -304,7 +304,7 @@ def load_model(self):
model=self.model,
backend="sglang_jax",
)
self.sampler_sglang = sglang_jax_sampler.SglangJaxSampler(
self.sampler_sglang = sglang_jax_sampler.SglangJaxSampler( # pyrefly: ignore[bad-instantiation]
tokenizer=self.tokenizer,
config=sglang_jax_sampler.SglangJaxConfig(
mesh=self.mesh,
Expand All @@ -330,7 +330,7 @@ def load_model(self):
model=self.model,
backend="vllm_jax",
)
self.sampler_vllm = vllm_sampler.VllmSampler(
self.sampler_vllm = vllm_sampler.VllmSampler( # pyrefly: ignore[bad-instantiation]
tokenizer=self.tokenizer,
config=vllm_sampler.VllmConfig(
mesh=self.mesh,
Expand Down Expand Up @@ -396,7 +396,7 @@ def process_item(item):
else:
instruction = "Please reason step by step. Your final answer must appear inside \\boxed{...} and nothing else."
prompt = f"{instruction} {question}"
prompt = self.tokenizer.apply_chat_template(
prompt = self.tokenizer.apply_chat_template( # pyrefly: ignore[missing-attribute]
[{"role": "user", "content": prompt}],
tokenize=False, add_generation_prompt=True)

Expand All @@ -410,7 +410,7 @@ def process_item(item):
print("\n" + "=" * 60)
print("DEBUG: First formatted prompt:")
first_item = dataset[0]
print(first_item["prompt"])
print(first_item["prompt"]) # pyrefly: ignore[unsupported-operation]
print("=" * 60 + "\n")

return dataset
Expand Down Expand Up @@ -451,7 +451,7 @@ def generate(
top_p=top_p,
echo=False,
eos_tokens=[stop_token_id],
seed=jax.random.PRNGKey(seed) if seed is not None else None,
seed=jax.random.PRNGKey(seed) if seed is not None else None, # pyrefly: ignore[bad-argument-type]
)
elif self.sampler_type == "sglang_jax":
out_data = self.sampler_sglang(
Expand Down Expand Up @@ -530,8 +530,8 @@ def evaluate(
batch_response = self.generate(
prompts=prompts,
temperature=temperature,
top_k=top_k,
top_p=top_p,
top_k=top_k, # pyrefly: ignore[bad-argument-type]
top_p=top_p, # pyrefly: ignore[bad-argument-type]
seed=pass_idx
if self.sampler_type != "vllm"
else None, # vllm handles seeding differently
Expand Down
10 changes: 5 additions & 5 deletions examples/deepscaler/train_deepscaler_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
trainer_devices = math.prod(TRAINER_MESH[0])
rollout_devices = math.prod(ROLLOUT_MESH[0])

if trainer_devices + rollout_devices > jax.device_count():
if trainer_devices + rollout_devices > jax.device_count(): # pyrefly: ignore[unsupported-operation]
raise ValueError(
"Trainer devices must be less than or equal to the number of devices"
" available."
Expand All @@ -237,7 +237,7 @@

if ROLLOUT_ENGINE in ("sglang_jax", "vllm"):
rollout_device_list = jax._src.mesh_utils.create_device_mesh(
ROLLOUT_MESH[0], jax.devices()[:rollout_devices]
ROLLOUT_MESH[0], jax.devices()[:rollout_devices] # pyrefly: ignore[bad-argument-type, bad-index]
)

rollout_mesh = jax.sharding.Mesh(
Expand All @@ -252,7 +252,7 @@
# )
print(f"YY {rollout_device_list=} {rollout_mesh.devices=}")
trainer_devices_list = jax._src.mesh_utils.create_device_mesh(
TRAINER_MESH[0], jax.devices()[-trainer_devices:]
TRAINER_MESH[0], jax.devices()[-trainer_devices:] # pyrefly: ignore[bad-argument-type, unsupported-operation]
)
# trainer_mesh = jax.make_mesh(
# *TRAINER_MESH,
Expand Down Expand Up @@ -545,15 +545,15 @@ def get_lora_model(base_model, model_mesh):
)
elif ROLLOUT_ENGINE == "vllm":
rollout_engine_config = base_rollout.RolloutConfig(
**base_rollout_dict, **vllm_rollout_dict
**base_rollout_dict, **vllm_rollout_dict # pyrefly: ignore[bad-argument-type]
)
elif ROLLOUT_ENGINE == "vanilla":
rollout_engine_config = base_rollout.RolloutConfig(**base_rollout_dict)
else:
raise ValueError(f"Unsupported rollout engine: {ROLLOUT_ENGINE}")

cluster_config = rl_cluster_lib.ClusterConfig(
role_to_mesh={
role_to_mesh={ # pyrefly: ignore[bad-argument-type]
rl_cluster_lib.Role.ACTOR: trainer_mesh,
rl_cluster_lib.Role.REFERENCE: trainer_mesh,
rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
Expand Down
2 changes: 1 addition & 1 deletion examples/deepswe/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def close(self) -> None:
os.system(f"docker rmi {docker_image}")

@staticmethod
def from_dict(extra_info: dict | str) -> "SWEEnv":
def from_dict(extra_info: dict | str) -> "SWEEnv": # pyrefly: ignore[bad-override]
"""Create an environment instance from JSON configuration.

Args:
Expand Down
8 changes: 4 additions & 4 deletions examples/deepswe/train_deepswe_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@
except ImportError as e:
print(f"❌ Still missing a module: {e}")

if pathwaysutils is not None and os.getenv("JAX_PLATFORMS", None) == "proxy":
if pathwaysutils is not None and os.getenv("JAX_PLATFORMS", None) == "proxy": # pyrefly: ignore[unbound-name]
pathwaysutils.initialize()


Expand Down Expand Up @@ -356,7 +356,7 @@
os.makedirs(MODEL_PATH, exist_ok=True)

# Assumes "Qwen/" organization prefix for HF download. Adjust if using other models.
snapshot_download(
snapshot_download( # pyrefly: ignore[no-matching-overload]
repo_id=f"Qwen/{MODEL_VERSION}",
local_dir=MODEL_PATH,
local_dir_use_symlinks=False,
Expand Down Expand Up @@ -651,7 +651,7 @@ def transform(entry):

dataset = dataset.map(
transform,
keep_in_memory=True,
keep_in_memory=True, # pyrefly: ignore[unexpected-keyword]
)

# %%
Expand Down Expand Up @@ -822,7 +822,7 @@ def transform(entry):
# ==========================================

dataset = dataset.shuffle(seed=SEED)
grain_dataset = grain.MapDataset.source(dataset)
grain_dataset = grain.MapDataset.source(dataset) # pyrefly: ignore[bad-argument-type]

def mixed_type_batch_fn(elements):
"""elements: A list of dicts."""
Expand Down
2 changes: 1 addition & 1 deletion tests/cli/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def test_perf_metrics_validation(
overrides = [o.format(log_dir=log_dir) for o in overrides]
argv = [main_command, "base_config.yaml"] + overrides
if expected_error:
with self.assertRaisesRegex(expected_error, error_regex):
with self.assertRaisesRegex(expected_error, error_regex): # pyrefly: ignore[bad-argument-type]
config.initialize(argv)
else:
config.initialize(argv)
Expand Down
34 changes: 17 additions & 17 deletions tests/distillation/feature_extraction/sowed_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_wrap_model_with_sowed_modules_recursive(self):

# Target SimpleLayer for wrapping
target_types = [SimpleLayer]
sowed_module.wrap_model_with_sowed_modules(model, target_types)
sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type]

# Assert all SimpleLayers are wrapped
self.assertIsInstance(
Expand Down Expand Up @@ -166,39 +166,39 @@ def test_wrap_model_with_multiple_target_types(self):

# Target SimpleLayer for wrapping
target_types = [SimpleLayer, Block]
sowed_module.wrap_model_with_sowed_modules(model, target_types)
sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type]

# Assert all SimpleLayers are wrapped
self.assertIsInstance(
model.block1.wrapped_model.layer1,
model.block1.wrapped_model.layer1, # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block1.wrapped_model.layer2,
model.block1.wrapped_model.layer2, # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block1.wrapped_model.other_layers[0],
model.block1.wrapped_model.other_layers[0], # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block1.wrapped_model.other_layers[1],
model.block1.wrapped_model.other_layers[1], # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block2.wrapped_model.layer1,
model.block2.wrapped_model.layer1, # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block2.wrapped_model.layer2,
model.block2.wrapped_model.layer2, # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block2.wrapped_model.other_layers[0],
model.block2.wrapped_model.other_layers[0], # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
model.block2.wrapped_model.other_layers[1],
model.block2.wrapped_model.other_layers[1], # pyrefly: ignore[missing-attribute]
sowed_module.SowedModule,
)
self.assertIsInstance(
Expand All @@ -215,11 +215,11 @@ def test_wrap_model_with_multiple_target_types(self):
model.block1.wrapped_model.layer2.wrapped_model, original_b1_l2
)
self.assertIs(
model.block1.wrapped_model.other_layers[0].wrapped_model,
model.block1.wrapped_model.other_layers[0].wrapped_model, # pyrefly: ignore[missing-attribute]
original_b1_other_layers_0,
)
self.assertIs(
model.block1.wrapped_model.other_layers[1].wrapped_model,
model.block1.wrapped_model.other_layers[1].wrapped_model, # pyrefly: ignore[missing-attribute]
original_b1_other_layers_1,
)
self.assertIs(
Expand All @@ -229,11 +229,11 @@ def test_wrap_model_with_multiple_target_types(self):
model.block2.wrapped_model.layer2.wrapped_model, original_b2_l2
)
self.assertIs(
model.block2.wrapped_model.other_layers[0].wrapped_model,
model.block2.wrapped_model.other_layers[0].wrapped_model, # pyrefly: ignore[missing-attribute]
original_b2_other_layers_0,
)
self.assertIs(
model.block2.wrapped_model.other_layers[1].wrapped_model,
model.block2.wrapped_model.other_layers[1].wrapped_model, # pyrefly: ignore[missing-attribute]
original_b2_other_layers_1,
)
self.assertIs(model.final_layer.wrapped_model, original_final)
Expand Down Expand Up @@ -316,11 +316,11 @@ def test_capture_avoids_double_wrapping(self):
original_l2 = model.layer2
# Manually wrap one layer first
manual_wrapper = sowed_module.SowedModule(original_l1)
model.layer1 = manual_wrapper
model.layer1 = manual_wrapper # pyrefly: ignore[bad-assignment]

# Run the utility function
target_types = [SimpleLayer]
sowed_module.wrap_model_with_sowed_modules(model, target_types)
sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type]

# Assert manually wrapped layer wasn't re-wrapped
self.assertIs(model.layer1, manual_wrapper)
Expand All @@ -347,7 +347,7 @@ def test_unwrap_sowed_modules_recursive(self):

# Wrap the model (all SimpleLayers)
target_types = [SimpleLayer]
sowed_module.wrap_model_with_sowed_modules(model, target_types)
sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type]

# Unwrap the sowed modules
sowed_module.unwrap_sowed_modules(model)
Expand Down
2 changes: 1 addition & 1 deletion tests/distillation/strategies/logit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_top_k_distillation_logic(self):
)

# The losses should be different because the distribution is truncated
self.assertNotAlmostEqual(loss_full, loss_k, places=4)
self.assertNotAlmostEqual(loss_full, loss_k, places=4) # pyrefly: ignore[no-matching-overload]

@parameterized.named_parameters(
("alpha_one", 1.0),
Expand Down
10 changes: 5 additions & 5 deletions tests/rl/ppo/ppo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def test_ppo_learner(
actor=model,
critic=value_model,
reference=ref_model,
reward=reward_model if use_reward_model else None, # pylint: disable=undefined-variable
reward=reward_model if use_reward_model else None, # pylint: disable=undefined-variable # pyrefly: ignore[unbound-name]
tokenizer=vocab,
cluster_config=cluster_config,
)
Expand Down Expand Up @@ -421,11 +421,11 @@ def test_ppo_learner(
expected_metrics.append('pg_clipfrac_lower')
for metric_name in expected_metrics:
self.assertLen(
actor_metric_logger.get_metric_history('actor', metric_name, 'train'),
actor_metric_logger.get_metric_history('actor', metric_name, 'train'), # pyrefly: ignore[missing-attribute]
ppo_learner._iter_steps,
)
self.assertLen(
actor_metric_logger.get_metric_history('actor', metric_name, 'eval'),
actor_metric_logger.get_metric_history('actor', metric_name, 'eval'), # pyrefly: ignore[missing-attribute]
ppo_learner.rl_cluster.actor_trainer.train_steps
/ cluster_config.training_config.eval_every_n_steps,
msg=f'metric_name: {metric_name}',
Expand All @@ -434,13 +434,13 @@ def test_ppo_learner(
critic_metric_logger = ppo_learner.rl_cluster.critic_trainer.metrics_logger
for metric_name in ['loss', 'vpred_mean', 'vf_clipfrac']:
self.assertLen(
critic_metric_logger.get_metric_history(
critic_metric_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
'critic', metric_name, 'train'
),
ppo_learner.rl_cluster.critic_trainer.train_steps,
)
self.assertLen(
critic_metric_logger.get_metric_history(
critic_metric_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
'critic', metric_name, 'eval'
),
ppo_learner.rl_cluster.critic_trainer.train_steps
Expand Down
Loading
Loading