From 3efd0d3f7ba165d9b16bfef7b34b145357481bbd Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Fri, 3 Jul 2026 01:51:36 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 942021390 --- tests/generate/beam_search_test.py | 16 ++++---- tests/generate/sampler_test.py | 18 ++++---- tests/generate/utils_test.py | 4 +- tests/processors/audio_processor_test.py | 12 +++--- tests/processors/image_processor_test.py | 8 ++-- tests/rl/grpo/dapo_learner_test.py | 4 +- tests/rl/grpo/drgrpo_learner_test.py | 2 +- tests/rl/grpo/grpo_learner_test.py | 10 ++--- tests/sft/checkpoint_manager_test.py | 8 ++-- tests/sft/checkpoint_options_test.py | 12 +++--- tests/sft/dpo/dpo_trainer_test.py | 6 +-- tests/sft/dpo/orpo_trainer_test.py | 6 +-- tests/utils/mesh_test.py | 24 +++++------ tests/utils/token_sanitization_test.py | 2 +- tunix/generate/mappings.py | 2 +- tunix/generate/sampler.py | 38 ++++++++--------- tunix/generate/tokenizer_adapter.py | 8 ++-- tunix/generate/utils.py | 18 ++++---- tunix/models/gemma3/merge_embeddings.py | 4 +- tunix/models/gemma3/model.py | 40 +++++++++--------- tunix/models/gemma3/params.py | 2 +- tunix/models/gemma3/params_safetensors.py | 2 +- tunix/models/gemma3/utils.py | 6 +-- tunix/models/gemma3/vision.py | 22 +++++----- tunix/models/gemma4/audio.py | 16 ++++---- tunix/models/gemma4/model.py | 50 +++++++++++------------ tunix/models/gemma4/params_safetensors.py | 2 +- tunix/models/gemma4/vision.py | 34 +++++++-------- tunix/processors/audio_processor.py | 4 +- tunix/processors/image_processor.py | 6 +-- tunix/rl/algo_core.py | 6 +-- tunix/rl/common.py | 12 +++--- tunix/rl/grpo/grpo_learner.py | 18 ++++---- tunix/rl/rl_cluster.py | 8 ++-- tunix/rl/rl_learner.py | 6 +-- tunix/rl/utils.py | 12 +++--- tunix/sft/checkpoint_manager.py | 10 ++--- tunix/sft/checkpoint_options.py | 10 ++--- tunix/sft/dpo/dpo_trainer.py | 24 +++++------ tunix/sft/hooks.py | 16 ++++---- tunix/sft/metrics_logger.py | 8 ++-- tunix/sft/peft_trainer.py | 30 +++++++------- tunix/sft/profiler.py | 2 +- tunix/utils/math_utils.py | 10 ++--- tunix/utils/mesh.py | 2 +- tunix/utils/script_utils.py | 2 +- tunix/utils/trajectory_logger.py | 14 +++---- 47 files changed, 288 insertions(+), 288 deletions(-) diff --git a/tests/generate/beam_search_test.py b/tests/generate/beam_search_test.py index 0712908e1..24640798c 100644 --- a/tests/generate/beam_search_test.py +++ b/tests/generate/beam_search_test.py @@ -109,10 +109,10 @@ def test_beam_search_step_without_stop(self) -> None: new_scores1, tokens1 = jax.lax.top_k( jax.nn.log_softmax(jnp.array([1, 2, 1.1])), 2 ) - self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) - self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) - self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) - self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) + self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) # pyrefly: ignore[no-matching-overload] updated_token_buffer = updated_params['token_buffer'] expected = token_buffer @@ -162,10 +162,10 @@ def _check(x, y) -> None: ).ravel(), 2, ) - self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) - self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) - self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) - self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) + self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) # pyrefly: ignore[no-matching-overload] + self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) # pyrefly: ignore[no-matching-overload] # before the beam search, the token buffer[:][0] is [2, 1, 1, 2] # token_buffer[0] should be [1, 1, -1, ...] # token_buffer[1] should be [2, 1, -1, ...] diff --git a/tests/generate/sampler_test.py b/tests/generate/sampler_test.py index 7a69033f8..a05def07b 100644 --- a/tests/generate/sampler_test.py +++ b/tests/generate/sampler_test.py @@ -139,8 +139,8 @@ def test_samples_padding_output(self, max_prompt_length, echo, return_logits): + 1 ) np.testing.assert_allclose( - result_not_padded.logits[i], - result_padded.logits[i][:valid_length], + result_not_padded.logits[i], # pyrefly: ignore[unsupported-operation] + result_padded.logits[i][:valid_length], # pyrefly: ignore[unsupported-operation] ) np.testing.assert_allclose( result_not_padded.tokens[i], @@ -197,7 +197,7 @@ def __call__(self, images): return_logits=True, max_prompt_length=8, echo=True, - images=images, + images=images, # pyrefly: ignore[bad-argument-type] ) self.assertIsNotNone(result) @@ -259,9 +259,9 @@ def test_samples(self, max_prompt_length, echo): self.assertIsNotNone(result) self.assertLen(result.logits, 2) if echo: - self.assertEqual(result.logits[0].shape, (13, vocab.GetPieceSize())) + self.assertEqual(result.logits[0].shape, (13, vocab.GetPieceSize())) # pyrefly: ignore[unsupported-operation] else: - self.assertEqual(result.logits[0].shape, (10, vocab.GetPieceSize())) + self.assertEqual(result.logits[0].shape, (10, vocab.GetPieceSize())) # pyrefly: ignore[unsupported-operation] # With 1 beam, the beam search result should be the # same as the greedy output @@ -438,7 +438,7 @@ def test_state_update(self): input_strings, max_generation_steps=10, return_logits=True ).logits with self.assertRaises(AssertionError): - for orig, new in zip(original_logits, new_logits): + for orig, new in zip(original_logits, new_logits): # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose(orig, new, atol=1e-1, rtol=1e-1) def test_lora_state_update(self): @@ -482,7 +482,7 @@ def test_lora_state_update(self): input_strings, max_generation_steps=10, return_logits=True ).logits with self.assertRaises(AssertionError): - for orig, new in zip(original_logits, new_logits): + for orig, new in zip(original_logits, new_logits): # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose(orig, new, atol=1e-1, rtol=1e-1) def test_invalid_state_update(self): @@ -721,8 +721,8 @@ def test_gemma4_decode_only_last_token_consistency(self): self.assertEqual(len(res_opt.tokens), len(res_unopt.tokens)) for t_opt, t_unopt in zip(res_opt.tokens, res_unopt.tokens): np.testing.assert_array_equal(t_opt, t_unopt) - self.assertEqual(len(res_opt.logits), len(res_unopt.logits)) - for l_opt, l_unopt in zip(res_opt.logits, res_unopt.logits): + self.assertEqual(len(res_opt.logits), len(res_unopt.logits)) # pyrefly: ignore[bad-argument-type] + for l_opt, l_unopt in zip(res_opt.logits, res_unopt.logits): # pyrefly: ignore[bad-argument-type] self.assertEqual(l_opt.shape, l_unopt.shape) np.testing.assert_allclose(l_opt, l_unopt, atol=1e-5, rtol=1e-5) diff --git a/tests/generate/utils_test.py b/tests/generate/utils_test.py index f14ff5941..c1582b30c 100644 --- a/tests/generate/utils_test.py +++ b/tests/generate/utils_test.py @@ -176,7 +176,7 @@ def test_logprobs_basic_extraction(self): ] expected = [-1.71, -0.37, 0.0] self.assertEqual( - utils.get_logprobs_from_vllm_output(token_ids, logprobs), + utils.get_logprobs_from_vllm_output(token_ids, logprobs), # pyrefly: ignore[bad-argument-type] expected, ) @@ -184,7 +184,7 @@ def test_logprobs_extraction_with_missing_token(self): token_ids = [100, 200] logprobs = [{101: Logprob(-0.5)}, {200: Logprob(-1.2)}] with self.assertRaises(ValueError): - utils.get_logprobs_from_vllm_output(token_ids, logprobs) + utils.get_logprobs_from_vllm_output(token_ids, logprobs) # pyrefly: ignore[bad-argument-type] @parameterized.named_parameters( ("none_logprobs", [], None), diff --git a/tests/processors/audio_processor_test.py b/tests/processors/audio_processor_test.py index 7df1feacb..c66f9f5b7 100644 --- a/tests/processors/audio_processor_test.py +++ b/tests/processors/audio_processor_test.py @@ -74,7 +74,7 @@ def test_process_inputs_batch_list_ndarray(self): ] processed_audios, new_tokens = audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, ) @@ -120,7 +120,7 @@ def test_process_inputs_batch_nested_list(self): ] proc_audios, new_tokens = audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, ) @@ -170,7 +170,7 @@ def test_process_inputs_no_audio_clips(self): ] processed_audios, new_tokens = audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, ) @@ -225,7 +225,7 @@ def test_batch_mismatch_error(self): ValueError, "Batch size of tokens.*does not match" ): audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, ) @@ -239,7 +239,7 @@ def test_placeholder_mismatch_error(self): ValueError, "Placeholders provided for 2 clips, but only 1 provided" ): audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, ) @@ -255,7 +255,7 @@ def test_max_audio_clips_exceeded_error(self): ValueError, "A batch entry has more clips than the specified" ): audio_processor.process_gemma4_inputs( - audios=audios, + audios=audios, # pyrefly: ignore[bad-argument-type] tokens=tokens, audio_encoder=self.audio_encoder, max_audio_clips=1, # Limit to 1 clip diff --git a/tests/processors/image_processor_test.py b/tests/processors/image_processor_test.py index d6069e77f..1be9cf9e2 100644 --- a/tests/processors/image_processor_test.py +++ b/tests/processors/image_processor_test.py @@ -108,7 +108,7 @@ def test_multiple_images_dim_1(self): img1 = np.zeros((100, 100, 3), dtype=np.uint8) img2 = np.zeros((50, 50, 3), dtype=np.uint8) images = [img1, img2] - processed_images = self.processor(images=images) + processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose( processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3)) ) @@ -129,7 +129,7 @@ def test_padding(self, input_type): else: images = [img1, [img1, img2]] - processed_images = self.processor(images=images) + processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose( processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3)) ) @@ -148,7 +148,7 @@ def test_mixed_inputs(self): img1 = np.zeros((100, 100, 3), dtype=np.uint8) img2 = self._create_dummy_image_file() images = [img1, [img1, img2]] - processed_images = self.processor(images=images) + processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose( processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3)) ) @@ -164,7 +164,7 @@ def test_mixed_inputs(self): def test_call_with_none_in_batch(self): images = [None, [np.zeros((100, 100, 3), dtype=np.uint8)]] - processed_images = self.processor(images=images) + processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type] np.testing.assert_allclose( processed_images[0][0], np.zeros((self.height, self.width, 3)) ) diff --git a/tests/rl/grpo/dapo_learner_test.py b/tests/rl/grpo/dapo_learner_test.py index 4e45d5438..eec619bb0 100644 --- a/tests/rl/grpo/dapo_learner_test.py +++ b/tests/rl/grpo/dapo_learner_test.py @@ -67,8 +67,8 @@ def create_train_example(self): def test_diff_loss(self): dapo_config = dapo_lib.DAPOConfig() grpo_config = grpo_lib.GRPOConfig() - dapo_config.temperature = 1.0 - grpo_config.temperature = 1.0 + dapo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute] + grpo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute] dapo_loss_fn_impl = fr.default_registry.get( "policy_loss_fn", dapo_config.policy_loss_fn diff --git a/tests/rl/grpo/drgrpo_learner_test.py b/tests/rl/grpo/drgrpo_learner_test.py index 6a3bc2edf..d88f7ea96 100644 --- a/tests/rl/grpo/drgrpo_learner_test.py +++ b/tests/rl/grpo/drgrpo_learner_test.py @@ -109,7 +109,7 @@ def test_drgrpo_advantage_estimator(self): def test_drgrpo_loss_fn(self): drgrpo_config = drgrpo_lib.DrGRPOConfig() - drgrpo_config.temperature = 1.0 + drgrpo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute] drgrpo_loss_fn_impl = fr.default_registry.get( "policy_loss_fn", drgrpo_config.policy_loss_fn diff --git a/tests/rl/grpo/grpo_learner_test.py b/tests/rl/grpo/grpo_learner_test.py index d21eef20d..6738d823f 100644 --- a/tests/rl/grpo/grpo_learner_test.py +++ b/tests/rl/grpo/grpo_learner_test.py @@ -148,7 +148,7 @@ def __init__(self, grpo_config): self._last_iter_step = 0 self.algo_config = grpo_config self._data_shuffle_seed = None - self.rl_cluster = types.SimpleNamespace( + self.rl_cluster = types.SimpleNamespace( # pyrefly: ignore[bad-assignment] global_steps=0, cluster_config=types.SimpleNamespace( training_config=types.SimpleNamespace( @@ -167,7 +167,7 @@ def _generate_and_compute_advantage(self, example, mode='train'): if 'trajectory_ids' in example: del example['trajectory_ids'] prompts = example['prompts'] - num_samples = len(prompts) + num_samples = len(prompts) # pyrefly: ignore[bad-argument-type] # Return a SimpleNamespace to mimic TrainExample attributes return types.SimpleNamespace( prompt_ids=np.array(prompts), @@ -348,12 +348,12 @@ def wrapper(*args, **kwargs): metric_logger = grpo_learner.rl_cluster.actor_trainer.metrics_logger for metric_name in ['loss', 'kl']: self.assertLen( - metric_logger.get_metric_history('actor', metric_name, 'train'), + metric_logger.get_metric_history('actor', metric_name, 'train'), # pyrefly: ignore[missing-attribute] grpo_learner.rl_cluster.actor_trainer.train_steps, msg=f'metric_name: {metric_name}', ) self.assertLen( - metric_logger.get_metric_history('actor', metric_name, 'eval'), + metric_logger.get_metric_history('actor', metric_name, 'eval'), # pyrefly: ignore[missing-attribute] grpo_learner.rl_cluster.actor_trainer.train_steps / kwargs['eval_every_n_steps'], msg=f'metric_name: {metric_name}', @@ -546,7 +546,7 @@ def wrapper(*args, **kwargs): // (kwargs.get('gradient_accumulation_steps') or 1), ) self.assertLen( - grpo_learner.rl_cluster.actor_trainer.metrics_logger.get_metric_history( + grpo_learner.rl_cluster.actor_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] 'actor', 'kl', 'train' ), grpo_learner.rl_cluster.actor_trainer.train_steps, diff --git a/tests/sft/checkpoint_manager_test.py b/tests/sft/checkpoint_manager_test.py index 63d120744..b3da70e27 100644 --- a/tests/sft/checkpoint_manager_test.py +++ b/tests/sft/checkpoint_manager_test.py @@ -104,8 +104,8 @@ def setUp(self): def test_empty_root_directory(self): cp_manager = checkpoint_manager.CheckpointManager(root_directory=None) self.assertIsNone(cp_manager.latest_step()) - self.assertFalse(cp_manager.save(1, None)) - self.assertEqual(cp_manager.maybe_restore(None), (0, {})) + self.assertFalse(cp_manager.save(1, None)) # pyrefly: ignore[bad-argument-type] + self.assertEqual(cp_manager.maybe_restore(None), (0, {})) # pyrefly: ignore[bad-argument-type] def test_checkpoint_manager_options_none_sets_default(self): cp_path = f'{self.temp_path}/{self.id()}' @@ -210,7 +210,7 @@ def test_restore_with_lora(self): ) # Save the model params. - self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True)) + self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True)) # pyrefly: ignore[bad-argument-type] cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error # Change the model state. @@ -219,7 +219,7 @@ def test_restore_with_lora(self): # Restore the model lora params. self.assertEqual( - cp_manager.maybe_restore(model, restore_only_lora_params=True), + cp_manager.maybe_restore(model, restore_only_lora_params=True), # pyrefly: ignore[bad-argument-type] (1, {}), ) # Check the model lora params are restored correctly. diff --git a/tests/sft/checkpoint_options_test.py b/tests/sft/checkpoint_options_test.py index dbf713dd0..6e6389cd6 100644 --- a/tests/sft/checkpoint_options_test.py +++ b/tests/sft/checkpoint_options_test.py @@ -39,7 +39,7 @@ def test_resolve_checkpointing_defaults_with_deprecated_options(self): with self.assertLogs(level='WARNING') as log: opts = checkpoint_options.resolve_checkpointing_defaults( - legacy_opts + legacy_opts # pyrefly: ignore[bad-argument-type] ) # Verify deprecation warnings were logged @@ -64,7 +64,7 @@ def test_resolve_checkpointing_defaults_with_legacy_options_dataclass(self): ), ) opts = checkpoint_options.resolve_checkpointing_defaults( - legacy_opts + legacy_opts # pyrefly: ignore[bad-argument-type] ) self.assertIsInstance( opts.save_decision_policy, @@ -92,10 +92,10 @@ def test_resolve_checkpointing_defaults_with_async_timeout(self): def test_resolve_checkpointing_defaults_with_modern_options(self): modern_opts = checkpoint_options.TunixCheckpointingOptions( - save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( + save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( # pyrefly: ignore[bad-argument-type] 50 ), - preservation_policy=ocp.training.preservation_policies.LatestN(10), + preservation_policy=ocp.training.preservation_policies.LatestN(10), # pyrefly: ignore[bad-argument-type] enable_async_checkpointing=False, ) opts = checkpoint_options.resolve_checkpointing_defaults( @@ -111,10 +111,10 @@ def test_resolve_checkpointing_defaults_with_modern_options(self): def test_create_checkpointing_options(self): opts = checkpoint_options.create_checkpointing_options( - save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( + save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( # pyrefly: ignore[bad-argument-type] 50 ), - preservation_policy=ocp.training.preservation_policies.LatestN(10), + preservation_policy=ocp.training.preservation_policies.LatestN(10), # pyrefly: ignore[bad-argument-type] enable_async_checkpointing=False, ) self.assertIsInstance(opts, checkpoint_options.TunixCheckpointingOptions) diff --git a/tests/sft/dpo/dpo_trainer_test.py b/tests/sft/dpo/dpo_trainer_test.py index ae2ff71ad..7c73a13b6 100644 --- a/tests/sft/dpo/dpo_trainer_test.py +++ b/tests/sft/dpo/dpo_trainer_test.py @@ -169,13 +169,13 @@ def test_dpo_trainer( "log_probs/rejected", ]: self.assertLen( - dpo_trainer.metrics_logger.get_metric_history( + dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "train" ), dpo_trainer._train_steps, ) self.assertLen( - dpo_trainer.metrics_logger.get_metric_history( + dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "eval" ), 3, @@ -246,7 +246,7 @@ def test_dpo_trainer_with_string_inputs(self, train_ds): "rewards/accuracy", ]: self.assertLen( - dpo_trainer.metrics_logger.get_metric_history( + dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "train" ), dpo_trainer._train_steps, diff --git a/tests/sft/dpo/orpo_trainer_test.py b/tests/sft/dpo/orpo_trainer_test.py index b42b769f3..f767a106b 100644 --- a/tests/sft/dpo/orpo_trainer_test.py +++ b/tests/sft/dpo/orpo_trainer_test.py @@ -155,13 +155,13 @@ def test_orpo_trainer( "odds_ratio", ]: self.assertLen( - orpo_trainer.metrics_logger.get_metric_history( + orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "train" ), orpo_trainer._train_steps, ) self.assertLen( - orpo_trainer.metrics_logger.get_metric_history( + orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "eval" ), 3, @@ -221,7 +221,7 @@ def test_orpo_trainer_with_string_inputs(self, train_ds): "rewards/accuracy", ]: self.assertLen( - orpo_trainer.metrics_logger.get_metric_history( + orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute] "", metric_name, "train" ), orpo_trainer._train_steps, diff --git a/tests/utils/mesh_test.py b/tests/utils/mesh_test.py index f42bb347b..b988ef586 100644 --- a/tests/utils/mesh_test.py +++ b/tests/utils/mesh_test.py @@ -106,7 +106,7 @@ def __init__(self, device_id, slice_index): ]) self.assertEqual( - [[device.id for device in group] for group in grouped], [[0, 1], [2, 3]] + [[device.id for device in group] for group in grouped], [[0, 1], [2, 3]] # pyrefly: ignore[not-iterable] ) def test_group_devices_by_slice_treats_missing_metadata_as_one_slice(self): @@ -121,7 +121,7 @@ def __init__(self, device_id): ]) self.assertEqual( - [[device.id for device in group] for group in grouped], [[0, 1]] + [[device.id for device in group] for group in grouped], [[0, 1]] # pyrefly: ignore[not-iterable] ) def test_partition_devices_by_host_groups_and_sorts_by_slice_then_host(self): @@ -140,7 +140,7 @@ def __init__(self, device_id, slice_index, process_index): ]) self.assertEqual( - [[device.id for device in group] for group in groups], + [[device.id for device in group] for group in groups], # pyrefly: ignore[not-iterable] [[0], [1], [2], [3]], ) @@ -167,13 +167,13 @@ def __init__(self, device_id, coords, core_on_chip): self.assertFalse( mesh.candidate_uses_whole_chips( - topology, + topology, # pyrefly: ignore[bad-argument-type] [(0, 0, 0, 0), (1, 0, 0, 0)], ) ) self.assertTrue( mesh.candidate_uses_whole_chips( - topology, + topology, # pyrefly: ignore[bad-argument-type] [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)], ) ) @@ -194,7 +194,7 @@ def __init__(self, device_id, coords): self.assertTrue( mesh.candidate_uses_whole_chips( - topology, + topology, # pyrefly: ignore[bad-argument-type] [(0, 0, 0)], ) ) @@ -477,7 +477,7 @@ def __init__(self, device_id, coords, process_index): allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 8) self.assertEqual( - [device.id for device in allocated], [0, 1, 4, 5, 8, 9, 12, 13] + [device.id for device in allocated], [0, 1, 4, 5, 8, 9, 12, 13] # pyrefly: ignore[not-iterable] ) def test_allocate_named_mesh_device_slices_prefers_coord_boxes(self): @@ -537,7 +537,7 @@ def __init__(self, device_id, coords, core_on_chip): allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 8) self.assertEqual( - [device.id for device in allocated], + [device.id for device in allocated], # pyrefly: ignore[not-iterable] [0, 1, 4, 5, 16, 17, 20, 21], ) @@ -573,7 +573,7 @@ def __init__(self, device_id, coords, process_index): allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 4) - self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) # pyrefly: ignore[not-iterable] def test_allocate_devices_by_coords_prefers_more_cubical_supported_shape( self, @@ -595,7 +595,7 @@ def __init__(self, device_id, coords): allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 256) - allocated_coords = [device.coords for device in allocated] + allocated_coords = [device.coords for device in allocated] # pyrefly: ignore[not-iterable] mins = tuple( min(coords[dim] for coords in allocated_coords) for dim in range(3) ) @@ -687,7 +687,7 @@ def __init__(self, device_id, coords, core_on_chip, process_index): self.assertEqual([device.id for device in actor_devices], [0, 1, 2, 3]) self.assertEqual( - next_state.remaining_coord_regions_by_slice[0][0], + next_state.remaining_coord_regions_by_slice[0][0], # pyrefly: ignore[unsupported-operation] mesh.CoordRegion((0, 1, 0, 0), (2, 1, 1, 2)), ) @@ -701,7 +701,7 @@ def __init__(self, device_id, coords, core_on_chip, process_index): self.assertEqual([device.id for device in rollout_devices], [8, 9, 10, 11]) self.assertNotIn( mesh.CoordRegion((0, 1, 0, 0), (2, 1, 1, 2)), - next_state.remaining_coord_regions_by_slice[0], + next_state.remaining_coord_regions_by_slice[0], # pyrefly: ignore[unsupported-operation] ) def test_allocate_devices_prefers_smallest_remaining_coord_region_first(self): diff --git a/tests/utils/token_sanitization_test.py b/tests/utils/token_sanitization_test.py index 5b70543cc..f0ea530c1 100644 --- a/tests/utils/token_sanitization_test.py +++ b/tests/utils/token_sanitization_test.py @@ -65,7 +65,7 @@ def test_sanitize_control_tokens_with_empty_extra(self): expected = 'hello world' self.assertEqual( token_sanitization.sanitize_control_tokens( - content, extra_tokens=extra_tokens + content, extra_tokens=extra_tokens # pyrefly: ignore[bad-argument-type] ), expected, ) diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py index 7a820e0ee..93418b4d4 100644 --- a/tunix/generate/mappings.py +++ b/tunix/generate/mappings.py @@ -99,7 +99,7 @@ def build( return mapping_obj if mapping_obj is None: - return cls.from_model(model, backend) + return cls.from_model(model, backend) # pyrefly: ignore[bad-argument-type] keys = ( 'to_hf_mappings', diff --git a/tunix/generate/sampler.py b/tunix/generate/sampler.py index 663070d25..8ded0bf50 100644 --- a/tunix/generate/sampler.py +++ b/tunix/generate/sampler.py @@ -134,7 +134,7 @@ def sample_top_p( return next_token, logp_sampled k = next_token_logits.shape[-1] if _no_topk else top_k - logits_sorted, indices = jax.lax.top_k(next_token_logits, k=k) + logits_sorted, indices = jax.lax.top_k(next_token_logits, k=k) # pyrefly: ignore[bad-argument-type] probs_sorted = jax.nn.softmax(logits_sorted, axis=-1) cumsum_probs = jnp.cumsum(probs_sorted, axis=-1) @@ -225,7 +225,7 @@ def __init__( self.tokenizer = tok_adapter.TokenizerAdapter(tokenizer) self.cache_config = cache_config self.image_processor = image_processor - self._transformer_graphdef: graph.NodeDef = nnx.graphdef(transformer) + self._transformer_graphdef: graph.NodeDef = nnx.graphdef(transformer) # pyrefly: ignore[bad-assignment] self._transformer_state: list[statelib.State] = nnx.variables(transformer) self._flattened_transformer_state: list[statelib.State] = jax.tree.leaves( self._transformer_state, @@ -260,7 +260,7 @@ def model_def_and_state(self) -> tuple[graph.NodeDef, statelib.State]: @property def transformer(self) -> nnx.Module: - return nnx.merge( + return nnx.merge( # pyrefly: ignore[no-matching-overload] self._transformer_graphdef, self._flattened_transformer_state ) @@ -421,7 +421,7 @@ def init_sample_state( if include_logits: logits_buffer = jnp.zeros( - (batch_size, total_sampling_steps, self.transformer.num_embed), + (batch_size, total_sampling_steps, self.transformer.num_embed), # pyrefly: ignore[missing-attribute] dtype=jnp.float32, ) else: @@ -438,16 +438,16 @@ def init_sample_state( sampling_mode = [None] if beam_size is not None: - utils.check_sampling_mode_conflict(sampling_mode, 'beam_search') + utils.check_sampling_mode_conflict(sampling_mode, 'beam_search') # pyrefly: ignore[bad-argument-type] sampling_parameters['beam_size'] = beam_size if top_p is not None: - utils.check_sampling_mode_conflict(sampling_mode, 'top_p') + utils.check_sampling_mode_conflict(sampling_mode, 'top_p') # pyrefly: ignore[bad-argument-type] sampling_parameters['top_p'] = top_p sampling_parameters['top_k'] = top_k if sampling_mode[0] is None: - sampling_mode[0] = 'greedy' + sampling_mode[0] = 'greedy' # pyrefly: ignore[unsupported-operation] logging.debug('Using sampling mode: %s', sampling_mode[0]) @@ -504,7 +504,7 @@ def _sample( token_buffer=token_buffer, cache=cache, logits_buffer=logits_buffer, - state=beam_search_state, + state=beam_search_state, # pyrefly: ignore[bad-argument-type] pad_token_id=eos[0], decoding_step=decoding_step, logprobs_buffer=logprobs_buffer, @@ -526,7 +526,7 @@ def _sample( key, sampler_state.temperature, sampler_state.sampling_parameters['top_p'], - sampler_state.sampling_parameters['top_k'], + sampler_state.sampling_parameters['top_k'], # pyrefly: ignore[bad-argument-type] return_logprobs=(logprobs_buffer is not None), ) else: @@ -601,7 +601,7 @@ def _prefill_fn( input_mask, self.cache_config.cache_size ) - transformer = nnx.merge(self._transformer_graphdef, params) + transformer = nnx.merge(self._transformer_graphdef, params) # pyrefly: ignore[no-matching-overload] kwargs = {} if images is not None: kwargs['images'] = images @@ -714,7 +714,7 @@ def _sample_step( decoding_step, self.cache_config.cache_size, input_mask ) - transformer = nnx.merge(self._transformer_graphdef, params) + transformer = nnx.merge(self._transformer_graphdef, params) # pyrefly: ignore[no-matching-overload] logits, cache = transformer( last_token, positions=step_positions, @@ -834,13 +834,13 @@ def __call__( assert self.transformer.vision_encoder is not None processed_images, tokens = image_processor.process_gemma4_inputs( images, - tokens, + tokens, # pyrefly: ignore[bad-argument-type] self.transformer.vision_encoder, self.tokenizer.pad_id(), ) elif images is not None and self.image_processor is not None: - processed_images = self.image_processor(images) + processed_images = self.image_processor(images) # pyrefly: ignore[bad-argument-type] processed_images = jnp.array(processed_images) processed_audios = None @@ -849,8 +849,8 @@ def __call__( assert hasattr(self.transformer, 'audio_encoder') assert self.transformer.audio_encoder is not None processed_audios, tokens = audio_processor.process_gemma4_inputs( - audios=audios, - tokens=tokens, + audios=audios, # pyrefly: ignore[bad-argument-type] + tokens=tokens, # pyrefly: ignore[bad-argument-type] audio_encoder=self.transformer.audio_encoder, max_audio_length=max_audio_length, max_audio_clips=max_audio_clips, @@ -864,7 +864,7 @@ def __call__( all_input_ids = np.array([ utils.pad_to_length( - x, + x, # pyrefly: ignore[bad-argument-type] target_length=max_prompt_length, pad_value=self.tokenizer.pad_id(), left=True, @@ -880,9 +880,9 @@ def __call__( ) if seed is None: - seed = jax.random.PRNGKey(0) + seed = jax.random.PRNGKey(0) # pyrefly: ignore[bad-assignment] elif isinstance(seed, int): - seed = jax.random.PRNGKey(seed) + seed = jax.random.PRNGKey(seed) # pyrefly: ignore[bad-assignment] sampling_state = self.init_sample_state( jnp.array(all_input_ids), include_logits=return_logits, @@ -891,7 +891,7 @@ def __call__( temperature=temperature, top_p=top_p, top_k=top_k, - seed=seed, + seed=seed, # pyrefly: ignore[bad-argument-type] beam_size=beam_size, include_logprobs=return_logprobs, ) diff --git a/tunix/generate/tokenizer_adapter.py b/tunix/generate/tokenizer_adapter.py index cbe2f4a19..5d8e0cae2 100644 --- a/tunix/generate/tokenizer_adapter.py +++ b/tunix/generate/tokenizer_adapter.py @@ -26,9 +26,9 @@ class TokenizerType(enum.Enum): - SP: str = 'sp' # sentencepiece tokenizer - HF: str = 'hf' # huggingface tokenizer - NONE: str = 'none' # Represents no tokenizer + SP: str = 'sp' # sentencepiece tokenizer # pyrefly: ignore[invalid-annotation] + HF: str = 'hf' # huggingface tokenizer # pyrefly: ignore[invalid-annotation] + NONE: str = 'none' # Represents no tokenizer # pyrefly: ignore[invalid-annotation] class TokenizerAdapter: @@ -98,7 +98,7 @@ def pad_id(self) -> int: # e.g. llama3 HF tokenizers do not have pad_id if self._tokenizer.pad_token_id is None: self._tokenizer.pad_token = self._tokenizer.eos_token - return self._tokenizer.pad_token_id + return self._tokenizer.pad_token_id # pyrefly: ignore[bad-return] else: return self._tokenizer.pad_id() diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 1c2714f17..056be4740 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -162,7 +162,7 @@ def np_find_first_eos_idx( """Numpy version of find_first_eos_idx. Works on CPU arrays.""" assert ids.ndim == 1, f'ids should be a 1d array. Got: {ids.shape}' if isinstance(eos_id, int): - eos_id = np.array([eos_id]) + eos_id = np.array([eos_id]) # pyrefly: ignore[bad-assignment] mask = np.isin(ids, eos_id) return int(np.argmax(mask)) if mask.any() else len(ids) @@ -334,8 +334,8 @@ def get_logprobs_from_vllm_output( extracted = [] for tok_id, tok_logprobs in zip(token_ids, logprobs): - if tok_id in tok_logprobs: - extracted.append(tok_logprobs[tok_id].logprob) + if tok_id in tok_logprobs: # pyrefly: ignore[not-iterable] + extracted.append(tok_logprobs[tok_id].logprob) # pyrefly: ignore[unsupported-operation] else: raise ValueError( f'The selected token id {tok_id} not in the return log probs list' @@ -896,14 +896,14 @@ def transfer_state_with_mappings( if kwargs.get('reshard_chunk_size', None) is not None: resharded_values_flat_dict = _reshard_in_chunks( src_flat=tgt_flat_dict, - spec_flat=sharding_dict, + spec_flat=sharding_dict, # pyrefly: ignore[bad-argument-type] reshard_fn=reshard_fn, chunk_size=kwargs['reshard_chunk_size'], delete_spec_buffers=kwargs.get('delete_dst_buffers', False), ) else: if kwargs.get('delete_dst_buffers', False): - _delete_target_buffers(sharding_dict, tgt_flat_dict) + _delete_target_buffers(sharding_dict, tgt_flat_dict) # pyrefly: ignore[bad-argument-type] resharded_values_flat_dict = reshard_fn(tgt_flat_dict, sharding_dict) for tgt_key, tgt_param in tgt_flat_list: @@ -992,7 +992,7 @@ def _unstack_scanned_param( return jnp.unstack(src_val) else: # Fallback for older JAX versions - return [src_val[i] for i in range(src_val.shape[0])] + return [src_val[i] for i in range(src_val.shape[0])] # pyrefly: ignore[bad-return] except Exception as e: logging.debug( "Failed to unstack parameter '%s'. Error: %s. Using original.", @@ -1041,7 +1041,7 @@ def _get_n_shards(arr: jax.Array | np.ndarray, axis: int) -> int: """Returns the number of shards for a given axis of an array.""" sharding = getattr(arr, 'sharding', None) if isinstance(sharding, jax.sharding.NamedSharding): - return _partition_size(_spec_at_axis(sharding, axis), sharding.mesh) + return _partition_size(_spec_at_axis(sharding, axis), sharding.mesh) # pyrefly: ignore[bad-argument-type] return 1 @@ -1133,7 +1133,7 @@ def _align_per_axis( mesh = tgt_sharding.mesh pad_specs = [] for axis, s, t in mismatches: - n_shards = _partition_size(_spec_at_axis(tgt_sharding, axis), mesh) + n_shards = _partition_size(_spec_at_axis(tgt_sharding, axis), mesh) # pyrefly: ignore[bad-argument-type] if t % n_shards != 0: raise ValueError( f"Target dimension {t} on axis {axis} for {key_path} is not " @@ -1512,7 +1512,7 @@ def _reshard_in_chunks( jax.block_until_ready(chunk_resharded) resharded.update(traverse_util.flatten_dict(chunk_resharded)) - del ( + del ( # pyrefly: ignore[unsupported-delete] chunk_src, chunk_dst_shardings, chunk_resharded, diff --git a/tunix/models/gemma3/merge_embeddings.py b/tunix/models/gemma3/merge_embeddings.py index 441c93dd7..86efbadb0 100644 --- a/tunix/models/gemma3/merge_embeddings.py +++ b/tunix/models/gemma3/merge_embeddings.py @@ -45,10 +45,10 @@ def _merge_embeddings_inner( ) # len(vision_embeddings) == max_num_images * num_tokens_per_image - target_pos = jnp.nonzero(mask, size=len(vision_embeddings)) + target_pos = jnp.nonzero(mask, size=len(vision_embeddings)) # pyrefly: ignore[bad-argument-type] # Save and restore the first position overwritten if there's no MM tokens. - first_pos = text_embeddings[0] + first_pos = text_embeddings[0] # pyrefly: ignore[bad-index] merged = text_embeddings.at[target_pos, :].set(vision_embeddings) # pytype: disable=attribute-error # jax-arraylike diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py index 3e877bdaf..dedc9e8f6 100644 --- a/tunix/models/gemma3/model.py +++ b/tunix/models/gemma3/model.py @@ -336,7 +336,7 @@ def __init__( vision_proj_dim, rngs=rngs, param_dtype=param_dtype, - sharding=shd_config.vision_soft_emb_norm_weight, + sharding=shd_config.vision_soft_emb_norm_weight, # pyrefly: ignore[bad-argument-type] ) self.mm_input_projection = Einsum( einsum_str='...TM,MD->...TD', @@ -351,7 +351,7 @@ def __init__( def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x *= jnp.sqrt(x.shape[-1]).astype(x.dtype) - x = sharding_utils.shard(x, self.shd_config.act_btd) + x = sharding_utils.shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @jax.named_scope('embedder_decode') @@ -360,7 +360,7 @@ def decode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: @jax.named_scope('embedder_encode_vision') def encode_vision(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: - x = self.mm_soft_embedding_norm(x) + x = self.mm_soft_embedding_norm(x) # pyrefly: ignore[bad-argument-type] x = self.mm_input_projection(x) return x @@ -576,9 +576,9 @@ def block( query_proj = self.q_einsum(x) key_proj, value_proj = self.kv_einsum(x) - query_proj = sharding_utils.shard(query_proj, self.shd_config.act_btnh) - key_proj = sharding_utils.shard(key_proj, self.shd_config.act_btnh) - value_proj = sharding_utils.shard(value_proj, self.shd_config.act_btnh) + query_proj = sharding_utils.shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = sharding_utils.shard(key_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = sharding_utils.shard(value_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = self._query_norm(query_proj) key_proj = self._key_norm(key_proj) @@ -630,13 +630,13 @@ def block( if segment_pos.shape[1] == 1: # for decoding sliding_mask = create_sliding_window_mask( attn_mask, - sliding_window_size=self.sliding_window_size, + sliding_window_size=self.sliding_window_size, # pyrefly: ignore[bad-argument-type] ) else: # for prefill all_ones = jnp.ones_like(attn_mask) sliding_mask = jnp.triu( - all_ones, -1 * self.sliding_window_size + 1 - ) * jnp.tril(all_ones, self.sliding_window_size - 1) + all_ones, -1 * self.sliding_window_size + 1 # pyrefly: ignore[unsupported-operation] + ) * jnp.tril(all_ones, self.sliding_window_size - 1) # pyrefly: ignore[unsupported-operation] attn_mask = sliding_mask * attn_mask padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK) @@ -657,7 +657,7 @@ def block( encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) attn_output = self.attn_vec_einsum(encoded) - attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) + attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] if cache is not None: new_cache = { @@ -784,7 +784,7 @@ def block( ff1 = self.up_proj(x) activations = gate_value * ff1 activations = sharding_utils.shard( - activations, self.config.shd_config.act_btf + activations, self.config.shd_config.act_btf # pyrefly: ignore[bad-argument-type] ) outputs = self.down_proj(activations) return outputs @@ -794,7 +794,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: if self.config.remat_config == RematConfig.BLOCK: return nnx.remat(self.block.__func__, graph_updates=False)(self, x) else: - return self.block(x) + return self.block(x) # pyrefly: ignore[bad-argument-type] class DecoderLayer(nnx.Module): @@ -811,7 +811,7 @@ def __init__( self.pre_attention_norm = RMSNorm( config.embed_dim, rngs=rngs, - sharding=config.shd_config.rms_norm_weight, + sharding=config.shd_config.rms_norm_weight, # pyrefly: ignore[bad-argument-type] param_dtype=config.param_dtype, ) self.attn = Attention( @@ -836,13 +836,13 @@ def __init__( self.post_attention_norm = RMSNorm( config.embed_dim, rngs=rngs, - sharding=config.shd_config.rms_norm_weight, + sharding=config.shd_config.rms_norm_weight, # pyrefly: ignore[bad-argument-type] param_dtype=config.param_dtype, ) self.pre_ffw_norm = RMSNorm( config.embed_dim, rngs=rngs, - sharding=config.shd_config.rms_norm_weight, + sharding=config.shd_config.rms_norm_weight, # pyrefly: ignore[bad-argument-type] param_dtype=config.param_dtype, ) self.mlp = FeedForward( @@ -852,7 +852,7 @@ def __init__( self.post_ffw_norm = RMSNorm( config.embed_dim, rngs=rngs, - sharding=config.shd_config.rms_norm_weight, + sharding=config.shd_config.rms_norm_weight, # pyrefly: ignore[bad-argument-type] param_dtype=config.param_dtype, ) @@ -908,7 +908,7 @@ def __init__( param_dtype: jnp.dtype = jnp.bfloat16, ): self.scale = nnx.Param( - nnx.initializers.zeros_init()(rngs.params(), dim).astype(param_dtype), + nnx.initializers.zeros_init()(rngs.params(), dim).astype(param_dtype), # pyrefly: ignore[bad-argument-type] sharding=sharding, ) @@ -965,7 +965,7 @@ def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs): self.final_norm = RMSNorm( config.embed_dim, rngs=rngs, - sharding=config.shd_config.rms_norm_weight, + sharding=config.shd_config.rms_norm_weight, # pyrefly: ignore[bad-argument-type] param_dtype=config.param_dtype, ) @@ -1012,9 +1012,9 @@ def __call__( with jax.named_scope(layer_name): layer_cache, x = layer( x, - positions, + positions, # pyrefly: ignore[bad-argument-type] layer_cache, - attention_mask, + attention_mask, # pyrefly: ignore[bad-argument-type] ) if cache is not None: new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch diff --git a/tunix/models/gemma3/params.py b/tunix/models/gemma3/params.py index e65a5aafc..21bd40a5f 100644 --- a/tunix/models/gemma3/params.py +++ b/tunix/models/gemma3/params.py @@ -371,5 +371,5 @@ def save_lora_merged_model_as_safetensors( alpha=alpha, state_key_transform_fn=_gemma3_state_key_to_safetensors_key, custom_layer_extractor_fn=_extract_gemma3_lora_layers, - transpose_rules=_GEMMA3_HUGGINGFACE_TRANSPOSE_RULES, + transpose_rules=_GEMMA3_HUGGINGFACE_TRANSPOSE_RULES, # pyrefly: ignore[bad-argument-type] ) diff --git a/tunix/models/gemma3/params_safetensors.py b/tunix/models/gemma3/params_safetensors.py index e0ba68fcf..55bf98cbb 100644 --- a/tunix/models/gemma3/params_safetensors.py +++ b/tunix/models/gemma3/params_safetensors.py @@ -114,7 +114,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): # Vision Tower (SigLIP). if cfg.vision_config is not None: - mapping.update({ + mapping.update({ # pyrefly: ignore[no-matching-overload] r"vision_tower\.vision_model\.embeddings\.patch_embedding\.weight": ( "vision_encoder.siglip_encoder.embedding.kernel", ((2, 3, 1, 0), None), diff --git a/tunix/models/gemma3/utils.py b/tunix/models/gemma3/utils.py index 07a14b9b6..d0dbf0e4d 100644 --- a/tunix/models/gemma3/utils.py +++ b/tunix/models/gemma3/utils.py @@ -116,7 +116,7 @@ def _make_causal_mask( ) seq_len = input_mask.shape[-1] # pytype: disable=attribute-error # jax-arraylike causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool)) - attn_mask = input_mask[..., None, :] + attn_mask = input_mask[..., None, :] # pyrefly: ignore[bad-index] attn_mask *= causal_mask[None, ...] return attn_mask @@ -147,7 +147,7 @@ def _add_bidirectional_mask( q_block_indices = _make_block_mask_indices(bidirectional_mask) kv_block_indices = q_block_indices attn_mask = attn_mask | ( - (kv_block_indices[:, None, :] == q_block_indices[..., None]) - & (q_block_indices[..., None] > 0) + (kv_block_indices[:, None, :] == q_block_indices[..., None]) # pyrefly: ignore[bad-index] + & (q_block_indices[..., None] > 0) # pyrefly: ignore[bad-index] ) return attn_mask diff --git a/tunix/models/gemma3/vision.py b/tunix/models/gemma3/vision.py index 2d8b850a3..8d775edf1 100644 --- a/tunix/models/gemma3/vision.py +++ b/tunix/models/gemma3/vision.py @@ -185,9 +185,9 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array: k = self.key_proj(x) v = self.value_proj(x) if self.shd_config: - q = sharding_utils.shard(q, self.shd_config.act_btd) - k = sharding_utils.shard(k, self.shd_config.act_btd) - v = sharding_utils.shard(v, self.shd_config.act_btd) + q = sharding_utils.shard(q, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] + k = sharding_utils.shard(k, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] + v = sharding_utils.shard(v, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] q = q.reshape(desired_shape) k = k.reshape(desired_shape) @@ -195,7 +195,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array: logits = jnp.einsum("BTNH,BSNH->BNTS", q, k) if self.shd_config: - logits = sharding_utils.shard(logits, self.shd_config.act_bnts) + logits = sharding_utils.shard(logits, self.shd_config.act_bnts) # pyrefly: ignore[bad-argument-type] logits = logits / jnp.sqrt(self.head_dim).astype(logits.dtype) @@ -206,12 +206,12 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array: batch_size, seq_length, self.hidden_dim ) if self.shd_config: - out = sharding_utils.shard(out, self.shd_config.act_btd) + out = sharding_utils.shard(out, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] # 5. Final Output Projection out = self.out_proj(out) if self.shd_config: - out = sharding_utils.shard(out, self.shd_config.act_btd) + out = sharding_utils.shard(out, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return out @@ -280,19 +280,19 @@ def __call__( Returns: The output tensor. """ - x = self.fc1(x) + x = self.fc1(x) # pyrefly: ignore[bad-argument-type] x = nnx.gelu(x, approximate=True) if self.shd_config: - x = sharding_utils.shard(x, self.shd_config.act_btd) + x = sharding_utils.shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] x = self.dropout(x, deterministic=deterministic) if self.shd_config: - x = sharding_utils.shard(x, self.shd_config.act_btd) + x = sharding_utils.shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] x = self.fc2(x) if self.shd_config: - x = sharding_utils.shard(x, self.shd_config.act_btd) + x = sharding_utils.shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x @@ -558,7 +558,7 @@ def __call__( # Patch extraction x = self.embedding(image) if self.shd_config: - x = sharding_utils.shard(x, self.shd_config.act_bhwd) + x = sharding_utils.shard(x, self.shd_config.act_bhwd) # pyrefly: ignore[bad-argument-type] n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) diff --git a/tunix/models/gemma4/audio.py b/tunix/models/gemma4/audio.py index fa58e1028..cc5422a2a 100644 --- a/tunix/models/gemma4/audio.py +++ b/tunix/models/gemma4/audio.py @@ -390,13 +390,13 @@ def __init__( self.ffn_layer1 = ClippedEinsum( rngs=rngs, shape=(num_features, num_features * 4), - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] param_dtype=param_dtype, ) self.ffn_layer2 = ClippedEinsum( rngs=rngs, shape=(num_features * 4, num_features), - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] param_dtype=param_dtype, ) self.post_layer_norm = nnx.RMSNorm( @@ -447,7 +447,7 @@ def __init__( self.linear_start = ClippedEinsum( rngs=rngs, shape=(num_features, 2 * num_features), # feature expansion for GLU. - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] param_dtype=param_dtype, ) self.depthwise_conv1d = nnx.Conv( @@ -472,7 +472,7 @@ def __init__( self.linear_end = ClippedEinsum( rngs=rngs, shape=(num_features, num_features), - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] param_dtype=param_dtype, ) @@ -660,19 +660,19 @@ def __init__( rngs=rngs, shape=(model_dims, model_dims), param_dtype=param_dtype, - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] ) self.key = ClippedEinsum( rngs=rngs, shape=(model_dims, model_dims), param_dtype=param_dtype, - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] ) self.value = ClippedEinsum( rngs=rngs, shape=(model_dims, model_dims), param_dtype=param_dtype, - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] ) self.per_dim_scale = nnx.Param(jnp.ones(units_per_head, dtype=param_dtype)) @@ -951,7 +951,7 @@ def __init__( self.post = ClippedEinsum( rngs=rngs, shape=(atten_num_heads, units_per_head, model_dims), - dtype=dtype, + dtype=dtype, # pyrefly: ignore[bad-argument-type] param_dtype=param_dtype, ) diff --git a/tunix/models/gemma4/model.py b/tunix/models/gemma4/model.py index a224a0f81..989314112 100644 --- a/tunix/models/gemma4/model.py +++ b/tunix/models/gemma4/model.py @@ -428,11 +428,11 @@ def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x *= jnp.sqrt(x.shape[-1]).astype(x.dtype) x = jnp.astype(x, self.config.dtype) - x = shard(x, self.config.shd_config.act_btd) + x = shard(x, self.config.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return x def encode_vision(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: - x = self.mm_pre_projection_norm(x) + x = self.mm_pre_projection_norm(x) # pyrefly: ignore[bad-argument-type] x = self.mm_input_projection(x) return x @@ -446,7 +446,7 @@ def encode_per_layer_input( self, x: jaxtyping.ArrayLike, t: jaxtyping.ArrayLike ) -> jaxtyping.Array: t = jnp.where( - jnp.logical_and(t >= 0, t < self.vocab_size), t, jnp.zeros_like(t) + jnp.logical_and(t >= 0, t < self.vocab_size), t, jnp.zeros_like(t) # pyrefly: ignore[unsupported-operation] ) x = self.per_layer_model_projection(x) x = jnp.reshape( @@ -512,10 +512,10 @@ def _add_bidirectional_mask( pad_len = attn_kv_len - kv_shape[-1] kv_block_indices = jnp.pad(kv_block_indices, [(0, 0), (0, pad_len)]) else: - kv_block_indices = kv_block_indices[..., -attn_kv_len:] + kv_block_indices = kv_block_indices[..., -attn_kv_len:] # pyrefly: ignore[bad-index] - bidir_cond = (kv_block_indices[:, None, :] == q_block_indices[..., None]) & ( - q_block_indices[..., None] > 0 + bidir_cond = (kv_block_indices[:, None, :] == q_block_indices[..., None]) & ( # pyrefly: ignore[bad-index] + q_block_indices[..., None] > 0 # pyrefly: ignore[bad-index] ) if len(attn_shape) == 4: @@ -645,7 +645,7 @@ def __init__( self.with_scale = with_scale if with_scale: self.scale = nnx.Param( - nnx.initializers.ones_init()(rngs.params(), dim).astype(param_dtype), + nnx.initializers.ones_init()(rngs.params(), dim).astype(param_dtype), # pyrefly: ignore[bad-argument-type] sharding=sharding.rms_norm_weight, ) self.dtype = dtype @@ -874,7 +874,7 @@ def block( x = x.astype(self.config.dtype) seq_len = x.shape[1] query_proj = self.q_einsum(x) - query_proj = shard(query_proj, self.config.shd_config.act_btnh) + query_proj = shard(query_proj, self.config.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] query_proj = self._query_norm(query_proj) query_proj = apply_rope( query_proj, @@ -895,8 +895,8 @@ def block( else: key_proj, value_proj = self.kv_einsum(x) - key_proj = shard(key_proj, self.config.shd_config.act_btnh) - value_proj = shard(value_proj, self.config.shd_config.act_btnh) + key_proj = shard(key_proj, self.config.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + value_proj = shard(value_proj, self.config.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] # Apply norms to computed KV value_var = jnp.mean(jnp.square(value_proj), axis=-1, keepdims=True) @@ -971,7 +971,7 @@ def block( if self.attn_type == AttentionType.LOCAL_SLIDING: mask = mask_lib.LocalMask( (seq_len, seq_len), - window_size=(self.config.sliding_window_size - 1, 0), + window_size=(self.config.sliding_window_size - 1, 0), # pyrefly: ignore[unsupported-operation] offset=0, ) else: @@ -1139,14 +1139,14 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block): # for decoding without sliding window cache sliding_mask = create_sliding_window_mask( attn_mask, - sliding_window_size=self.config.sliding_window_size, + sliding_window_size=self.config.sliding_window_size, # pyrefly: ignore[bad-argument-type] ) attn_mask = sliding_mask * attn_mask else: # for prefill all_ones = jnp.ones_like(attn_mask) sliding_mask = jnp.triu( - all_ones, -1 * self.config.sliding_window_size + 1 - ) * jnp.tril(all_ones, self.config.sliding_window_size - 1) + all_ones, -1 * self.config.sliding_window_size + 1 # pyrefly: ignore[unsupported-operation] + ) * jnp.tril(all_ones, self.config.sliding_window_size - 1) # pyrefly: ignore[unsupported-operation] attn_mask = sliding_mask * attn_mask attn = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK) @@ -1165,7 +1165,7 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block): encoded = jnp.einsum('BTNS,BSNH->BTNH', attn, value_proj) attn_output = self.attn_vec_einsum(encoded) - attn_output = shard(attn_output, self.config.shd_config.act_btd) + attn_output = shard(attn_output, self.config.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return new_cache, attn_output, (key_proj, value_proj) @property @@ -1213,18 +1213,18 @@ def init_cache(self, batch_size, max_seq_len, dtype): cache_shape = (batch_size, cache_len, self.num_kv_heads, self.head_dim) k = shard( - np.zeros(cache_shape, dtype), - self.config.shd_config.act_btnh, + np.zeros(cache_shape, dtype), # pyrefly: ignore[bad-argument-type] + self.config.shd_config.act_btnh, # pyrefly: ignore[bad-argument-type] eager=True, ) v = shard( - np.zeros(cache_shape, dtype), - self.config.shd_config.act_btnh, + np.zeros(cache_shape, dtype), # pyrefly: ignore[bad-argument-type] + self.config.shd_config.act_btnh, # pyrefly: ignore[bad-argument-type] eager=True, ) end_index = shard( - np.zeros((batch_size,), np.int32), - self.config.shd_config.act_btnh[:1], + np.zeros((batch_size,), np.int32), # pyrefly: ignore[bad-argument-type] + self.config.shd_config.act_btnh[:1], # pyrefly: ignore[bad-argument-type] eager=True, ) return {'k': k, 'v': v, 'end_index': end_index} @@ -1688,7 +1688,7 @@ def _encode_vision(self, vision_input: PreprocessedVisionInput): else: soft_token_counts = vision_input.soft_token_counts - max_n_images = max((len(counts) for counts in soft_token_counts), default=0) + max_n_images = max((len(counts) for counts in soft_token_counts), default=0) # pyrefly: ignore[bad-argument-type] if max_n_images == 0: return jnp.zeros((batch_size, 0, self.config.embed_dim)) @@ -1713,11 +1713,11 @@ def _encode_vision(self, vision_input: PreprocessedVisionInput): for b in range(batch_size): per_image_tokens = [] counts = soft_token_counts[b] if b < len(soft_token_counts) else () - for i in range(len(counts)): + for i in range(len(counts)): # pyrefly: ignore[bad-argument-type] idx = b * max_n_images + i - expected_count = counts[i] + expected_count = counts[i] # pyrefly: ignore[bad-index] if mask is not None: - valid_indices = jnp.nonzero(mask[idx], size=expected_count)[0] + valid_indices = jnp.nonzero(mask[idx], size=expected_count)[0] # pyrefly: ignore[bad-argument-type] real_tokens = embeddings[idx][valid_indices] else: real_tokens = embeddings[idx][:expected_count] diff --git a/tunix/models/gemma4/params_safetensors.py b/tunix/models/gemma4/params_safetensors.py index e8d9c3b07..3ecf22c5e 100644 --- a/tunix/models/gemma4/params_safetensors.py +++ b/tunix/models/gemma4/params_safetensors.py @@ -423,7 +423,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): }) if cfg.audio_encoder is not None: - mapping.update({ + mapping.update({ # pyrefly: ignore[no-matching-overload] # Audio Embedder r"(?:model\.)?embed_audio\.embedding_projection\.weight": ( "embedder.audio_input_projection.w", diff --git a/tunix/models/gemma4/vision.py b/tunix/models/gemma4/vision.py index a26dec9cf..52f8e07bb 100644 --- a/tunix/models/gemma4/vision.py +++ b/tunix/models/gemma4/vision.py @@ -424,9 +424,9 @@ def __call__( value_proj = self.value_norm(value_proj) if self.shd_config: - query_proj = sharding_utils.shard(query_proj, self.shd_config.act_btnh) - key_proj = sharding_utils.shard(key_proj, self.shd_config.act_bskh) - value_proj = sharding_utils.shard(value_proj, self.shd_config.act_bskh) + query_proj = sharding_utils.shard(query_proj, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] + key_proj = sharding_utils.shard(key_proj, self.shd_config.act_bskh) # pyrefly: ignore[bad-argument-type] + value_proj = sharding_utils.shard(value_proj, self.shd_config.act_bskh) # pyrefly: ignore[bad-argument-type] query_proj = apply_multidimensional_rope( query_proj, @@ -446,11 +446,11 @@ def __call__( ) if self.shd_config: - attn_vec = sharding_utils.shard(attn_vec, self.shd_config.act_btnh) + attn_vec = sharding_utils.shard(attn_vec, self.shd_config.act_btnh) # pyrefly: ignore[bad-argument-type] attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', attn_vec) if self.shd_config: - attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) + attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return attn_output def _compute_attn_vec( @@ -488,7 +488,7 @@ def _qkv( attn_logits += attn_mask[:, None, None, :, :] if self.shd_config: - attn_logits = sharding_utils.shard(attn_logits, self.shd_config.act_bkgts) + attn_logits = sharding_utils.shard(attn_logits, self.shd_config.act_bkgts) # pyrefly: ignore[bad-argument-type] attn_weights = jax.nn.softmax(attn_logits, axis=-1).astype(v.dtype) result = jnp.einsum('bkgts,bskh->btkgh', attn_weights, v) @@ -527,10 +527,10 @@ def __call__(self, x: jax.Array) -> jax.Array: gate = self.gating_einsum('btd,cfd->btcf', x) activations = nnx.gelu(gate[..., 0, :]) * gate[..., 1, :] if self.shd_config: - activations = sharding_utils.shard(activations, self.shd_config.act_btf) + activations = sharding_utils.shard(activations, self.shd_config.act_btf) # pyrefly: ignore[bad-argument-type] out = self.linear('btf,fd->btd', activations) if self.shd_config: - out = sharding_utils.shard(out, self.shd_config.act_btd) + out = sharding_utils.shard(out, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return out @@ -604,11 +604,11 @@ def __call__( attn_mask: jax.Array | None, ) -> jax.Array: if self.shd_config: - inputs = sharding_utils.shard(inputs, self.shd_config.act_btd) + inputs = sharding_utils.shard(inputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] normed_inputs = self.pre_attention_norm(inputs) if self.shd_config: normed_inputs = sharding_utils.shard( - normed_inputs, self.shd_config.act_btd + normed_inputs, self.shd_config.act_btd # pyrefly: ignore[bad-argument-type] ) attn_output = self.attn( x=normed_inputs, @@ -617,15 +617,15 @@ def __call__( ) attn_output = self.post_attention_norm(attn_output) if self.shd_config: - attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) + attn_output = sharding_utils.shard(attn_output, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] attn_output += inputs outputs = self.pre_ffw_norm(attn_output) if self.shd_config: - outputs = sharding_utils.shard(outputs, self.shd_config.act_btd) + outputs = sharding_utils.shard(outputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] outputs = self.mlp(outputs) outputs = self.post_ffw_norm(outputs) if self.shd_config: - outputs = sharding_utils.shard(outputs, self.shd_config.act_btd) + outputs = sharding_utils.shard(outputs, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] outputs += attn_output return outputs @@ -671,15 +671,15 @@ def __call__( patches = 2.0 * (patches - 0.5) x = self.input_projection('btm,md->btd', patches) if self.shd_config: - x = sharding_utils.shard(x, self.shd_config.act_btd) + x = sharding_utils.shard(x, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] pos_embed = factorized_posemb(self.pos_emb.value, positions_xy).astype( x.dtype ) if self.shd_config: - pos_embed = sharding_utils.shard(pos_embed, self.shd_config.act_btd) + pos_embed = sharding_utils.shard(pos_embed, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] out = x + pos_embed if self.shd_config: - out = sharding_utils.shard(out, self.shd_config.act_btd) + out = sharding_utils.shard(out, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] return out @@ -841,7 +841,7 @@ def __call__( if self.shd_config: sharded_outputs = [] for emb, mask in outputs: - emb = sharding_utils.shard(emb, self.shd_config.act_btd) + emb = sharding_utils.shard(emb, self.shd_config.act_btd) # pyrefly: ignore[bad-argument-type] sharded_outputs.append((emb, mask)) outputs = tuple(sharded_outputs) diff --git a/tunix/processors/audio_processor.py b/tunix/processors/audio_processor.py index ba3f9f467..ef29bfaf7 100644 --- a/tunix/processors/audio_processor.py +++ b/tunix/processors/audio_processor.py @@ -187,8 +187,8 @@ def process_gemma4_inputs( padded_audio_lengths[b, i] = len(clip) processed_audios = gemma4_model_lib.PreprocessedAudioInput( - audios=padded_audios, - sequence_lengths=padded_audio_lengths, + audios=padded_audios, # pyrefly: ignore[bad-argument-type] + sequence_lengths=padded_audio_lengths, # pyrefly: ignore[bad-argument-type] ) return processed_audios, expanded_token_batch diff --git a/tunix/processors/image_processor.py b/tunix/processors/image_processor.py index ba5529ecd..fae604814 100644 --- a/tunix/processors/image_processor.py +++ b/tunix/processors/image_processor.py @@ -439,11 +439,11 @@ def add_variable_extra_tokens_for_images( counts = soft_token_counts[b] if b < len(soft_token_counts) else () for token in row: - if token == placeholder_token and image_idx < len(counts): - count = counts[image_idx] + if token == placeholder_token and image_idx < len(counts): # pyrefly: ignore[bad-argument-type] + count = counts[image_idx] # pyrefly: ignore[bad-index] expanded.append(double_new_line_token) expanded.append(start_token) - expanded.extend([soft_token_placeholder] * count) + expanded.extend([soft_token_placeholder] * count) # pyrefly: ignore[unsupported-operation] expanded.append(end_token) expanded.append(double_new_line_token) image_idx += 1 diff --git a/tunix/rl/algo_core.py b/tunix/rl/algo_core.py index 8bca37943..0e56daa6d 100644 --- a/tunix/rl/algo_core.py +++ b/tunix/rl/algo_core.py @@ -239,7 +239,7 @@ def ppo_policy_loss_fn( loss = policy_loss if return_entropy: - entropy_loss = masked_mean(token_entropy, completion_mask) + entropy_loss = masked_mean(token_entropy, completion_mask) # pyrefly: ignore[unbound-name] loss = loss - entropy_coef * entropy_loss aux["loss/entropy"] = entropy_loss @@ -519,12 +519,12 @@ def grpo_loss_fn( clamp_value=algo_config.kl_clamp_value, ) # Log mean KL. - aux["kl"] = jnp.astype( + aux["kl"] = jnp.astype( # pyrefly: ignore[bad-assignment] (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1), jnp.float32, ) kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode) - aux["kl_loss"] = kl_loss + aux["kl_loss"] = kl_loss # pyrefly: ignore[bad-assignment] if beta is not None and beta != 0.0: loss = loss + beta * kl_loss diff --git a/tunix/rl/common.py b/tunix/rl/common.py index 0329a06cc..ac8fc1330 100644 --- a/tunix/rl/common.py +++ b/tunix/rl/common.py @@ -79,8 +79,8 @@ def _shuffle_and_slice_one_batch(self, rollout_batch: Any): # Slice the rollout batch into mini-batches. for i in range(num_mini_batches): - start = i * self._mini_batch_size - end = start + self._mini_batch_size + start = i * self._mini_batch_size # pyrefly: ignore[unsupported-operation] + end = start + self._mini_batch_size # pyrefly: ignore[unsupported-operation] batch_indices = shuffled_indices[start:end] mini_batch = jtu.tree_map( @@ -414,16 +414,16 @@ def compute_per_token_logps( ) if return_entropy: per_token_entropy = jnp.pad( - per_token_entropy, ((0, 0), (1, 0)), constant_values=0.0 + per_token_entropy, ((0, 0), (1, 0)), constant_values=0.0 # pyrefly: ignore[unbound-name] ) if stop_gradient: per_token_logps = jax.lax.stop_gradient(per_token_logps) if return_entropy: - per_token_entropy = jax.lax.stop_gradient(per_token_entropy) + per_token_entropy = jax.lax.stop_gradient(per_token_entropy) # pyrefly: ignore[unbound-name] if return_entropy: - return per_token_logps, per_token_entropy + return per_token_logps, per_token_entropy # pyrefly: ignore[unbound-name] return per_token_logps else: logits = outputs[:, -logits_to_keep - 1 : -1, :] @@ -769,4 +769,4 @@ def _check_get_norm( raise ValueError( f"Invalid 'norm' value: {norm}. Must be a positive number." ) - return norm + return norm # pyrefly: ignore[bad-return] diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py index 8c1ca04d4..45b35f28d 100644 --- a/tunix/rl/grpo/grpo_learner.py +++ b/tunix/rl/grpo/grpo_learner.py @@ -166,7 +166,7 @@ def __init__( data_shuffle_seed=data_shuffle_seed, ) - self.algo_config.temperature = self.rl_cluster.get_rollout_config( + self.algo_config.temperature = self.rl_cluster.get_rollout_config( # pyrefly: ignore[missing-attribute] mode=rl_cluster_lib.Mode.TRAIN ).temperature @@ -190,7 +190,7 @@ def __init__( has_aux=True, ) self.rl_cluster.actor_trainer.with_gen_model_input_fn( - lambda x: { + lambda x: { # pyrefly: ignore[bad-argument-type] "train_example": x, "algo_config": self.algo_config, } @@ -199,7 +199,7 @@ def __init__( "kl": np.mean, "pg_clipfrac": np.mean, }) - self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ + self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ # pyrefly: ignore[bad-argument-type] lambda: "kl" if self.algo_config.beta != 0.0 else None, ]) @@ -224,7 +224,7 @@ def _generate_and_compute_advantage( if isinstance(rollout_config, dict): rollout_config = rollout_config[mode] - training_input["prompts"] = list(training_input["prompts"]) + training_input["prompts"] = list(training_input["prompts"]) # pyrefly: ignore[bad-argument-type] pad_value = self.rl_cluster.rollout.pad_id() eos_value = self.rl_cluster.rollout.eos_id() @@ -237,7 +237,7 @@ def _generate_and_compute_advantage( prompts=training_input["prompts"], mode=mode, micro_batch_size=( - self._rollout_micro_batch_size * self.algo_config.num_generations + self._rollout_micro_batch_size * self.algo_config.num_generations # pyrefly: ignore[unsupported-operation] ), trace_tags=perf_tags, ) @@ -274,7 +274,7 @@ def _generate_and_compute_advantage( pad_id=pad_value, eos_id=eos_value, micro_batch_size=( - self._compute_logps_micro_batch_size + self._compute_logps_micro_batch_size # pyrefly: ignore[unsupported-operation] * self.algo_config.num_generations ), ) @@ -293,7 +293,7 @@ def _generate_and_compute_advantage( prompt_tokens=prompt_ids, completion_tokens=jax_completion_ids, micro_batch_size=( - self._compute_logps_micro_batch_size + self._compute_logps_micro_batch_size # pyrefly: ignore[unsupported-operation] * self.algo_config.num_generations ), ) @@ -312,7 +312,7 @@ def _generate_and_compute_advantage( prompts=training_input["prompts"], completions=rollout_output.text, mode=mode, - **{k: v for k, v in training_input.items() if k != "prompts"}, + **{k: v for k, v in training_input.items() if k != "prompts"}, # pyrefly: ignore[bad-argument-type] ) advantage_estimator = function_registry.get_advantage_estimator( self.algo_config.advantage_estimator @@ -388,7 +388,7 @@ def _compute_trajectory_ids( Returns: A list of trajectory IDs, one for each prompt in the batch. """ - batch_size = len(example["prompts"]) // self.algo_config.num_generations + batch_size = len(example["prompts"]) // self.algo_config.num_generations # pyrefly: ignore[bad-argument-type] row_offset = steps * batch_size row_offsets = np.repeat( np.arange(row_offset, row_offset + batch_size), diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py index 64866ee05..cd07dc166 100644 --- a/tunix/rl/rl_cluster.py +++ b/tunix/rl/rl_cluster.py @@ -474,7 +474,7 @@ def _init_cluster(self): ) or ( isinstance(self.cluster_config.rollout_engine, functools.partial) and issubclass( - self.cluster_config.rollout_engine.func, + self.cluster_config.rollout_engine.func, # pyrefly: ignore[bad-argument-type] base_rollout.BaseRollout, ) ): @@ -556,7 +556,7 @@ def _init_cluster(self): with self._get_mesh_and_logical_axis_rules_cm(Role.CRITIC): self._critic_trainer = rl_trainer.Trainer( model=self.critic, - optimizer=self.cluster_config.training_config.critic_optimizer, + optimizer=self.cluster_config.training_config.critic_optimizer, # pyrefly: ignore[bad-argument-type] training_config=critic_config, custom_checkpoint_metadata_fn=lambda: { "global_step": self.global_steps + 1, @@ -748,7 +748,7 @@ def _log_metrics(self, metrics_buffer: MetricsBuffer) -> None: self._rl_metrics_logger.log( prefix, metric_name, - agg_value, + agg_value, # pyrefly: ignore[bad-argument-type] metrics_buffer.mode, metrics_buffer.global_steps, ) @@ -960,7 +960,7 @@ def generate( logprobs = None if outputs[0].logprobs is not None: logprobs = list( - itertools.chain.from_iterable(out.logprobs for out in outputs) + itertools.chain.from_iterable(out.logprobs for out in outputs) # pyrefly: ignore[bad-argument-type] ) logits = None diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index 0c0bfc7f8..ce12e7ce0 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -394,7 +394,7 @@ def _process_and_enqueue_tail(): ): # Fetch one training micro-batch example = next(iterator) - cur_batch_size = len(example["prompts"]) + cur_batch_size = len(example["prompts"]) # pyrefly: ignore[bad-argument-type] # Buffer the fetched micro-batch. We accumulate micro-batches and track # their sizes and the total number of samples. This allows us to form a @@ -527,7 +527,7 @@ def train( """Main entry point for the training loop.""" full_batch_iterator = iter(train_ds) first_item = next(full_batch_iterator) - full_batch_size = len(first_item["prompts"]) + full_batch_size = len(first_item["prompts"]) # pyrefly: ignore[bad-argument-type] full_batch_iterator = itertools.chain([first_item], full_batch_iterator) # Initialize batch sizes. mini_batch_size = self._training_config.mini_batch_size or full_batch_size @@ -622,7 +622,7 @@ def train( ) if ( - self.rl_cluster.actor_trainer.train_steps + self.rl_cluster.actor_trainer.train_steps # pyrefly: ignore[unsupported-operation] >= self.rl_cluster.cluster_config.training_config.max_steps ): break diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py index 69fd1ef9a..9ea270220 100644 --- a/tunix/rl/utils.py +++ b/tunix/rl/utils.py @@ -90,7 +90,7 @@ def _get_mesh_info(leaf: jaxtyping.PyTree): def _is_same_state(s1: jaxtyping.PyTree, s2: jaxtyping.PyTree) -> bool: """Returns whether two states refer to the same Params.""" - return np.all( + return np.all( # pyrefly: ignore[bad-return] jax.tree.map( lambda x, y: x is y, jax.tree_util.tree_leaves(s1), @@ -321,11 +321,11 @@ def unpad_train_example(example: common.TrainExample) -> list[dict[str, Any]]: "completion_mask": c_mask[i, :c_len], "advantages": adv[i, :c_len] if adv_is_per_token else adv[i], "adv_is_per_token": adv_is_per_token, - "ref_per_token_logps": ref_logps[i, :c_len] if has_ref else None, - "old_per_token_logps": old_logps[i, :c_len] if has_old else None, - "returns": returns_np[i, :c_len] if has_returns else None, - "old_values": old_values_np[i, :c_len] if has_old_values else None, - "policy_version": policy_version_np if has_policy_version else None, + "ref_per_token_logps": ref_logps[i, :c_len] if has_ref else None, # pyrefly: ignore[unbound-name] + "old_per_token_logps": old_logps[i, :c_len] if has_old else None, # pyrefly: ignore[unbound-name] + "returns": returns_np[i, :c_len] if has_returns else None, # pyrefly: ignore[unbound-name] + "old_values": old_values_np[i, :c_len] if has_old_values else None, # pyrefly: ignore[unbound-name] + "policy_version": policy_version_np if has_policy_version else None, # pyrefly: ignore[unbound-name] } res.append(item) return res diff --git a/tunix/sft/checkpoint_manager.py b/tunix/sft/checkpoint_manager.py index c717e82c1..ae376352e 100644 --- a/tunix/sft/checkpoint_manager.py +++ b/tunix/sft/checkpoint_manager.py @@ -75,7 +75,7 @@ def __init__( 'model_params': ocp.PyTreeCheckpointHandler(), 'optimizer_state': ocp.PyTreeCheckpointHandler(), } - item_handlers['custom_metadata'] = ocp.JsonCheckpointHandler() + item_handlers['custom_metadata'] = ocp.JsonCheckpointHandler() # pyrefly: ignore[unsupported-operation] self._checkpoint_manager = ocp.CheckpointManager( root_directory, item_handlers=item_handlers, @@ -201,11 +201,11 @@ def fix_sharding(state): for s in jax.tree_util.tree_leaves(shardings) if isinstance(s, jax.sharding.NamedSharding) ) - return nnx.get_named_sharding(optimizer_state, named_sharding.mesh) + return nnx.get_named_sharding(optimizer_state, named_sharding.mesh) # pyrefly: ignore[bad-argument-type] except StopIteration: return shardings - if optimizer is not None and 'optimizer_state' in metadata.item_metadata: + if optimizer is not None and 'optimizer_state' in metadata.item_metadata: # pyrefly: ignore[not-iterable] optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) fixed_sharding = fix_sharding(optimizer_state) optimizer_cp_args = ocp.args.PyTreeRestore( @@ -221,7 +221,7 @@ def fix_sharding(state): optimizer_state=optimizer_cp_args, ), ) - nnx.update(optimizer, ckpt.optimizer_state) + nnx.update(optimizer, ckpt.optimizer_state) # pyrefly: ignore[missing-attribute] else: ckpt = self._checkpoint_manager.restore( step, @@ -230,7 +230,7 @@ def fix_sharding(state): ), ) # Update the model state with params from the restored checkpoint. - nnx.update(model, ckpt.model_params) + nnx.update(model, ckpt.model_params) # pyrefly: ignore[missing-attribute] logging.info( 'Restored params from step: %d in %.3f seconds', step, diff --git a/tunix/sft/checkpoint_options.py b/tunix/sft/checkpoint_options.py index 579b0d6d8..9fe1ff5dc 100644 --- a/tunix/sft/checkpoint_options.py +++ b/tunix/sft/checkpoint_options.py @@ -99,10 +99,10 @@ class TunixCheckpointingOptions: # - Use async checkpointing. # - Timeout for async operations is 1200 seconds. DEFAULT_CHECKPOINTING_OPTIONS = TunixCheckpointingOptions( - save_decision_policy=ocp.training.save_decision_policies.ContinuousCheckpointingPolicy( + save_decision_policy=ocp.training.save_decision_policies.ContinuousCheckpointingPolicy( # pyrefly: ignore[bad-argument-type] minimum_interval_secs=180, ), - preservation_policy=ocp.training.preservation_policies.LatestN(n=3), + preservation_policy=ocp.training.preservation_policies.LatestN(n=3), # pyrefly: ignore[bad-argument-type] step_name_format=ocp.path.step.standard_name_format(), enable_async_checkpointing=True, async_options=ocp.options.AsyncOptions(timeout_secs=1200), @@ -206,9 +206,9 @@ def resolve_checkpointing_defaults( async_options = DEFAULT_CHECKPOINTING_OPTIONS.async_options return create_checkpointing_options( - save_decision_policy=save_policy, - preservation_policy=preserve_policy, - step_name_format=step_name_format, + save_decision_policy=save_policy, # pyrefly: ignore[bad-argument-type] + preservation_policy=preserve_policy, # pyrefly: ignore[bad-argument-type] + step_name_format=step_name_format, # pyrefly: ignore[bad-argument-type] enable_async_checkpointing=enable_async, async_options=async_options, ) diff --git a/tunix/sft/dpo/dpo_trainer.py b/tunix/sft/dpo/dpo_trainer.py index 88f625144..ffe09ca2e 100644 --- a/tunix/sft/dpo/dpo_trainer.py +++ b/tunix/sft/dpo/dpo_trainer.py @@ -166,7 +166,7 @@ def compute_logps( completion_logps = (completion_logps * completion_mask).sum(axis=-1) # Extract log probs for prompt + completion (excluding first token) - full_sequence_mask = full_mask[:, 1:] + full_sequence_mask = full_mask[:, 1:] # pyrefly: ignore[unsupported-operation] full_logps = (token_logps * full_sequence_mask).sum(axis=-1) batch_size = token_logps.shape[0] @@ -257,7 +257,7 @@ def __init__( if self.algorithm == "orpo": self.with_gen_model_input_fn( - lambda x: { + lambda x: { # pyrefly: ignore[bad-argument-type] "train_example": x, "algorithm": "orpo", "lambda_orpo": self.dpo_config.lambda_orpo, @@ -278,7 +278,7 @@ def __init__( } else: self.with_gen_model_input_fn( - lambda x: { + lambda x: { # pyrefly: ignore[bad-argument-type] "train_example": x, "algorithm": "dpo", "beta": self.dpo_config.beta, @@ -342,7 +342,7 @@ def _prepare_inputs( ) training_input = process_dpo_record( - record={ + record={ # pyrefly: ignore[bad-argument-type] "prompts": training_input.prompts, "images": training_input.images, "chosen_responses": training_input.chosen_responses, @@ -499,7 +499,7 @@ def dpo_loss_fn( # SFT log probs (can include prompt) if enable_prompt_loss_orpo: if average_log_prob_orpo: - chosen_full_mask = train_example.full_mask[:batch_size, 1:] + chosen_full_mask = train_example.full_mask[:batch_size, 1:] # pyrefly: ignore[unsupported-operation] chosen_sft_lengths = jnp.maximum(chosen_full_mask.sum(axis=-1), 1.0) sft_loss = -prompt_chosen_logps / chosen_sft_lengths else: @@ -627,7 +627,7 @@ def _preprocess_dict( for field in tokenized_input_fields if field != "images" ): - return TrainingInput(**{ + return TrainingInput(**{ # pyrefly: ignore[bad-argument-type] field: training_input.get(field, None) for field in tokenized_input_fields }) @@ -636,7 +636,7 @@ def _preprocess_dict( for field in data_input_fields if field != "images" ): - return DataInput(**{ + return DataInput(**{ # pyrefly: ignore[bad-argument-type] field: training_input.get(field, None) for field in data_input_fields }) else: @@ -697,16 +697,16 @@ def process_dpo_record( # Only prompt is left padded, others are right padded. prompt_ids, prompt_mask = _generate_ids_and_masks( - prompts, + prompts, # pyrefly: ignore[bad-argument-type] tokenizer, max_prompt_length, left_pad=True, ) chosen_ids, chosen_mask = _generate_ids_and_masks( - chosen_responses, tokenizer, max_response_length, left_pad=False + chosen_responses, tokenizer, max_response_length, left_pad=False # pyrefly: ignore[bad-argument-type] ) rejected_ids, rejected_mask = _generate_ids_and_masks( - rejected_responses, tokenizer, max_response_length, left_pad=False + rejected_responses, tokenizer, max_response_length, left_pad=False # pyrefly: ignore[bad-argument-type] ) if images is not None: if image_processor is None: @@ -725,7 +725,7 @@ def process_dpo_record( chosen_mask = jnp.squeeze(chosen_mask, axis=0) rejected_mask = jnp.squeeze(rejected_mask, axis=0) if images is not None: - images = jnp.squeeze(images, axis=0) + images = jnp.squeeze(images, axis=0) # pyrefly: ignore[bad-argument-type] return TrainingInput( prompt_ids=prompt_ids, @@ -734,7 +734,7 @@ def process_dpo_record( chosen_mask=chosen_mask, rejected_ids=rejected_ids, rejected_mask=rejected_mask, - images=images, + images=images, # pyrefly: ignore[bad-argument-type] ) diff --git a/tunix/sft/hooks.py b/tunix/sft/hooks.py index 1c9ae811b..7088e9db3 100644 --- a/tunix/sft/hooks.py +++ b/tunix/sft/hooks.py @@ -26,24 +26,24 @@ class TrainingHooks(ABC): """Hooks to be used for training.""" @abstractmethod - def on_train_start(self, train_ctx: "PeftTrainer.PeftTrainer"): + def on_train_start(self, train_ctx: "PeftTrainer.PeftTrainer"): # pyrefly: ignore[missing-attribute] """Called at the beginning of training.""" pass @abstractmethod - def on_train_end(self, train_ctx: "PeftTrainer.PeftTrainer"): + def on_train_end(self, train_ctx: "PeftTrainer.PeftTrainer"): # pyrefly: ignore[missing-attribute] """Called at the end of training.""" pass @abstractmethod - def on_train_step_start(self, train_ctx: "PeftTrainer.PeftTrainer"): + def on_train_step_start(self, train_ctx: "PeftTrainer.PeftTrainer"): # pyrefly: ignore[missing-attribute] """Called at the beginning of a training step.""" pass @abstractmethod def on_train_step_end( self, - train_ctx: "PeftTrainer.PeftTrainer", + train_ctx: "PeftTrainer.PeftTrainer", # pyrefly: ignore[missing-attribute] train_step: int, train_loss: float, ): @@ -51,13 +51,13 @@ def on_train_step_end( pass @abstractmethod - def on_eval_step_start(self, train_ctx: "PeftTrainer.PeftTrainer"): + def on_eval_step_start(self, train_ctx: "PeftTrainer.PeftTrainer"): # pyrefly: ignore[missing-attribute] """Called at the beginning of an evaluation step.""" pass @abstractmethod def on_eval_step_end( - self, train_ctx: "PeftTrainer.PeftTrainer", eval_loss: float + self, train_ctx: "PeftTrainer.PeftTrainer", eval_loss: float # pyrefly: ignore[missing-attribute] ): """Called at the end of an evaluation step.""" pass @@ -67,11 +67,11 @@ class DataHooks(ABC): """Hooks to wire in external data loader and processing logic.""" @abstractmethod - def load_next_train_batch(self, train_ctx: "PeftTrainer.PeftTrainer") -> Any: + def load_next_train_batch(self, train_ctx: "PeftTrainer.PeftTrainer") -> Any: # pyrefly: ignore[missing-attribute] """Loads the next batch of data for training.""" raise NotImplementedError() @abstractmethod - def load_next_eval_batch(self, train_ctx: "PeftTrainer.PeftTrainer") -> Any: + def load_next_eval_batch(self, train_ctx: "PeftTrainer.PeftTrainer") -> Any: # pyrefly: ignore[missing-attribute] """Loads the next batch of data for evaluation.""" raise NotImplementedError() diff --git a/tunix/sft/metrics_logger.py b/tunix/sft/metrics_logger.py index 10b83c40e..abb87f1f2 100644 --- a/tunix/sft/metrics_logger.py +++ b/tunix/sft/metrics_logger.py @@ -60,7 +60,7 @@ def create_backends(self) -> list[LoggingBackend]: "custom_backend" in self.backend_kwargs and self.backend_kwargs["custom_backend"] ): - return [factory() for factory in self.backend_kwargs["custom_backend"]] + return [factory() for factory in self.backend_kwargs["custom_backend"]] # pyrefly: ignore[not-callable] # Case 2: Defaults. active_backends = [] @@ -79,7 +79,7 @@ def create_backends(self) -> list[LoggingBackend]: TensorboardBackend( log_dir=self.log_dir, flush_every_n_steps=self.flush_every_n_steps, - **tb_kwargs, + **tb_kwargs, # pyrefly: ignore[bad-unpacking] ) ) try: @@ -88,7 +88,7 @@ def create_backends(self) -> list[LoggingBackend]: WandbBackend( project=self.project_name, name=self.run_name, - **wandb_kwargs, + **wandb_kwargs, # pyrefly: ignore[bad-unpacking] ) ) except ImportError: @@ -146,7 +146,7 @@ def log( mode_metrics[metric_name].append(scalar_value) jax.monitoring.record_scalar( - f"{metrics_prefix}/{mode}/{metric_name}", scalar_value, step=step + f"{metrics_prefix}/{mode}/{metric_name}", scalar_value, step=step # pyrefly: ignore[bad-argument-type] ) def metric_exists( diff --git a/tunix/sft/peft_trainer.py b/tunix/sft/peft_trainer.py index 501e5739c..336bcc3bc 100644 --- a/tunix/sft/peft_trainer.py +++ b/tunix/sft/peft_trainer.py @@ -210,7 +210,7 @@ def __init__( self.config = training_config self._lora_enabled = utils.is_lora_enabled(self.model) if training_config.gradient_accumulation_steps is not None: - optimizer = optax.MultiSteps( + optimizer = optax.MultiSteps( # pyrefly: ignore[bad-assignment] optimizer, training_config.gradient_accumulation_steps ) if self._lora_enabled: @@ -304,8 +304,8 @@ def with_loss_fn( has_aux: bool = False, ): self.clear_jit_cache() - self.loss_fn = loss_fn - self.eval_loss_fn = loss_fn + self.loss_fn = loss_fn # pyrefly: ignore[bad-assignment] + self.eval_loss_fn = loss_fn # pyrefly: ignore[bad-assignment] self._has_aux = has_aux return self @@ -325,7 +325,7 @@ def with_gen_model_input_fn( PeftTrainer. """ self.clear_jit_cache() - self.gen_model_input_fn = gen_model_input_fn + self.gen_model_input_fn = gen_model_input_fn # pyrefly: ignore[bad-assignment] return self def _train_step( @@ -364,7 +364,7 @@ def _eval_step( inputs = self.gen_model_input_fn(inputs) out = self.eval_loss_fn(model, **inputs) if self._has_aux: - loss, aux = out + loss, aux = out # pyrefly: ignore[not-iterable] return loss, aux else: return out, None @@ -377,7 +377,7 @@ def create_train_step_fn( def create_eval_step_fn(self) -> Callable[..., ArrayLike]: """Creates the eval step function.""" - return self._eval_step + return self._eval_step # pyrefly: ignore[bad-return] def _shard_optimizer(self, mesh: shd.Mesh) -> None: """Optimizer states should be sharded before calling the jit function. @@ -470,13 +470,13 @@ def _log_metrics( ): """Logs the metrics to the metrics logger and console.""" perplexity = np.exp(jax.device_get(loss)) - self.metrics_logger.log(self.metrics_prefix, "loss", loss, self._mode, step) - self.metrics_logger.log( + self.metrics_logger.log(self.metrics_prefix, "loss", loss, self._mode, step) # pyrefly: ignore[missing-attribute] + self.metrics_logger.log( # pyrefly: ignore[missing-attribute] self.metrics_prefix, "perplexity", perplexity, self._mode, step ) learning_rate = self._try_get_learning_rate() if learning_rate is not None: - self.metrics_logger.log( + self.metrics_logger.log( # pyrefly: ignore[missing-attribute] self.metrics_prefix, "learning_rate", jax.device_get(learning_rate), @@ -492,7 +492,7 @@ def _log_metrics( perplexity, ) for k, v in (additional_metrics or {}).items(): - self.metrics_logger.log(self.metrics_prefix, k, v, self._mode, step) + self.metrics_logger.log(self.metrics_prefix, k, v, self._mode, step) # pyrefly: ignore[missing-attribute] def _buffer_metrics( self, @@ -618,7 +618,7 @@ def train( if self.config.max_steps is not None and self._pbar is None: self._pbar = progress_bar.ProgressBar( metrics_prefix=self.metrics_prefix, - metrics_logger=self.metrics_logger, + metrics_logger=self.metrics_logger, # pyrefly: ignore[bad-argument-type] initial_steps=self._train_steps, max_steps=self.config.max_steps, description=self.config.pbar_description, @@ -788,7 +788,7 @@ def close(self): self._write_train_metrics() self._save_last_checkpoint() self.checkpoint_manager.close() - self.metrics_logger.close() + self.metrics_logger.close() # pyrefly: ignore[missing-attribute] if self._pbar is not None: self._pbar.close() self._pbar = None @@ -836,12 +836,12 @@ def _run_eval( ) return - self._write_metrics(self._buffered_eval_metrics) + self._write_metrics(self._buffered_eval_metrics) # pyrefly: ignore[bad-argument-type] logging.info( "Train step %d eval loss: %f - eval perplexity: %f", self._train_steps, - self.metrics_logger.get_metric(self.metrics_prefix, "loss", "eval"), - self.metrics_logger.get_metric( + self.metrics_logger.get_metric(self.metrics_prefix, "loss", "eval"), # pyrefly: ignore[missing-attribute] + self.metrics_logger.get_metric( # pyrefly: ignore[missing-attribute] self.metrics_prefix, "perplexity", "eval" ), ) diff --git a/tunix/sft/profiler.py b/tunix/sft/profiler.py index 598845dca..616548dfd 100644 --- a/tunix/sft/profiler.py +++ b/tunix/sft/profiler.py @@ -107,7 +107,7 @@ def _start_trace( jax.profiler.start_trace( # pytype: disable=wrong-keyword-args log_dir=log_dir, profiler_options=profiler_options, - max_num_hosts=self._profiler_options.max_num_hosts, + max_num_hosts=self._profiler_options.max_num_hosts, # pyrefly: ignore[unexpected-keyword] ) else: jax.profiler.start_trace( diff --git a/tunix/utils/math_utils.py b/tunix/utils/math_utils.py index 6c3302e9b..69c3bcb61 100644 --- a/tunix/utils/math_utils.py +++ b/tunix/utils/math_utils.py @@ -231,15 +231,15 @@ def _is_frac(expr: str) -> bool: def _str_is_int(x: str) -> bool: try: x = _strip_properly_formatted_commas(x) - x = float(x) - return abs(x - int(round(x))) <= 1e-7 + x = float(x) # pyrefly: ignore[bad-assignment] + return abs(x - int(round(x))) <= 1e-7 # pyrefly: ignore[bad-argument-type, unsupported-operation] except Exception: return False def _str_to_int(x: str) -> int: x = x.replace(",", "") - x = float(x) + x = float(x) # pyrefly: ignore[bad-assignment] return int(x) @@ -267,7 +267,7 @@ def _strip_properly_formatted_commas(expr: str): def _normalize(expr: str) -> str: """Normalize answer expressions.""" if expr is None: - return None + return None # pyrefly: ignore[bad-return] # Remove enclosing `\text{}`. m = re.search(r"^\\text\{(?P.+?)\}", expr) @@ -434,7 +434,7 @@ def remove_boxed(s): def extract_boxed_answer(solution: str): """Extract the answer from inside a LaTeX \\boxed{} command""" solution = last_boxed_only_string(solution) - solution = remove_boxed(solution) if solution is not None else solution + solution = remove_boxed(solution) if solution is not None else solution # pyrefly: ignore[bad-assignment] logging.vlog(4, f"{solution=} in extracted_boxed_answer") return solution diff --git a/tunix/utils/mesh.py b/tunix/utils/mesh.py index 9f07ebebc..a1bb5687b 100644 --- a/tunix/utils/mesh.py +++ b/tunix/utils/mesh.py @@ -299,7 +299,7 @@ def infer_core_on_chip_count(devices: Sequence[Any]) -> int | None: if min_core is None: return None - return max_core - min_core + 1 + return max_core - min_core + 1 # pyrefly: ignore[unsupported-operation] def summarize_devices_for_logging( diff --git a/tunix/utils/script_utils.py b/tunix/utils/script_utils.py index 72d8d9030..958187bc6 100644 --- a/tunix/utils/script_utils.py +++ b/tunix/utils/script_utils.py @@ -43,7 +43,7 @@ def get_dataset( ) -> grain.MapDataset: """Loads the dataset, from CNS in g3 or downloading in OSS.""" if ENV == 'g3': - with gfile.Open(path, 'rb') as f: + with gfile.Open(path, 'rb') as f: # pyrefly: ignore[missing-attribute] data = json.loads(f.read()) else: # oss if path.startswith('gs://'): diff --git a/tunix/utils/trajectory_logger.py b/tunix/utils/trajectory_logger.py index 0dfe4e04d..71a9504db 100644 --- a/tunix/utils/trajectory_logger.py +++ b/tunix/utils/trajectory_logger.py @@ -101,10 +101,10 @@ def log_item( else: raise ValueError(f'Item {item} is not a dataclass, dictionary or list.') - log_path = epath.Path(log_path) - log_path.mkdir(parents=True, exist_ok=True) + log_path = epath.Path(log_path) # pyrefly: ignore[bad-assignment] + log_path.mkdir(parents=True, exist_ok=True) # pyrefly: ignore[missing-attribute] - assert log_path.is_dir(), f'log_path `{log_path}` must be a directory.' + assert log_path.is_dir(), f'log_path `{log_path}` must be a directory.' # pyrefly: ignore[missing-attribute] if isinstance(item, list): item_name = _get_item_name(item[0]) @@ -113,7 +113,7 @@ def log_item( file_stem = item_name if item_name else 'trajectory_log' filename = f'{file_stem}_{suffix}.csv' if suffix else f'{file_stem}.csv' - file_path = log_path / filename + file_path = log_path / filename # pyrefly: ignore[unsupported-operation] logging.log_first_n(logging.INFO, f'Logging item to {file_path}', 1) write_header = not file_path.exists() @@ -197,9 +197,9 @@ def _worker(): # Register signal handlers for robust termination if threading.current_thread() is threading.main_thread(): try: - signal.signal(signal.SIGINT, self._handle_signal) - signal.signal(signal.SIGTERM, self._handle_signal) - signal.signal(signal.SIGHUP, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) # pyrefly: ignore[bad-argument-type] + signal.signal(signal.SIGTERM, self._handle_signal) # pyrefly: ignore[bad-argument-type] + signal.signal(signal.SIGHUP, self._handle_signal) # pyrefly: ignore[bad-argument-type] except ValueError: logging.warning('Failed to register signal handlers.')