@@ -479,27 +479,30 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
479479 if image_latents is None :
480480 block_state .image_latents = None
481481 block_state .image_latent_ids = None
482- else :
483- device = components ._execution_device
484- batch_size = block_state .batch_size * block_state .num_images_per_prompt
482+ self .set_block_state (state , block_state )
483+
484+ return components , state
485+
486+ device = components ._execution_device
487+ batch_size = block_state .batch_size * block_state .num_images_per_prompt
485488
486- image_latent_ids = self ._prepare_image_ids (image_latents )
489+ image_latent_ids = self ._prepare_image_ids (image_latents )
487490
488- packed_latents = []
489- for latent in image_latents :
490- packed = self ._pack_latents (latent )
491- packed = packed .squeeze (0 )
492- packed_latents .append (packed )
491+ packed_latents = []
492+ for latent in image_latents :
493+ packed = self ._pack_latents (latent )
494+ packed = packed .squeeze (0 )
495+ packed_latents .append (packed )
493496
494- image_latents = torch .cat (packed_latents , dim = 0 )
495- image_latents = image_latents .unsqueeze (0 )
497+ image_latents = torch .cat (packed_latents , dim = 0 )
498+ image_latents = image_latents .unsqueeze (0 )
496499
497- image_latents = image_latents .repeat (batch_size , 1 , 1 )
498- image_latent_ids = image_latent_ids .repeat (batch_size , 1 , 1 )
499- image_latent_ids = image_latent_ids .to (device )
500+ image_latents = image_latents .repeat (batch_size , 1 , 1 )
501+ image_latent_ids = image_latent_ids .repeat (batch_size , 1 , 1 )
502+ image_latent_ids = image_latent_ids .to (device )
500503
501- block_state .image_latents = image_latents
502- block_state .image_latent_ids = image_latent_ids
504+ block_state .image_latents = image_latents
505+ block_state .image_latent_ids = image_latent_ids
503506
504507 self .set_block_state (state , block_state )
505508 return components , state
0 commit comments