Skip to content

Commit 2e8c97b

Browse files
committed
update
1 parent 3806a9a commit 2e8c97b

1 file changed

Lines changed: 19 additions & 16 deletions

File tree

src/diffusers/modular_pipelines/flux2/before_denoise.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -479,27 +479,30 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
479479
if image_latents is None:
480480
block_state.image_latents = None
481481
block_state.image_latent_ids = None
482-
else:
483-
device = components._execution_device
484-
batch_size = block_state.batch_size * block_state.num_images_per_prompt
482+
self.set_block_state(state, block_state)
483+
484+
return components, state
485+
486+
device = components._execution_device
487+
batch_size = block_state.batch_size * block_state.num_images_per_prompt
485488

486-
image_latent_ids = self._prepare_image_ids(image_latents)
489+
image_latent_ids = self._prepare_image_ids(image_latents)
487490

488-
packed_latents = []
489-
for latent in image_latents:
490-
packed = self._pack_latents(latent)
491-
packed = packed.squeeze(0)
492-
packed_latents.append(packed)
491+
packed_latents = []
492+
for latent in image_latents:
493+
packed = self._pack_latents(latent)
494+
packed = packed.squeeze(0)
495+
packed_latents.append(packed)
493496

494-
image_latents = torch.cat(packed_latents, dim=0)
495-
image_latents = image_latents.unsqueeze(0)
497+
image_latents = torch.cat(packed_latents, dim=0)
498+
image_latents = image_latents.unsqueeze(0)
496499

497-
image_latents = image_latents.repeat(batch_size, 1, 1)
498-
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
499-
image_latent_ids = image_latent_ids.to(device)
500+
image_latents = image_latents.repeat(batch_size, 1, 1)
501+
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
502+
image_latent_ids = image_latent_ids.to(device)
500503

501-
block_state.image_latents = image_latents
502-
block_state.image_latent_ids = image_latent_ids
504+
block_state.image_latents = image_latents
505+
block_state.image_latent_ids = image_latent_ids
503506

504507
self.set_block_state(state, block_state)
505508
return components, state

0 commit comments

Comments
 (0)