diff --git a/examples/deepscaler/math_eval_nb.py b/examples/deepscaler/math_eval_nb.py index 9cfd3198c..51188d03f 100644 --- a/examples/deepscaler/math_eval_nb.py +++ b/examples/deepscaler/math_eval_nb.py @@ -209,7 +209,7 @@ def __init__( if mesh_config is None: # Default: 4-way tensor parallelism mesh_config = [[1, 4], ["fsdp", "tp"]] - self.mesh = jax.make_mesh(*mesh_config, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_config[0])) + self.mesh = jax.make_mesh(*mesh_config, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_config[0])) # pyrefly: ignore[bad-argument-type] self.tokenizer = None self.model = None self.sampler = None @@ -259,7 +259,7 @@ def model_from_orbax_ckpt(self): ), ) graphdef, _ = nnx.split(abs_model) - new_state = nnx.State(ckpt.model_params) + new_state = nnx.State(ckpt.model_params) # pyrefly: ignore[missing-attribute] self.model = nnx.merge(graphdef, new_state) def load_model(self): @@ -292,7 +292,7 @@ def load_model(self): if self.sampler_type == "vanilla": self.sampler_vanilla = sampler_lib.Sampler( - transformer=self.model, + transformer=self.model, # pyrefly: ignore[bad-argument-type] tokenizer=self.tokenizer, cache_config=cache_config, ) @@ -304,7 +304,7 @@ def load_model(self): model=self.model, backend="sglang_jax", ) - self.sampler_sglang = sglang_jax_sampler.SglangJaxSampler( + self.sampler_sglang = sglang_jax_sampler.SglangJaxSampler( # pyrefly: ignore[bad-instantiation] tokenizer=self.tokenizer, config=sglang_jax_sampler.SglangJaxConfig( mesh=self.mesh, @@ -330,7 +330,7 @@ def load_model(self): model=self.model, backend="vllm_jax", ) - self.sampler_vllm = vllm_sampler.VllmSampler( + self.sampler_vllm = vllm_sampler.VllmSampler( # pyrefly: ignore[bad-instantiation] tokenizer=self.tokenizer, config=vllm_sampler.VllmConfig( mesh=self.mesh, @@ -396,7 +396,7 @@ def process_item(item): else: instruction = "Please reason step by step. Your final answer must appear inside \\boxed{...} and nothing else." prompt = f"{instruction} {question}" - prompt = self.tokenizer.apply_chat_template( + prompt = self.tokenizer.apply_chat_template( # pyrefly: ignore[missing-attribute] [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True) @@ -410,7 +410,7 @@ def process_item(item): print("\n" + "=" * 60) print("DEBUG: First formatted prompt:") first_item = dataset[0] - print(first_item["prompt"]) + print(first_item["prompt"]) # pyrefly: ignore[unsupported-operation] print("=" * 60 + "\n") return dataset @@ -451,7 +451,7 @@ def generate( top_p=top_p, echo=False, eos_tokens=[stop_token_id], - seed=jax.random.PRNGKey(seed) if seed is not None else None, + seed=jax.random.PRNGKey(seed) if seed is not None else None, # pyrefly: ignore[bad-argument-type] ) elif self.sampler_type == "sglang_jax": out_data = self.sampler_sglang( @@ -530,8 +530,8 @@ def evaluate( batch_response = self.generate( prompts=prompts, temperature=temperature, - top_k=top_k, - top_p=top_p, + top_k=top_k, # pyrefly: ignore[bad-argument-type] + top_p=top_p, # pyrefly: ignore[bad-argument-type] seed=pass_idx if self.sampler_type != "vllm" else None, # vllm handles seeding differently diff --git a/examples/deepscaler/train_deepscaler_nb.py b/examples/deepscaler/train_deepscaler_nb.py index 8cc903077..3a43fef6b 100644 --- a/examples/deepscaler/train_deepscaler_nb.py +++ b/examples/deepscaler/train_deepscaler_nb.py @@ -228,7 +228,7 @@ trainer_devices = math.prod(TRAINER_MESH[0]) rollout_devices = math.prod(ROLLOUT_MESH[0]) -if trainer_devices + rollout_devices > jax.device_count(): +if trainer_devices + rollout_devices > jax.device_count(): # pyrefly: ignore[unsupported-operation] raise ValueError( "Trainer devices must be less than or equal to the number of devices" " available." @@ -237,7 +237,7 @@ if ROLLOUT_ENGINE in ("sglang_jax", "vllm"): rollout_device_list = jax._src.mesh_utils.create_device_mesh( - ROLLOUT_MESH[0], jax.devices()[:rollout_devices] + ROLLOUT_MESH[0], jax.devices()[:rollout_devices] # pyrefly: ignore[bad-argument-type, bad-index] ) rollout_mesh = jax.sharding.Mesh( @@ -252,7 +252,7 @@ # ) print(f"YY {rollout_device_list=} {rollout_mesh.devices=}") trainer_devices_list = jax._src.mesh_utils.create_device_mesh( - TRAINER_MESH[0], jax.devices()[-trainer_devices:] + TRAINER_MESH[0], jax.devices()[-trainer_devices:] # pyrefly: ignore[bad-argument-type, unsupported-operation] ) # trainer_mesh = jax.make_mesh( # *TRAINER_MESH, @@ -545,7 +545,7 @@ def get_lora_model(base_model, model_mesh): ) elif ROLLOUT_ENGINE == "vllm": rollout_engine_config = base_rollout.RolloutConfig( - **base_rollout_dict, **vllm_rollout_dict + **base_rollout_dict, **vllm_rollout_dict # pyrefly: ignore[bad-argument-type] ) elif ROLLOUT_ENGINE == "vanilla": rollout_engine_config = base_rollout.RolloutConfig(**base_rollout_dict) @@ -553,7 +553,7 @@ def get_lora_model(base_model, model_mesh): raise ValueError(f"Unsupported rollout engine: {ROLLOUT_ENGINE}") cluster_config = rl_cluster_lib.ClusterConfig( - role_to_mesh={ + role_to_mesh={ # pyrefly: ignore[bad-argument-type] rl_cluster_lib.Role.ACTOR: trainer_mesh, rl_cluster_lib.Role.REFERENCE: trainer_mesh, rl_cluster_lib.Role.ROLLOUT: rollout_mesh, diff --git a/examples/deepswe/swe_env.py b/examples/deepswe/swe_env.py index 6164c4991..d054c449d 100644 --- a/examples/deepswe/swe_env.py +++ b/examples/deepswe/swe_env.py @@ -153,7 +153,7 @@ def close(self) -> None: os.system(f"docker rmi {docker_image}") @staticmethod - def from_dict(extra_info: dict | str) -> "SWEEnv": + def from_dict(extra_info: dict | str) -> "SWEEnv": # pyrefly: ignore[bad-override] """Create an environment instance from JSON configuration. Args: diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py index 3357367a0..7f92333bd 100644 --- a/examples/deepswe/train_deepswe_nb.py +++ b/examples/deepswe/train_deepswe_nb.py @@ -280,7 +280,7 @@ except ImportError as e: print(f"❌ Still missing a module: {e}") -if pathwaysutils is not None and os.getenv("JAX_PLATFORMS", None) == "proxy": +if pathwaysutils is not None and os.getenv("JAX_PLATFORMS", None) == "proxy": # pyrefly: ignore[unbound-name] pathwaysutils.initialize() @@ -356,7 +356,7 @@ os.makedirs(MODEL_PATH, exist_ok=True) # Assumes "Qwen/" organization prefix for HF download. Adjust if using other models. - snapshot_download( + snapshot_download( # pyrefly: ignore[no-matching-overload] repo_id=f"Qwen/{MODEL_VERSION}", local_dir=MODEL_PATH, local_dir_use_symlinks=False, @@ -651,7 +651,7 @@ def transform(entry): dataset = dataset.map( transform, - keep_in_memory=True, + keep_in_memory=True, # pyrefly: ignore[unexpected-keyword] ) # %% @@ -822,7 +822,7 @@ def transform(entry): # ========================================== dataset = dataset.shuffle(seed=SEED) -grain_dataset = grain.MapDataset.source(dataset) +grain_dataset = grain.MapDataset.source(dataset) # pyrefly: ignore[bad-argument-type] def mixed_type_batch_fn(elements): """elements: A list of dicts.""" diff --git a/tests/cli/config_test.py b/tests/cli/config_test.py index 2c142dd59..c7d1788d3 100644 --- a/tests/cli/config_test.py +++ b/tests/cli/config_test.py @@ -688,7 +688,7 @@ def test_perf_metrics_validation( overrides = [o.format(log_dir=log_dir) for o in overrides] argv = [main_command, "base_config.yaml"] + overrides if expected_error: - with self.assertRaisesRegex(expected_error, error_regex): + with self.assertRaisesRegex(expected_error, error_regex): # pyrefly: ignore[bad-argument-type] config.initialize(argv) else: config.initialize(argv) diff --git a/tests/distillation/feature_extraction/sowed_module_test.py b/tests/distillation/feature_extraction/sowed_module_test.py index 88c041712..b7169026b 100644 --- a/tests/distillation/feature_extraction/sowed_module_test.py +++ b/tests/distillation/feature_extraction/sowed_module_test.py @@ -87,7 +87,7 @@ def test_wrap_model_with_sowed_modules_recursive(self): # Target SimpleLayer for wrapping target_types = [SimpleLayer] - sowed_module.wrap_model_with_sowed_modules(model, target_types) + sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type] # Assert all SimpleLayers are wrapped self.assertIsInstance( @@ -166,39 +166,39 @@ def test_wrap_model_with_multiple_target_types(self): # Target SimpleLayer for wrapping target_types = [SimpleLayer, Block] - sowed_module.wrap_model_with_sowed_modules(model, target_types) + sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type] # Assert all SimpleLayers are wrapped self.assertIsInstance( - model.block1.wrapped_model.layer1, + model.block1.wrapped_model.layer1, # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block1.wrapped_model.layer2, + model.block1.wrapped_model.layer2, # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block1.wrapped_model.other_layers[0], + model.block1.wrapped_model.other_layers[0], # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block1.wrapped_model.other_layers[1], + model.block1.wrapped_model.other_layers[1], # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block2.wrapped_model.layer1, + model.block2.wrapped_model.layer1, # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block2.wrapped_model.layer2, + model.block2.wrapped_model.layer2, # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block2.wrapped_model.other_layers[0], + model.block2.wrapped_model.other_layers[0], # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( - model.block2.wrapped_model.other_layers[1], + model.block2.wrapped_model.other_layers[1], # pyrefly: ignore[missing-attribute] sowed_module.SowedModule, ) self.assertIsInstance( @@ -215,11 +215,11 @@ def test_wrap_model_with_multiple_target_types(self): model.block1.wrapped_model.layer2.wrapped_model, original_b1_l2 ) self.assertIs( - model.block1.wrapped_model.other_layers[0].wrapped_model, + model.block1.wrapped_model.other_layers[0].wrapped_model, # pyrefly: ignore[missing-attribute] original_b1_other_layers_0, ) self.assertIs( - model.block1.wrapped_model.other_layers[1].wrapped_model, + model.block1.wrapped_model.other_layers[1].wrapped_model, # pyrefly: ignore[missing-attribute] original_b1_other_layers_1, ) self.assertIs( @@ -229,11 +229,11 @@ def test_wrap_model_with_multiple_target_types(self): model.block2.wrapped_model.layer2.wrapped_model, original_b2_l2 ) self.assertIs( - model.block2.wrapped_model.other_layers[0].wrapped_model, + model.block2.wrapped_model.other_layers[0].wrapped_model, # pyrefly: ignore[missing-attribute] original_b2_other_layers_0, ) self.assertIs( - model.block2.wrapped_model.other_layers[1].wrapped_model, + model.block2.wrapped_model.other_layers[1].wrapped_model, # pyrefly: ignore[missing-attribute] original_b2_other_layers_1, ) self.assertIs(model.final_layer.wrapped_model, original_final) @@ -316,11 +316,11 @@ def test_capture_avoids_double_wrapping(self): original_l2 = model.layer2 # Manually wrap one layer first manual_wrapper = sowed_module.SowedModule(original_l1) - model.layer1 = manual_wrapper + model.layer1 = manual_wrapper # pyrefly: ignore[bad-assignment] # Run the utility function target_types = [SimpleLayer] - sowed_module.wrap_model_with_sowed_modules(model, target_types) + sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type] # Assert manually wrapped layer wasn't re-wrapped self.assertIs(model.layer1, manual_wrapper) @@ -347,7 +347,7 @@ def test_unwrap_sowed_modules_recursive(self): # Wrap the model (all SimpleLayers) target_types = [SimpleLayer] - sowed_module.wrap_model_with_sowed_modules(model, target_types) + sowed_module.wrap_model_with_sowed_modules(model, target_types) # pyrefly: ignore[bad-argument-type] # Unwrap the sowed modules sowed_module.unwrap_sowed_modules(model) diff --git a/tests/distillation/strategies/logit_test.py b/tests/distillation/strategies/logit_test.py index d79d5ce89..efafca175 100644 --- a/tests/distillation/strategies/logit_test.py +++ b/tests/distillation/strategies/logit_test.py @@ -156,7 +156,7 @@ def test_top_k_distillation_logic(self): ) # The losses should be different because the distribution is truncated - self.assertNotAlmostEqual(loss_full, loss_k, places=4) + self.assertNotAlmostEqual(loss_full, loss_k, places=4) # pyrefly: ignore[no-matching-overload] @parameterized.named_parameters( ("alpha_one", 1.0), diff --git a/tests/rl/ppo/ppo_learner_test.py b/tests/rl/ppo/ppo_learner_test.py index c9a59de23..11c9c1c13 100644 --- a/tests/rl/ppo/ppo_learner_test.py +++ b/tests/rl/ppo/ppo_learner_test.py @@ -360,7 +360,7 @@ def test_ppo_learner( actor=model, critic=value_model, reference=ref_model, - reward=reward_model if use_reward_model else None, # pylint: disable=undefined-variable + reward=reward_model if use_reward_model else None, # pylint: disable=undefined-variable # pyrefly: ignore[unbound-name] tokenizer=vocab, cluster_config=cluster_config, ) @@ -421,11 +421,11 @@ def test_ppo_learner( expected_metrics.append('pg_clipfrac_lower') for metric_name in expected_metrics: self.assertLen( - actor_metric_logger.get_metric_history('actor', metric_name, 'train'), + actor_metric_logger.get_metric_history('actor', metric_name, 'train'), # pyrefly: ignore[missing-attribute] ppo_learner._iter_steps, ) self.assertLen( - actor_metric_logger.get_metric_history('actor', metric_name, 'eval'), + actor_metric_logger.get_metric_history('actor', metric_name, 'eval'), # pyrefly: ignore[missing-attribute] ppo_learner.rl_cluster.actor_trainer.train_steps / cluster_config.training_config.eval_every_n_steps, msg=f'metric_name: {metric_name}', @@ -434,13 +434,13 @@ def test_ppo_learner( critic_metric_logger = ppo_learner.rl_cluster.critic_trainer.metrics_logger for metric_name in ['loss', 'vpred_mean', 'vf_clipfrac']: self.assertLen( - critic_metric_logger.get_metric_history( + critic_metric_logger.get_metric_history( # pyrefly: ignore[missing-attribute] 'critic', metric_name, 'train' ), ppo_learner.rl_cluster.critic_trainer.train_steps, ) self.assertLen( - critic_metric_logger.get_metric_history( + critic_metric_logger.get_metric_history( # pyrefly: ignore[missing-attribute] 'critic', metric_name, 'eval' ), ppo_learner.rl_cluster.critic_trainer.train_steps diff --git a/tunix/cli/config.py b/tunix/cli/config.py index 5b0aeb631..b9b6ed568 100644 --- a/tunix/cli/config.py +++ b/tunix/cli/config.py @@ -201,7 +201,7 @@ def __init__(self, argv: list[str], **kwargs): base_config_file = pathlib.Path(__file__).parent / argv[1] else: base_config_file = argv[1] - raw_data_from_yaml = self._load_config_from_yaml(base_config_file) + raw_data_from_yaml = self._load_config_from_yaml(base_config_file) # pyrefly: ignore[bad-argument-type] self._validate_env_variable(raw_data_from_yaml) base_model_config = raw_data_from_yaml.get("model_config", {}) @@ -649,7 +649,7 @@ def create_optimizer( # Handle learning rate, potentially creating a schedule learning_rate_val = self._create_learning_rate( - optimizer_config, config_path_info + optimizer_config, config_path_info # pyrefly: ignore[bad-argument-type] ) if learning_rate_val is None and ( "learning_rate" in inspect.signature(opt_func).parameters @@ -664,7 +664,7 @@ def create_optimizer( ) opt_kwargs = self._extract_kwargs( - opt_func, optimizer_config, config_path_info, learning_rate_val + opt_func, optimizer_config, config_path_info, learning_rate_val # pyrefly: ignore[bad-argument-type] ) # Wrap the optimizer function with inject_hyperparams so that # the learning rate can be tracked and logged during training. @@ -870,7 +870,7 @@ def _update_from_env_and_command_line( updated_keys = [] # Check for conflicts and unknown keys. - for k in raw_data_from_cmd_line: + for k in raw_data_from_cmd_line: # pyrefly: ignore[not-iterable] if not k: continue if k not in raw_data_from_yaml: @@ -882,7 +882,7 @@ def _update_from_env_and_command_line( for k in raw_data_from_yaml: # Error out if same key defined in cmd line and environment - if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ: + if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ: # pyrefly: ignore[not-iterable] raise ValueError( f"You are passing overrides by both CLI and ENV for `{k}`. This" " isn't allowed." @@ -891,7 +891,7 @@ def _update_from_env_and_command_line( # Take value from base config yaml if key is not specified in command line # or environment. if ( - k not in raw_data_from_cmd_line + k not in raw_data_from_cmd_line # pyrefly: ignore[not-iterable] and yaml_key_to_env_key(k) not in os.environ ): # take the config value from the YAML file. @@ -902,8 +902,8 @@ def _update_from_env_and_command_line( updated_keys.append(k) # take updated value from command line or enviornment - if k in raw_data_from_cmd_line: - new_proposal = raw_data_from_cmd_line[k] + if k in raw_data_from_cmd_line: # pyrefly: ignore[not-iterable] + new_proposal = raw_data_from_cmd_line[k] # pyrefly: ignore[bad-index, unsupported-operation] else: new_proposal = os.environ.get(yaml_key_to_env_key(k)) @@ -991,7 +991,7 @@ def update_dict(self, schema: dict[str, Any], source: dict[str, Any]): ): # Both are dictionaries, so recurse. # The recursive call uses the same self.replace_keys instance. - output[key] = self.update_dict(output_val, source_val) + output[key] = self.update_dict(output_val, source_val) # pyrefly: ignore[bad-argument-type] else: # Otherwise (not both dictionaries), the source value overwrites. output[key] = copy.deepcopy(source_val) @@ -1103,7 +1103,7 @@ def reward_fn( for name, member in inspect.getmembers(module): if inspect.isfunction(member) and not name.startswith("_"): # Check if the function was defined in this module - if member.__module__ == module_name: + if member.__module__ == module_name: # pyrefly: ignore[unbound-name] defined_functions.append(member) reward_fns.extend(defined_functions) return reward_fns diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 3e9a50e01..b574ac7ba 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -375,7 +375,7 @@ def _rollout_engine_extra( "rollout_vllm_max_num_batched_tokens", vllm.get( "max_num_batched_tokens", - (max_num_seqs * kv_cache_size) // 4, + (max_num_seqs * kv_cache_size) // 4, # pyrefly: ignore[unsupported-operation] ), ) submission_threshold = rollout_cfg.get( @@ -655,7 +655,7 @@ def _get_dataset(self, tokenizer): data_source=self.config["data_source"], dataset=self.config["dataset_name"], tfds_download=self.config["tfds_download"], - split=self.config.get( + split=self.config.get( # pyrefly: ignore[bad-argument-type] "train_split", self.config.get("split", "train") ), apply_chat_template_to_dataset=apply_chat_template_to_dataset, @@ -665,7 +665,7 @@ def _get_dataset(self, tokenizer): data_source=self.config["data_source"], dataset=self.config["dataset_name"], tokenizer=tokenizer, - split=self.config.get( + split=self.config.get( # pyrefly: ignore[bad-argument-type] "train_split", self.config.get("split", "train") ), apply_chat_template_to_dataset=apply_chat_template_to_dataset, diff --git a/tunix/cli/utils/data.py b/tunix/cli/utils/data.py index 9aa62fb9d..b05909a6c 100644 --- a/tunix/cli/utils/data.py +++ b/tunix/cli/utils/data.py @@ -120,7 +120,7 @@ def get_dataset_from_module( if os.path.exists(specifier) and specifier.endswith(".py"): module_name = os.path.splitext(os.path.basename(specifier))[0] spec = importlib.util.spec_from_file_location(module_name, specifier) - module = importlib.util.module_from_spec(spec) + module = importlib.util.module_from_spec(spec) # pyrefly: ignore[bad-argument-type] if spec is None: raise ImportError(f"Failed to create spec for {specifier}") diff --git a/tunix/cli/utils/model.py b/tunix/cli/utils/model.py index 38749257c..04b21a903 100644 --- a/tunix/cli/utils/model.py +++ b/tunix/cli/utils/model.py @@ -70,10 +70,10 @@ def apply_lora_to_model(base_model, mesh, lora_config, rng_seed=0): ) if original_remat is not None: - lora_model.config.remat_config = original_remat + lora_model.config.remat_config = original_remat # pyrefly: ignore[missing-attribute] if mesh is not None: - lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) + lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) # pyrefly: ignore[bad-argument-type] return lora_model diff --git a/tunix/distillation/distillation_trainer.py b/tunix/distillation/distillation_trainer.py index 57c9b3944..d057ff95a 100644 --- a/tunix/distillation/distillation_trainer.py +++ b/tunix/distillation/distillation_trainer.py @@ -64,8 +64,8 @@ def __init__( super().__init__(student_model, optimizer, training_config) self.strategy = strategy self.teacher_model = teacher_model - self.loss_fn = self.get_train_loss - self.eval_loss_fn = self.get_eval_loss + self.loss_fn = self.get_train_loss # pyrefly: ignore[bad-assignment] + self.eval_loss_fn = self.get_eval_loss # pyrefly: ignore[bad-assignment] # Since distillation strategies return (loss, metrics), # we explicitly enable auxiliary metric handling. @@ -147,7 +147,7 @@ def get_train_loss( teacher_output: Any, inputs: dict[str, ArrayLike], ) -> tuple[ArrayLike, dict[str, Any]]: - output = self.strategy.get_train_loss(model, teacher_output, inputs) + output = self.strategy.get_train_loss(model, teacher_output, inputs) # pyrefly: ignore[bad-argument-type] return self._standardize_loss_output(output) def get_eval_loss( @@ -157,7 +157,7 @@ def get_eval_loss( inputs: dict[str, ArrayLike], ) -> tuple[ArrayLike, dict[str, Any]]: del teacher_output # Not computed in eval. - output = self.strategy.get_eval_loss(model, inputs) + output = self.strategy.get_eval_loss(model, inputs) # pyrefly: ignore[bad-argument-type] return self._standardize_loss_output(output) def close(self): diff --git a/tunix/distillation/feature_extraction/projection.py b/tunix/distillation/feature_extraction/projection.py index d4c1ad44d..5ec203e24 100644 --- a/tunix/distillation/feature_extraction/projection.py +++ b/tunix/distillation/feature_extraction/projection.py @@ -52,7 +52,7 @@ def __init__( self.projection_layer = nnx.LinearGeneral( feature_shape, feature_target_shape, - axis=np.arange(len(feature_shape)), + axis=np.arange(len(feature_shape)), # pyrefly: ignore[bad-argument-type] rngs=rngs, ) diff --git a/tunix/distillation/feature_extraction/sowed_module.py b/tunix/distillation/feature_extraction/sowed_module.py index 4cb7796ff..5754d4bdd 100644 --- a/tunix/distillation/feature_extraction/sowed_module.py +++ b/tunix/distillation/feature_extraction/sowed_module.py @@ -96,7 +96,7 @@ def wrap_model_with_sowed_modules( if isinstance(part, str): current_parent_obj = getattr(current_parent_obj, part) elif isinstance(part, int): - current_parent_obj = current_parent_obj[part] + current_parent_obj = current_parent_obj[part] # pyrefly: ignore[bad-index] else: raise TypeError( f"Unsupported path part type: {type(part)}. Path: {path}" @@ -112,7 +112,7 @@ def wrap_model_with_sowed_modules( setattr(current_parent_obj, last_key, wrapped_instance) elif isinstance(last_key, int): # If the parent is a sequence, try to modify in-place - current_parent_obj[last_key] = wrapped_instance + current_parent_obj[last_key] = wrapped_instance # pyrefly: ignore[unsupported-operation] else: raise TypeError( f"Unsupported key type for replacement: {type(last_key)}. Path:" @@ -160,7 +160,7 @@ def unwrap_sowed_modules(model: nnx.Module): if isinstance(part, str): current_parent_obj = getattr(current_parent_obj, part) elif isinstance(part, int): - current_parent_obj = current_parent_obj[part] + current_parent_obj = current_parent_obj[part] # pyrefly: ignore[bad-index] else: raise TypeError( f"Unsupported path part type: {type(part)}. Path: {path}" @@ -171,7 +171,7 @@ def unwrap_sowed_modules(model: nnx.Module): if isinstance(last_key, str): setattr(current_parent_obj, last_key, original_module) elif isinstance(last_key, int): - current_parent_obj[last_key] = original_module + current_parent_obj[last_key] = original_module # pyrefly: ignore[unsupported-operation] else: raise TypeError( f"Unsupported key type for replacement: {type(last_key)}. Path:" diff --git a/tunix/distillation/strategies/feature_pooling.py b/tunix/distillation/strategies/feature_pooling.py index a522a6203..6b35b3cb9 100644 --- a/tunix/distillation/strategies/feature_pooling.py +++ b/tunix/distillation/strategies/feature_pooling.py @@ -183,7 +183,7 @@ def compute_loss( teacher_features = feature_extraction.avg_pool_array_to_target_shape( teacher_output, student_features.shape ) - feature_loss = self.feature_loss_fn(student_features, teacher_features) + feature_loss = self.feature_loss_fn(student_features, teacher_features) # pyrefly: ignore[not-callable] # Calculate Task Loss (Cross-Entropy) ce_loss_per_example = optax.softmax_cross_entropy( diff --git a/tunix/distillation/strategies/feature_projection.py b/tunix/distillation/strategies/feature_projection.py index da5dfa43e..7a11af5cb 100644 --- a/tunix/distillation/strategies/feature_projection.py +++ b/tunix/distillation/strategies/feature_projection.py @@ -182,7 +182,7 @@ def compute_loss( student_logits = student_output["logits"] student_features = student_output["features"] - feature_loss = self.feature_loss_fn(student_features, teacher_output) + feature_loss = self.feature_loss_fn(student_features, teacher_output) # pyrefly: ignore[not-callable] # Calculate Task Loss (Cross-Entropy) ce_loss_per_example = optax.softmax_cross_entropy( diff --git a/tunix/examples/data/math_dataset.py b/tunix/examples/data/math_dataset.py index b32016f07..593415447 100644 --- a/tunix/examples/data/math_dataset.py +++ b/tunix/examples/data/math_dataset.py @@ -150,7 +150,7 @@ def get_huggingface_dataset( split=split, ) data = data.shuffle(seed=shuffle_seed) - return grain.MapDataset.source(data) + return grain.MapDataset.source(data) # pyrefly: ignore[bad-argument-type] def create_dataset( diff --git a/tunix/examples/data/translation_dataset.py b/tunix/examples/data/translation_dataset.py index ad19d1aa5..2e1bf1bbb 100644 --- a/tunix/examples/data/translation_dataset.py +++ b/tunix/examples/data/translation_dataset.py @@ -70,7 +70,7 @@ def create_datasets( ) elif dataset_name == "Helsinki-NLP/opus-100": # Hugging Face dataloader train_ds, eval_ds = datasets.load_dataset( - dataset_name, data_dir="en-fr", split=("train", "validation") + dataset_name, data_dir="en-fr", split=("train", "validation") # pyrefly: ignore[bad-argument-type] ) else: raise ValueError(f"Unsupported dataset: {dataset_name}") @@ -78,7 +78,7 @@ def create_datasets( input_template = INPUT_TEMPLATE_IT if instruct_tuned else INPUT_TEMPLATE train_loader = _build_data_loader( - data_source=train_ds, + data_source=train_ds, # pyrefly: ignore[bad-argument-type] batch_size=global_batch_size, num_epochs=num_train_epochs, max_seq_len=max_target_length, @@ -86,7 +86,7 @@ def create_datasets( input_template=input_template, ) eval_loader = _build_data_loader( - data_source=eval_ds, + data_source=eval_ds, # pyrefly: ignore[bad-argument-type] batch_size=global_batch_size, num_epochs=1, max_seq_len=max_target_length, @@ -168,7 +168,7 @@ def map(self, tokens: tuple[np.ndarray, np.ndarray]) -> TrainingInput: # The input sequence fed to the model is simply the concatenation of the # source and the destination. - tokens = np.concat([src_tokens, dst_tokens], axis=0) + tokens = np.concat([src_tokens, dst_tokens], axis=0) # pyrefly: ignore[bad-assignment] # To prevent the model from updating based on the source (input) # tokens, add a target mask to each input. @@ -178,12 +178,12 @@ def map(self, tokens: tuple[np.ndarray, np.ndarray]) -> TrainingInput: # If the input tokens sequence is smaller than the target sequence size, # then pad it with pad tokens. - tokens = self._pad_up_to_max_len(tokens, self._pad_value) + tokens = self._pad_up_to_max_len(tokens, self._pad_value) # pyrefly: ignore[bad-argument-type, bad-assignment] # Don't want to perform the backward pass on the pad tokens. mask = self._pad_up_to_max_len(mask, 0) - return TrainingInput(input_tokens=tokens, input_mask=mask) + return TrainingInput(input_tokens=tokens, input_mask=mask) # pyrefly: ignore[bad-argument-type] def _pad_up_to_max_len( self, input_tensor: np.ndarray, pad_value: int diff --git a/tunix/models/automodel.py b/tunix/models/automodel.py index 11c8cf405..038dd7893 100644 --- a/tunix/models/automodel.py +++ b/tunix/models/automodel.py @@ -95,13 +95,13 @@ def call_model_config(model_name: str) -> Any: model_lib_module = get_model_module(model_name, ModelModule.MODEL) target_obj = model_lib_module.ModelConfig - if not hasattr(target_obj, config_id): + if not hasattr(target_obj, config_id): # pyrefly: ignore[bad-argument-type] raise AttributeError( f"Error: Function '{config_id}' not found on the target object " f"for model '{model_name}'. Target object type: {type(target_obj)}" ) - method_to_call = getattr(target_obj, config_id) + method_to_call = getattr(target_obj, config_id) # pyrefly: ignore[bad-argument-type] if not callable(method_to_call): raise TypeError( @@ -198,7 +198,7 @@ def _nnx_convert_and_reload() -> tuple[nnx.Module, Any]: else: # gemma dir_name = version_dashed - params_path = os.path.join(ckpt_path, dir_name) + params_path = os.path.join(ckpt_path, dir_name) # pyrefly: ignore[no-matching-overload] model, params = create_gemma_model_from_params(params_path, model_name) @@ -310,11 +310,11 @@ def download_model( if model_source == ModelSource.KAGGLE: from tunix.oss import utils as oss_utils # pylint: disable=g-import-not-at-top - return oss_utils.kaggle_pipeline(model_id_or_path, model_download_path) + return oss_utils.kaggle_pipeline(model_id_or_path, model_download_path) # pyrefly: ignore[bad-argument-type] elif model_source == ModelSource.HUGGINGFACE: from tunix.oss import utils as oss_utils # pylint: disable=g-import-not-at-top - return oss_utils.hf_pipeline(model_id_or_path, model_download_path) + return oss_utils.hf_pipeline(model_id_or_path, model_download_path) # pyrefly: ignore[bad-argument-type] elif model_source in (ModelSource.GCS, ModelSource.MAXTEXT): return model_id_or_path elif model_source == ModelSource.INTERNAL: @@ -428,7 +428,7 @@ def from_pretrained( The path where the model was downloaded to. """ # TODO(b/477915179): Allow model_id to be config_id or a Kaggle_id - model: nnx.Module = None + model: nnx.Module = None # pyrefly: ignore[bad-assignment] model_params: Any = None naming_info = naming.ModelNaming(model_id=model_id) @@ -499,7 +499,7 @@ def from_pretrained( if model_source in (ModelSource.GCS, ModelSource.INTERNAL): model, model_params = create_gemma3_model_from_checkpoint( ckpt_path=resolved_model_path, - model_name=naming_info.model_name, + model_name=naming_info.model_name, # pyrefly: ignore[bad-argument-type] mesh=mesh, **kwargs, ) @@ -516,7 +516,7 @@ def from_pretrained( # Name is legacy — dynamically resolves to gemma4 via ModelNaming. model, model_params = create_gemma3_model_from_checkpoint( ckpt_path=resolved_model_path, - model_name=naming_info.model_name, + model_name=naming_info.model_name, # pyrefly: ignore[bad-argument-type] mesh=mesh, **kwargs, ) @@ -539,16 +539,16 @@ def from_pretrained( intermediate_ckpt_dir = kwargs.get('intermediate_ckpt_dir') rng_seed = kwargs.get('rng_seed', 0) model, model_params = create_gemma_model_with_nnx_conversion( - model_name=naming_info.model_name, + model_name=naming_info.model_name, # pyrefly: ignore[bad-argument-type] ckpt_path=resolved_model_path, - intermediate_ckpt_dir=intermediate_ckpt_dir, + intermediate_ckpt_dir=intermediate_ckpt_dir, # pyrefly: ignore[bad-argument-type] rng_seed=rng_seed, mesh=mesh, model_path=model_path, ) elif model_source == ModelSource.INTERNAL: model, model_params = create_gemma_model_from_params( - params_path=resolved_model_path, model_name=naming_info.model_name + params_path=resolved_model_path, model_name=naming_info.model_name # pyrefly: ignore[bad-argument-type] ) else: raise NotImplementedError( @@ -566,7 +566,7 @@ def from_pretrained( # Common path for all other native Tunix models -- create model from safe tensors if not model_params: # pick corresponding config based on model version - model_params = call_model_config(naming_info.model_name) + model_params = call_model_config(naming_info.model_name) # pyrefly: ignore[bad-argument-type] # Apply any model config field overrides passed via kwargs (e.g. # use_flash_attention, flash_attention_block_size). @@ -579,7 +579,7 @@ def from_pretrained( with mesh: model = create_model_from_safe_tensors( - naming_info.model_name, + naming_info.model_name, # pyrefly: ignore[bad-argument-type] resolved_model_path, model_params, mesh, diff --git a/tunix/models/gemma/model.py b/tunix/models/gemma/model.py index 97bd94fc7..6e8c5b40f 100644 --- a/tunix/models/gemma/model.py +++ b/tunix/models/gemma/model.py @@ -243,7 +243,7 @@ def __init__( def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x *= jnp.sqrt(x.shape[-1]).astype(x.dtype) - x = shard(x, self.shd_config.act_btd) + x = shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @jax.named_scope('embedder_decode') @@ -409,9 +409,9 @@ def block( query_proj = self.q_einsum(x) key_proj, value_proj = self.kv_einsum(x) - query_proj = shard(query_proj, self.shd_config.act_btnh) - key_proj = shard(key_proj, self.shd_config.act_btnh) - value_proj = shard(value_proj, self.shd_config.act_btnh) + query_proj = shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = shard(key_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = shard(value_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = apply_rope( query_proj, @@ -460,7 +460,7 @@ def block( sliding_mask = _create_sliding_mask( segment_pos, cache_len=attn_mask.shape[-1], - sliding_window_size=self.sliding_window_size, + sliding_window_size=self.sliding_window_size, # pyrefly: ignore[bad-argument-type] ) attn_mask = sliding_mask * attn_mask @@ -482,7 +482,7 @@ def block( encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) attn_output = self.attn_vec_einsum(encoded) - attn_output = shard(attn_output, self.shd_config.act_btd) + attn_output = shard(attn_output, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] if cache is not None: new_cache = { @@ -585,12 +585,12 @@ def __init__( @jax.named_scope('feed_forward') def block(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: - ff_gate = self.gate_proj(x) + ff_gate = self.gate_proj(x) # pyrefly: ignore[bad-argument-type] gate_value = nnx.gelu(ff_gate) - ff1 = self.up_proj(x) + ff1 = self.up_proj(x) # pyrefly: ignore[bad-argument-type] activations = gate_value * ff1 - activations = shard(activations, self.config.shd_config.act_btf) + activations = shard(activations, self.config.shd_config.act_btf) # pyrefly: ignore[bad-argument-type] outputs = self.down_proj(activations) return outputs @@ -700,7 +700,7 @@ def __init__( shd_config: ShardingConfig = ShardingConfig.get_default_sharding(), ): self.scale = nnx.Param( - nnx.initializers.zeros_init()(rngs.params(), dim), + nnx.initializers.zeros_init()(rngs.params(), dim), # pyrefly: ignore[bad-argument-type] sharding=shd_config.rms_norm_weight, ) @@ -776,7 +776,7 @@ def assign_val_fn( mapped_path: tuple[str | int, ...], val: Any, ) -> dict[tuple[str, ...], Any]: - state[mapped_path].value = val + state[mapped_path].value = val # pyrefly: ignore[bad-index] return state mdl: nnx.Module = nnx.eval_shape(module_factory) @@ -828,10 +828,10 @@ def _assign_linen_params_to_nnx_state( val: Any, ) -> dict[tuple[str, ...], Any]: if 'gate_proj' in mapped_path: - state[mapped_path].value = val[0] - state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1] + state[mapped_path].value = val[0] # pyrefly: ignore[bad-index] + state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1] # pyrefly: ignore[bad-index] else: - state[mapped_path].value = val + state[mapped_path].value = val # pyrefly: ignore[bad-index] return state @@ -855,7 +855,7 @@ def from_params(cls, params: params_lib.Params, version: str) -> 'Gemma': except AttributeError as exc: raise ValueError(f'Unsupported version: {version}') from exc - return module_from_linen_variables( + return module_from_linen_variables( # pyrefly: ignore[bad-return] module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)), variables=params['transformer'], map_key_fn=_map_linen_var_names, diff --git a/tunix/models/gemma/sampler.py b/tunix/models/gemma/sampler.py index 5d6771b15..ce568319b 100644 --- a/tunix/models/gemma/sampler.py +++ b/tunix/models/gemma/sampler.py @@ -159,7 +159,7 @@ def __init__( """ self.vocab = vocab self.cache_size = cache_size - self._transformer_graphdef: graph.NodeDef = nnx.graphdef(transformer) + self._transformer_graphdef: graph.NodeDef = nnx.graphdef(transformer) # pyrefly: ignore[bad-assignment] self._transformer_state: list[statelib.State] = nnx.variables(transformer) self._flattened_transformer_state: list[statelib.State] = jax.tree.leaves( self._transformer_state, @@ -173,7 +173,7 @@ def __init__( @property def transformer(self) -> gemma_lib.Gemma: - return nnx.merge( + return nnx.merge( # pyrefly: ignore[no-matching-overload] self._transformer_graphdef, self._flattened_transformer_state ) @@ -359,7 +359,7 @@ def _prefill_fn( input_mask = tokens != self.vocab.pad_id() attention_mask = make_causal_attn_mask(input_mask, self.cache_size) - transformer = nnx.merge(self._transformer_graphdef, params) + transformer = nnx.merge(self._transformer_graphdef, params) # pyrefly: ignore[no-matching-overload] logits, cache = transformer( tokens, step_positions, @@ -377,7 +377,7 @@ def _prefill_fn( key, sampler_state.temperature, sampler_state.top_p, - sampler_state.top_k, + sampler_state.top_k, # pyrefly: ignore[bad-argument-type] ) else: next_token_candidate = sample_best(logits) @@ -450,7 +450,7 @@ def _sample_step( decoding_step, self.cache_size, input_mask ) - transformer = nnx.merge(self._transformer_graphdef, params) + transformer = nnx.merge(self._transformer_graphdef, params) # pyrefly: ignore[no-matching-overload] logits, cache = transformer( last_token, step_positions, @@ -467,7 +467,7 @@ def _sample_step( key, sampler_state.temperature, sampler_state.top_p, - sampler_state.top_k, + sampler_state.top_k, # pyrefly: ignore[bad-argument-type] ) else: next_token_candidate = sample_best(logits) @@ -571,9 +571,9 @@ def __call__( total_sampling_steps = max_prompt_length + max_generation_steps if seed is None: - seed = jax.random.PRNGKey(0) + seed = jax.random.PRNGKey(0) # pyrefly: ignore[bad-assignment] elif isinstance(seed, int): - seed = jax.random.PRNGKey(seed) + seed = jax.random.PRNGKey(seed) # pyrefly: ignore[bad-assignment] sampling_state = self.init_sample_state( all_input_ids, include_logits=return_logits, @@ -581,8 +581,8 @@ def __call__( forbidden_token_ids=forbidden_token_ids, temperature=temperature, top_p=top_p, - top_k=top_k, - seed=seed, + top_k=top_k, # pyrefly: ignore[bad-argument-type] + seed=seed, # pyrefly: ignore[bad-argument-type] ) sampling_state = self._compiled_prefill_fn( self._flattened_transformer_state, sampling_state diff --git a/tunix/models/llama3/model.py b/tunix/models/llama3/model.py index b74f5510a..384ca7656 100644 --- a/tunix/models/llama3/model.py +++ b/tunix/models/llama3/model.py @@ -235,7 +235,7 @@ def __init__( @jax.named_scope('embedder_encode') def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] - x = shard(x, self.shd_config.act_btd) + x = shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @jax.named_scope('embedder_decode') @@ -279,7 +279,7 @@ def __init__( shd_config: ShardingConfig = ShardingConfig.get_default_sharding(), ): self.w = nnx.Param( - nnx.initializers.ones_init()(rngs.params(), dim), + nnx.initializers.ones_init()(rngs.params(), dim), # pyrefly: ignore[bad-argument-type] sharding=shd_config.rms_norm_weight, ) self.norm_eps = norm_eps @@ -346,9 +346,9 @@ def block( key_proj = self.k_proj(x) value_proj = self.v_proj(x) - query_proj = shard(query_proj, self.shd_config.act_btnh) - key_proj = shard(key_proj, self.shd_config.act_btnh) - value_proj = shard(value_proj, self.shd_config.act_btnh) + query_proj = shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = shard(key_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = shard(value_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = apply_rope( query_proj, @@ -393,7 +393,7 @@ def block( qkv = qkv.reshape((b, t, qh, d)) outputs = self.o_proj(qkv) - outputs = shard(outputs, self.shd_config.act_btd) + outputs = shard(outputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] if cache is not None: new_cache = { @@ -484,7 +484,7 @@ def block( x: jaxtyping.Array, ) -> jaxtyping.Array: activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x) - activations = shard(activations, self.shd_config.act_btf) + activations = shard(activations, self.shd_config.act_btf) # pyrefly: ignore[bad-argument-type] outputs = self.down_proj(activations) return outputs @@ -496,7 +496,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: ): return nnx.remat(self.block.__func__, graph_updates=False)(self, x) else: - return self.block(x) + return self.block(x) # pyrefly: ignore[bad-argument-type] class DecoderLayer(nnx.Module): diff --git a/tunix/models/qwen2/model.py b/tunix/models/qwen2/model.py index 4eecdf382..620b75e28 100644 --- a/tunix/models/qwen2/model.py +++ b/tunix/models/qwen2/model.py @@ -73,13 +73,13 @@ def get_default_sharding(is_sampling: bool = False, enable_sp: bool = False): fsdp = (fsdp, sp) if fsdp and sp else fsdp return ShardingConfig( - emb_vd=('tp', fsdp), - emb_dv=(fsdp, 'tp'), - q_weight_dnh=(fsdp, 'tp', None), - kv_weight_dnh=(fsdp, 'tp', None), - o_weight_nhd=('tp', None, fsdp), - ffw_weight_df=(fsdp, 'tp'), - ffw_weight_fd=('tp', fsdp), + emb_vd=('tp', fsdp), # pyrefly: ignore[bad-argument-type] + emb_dv=(fsdp, 'tp'), # pyrefly: ignore[bad-argument-type] + q_weight_dnh=(fsdp, 'tp', None), # pyrefly: ignore[bad-argument-type] + kv_weight_dnh=(fsdp, 'tp', None), # pyrefly: ignore[bad-argument-type] + o_weight_nhd=('tp', None, fsdp), # pyrefly: ignore[bad-argument-type] + ffw_weight_df=(fsdp, 'tp'), # pyrefly: ignore[bad-argument-type] + ffw_weight_fd=('tp', fsdp), # pyrefly: ignore[bad-argument-type] rms_norm_weight=('tp',), act_btd=('fsdp', sp, None if is_sampling else 'tp'), act_btf=('fsdp', sp, 'tp'), @@ -300,7 +300,7 @@ def __init__( def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x = jnp.astype(x, self.dtype) - x = shard(x, self.shd_config.act_btd) + x = shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @jax.named_scope('embedder_decode') @@ -485,9 +485,9 @@ def block( ) value_proj = jnp.reshape(value_proj, (b, s, k, h)) - query_proj = shard(query_proj, self.shd_config.act_btnh) - key_proj = shard(key_proj, self.shd_config.act_btnh) - value_proj = shard(value_proj, self.shd_config.act_btnh) + query_proj = shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = shard(key_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = shard(value_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = apply_rotary_embedding( query_proj, @@ -641,7 +641,7 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block): qkv = qkv.reshape((b, t, qh, d)) outputs = self.o_proj(qkv) - outputs = shard(outputs, self.shd_config.act_btd) + outputs = shard(outputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] if cache is not None: new_cache = { @@ -740,7 +740,7 @@ def block( x: jaxtyping.Array, ) -> jaxtyping.Array: activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x) - activations = shard(activations, self.shd_config.act_btf) + activations = shard(activations, self.shd_config.act_btf) # pyrefly: ignore[bad-argument-type] outputs = self.down_proj(activations) return outputs @@ -752,7 +752,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: ): return nnx.remat(self.block.__func__, graph_updates=False)(self, x) else: - return self.block(x) + return self.block(x) # pyrefly: ignore[bad-argument-type] class DecoderLayer(nnx.Module): @@ -926,7 +926,7 @@ def __call__( layer_cache, x = layer( x, layer_cache, - attention_mask, + attention_mask, # pyrefly: ignore[bad-argument-type] sin, cos, segment_ids=segment_ids, diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py index 078acfc26..4011644df 100644 --- a/tunix/models/qwen3/model.py +++ b/tunix/models/qwen3/model.py @@ -114,13 +114,13 @@ def get_default_sharding(is_sampling: bool = False, enable_sp: bool = False): fsdp = (fsdp, sp) if fsdp and sp else fsdp return ShardingConfig( - emb_vd=('tp', fsdp), - emb_dv=(fsdp, 'tp'), - q_weight_dnh=(fsdp, 'tp', None), - kv_weight_dnh=(fsdp, 'tp', None), - o_weight_nhd=('tp', None, fsdp), - ffw_weight_df=(fsdp, 'tp'), - ffw_weight_fd=('tp', fsdp), + emb_vd=('tp', fsdp), # pyrefly: ignore[bad-argument-type] + emb_dv=(fsdp, 'tp'), # pyrefly: ignore[bad-argument-type] + q_weight_dnh=(fsdp, 'tp', None), # pyrefly: ignore[bad-argument-type] + kv_weight_dnh=(fsdp, 'tp', None), # pyrefly: ignore[bad-argument-type] + o_weight_nhd=('tp', None, fsdp), # pyrefly: ignore[bad-argument-type] + ffw_weight_df=(fsdp, 'tp'), # pyrefly: ignore[bad-argument-type] + ffw_weight_fd=('tp', fsdp), # pyrefly: ignore[bad-argument-type] rms_norm_weight=('tp',), act_btd=('fsdp', sp, None if is_sampling else 'tp'), act_btf=('fsdp', sp, 'tp'), @@ -356,7 +356,7 @@ def __init__( def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x = jnp.astype(x, self.dtype) - x = shard(x, self.shd_config.act_btd) + x = shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @jax.named_scope('embedder_decode') @@ -404,7 +404,7 @@ def __init__( param_dtype: jnp.dtype, ): self.w = nnx.Param( - nnx.initializers.ones_init()(rngs.params(), dim, param_dtype), + nnx.initializers.ones_init()(rngs.params(), dim, param_dtype), # pyrefly: ignore[bad-argument-type] sharding=shd_config.rms_norm_weight, ) self.norm_eps = norm_eps @@ -495,9 +495,9 @@ def block( key_proj = self.k_norm(self.k_proj(x)) value_proj = self.v_proj(x) - query_proj = shard(query_proj, self.shd_config.act_btnh) - key_proj = shard(key_proj, self.shd_config.act_btnh) - value_proj = shard(value_proj, self.shd_config.act_btnh) + query_proj = shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = shard(key_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = shard(value_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = apply_rope( query_proj, @@ -643,7 +643,7 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block): qkv = qkv.reshape((b, t, qh, d)) outputs = self.o_proj(qkv) - outputs = shard(outputs, self.shd_config.act_btd) + outputs = shard(outputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] if cache is not None: new_cache = { @@ -704,7 +704,7 @@ def __init__( self.num_experts = config.num_experts self.router = nnx.Linear( in_features=config.embed_dim, - out_features=config.num_experts, + out_features=config.num_experts, # pyrefly: ignore[bad-argument-type] use_bias=False, rngs=rngs, dtype=config.dtype, @@ -736,7 +736,7 @@ def __init__( def __call__(self, x, use_megablox=True): scores = self.router(x).astype(jnp.float32) # [B,T,E] routing_weights, routing_idx = jax.lax.top_k( - jax.nn.softmax(scores, axis=-1), self.experts_per_tok + jax.nn.softmax(scores, axis=-1), self.experts_per_tok # pyrefly: ignore[bad-argument-type] ) routing_weights = ( routing_weights / jnp.sum(routing_weights, axis=-1, keepdims=True) @@ -751,7 +751,7 @@ def __call__(self, x, use_megablox=True): # ------------------------------------------------------------- if not use_megablox or (mesh.empty or jax.devices()[0].platform == 'cpu'): dispatch_mask = jax.nn.one_hot( - routing_idx, num_classes=self.num_experts, dtype=self.dtype + routing_idx, num_classes=self.num_experts, dtype=self.dtype # pyrefly: ignore[bad-argument-type] ) # [B, T, K, E] dispatch_mask = jnp.swapaxes(dispatch_mask, -1, -2) # [B, T, E, K] dispatched_input = jnp.einsum( @@ -759,7 +759,7 @@ def __call__(self, x, use_megablox=True): ).astype(self.dtype) expert_outputs = [] - for i in range(self.num_experts): + for i in range(self.num_experts): # pyrefly: ignore[bad-argument-type] expert_input = dispatched_input[:, :, i, :] gate_proj = jnp.astype(self.gate_proj.value[i], self.dtype) up_proj = jnp.astype(self.up_proj.value[i], self.dtype) @@ -824,10 +824,10 @@ def sharded_megablox_moe(inputs, weights, indices, gate_w, up_w, down_w): num_ep = 1 ep_shard_idx = 0 - num_local_experts = self.num_experts // num_ep + num_local_experts = self.num_experts // num_ep # pyrefly: ignore[unsupported-operation] flat_repeated_inputs = jnp.repeat( - inputs.reshape(B * T, D_global), self.experts_per_tok, axis=0 + inputs.reshape(B * T, D_global), self.experts_per_tok, axis=0 # pyrefly: ignore[bad-argument-type] ) flat_selected_indices = indices.reshape(-1) @@ -860,7 +860,7 @@ def sharded_megablox_moe(inputs, weights, indices, gate_w, up_w, down_w): local_output_offsets = global_out_offsets[ep_shard_idx] output_buffer_size = ( - min(self.experts_per_tok, num_local_experts) * B * T * num_ep + min(self.experts_per_tok, num_local_experts) * B * T * num_ep # pyrefly: ignore[bad-specialization] ) output_buffer = jax.lax.empty( shape=(output_buffer_size, D_global), dtype=inputs.dtype @@ -939,9 +939,9 @@ def sharded_megablox_moe(inputs, weights, indices, gate_w, up_w, down_w): ) global_in_offsets, global_out_offsets = get_global_input_output_offsets( - global_send_sizes.T, num_ep # pylint: disable=undefined-variable + global_send_sizes.T, num_ep # pylint: disable=undefined-variable # pyrefly: ignore[unbound-name] ) - local_send_sizes, local_recv_sizes = local_recv_sizes, local_send_sizes # pylint: disable=undefined-variable + local_send_sizes, local_recv_sizes = local_recv_sizes, local_send_sizes # pylint: disable=undefined-variable # pyrefly: ignore[unbound-name] output_buffer = jax.lax.empty( shape=(B * T * self.experts_per_tok, D_global), dtype=inputs.dtype @@ -1033,7 +1033,7 @@ def block( x: jaxtyping.Array, ) -> jaxtyping.Array: activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x) - activations = shard(activations, self.shd_config.act_btf) + activations = shard(activations, self.shd_config.act_btf) # pyrefly: ignore[bad-argument-type] outputs = self.down_proj(activations) return outputs @@ -1045,7 +1045,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: ): return nnx.remat(self.block.__func__, graph_updates=False)(self, x) else: - return self.block(x) + return self.block(x) # pyrefly: ignore[bad-argument-type] class DecoderLayer(nnx.Module): diff --git a/tunix/models/qwen3/params.py b/tunix/models/qwen3/params.py index d488f332f..d8b61336a 100644 --- a/tunix/models/qwen3/params.py +++ b/tunix/models/qwen3/params.py @@ -178,5 +178,5 @@ def save_lora_merged_model_as_safetensors( rank=rank, alpha=alpha, state_key_transform_fn=_qwen3_state_key_to_safetensors_key, - transpose_rules=_QWEN3_HUGGINGFACE_TRANSPOSE_RULES, + transpose_rules=_QWEN3_HUGGINGFACE_TRANSPOSE_RULES, # pyrefly: ignore[bad-argument-type] ) diff --git a/tunix/models/safetensors_loader.py b/tunix/models/safetensors_loader.py index 0f429623a..15efa910a 100644 --- a/tunix/models/safetensors_loader.py +++ b/tunix/models/safetensors_loader.py @@ -246,7 +246,7 @@ def process_key(k_name, f, sf_file, file_loaded_tensors): for future in concurrent.futures.as_completed(futures): if future.exception(): - raise future.exception() + raise future.exception() # pyrefly: ignore[bad-raise] # Apply preprocessing if provided (e.g., for MoE expert stacking) if preprocess_fn is not None: diff --git a/tunix/models/safetensors_saver.py b/tunix/models/safetensors_saver.py index c08049818..5acf88918 100644 --- a/tunix/models/safetensors_saver.py +++ b/tunix/models/safetensors_saver.py @@ -69,12 +69,12 @@ def save_lora_merged_model_as_safetensors( path_str = join_path(path[:-1]) if path_str in lora_layers: assert ( - 'lora_b' in path[-1] + 'lora_b' in path[-1] # pyrefly: ignore[not-iterable] ), f'Expect second LoRAParam to be lora_b, got {path[-1]}' lora_layers[path_str].append(value) else: assert ( - 'lora_a' in path[-1] + 'lora_a' in path[-1] # pyrefly: ignore[not-iterable] ), f'Expect first LoRAParam to be lora_a, got {path[-1]}' lora_layers[path_str] = [value] diff --git a/tunix/perf/experimental/timeline.py b/tunix/perf/experimental/timeline.py index b0967dc54..e19657d37 100644 --- a/tunix/perf/experimental/timeline.py +++ b/tunix/perf/experimental/timeline.py @@ -61,17 +61,17 @@ def add_tag(self, key: str, value: Any) -> None: key: The tag key. value: The tag value. """ - if key in self.tags: + if key in self.tags: # pyrefly: ignore[not-iterable] logging.warning( "Span '%s' (id=%s): Tag %r already exists with value %r." " Overwriting with %r.", self.name, self.id, key, - self.tags[key], + self.tags[key], # pyrefly: ignore[unsupported-operation] value, ) - self.tags[key] = value + self.tags[key] = value # pyrefly: ignore[unsupported-operation] def _format_relative(self, born_at: float) -> str: """Returns a string representation of the span with relative times. diff --git a/tunix/perf/experimental/trace_writer.py b/tunix/perf/experimental/trace_writer.py index a53fe6f3a..35b10044f 100644 --- a/tunix/perf/experimental/trace_writer.py +++ b/tunix/perf/experimental/trace_writer.py @@ -297,7 +297,7 @@ def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: # Write track descriptors for parent tracks. for track_info in self._track_info.values(): packet = builder.add_packet() - packet.track_descriptor.uuid = track_info.uuid + packet.track_descriptor.uuid = track_info.uuid # pyrefly: ignore[bad-assignment] packet.track_descriptor.name = track_info.name # Sort timelines by ID to ensure consistent track ordering. @@ -325,7 +325,7 @@ def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: if tl_id in self._timeline_tracks: track_info = self._track_info[self._timeline_tracks[tl_id]] - packet.track_descriptor.parent_uuid = track_info.uuid + packet.track_descriptor.parent_uuid = track_info.uuid # pyrefly: ignore[bad-assignment] # TODO: noghabi - limit processing to last steps. we don't need to start # from the beginning every time. @@ -355,7 +355,7 @@ def write_timelines(self, timelines: Mapping[str, Timeline]) -> None: "timestamp": start_ns, "type": TrackEvent.Type.TYPE_SLICE_BEGIN, "uuid": lane_uuid, - "name": _create_span_name(s.name, s.tags), + "name": _create_span_name(s.name, s.tags), # pyrefly: ignore[bad-argument-type] }) if s.ended: diff --git a/tunix/perf/export.py b/tunix/perf/export.py index 9e2f424d6..a3bd3e251 100644 --- a/tunix/perf/export.py +++ b/tunix/perf/export.py @@ -47,7 +47,7 @@ def __call__( ) -> tuple[ bool, list[SpanGroup], list[Span], list[Span], list[SpanGroup], list[Span] ]: - return (False, None, None, None, None, None) + return (False, None, None, None, None, None) # pyrefly: ignore[bad-return] class PerfMetricsExport: diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py index e0544f060..287c2953c 100644 --- a/tunix/rl/agentic/agentic_grpo_learner.py +++ b/tunix/rl/agentic/agentic_grpo_learner.py @@ -231,7 +231,7 @@ def __init__( else: logging.warning("Metrics log dir is None, skipping trajectory logging.") - self.algo_config.temperature = self.rl_cluster.get_rollout_config( + self.algo_config.temperature = self.rl_cluster.get_rollout_config( # pyrefly: ignore[missing-attribute] mode=rl_cluster_lib.Mode.TRAIN ).temperature @@ -253,7 +253,7 @@ def __init__( has_aux=True, ) self.rl_cluster.actor_trainer.with_gen_model_input_fn( - lambda x: { + lambda x: { # pyrefly: ignore[bad-argument-type] "train_example": x, "algo_config": self.algo_config, } @@ -278,7 +278,7 @@ def __init__( "sampler_is/weight_mean": np.mean, "sampler_is/weight_min": np.min, }) - self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ + self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ # pyrefly: ignore[bad-argument-type] lambda: "kl" if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0 else None, @@ -386,8 +386,8 @@ def _process_results( clipped_completion_count += 1 padded_prompt, padded_completion, _ = ( agentic_utils.pad_prompt_and_completion( - prompt_tokens, - completion_tokens, + prompt_tokens, # pyrefly: ignore[bad-argument-type] + completion_tokens, # pyrefly: ignore[bad-argument-type] rollout_config.max_prompt_length, max_response_length, pad_value, @@ -529,7 +529,7 @@ def _process_results( "generation/prompts/min_length": (prompt_token_len, np.min), }, mode=mode, - step=expected_step, + step=expected_step, # pyrefly: ignore[bad-argument-type] ) reward_kwargs = { @@ -707,9 +707,9 @@ def _process_results( f"{prefix}/{sub_key}/min": (np.min(vals), np.min), }) self.rl_cluster.buffer_metrics_async( - metrics_to_log, + metrics_to_log, # pyrefly: ignore[bad-argument-type] mode=mode, - step=expected_step, + step=expected_step, # pyrefly: ignore[bad-argument-type] ) for metric_fn in self.metric_fns: @@ -725,7 +725,7 @@ def _process_results( }, ) self.rl_cluster.buffer_metrics_async( - user_defined_metric, mode=mode, step=expected_step + user_defined_metric, mode=mode, step=expected_step # pyrefly: ignore[bad-argument-type] ) combined_batch = TrainExample( diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py index 164eeeeee..e79e924d8 100644 --- a/tunix/rl/agentic/agentic_rl_learner.py +++ b/tunix/rl/agentic/agentic_rl_learner.py @@ -421,7 +421,7 @@ def _create_agent_env_pair( assert "pair_index" not in self.env_kwargs env = self.env_class( single_example, - **{"group_id": group_id, "pair_index": pair_index, **self.env_kwargs}, + **{"group_id": group_id, "pair_index": pair_index, **self.env_kwargs}, # pyrefly: ignore[bad-argument-type] ) return agent, env @@ -454,7 +454,7 @@ def _model_call( tags[perf_constants.PAIR_INDEX] = env.extra_kwargs["pair_index"] result = self.rl_cluster.generate( - prompts=chat_lists, + prompts=chat_lists, # pyrefly: ignore[bad-argument-type] apply_chat_template=False if self.chat_parser else True, mode=rl_cluster_lib.Mode.TRAIN, trace_tags=tags, @@ -508,7 +508,7 @@ async def pairs_stream_generator(): # with mini-batch. group_id = self.rl_cluster.global_steps * self._full_batch_size if is_async_iterator: - async for single_example in prompt_iterator: + async for single_example in prompt_iterator: # pyrefly: ignore[not-iterable] # Create agent-env pairs in parallel for a group to handle potential # cold start latency on env creation. agent_env_pairs = await asyncio.gather(*[ @@ -525,7 +525,7 @@ async def pairs_stream_generator(): yield agent, env group_id += 1 else: - for single_example in prompt_iterator: + for single_example in prompt_iterator: # pyrefly: ignore[not-iterable] agent_env_pairs = await asyncio.gather(*[ self.loop.run_in_executor( None, @@ -725,7 +725,7 @@ def train( self.rl_cluster.close() return - full_batch_size = len(next(iter(first_item.values()))) + full_batch_size = len(next(iter(first_item.values()))) # pyrefly: ignore[bad-argument-type] self._full_batch_size = full_batch_size # Initialize batch sizes. mini_batch_size = self._training_config.mini_batch_size or full_batch_size @@ -1158,7 +1158,7 @@ def _filter_outdated_offpolicy_examples( len(train_micro_batch), self.policy_version, str([ - train_example.policy_version[0] + train_example.policy_version[0] # pyrefly: ignore[unsupported-operation] for train_example in train_micro_batch ]), self.algo_config.off_policy_steps, diff --git a/tunix/rl/agentic/agents/base_agent.py b/tunix/rl/agentic/agents/base_agent.py index 4f89b8194..327b6bd1f 100644 --- a/tunix/rl/agentic/agents/base_agent.py +++ b/tunix/rl/agentic/agents/base_agent.py @@ -218,7 +218,7 @@ def update_from_env( # Let subclass / default handler convert observation into messages. if observation is not None: - self._observation_to_messages(observation, reward, done, info) + self._observation_to_messages(observation, reward, done, info) # pyrefly: ignore[bad-argument-type] def reset(self) -> None: """Reset trajectory, cache, and conversation history.""" diff --git a/tunix/rl/agentic/pipeline/rollout_orchestrator.py b/tunix/rl/agentic/pipeline/rollout_orchestrator.py index 1c607e1fa..9b4a66c22 100644 --- a/tunix/rl/agentic/pipeline/rollout_orchestrator.py +++ b/tunix/rl/agentic/pipeline/rollout_orchestrator.py @@ -249,7 +249,7 @@ async def run_producers_from_stream( if is_async_stream: pairs_iterator = aiter(pairs_stream) # pytype: disable=wrong-arg-types else: - pairs_iterator = iter(pairs_stream) + pairs_iterator = iter(pairs_stream) # pyrefly: ignore[no-matching-overload] active_tasks: set[asyncio.Task] = set() stream_exhausted = False @@ -271,7 +271,7 @@ async def run_producers_from_stream( if is_async_stream: agent, env = await anext(pairs_iterator) # pytype: disable=name-error else: - agent, env = next(pairs_iterator) + agent, env = next(pairs_iterator) # pyrefly: ignore[bad-argument-type] task = asyncio.create_task( self._runner( agent=agent, diff --git a/tunix/rl/agentic/rewards/reward.py b/tunix/rl/agentic/rewards/reward.py index d7174bbf1..115a61041 100644 --- a/tunix/rl/agentic/rewards/reward.py +++ b/tunix/rl/agentic/rewards/reward.py @@ -187,12 +187,12 @@ def _eval(node): op_type = type(node.op) if op_type not in _OP_MAP: raise ValueError(f"Unsupported operator: {op_type}") - return _OP_MAP[op_type](_eval(node.left), _eval(node.right)) + return _OP_MAP[op_type](_eval(node.left), _eval(node.right)) # pyrefly: ignore[bad-argument-count] elif isinstance(node, ast.UnaryOp): op_type = type(node.op) if op_type not in _OP_MAP: raise ValueError(f"Unsupported operator: {op_type}") - return _OP_MAP[op_type](_eval(node.operand)) + return _OP_MAP[op_type](_eval(node.operand)) # pyrefly: ignore[bad-argument-count, no-matching-overload] raise ValueError(f"Unsupported node: {type(node)}") tree = ast.parse(expr_str, mode="eval") diff --git a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py index 995f22d7d..93c3a77ce 100644 --- a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py +++ b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py @@ -431,7 +431,7 @@ async def _reset(self): contains_first_msg=True, contains_generation_msg=True, ) - self.agent.trajectory.prompt_tokens = prompt_tokens + self.agent.trajectory.prompt_tokens = prompt_tokens # pyrefly: ignore[missing-attribute] @property def _debug_prefix(self) -> str: @@ -526,7 +526,7 @@ def _safe_model_call(): not self.agent.trajectory.steps and rollout_output.left_padded_prompt_tokens is not None ): - self.agent.trajectory.prompt_tokens = ( + self.agent.trajectory.prompt_tokens = ( # pyrefly: ignore[missing-attribute] rollout_output.left_padded_prompt_tokens[0] ) @@ -653,7 +653,7 @@ def _safe_model_call(): self.agent.trajectory.status = agent_types.TrajectoryStatus.TIMEOUT logging.warning("Episode timed out after %d seconds.", self.timeout) self._log_trajectory_clip("TIMEOUT") - self.agent.get_current_step().done = True + self.agent.get_current_step().done = True # pyrefly: ignore[missing-attribute] return True return done diff --git a/tunix/rl/inference/inference_worker.py b/tunix/rl/inference/inference_worker.py index 4a07a58fd..984493579 100644 --- a/tunix/rl/inference/inference_worker.py +++ b/tunix/rl/inference/inference_worker.py @@ -57,7 +57,7 @@ def get_ref_per_token_logps( eos_id: int, temperature: float = 1.0, ) -> jax.Array: - graphdef, state = self._model_states.get("reference") + graphdef, state = self._model_states.get("reference") # pyrefly: ignore[not-iterable] if graphdef is None: raise ValueError("Reference model is not available.") return common.compute_per_token_logps( @@ -78,7 +78,7 @@ def get_values( pad_id: int, eos_id: int, ) -> jax.Array: - graphdef, state = self._model_states.get("critic") + graphdef, state = self._model_states.get("critic") # pyrefly: ignore[not-iterable] critic_model = nnx.merge(graphdef, state) if critic_model is None: raise ValueError("Critic model is not available.") diff --git a/tunix/rl/ppo/ppo_learner.py b/tunix/rl/ppo/ppo_learner.py index ae69bde05..88a230565 100644 --- a/tunix/rl/ppo/ppo_learner.py +++ b/tunix/rl/ppo/ppo_learner.py @@ -159,7 +159,7 @@ def __init__( super().__init__( rl_cluster=rl_cluster, algo_config=ppo_config, - reward_fns=reward_fns, + reward_fns=reward_fns, # pyrefly: ignore[bad-argument-type] metric_fns=metric_fns, data_shuffle_seed=data_shuffle_seed, ) @@ -198,7 +198,7 @@ def __init__( ) self.rl_cluster.actor_trainer.with_loss_fn(loss_fn, has_aux=True) self.rl_cluster.actor_trainer.with_gen_model_input_fn( - lambda x: { + lambda x: { # pyrefly: ignore[bad-argument-type] "train_example": x, "algo_config": self.algo_config, } @@ -230,7 +230,7 @@ def __init__( ): actor_rl_metrics_to_log["loss/entropy"] = np.mean self.rl_cluster.actor_trainer.with_rl_metrics_to_log( - actor_rl_metrics_to_log + actor_rl_metrics_to_log # pyrefly: ignore[bad-argument-type] ) self.rl_cluster.critic_trainer.with_rl_metrics_to_log({ @@ -266,7 +266,7 @@ def _generate_and_compute_advantage( # Generate. We use `model`, i.e., the policy model for generating the # "experiences". rollout_output = self.rl_cluster.generate( - prompts=training_input["prompts"], + prompts=training_input["prompts"], # pyrefly: ignore[bad-argument-type] micro_batch_size=self._rollout_micro_batch_size, ) padded_completion_ids = np.array([ @@ -343,10 +343,10 @@ def _generate_and_compute_advantage( last_token_scores = jax.device_get(jax_last_token_scores) else: last_token_scores = self._compute_rewards( - prompts=training_input["prompts"], + prompts=training_input["prompts"], # pyrefly: ignore[bad-argument-type] completions=rollout_output.text, mode=mode, - **{k: v for k, v in training_input.items() if k != "prompts"}, + **{k: v for k, v in training_input.items() if k != "prompts"}, # pyrefly: ignore[bad-argument-type] ) jax_last_token_scores = jax.device_put(last_token_scores) @@ -365,7 +365,7 @@ def _generate_and_compute_advantage( # rewards or computed in the loss function. kl = common.compute_kl_divergence( old_per_token_logps, - ref_per_token_logps, + ref_per_token_logps, # pyrefly: ignore[bad-argument-type] method=self.algo_config.kl_method, clamp_value=self.algo_config.kl_clamp_value, ) @@ -408,7 +408,7 @@ def _generate_and_compute_advantage( if self.algo_config.beta != 0.0: # Average of the per-sequence mean KL per_sequence_mean_kl = ppo_helpers.masked_mean( - kl, jax_completion_mask, axis=-1 # pylint: disable=undefined-variable + kl, jax_completion_mask, axis=-1 # pylint: disable=undefined-variable # pyrefly: ignore[unbound-name] ) self.rl_cluster.buffer_metrics( { @@ -516,7 +516,7 @@ def _compute_trajectory_ids( Returns: A list of trajectory IDs, one for each prompt in the batch. """ - batch_size = len(example["prompts"]) // self._num_generations() + batch_size = len(example["prompts"]) // self._num_generations() # pyrefly: ignore[bad-argument-type] row_offset = steps * batch_size row_offsets = np.arange(row_offset, row_offset + batch_size) return row_offsets.astype(str).tolist() diff --git a/tunix/rl/rollout/mock_rollout.py b/tunix/rl/rollout/mock_rollout.py index 9257a30f8..043ab5e10 100644 --- a/tunix/rl/rollout/mock_rollout.py +++ b/tunix/rl/rollout/mock_rollout.py @@ -314,7 +314,7 @@ def get_per_token_logps( """ batch_size, length = completion_tokens.shape # Use numpy to keep it on host memory. - return np.zeros((batch_size, length), dtype=np.float32) + return np.zeros((batch_size, length), dtype=np.float32) # pyrefly: ignore[bad-return] def update_params( self, @@ -332,13 +332,13 @@ def update_params( def pad_id(self) -> int: if self._tokenizer is not None and hasattr(self._tokenizer, "pad_id"): pad_id_attr = self._tokenizer.pad_id - return pad_id_attr() if callable(pad_id_attr) else pad_id_attr + return pad_id_attr() if callable(pad_id_attr) else pad_id_attr # pyrefly: ignore[bad-return] return self._pad_id def eos_id(self) -> int: if self._tokenizer is not None and hasattr(self._tokenizer, "eos_id"): eos_id_attr = self._tokenizer.eos_id - return eos_id_attr() if callable(eos_id_attr) else eos_id_attr + return eos_id_attr() if callable(eos_id_attr) else eos_id_attr # pyrefly: ignore[bad-return] return self._eos_id def model(self) -> Any: diff --git a/tunix/rl/rollout/vanilla_rollout.py b/tunix/rl/rollout/vanilla_rollout.py index 1e7be3803..cf092fd85 100644 --- a/tunix/rl/rollout/vanilla_rollout.py +++ b/tunix/rl/rollout/vanilla_rollout.py @@ -59,17 +59,17 @@ def generate( temperature=rollout_config.temperature, top_p=rollout_config.top_p, top_k=rollout_config.top_k, - seed=rollout_config.seed, + seed=rollout_config.seed, # pyrefly: ignore[bad-argument-type] pad_output=False, eos_tokens=rollout_config.eos_tokens, return_logprobs=rollout_config.return_logprobs, ) return base_rollout.RolloutOutput( text=output.text, - logits=output.logits, - tokens=output.tokens, + logits=output.logits, # pyrefly: ignore[bad-argument-type] + tokens=output.tokens, # pyrefly: ignore[bad-argument-type] left_padded_prompt_tokens=output.padded_prompt_tokens, - logprobs=output.logprobs, + logprobs=output.logprobs, # pyrefly: ignore[bad-argument-type] ) def get_per_token_logps( @@ -117,7 +117,7 @@ def update_params( operator.ior, [flat_old_params, flat_new_params], {} ) merged_params = jax.tree.unflatten(tree_def, merged_params.values()) - new_model = nnx.merge(self._sampler._transformer_graphdef, merged_params) # pylint: disable=protected-access + new_model = nnx.merge(self._sampler._transformer_graphdef, merged_params) # pylint: disable=protected-access # pyrefly: ignore[no-matching-overload] self._sampler.transformer_state = nnx.variables(new_model, nnx.Param) def pad_id(self) -> int: diff --git a/tunix/tests/test_common.py b/tunix/tests/test_common.py index 45bdca740..bef6b27c2 100644 --- a/tunix/tests/test_common.py +++ b/tunix/tests/test_common.py @@ -254,8 +254,8 @@ def get_lora_model( model, lora_provider, **dummy_model_input ) if mesh is not None: - lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) - return lora_model + lora_model = reshard.reshard_model_to_mesh(lora_model, mesh) # pyrefly: ignore[bad-argument-type] + return lora_model # pyrefly: ignore[bad-return] class MockVocab(spm.SentencePieceProcessor): @@ -342,7 +342,7 @@ def __init__(self, transformer: nnx.Module, rngs: nnx.Rngs): self.transformer = transformer self.score = nnx.Linear( - in_features=transformer.head_dim, + in_features=transformer.head_dim, # pyrefly: ignore[missing-attribute] out_features=1, use_bias=False, rngs=rngs,