Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tests/generate/beam_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def test_beam_search_step_without_stop(self) -> None:
new_scores1, tokens1 = jax.lax.top_k(
jax.nn.log_softmax(jnp.array([1, 2, 1.1])), 2
)
self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6)
self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6)
self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6)
self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6)
self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) # pyrefly: ignore[no-matching-overload]

updated_token_buffer = updated_params['token_buffer']
expected = token_buffer
Expand Down Expand Up @@ -162,10 +162,10 @@ def _check(x, y) -> None:
).ravel(),
2,
)
self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6)
self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6)
self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6)
self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6)
self.assertAlmostEqual(state.scores[0][0], new_scores0[0], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[0][1], new_scores0[1], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[1][0], new_scores1[0], places=6) # pyrefly: ignore[no-matching-overload]
self.assertAlmostEqual(state.scores[1][1], new_scores1[1], places=6) # pyrefly: ignore[no-matching-overload]
# before the beam search, the token buffer[:][0] is [2, 1, 1, 2]
# token_buffer[0] should be [1, 1, -1, ...]
# token_buffer[1] should be [2, 1, -1, ...]
Expand Down
18 changes: 9 additions & 9 deletions tests/generate/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def test_samples_padding_output(self, max_prompt_length, echo, return_logits):
+ 1
)
np.testing.assert_allclose(
result_not_padded.logits[i],
result_padded.logits[i][:valid_length],
result_not_padded.logits[i], # pyrefly: ignore[unsupported-operation]
result_padded.logits[i][:valid_length], # pyrefly: ignore[unsupported-operation]
)
np.testing.assert_allclose(
result_not_padded.tokens[i],
Expand Down Expand Up @@ -197,7 +197,7 @@ def __call__(self, images):
return_logits=True,
max_prompt_length=8,
echo=True,
images=images,
images=images, # pyrefly: ignore[bad-argument-type]
)

self.assertIsNotNone(result)
Expand Down Expand Up @@ -259,9 +259,9 @@ def test_samples(self, max_prompt_length, echo):
self.assertIsNotNone(result)
self.assertLen(result.logits, 2)
if echo:
self.assertEqual(result.logits[0].shape, (13, vocab.GetPieceSize()))
self.assertEqual(result.logits[0].shape, (13, vocab.GetPieceSize())) # pyrefly: ignore[unsupported-operation]
else:
self.assertEqual(result.logits[0].shape, (10, vocab.GetPieceSize()))
self.assertEqual(result.logits[0].shape, (10, vocab.GetPieceSize())) # pyrefly: ignore[unsupported-operation]

# With 1 beam, the beam search result should be the
# same as the greedy output
Expand Down Expand Up @@ -438,7 +438,7 @@ def test_state_update(self):
input_strings, max_generation_steps=10, return_logits=True
).logits
with self.assertRaises(AssertionError):
for orig, new in zip(original_logits, new_logits):
for orig, new in zip(original_logits, new_logits): # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(orig, new, atol=1e-1, rtol=1e-1)

def test_lora_state_update(self):
Expand Down Expand Up @@ -482,7 +482,7 @@ def test_lora_state_update(self):
input_strings, max_generation_steps=10, return_logits=True
).logits
with self.assertRaises(AssertionError):
for orig, new in zip(original_logits, new_logits):
for orig, new in zip(original_logits, new_logits): # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(orig, new, atol=1e-1, rtol=1e-1)

def test_invalid_state_update(self):
Expand Down Expand Up @@ -721,8 +721,8 @@ def test_gemma4_decode_only_last_token_consistency(self):
self.assertEqual(len(res_opt.tokens), len(res_unopt.tokens))
for t_opt, t_unopt in zip(res_opt.tokens, res_unopt.tokens):
np.testing.assert_array_equal(t_opt, t_unopt)
self.assertEqual(len(res_opt.logits), len(res_unopt.logits))
for l_opt, l_unopt in zip(res_opt.logits, res_unopt.logits):
self.assertEqual(len(res_opt.logits), len(res_unopt.logits)) # pyrefly: ignore[bad-argument-type]
for l_opt, l_unopt in zip(res_opt.logits, res_unopt.logits): # pyrefly: ignore[bad-argument-type]
self.assertEqual(l_opt.shape, l_unopt.shape)
np.testing.assert_allclose(l_opt, l_unopt, atol=1e-5, rtol=1e-5)

Expand Down
4 changes: 2 additions & 2 deletions tests/generate/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ def test_logprobs_basic_extraction(self):
]
expected = [-1.71, -0.37, 0.0]
self.assertEqual(
utils.get_logprobs_from_vllm_output(token_ids, logprobs),
utils.get_logprobs_from_vllm_output(token_ids, logprobs), # pyrefly: ignore[bad-argument-type]
expected,
)

def test_logprobs_extraction_with_missing_token(self):
token_ids = [100, 200]
logprobs = [{101: Logprob(-0.5)}, {200: Logprob(-1.2)}]
with self.assertRaises(ValueError):
utils.get_logprobs_from_vllm_output(token_ids, logprobs)
utils.get_logprobs_from_vllm_output(token_ids, logprobs) # pyrefly: ignore[bad-argument-type]

@parameterized.named_parameters(
("none_logprobs", [], None),
Expand Down
12 changes: 6 additions & 6 deletions tests/processors/audio_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_process_inputs_batch_list_ndarray(self):
]

processed_audios, new_tokens = audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
)
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_process_inputs_batch_nested_list(self):
]

proc_audios, new_tokens = audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
)
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_process_inputs_no_audio_clips(self):
]

processed_audios, new_tokens = audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
)
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_batch_mismatch_error(self):
ValueError, "Batch size of tokens.*does not match"
):
audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
)
Expand All @@ -239,7 +239,7 @@ def test_placeholder_mismatch_error(self):
ValueError, "Placeholders provided for 2 clips, but only 1 provided"
):
audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
)
Expand All @@ -255,7 +255,7 @@ def test_max_audio_clips_exceeded_error(self):
ValueError, "A batch entry has more clips than the specified"
):
audio_processor.process_gemma4_inputs(
audios=audios,
audios=audios, # pyrefly: ignore[bad-argument-type]
tokens=tokens,
audio_encoder=self.audio_encoder,
max_audio_clips=1, # Limit to 1 clip
Expand Down
8 changes: 4 additions & 4 deletions tests/processors/image_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_multiple_images_dim_1(self):
img1 = np.zeros((100, 100, 3), dtype=np.uint8)
img2 = np.zeros((50, 50, 3), dtype=np.uint8)
images = [img1, img2]
processed_images = self.processor(images=images)
processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(
processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3))
)
Expand All @@ -129,7 +129,7 @@ def test_padding(self, input_type):
else:
images = [img1, [img1, img2]]

processed_images = self.processor(images=images)
processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(
processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3))
)
Expand All @@ -148,7 +148,7 @@ def test_mixed_inputs(self):
img1 = np.zeros((100, 100, 3), dtype=np.uint8)
img2 = self._create_dummy_image_file()
images = [img1, [img1, img2]]
processed_images = self.processor(images=images)
processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(
processed_images[0][0], -1.0 * np.ones((self.height, self.width, 3))
)
Expand All @@ -164,7 +164,7 @@ def test_mixed_inputs(self):

def test_call_with_none_in_batch(self):
images = [None, [np.zeros((100, 100, 3), dtype=np.uint8)]]
processed_images = self.processor(images=images)
processed_images = self.processor(images=images) # pyrefly: ignore[bad-argument-type]
np.testing.assert_allclose(
processed_images[0][0], np.zeros((self.height, self.width, 3))
)
Expand Down
4 changes: 2 additions & 2 deletions tests/rl/grpo/dapo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def create_train_example(self):
def test_diff_loss(self):
dapo_config = dapo_lib.DAPOConfig()
grpo_config = grpo_lib.GRPOConfig()
dapo_config.temperature = 1.0
grpo_config.temperature = 1.0
dapo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute]
grpo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute]

dapo_loss_fn_impl = fr.default_registry.get(
"policy_loss_fn", dapo_config.policy_loss_fn
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/grpo/drgrpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_drgrpo_advantage_estimator(self):

def test_drgrpo_loss_fn(self):
drgrpo_config = drgrpo_lib.DrGRPOConfig()
drgrpo_config.temperature = 1.0
drgrpo_config.temperature = 1.0 # pyrefly: ignore[missing-attribute]

drgrpo_loss_fn_impl = fr.default_registry.get(
"policy_loss_fn", drgrpo_config.policy_loss_fn
Expand Down
10 changes: 5 additions & 5 deletions tests/rl/grpo/grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(self, grpo_config):
self._last_iter_step = 0
self.algo_config = grpo_config
self._data_shuffle_seed = None
self.rl_cluster = types.SimpleNamespace(
self.rl_cluster = types.SimpleNamespace( # pyrefly: ignore[bad-assignment]
global_steps=0,
cluster_config=types.SimpleNamespace(
training_config=types.SimpleNamespace(
Expand All @@ -167,7 +167,7 @@ def _generate_and_compute_advantage(self, example, mode='train'):
if 'trajectory_ids' in example:
del example['trajectory_ids']
prompts = example['prompts']
num_samples = len(prompts)
num_samples = len(prompts) # pyrefly: ignore[bad-argument-type]
# Return a SimpleNamespace to mimic TrainExample attributes
return types.SimpleNamespace(
prompt_ids=np.array(prompts),
Expand Down Expand Up @@ -348,12 +348,12 @@ def wrapper(*args, **kwargs):
metric_logger = grpo_learner.rl_cluster.actor_trainer.metrics_logger
for metric_name in ['loss', 'kl']:
self.assertLen(
metric_logger.get_metric_history('actor', metric_name, 'train'),
metric_logger.get_metric_history('actor', metric_name, 'train'), # pyrefly: ignore[missing-attribute]
grpo_learner.rl_cluster.actor_trainer.train_steps,
msg=f'metric_name: {metric_name}',
)
self.assertLen(
metric_logger.get_metric_history('actor', metric_name, 'eval'),
metric_logger.get_metric_history('actor', metric_name, 'eval'), # pyrefly: ignore[missing-attribute]
grpo_learner.rl_cluster.actor_trainer.train_steps
/ kwargs['eval_every_n_steps'],
msg=f'metric_name: {metric_name}',
Expand Down Expand Up @@ -546,7 +546,7 @@ def wrapper(*args, **kwargs):
// (kwargs.get('gradient_accumulation_steps') or 1),
)
self.assertLen(
grpo_learner.rl_cluster.actor_trainer.metrics_logger.get_metric_history(
grpo_learner.rl_cluster.actor_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
'actor', 'kl', 'train'
),
grpo_learner.rl_cluster.actor_trainer.train_steps,
Expand Down
8 changes: 4 additions & 4 deletions tests/sft/checkpoint_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def setUp(self):
def test_empty_root_directory(self):
cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
self.assertIsNone(cp_manager.latest_step())
self.assertFalse(cp_manager.save(1, None))
self.assertEqual(cp_manager.maybe_restore(None), (0, {}))
self.assertFalse(cp_manager.save(1, None)) # pyrefly: ignore[bad-argument-type]
self.assertEqual(cp_manager.maybe_restore(None), (0, {})) # pyrefly: ignore[bad-argument-type]

def test_checkpoint_manager_options_none_sets_default(self):
cp_path = f'{self.temp_path}/{self.id()}'
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_restore_with_lora(self):
)

# Save the model params.
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True))
self.assertTrue(cp_manager.save(1, model, save_only_lora_params=True)) # pyrefly: ignore[bad-argument-type]
cp_manager._checkpoint_manager.wait_until_finished() # pytype: disable=attribute-error

# Change the model state.
Expand All @@ -219,7 +219,7 @@ def test_restore_with_lora(self):

# Restore the model lora params.
self.assertEqual(
cp_manager.maybe_restore(model, restore_only_lora_params=True),
cp_manager.maybe_restore(model, restore_only_lora_params=True), # pyrefly: ignore[bad-argument-type]
(1, {}),
)
# Check the model lora params are restored correctly.
Expand Down
12 changes: 6 additions & 6 deletions tests/sft/checkpoint_options_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_resolve_checkpointing_defaults_with_deprecated_options(self):

with self.assertLogs(level='WARNING') as log:
opts = checkpoint_options.resolve_checkpointing_defaults(
legacy_opts
legacy_opts # pyrefly: ignore[bad-argument-type]
)

# Verify deprecation warnings were logged
Expand All @@ -64,7 +64,7 @@ def test_resolve_checkpointing_defaults_with_legacy_options_dataclass(self):
),
)
opts = checkpoint_options.resolve_checkpointing_defaults(
legacy_opts
legacy_opts # pyrefly: ignore[bad-argument-type]
)
self.assertIsInstance(
opts.save_decision_policy,
Expand Down Expand Up @@ -92,10 +92,10 @@ def test_resolve_checkpointing_defaults_with_async_timeout(self):

def test_resolve_checkpointing_defaults_with_modern_options(self):
modern_opts = checkpoint_options.TunixCheckpointingOptions(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( # pyrefly: ignore[bad-argument-type]
50
),
preservation_policy=ocp.training.preservation_policies.LatestN(10),
preservation_policy=ocp.training.preservation_policies.LatestN(10), # pyrefly: ignore[bad-argument-type]
enable_async_checkpointing=False,
)
opts = checkpoint_options.resolve_checkpointing_defaults(
Expand All @@ -111,10 +111,10 @@ def test_resolve_checkpointing_defaults_with_modern_options(self):

def test_create_checkpointing_options(self):
opts = checkpoint_options.create_checkpointing_options(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy( # pyrefly: ignore[bad-argument-type]
50
),
preservation_policy=ocp.training.preservation_policies.LatestN(10),
preservation_policy=ocp.training.preservation_policies.LatestN(10), # pyrefly: ignore[bad-argument-type]
enable_async_checkpointing=False,
)
self.assertIsInstance(opts, checkpoint_options.TunixCheckpointingOptions)
Expand Down
6 changes: 3 additions & 3 deletions tests/sft/dpo/dpo_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ def test_dpo_trainer(
"log_probs/rejected",
]:
self.assertLen(
dpo_trainer.metrics_logger.get_metric_history(
dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "train"
),
dpo_trainer._train_steps,
)
self.assertLen(
dpo_trainer.metrics_logger.get_metric_history(
dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "eval"
),
3,
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_dpo_trainer_with_string_inputs(self, train_ds):
"rewards/accuracy",
]:
self.assertLen(
dpo_trainer.metrics_logger.get_metric_history(
dpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "train"
),
dpo_trainer._train_steps,
Expand Down
6 changes: 3 additions & 3 deletions tests/sft/dpo/orpo_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ def test_orpo_trainer(
"odds_ratio",
]:
self.assertLen(
orpo_trainer.metrics_logger.get_metric_history(
orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "train"
),
orpo_trainer._train_steps,
)
self.assertLen(
orpo_trainer.metrics_logger.get_metric_history(
orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "eval"
),
3,
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_orpo_trainer_with_string_inputs(self, train_ds):
"rewards/accuracy",
]:
self.assertLen(
orpo_trainer.metrics_logger.get_metric_history(
orpo_trainer.metrics_logger.get_metric_history( # pyrefly: ignore[missing-attribute]
"", metric_name, "train"
),
orpo_trainer._train_steps,
Expand Down
Loading
Loading