diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index e463c957f..c839c83ab 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -189,6 +189,7 @@ def __init__( attn_mask_enabled=False, long_skip_connection=False, checkpoint_activations=False, + gc_checkpoint_interval: int = 1, # when checkpoint_activations=True, only ckpt every Nth block ): super().__init__() @@ -231,6 +232,11 @@ def __init__( self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = checkpoint_activations + # gc_checkpoint_interval > 1 → selective checkpointing every Nth block (saves ~50% of + # the activation memory of full GC at ~half the compute overhead). Set to 2 to trade + # ~10% throughput for headroom to ~2x batch size. interval=1 (default) preserves the + # prior behavior of checkpoint every block when checkpoint_activations=True. + self.gc_checkpoint_interval = max(1, int(gc_checkpoint_interval)) self.initialize_weights() @@ -354,8 +360,9 @@ def forward( if self.long_skip_connection is not None: residual = x - for block in self.transformer_blocks: - if self.checkpoint_activations: + for i, block in enumerate(self.transformer_blocks): + do_ckpt = self.checkpoint_activations and (i % self.gc_checkpoint_interval == 0) + if do_ckpt: # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) else: diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index 262ad0d5f..292fa137b 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -99,6 +99,7 @@ def __init__( text_mask_padding=True, qk_norm=None, checkpoint_activations=False, + gc_checkpoint_interval: int = 1, attn_backend="torch", attn_mask_enabled=False, ): @@ -133,6 +134,8 @@ def __init__( self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = 
checkpoint_activations + # Selective GC: when checkpoint_activations=True, checkpoint every Nth block (default 1 = every). + self.gc_checkpoint_interval = max(1, int(gc_checkpoint_interval)) self.initialize_weights() @@ -248,8 +251,9 @@ def forward( rope_audio = self.rotary_embed.forward_from_seq_len(seq_len) rope_text = self.rotary_embed.forward_from_seq_len(text_len) - for block in self.transformer_blocks: - if self.checkpoint_activations: + for i, block in enumerate(self.transformer_blocks): + do_ckpt = self.checkpoint_activations and (i % self.gc_checkpoint_interval == 0) + if do_ckpt: c, x = torch.utils.checkpoint.checkpoint( self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False ) diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index e001c2ec6..ffc0f7fca 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -99,6 +99,8 @@ def sample( duplicate_test=False, t_inter=0.1, edit_mask=None, + cfg_zero_init_steps: int = 0, + cfg_zero_star_velocity: bool = False, ): self.eval() # raw wave @@ -159,36 +161,58 @@ def sample( # neural ode - def fn(t, x): - # at each step, conditioning is fixed - # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - - # predict flow (cond) - if cfg_strength < 1e-5: - pred = self.transformer( + # CFG-Zero* (https://arxiv.org/abs/2503.18886): zero-init the first N solver steps + # (no DiT forward; integrator stays at x0) and optionally replace fixed cfg_strength + # with the optimized projection scalar α* per step. Drop-in inference-time only; + # no retrain. Composes with EPSS — zero-init skips the cheapest early-noise steps. + # `zero_init_threshold` is computed against the t schedule below. 
+
+        def make_fn(zero_init_threshold: float):  # ODE velocity fn factory; threshold -1.0 disables zero-init
+            def fn(t_scalar, x):
+                if zero_init_threshold > -1.0 and float(t_scalar) < zero_init_threshold:
+                    return torch.zeros_like(x)
+
+                # at each step, conditioning is fixed
+                # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
+
+                # predict flow (cond)
+                if cfg_strength < 1e-5:
+                    pred = self.transformer(
+                        x=x,
+                        cond=step_cond,
+                        text=text,
+                        time=t_scalar,
+                        mask=mask,
+                        drop_audio_cond=False,
+                        drop_text=False,
+                        cache=True,
+                    )
+                    return pred
+
+                # predict flow (cond and uncond), for classifier-free guidance
+                pred_cfg = self.transformer(
                     x=x,
                     cond=step_cond,
                     text=text,
-                    time=t,
+                    time=t_scalar,
                     mask=mask,
-                    drop_audio_cond=False,
-                    drop_text=False,
+                    cfg_infer=True,
                     cache=True,
                 )
-                return pred
-
-            # predict flow (cond and uncond), for classifier-free guidance
-            pred_cfg = self.transformer(
-                x=x,
-                cond=step_cond,
-                text=text,
-                time=t,
-                mask=mask,
-                cfg_infer=True,
-                cache=True,
-            )
-            pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
-            return pred + (pred - null_pred) * cfg_strength
+                pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
+
+                if cfg_zero_star_velocity:
+                    # Optimized scale α* = <pred, null_pred> / ||null_pred||² per sample (fp32: 1e-8 rounds to 0 in fp16).
+                    pos = pred.reshape(pred.shape[0], -1).float()
+                    neg = null_pred.reshape(null_pred.shape[0], -1).float()
+                    dot = (pos * neg).sum(dim=1, keepdim=True)
+                    sq_norm = (neg * neg).sum(dim=1, keepdim=True).clamp_min(1e-8)
+                    alpha = (dot / sq_norm).view(pred.shape[0], 1, 1).to(pred.dtype)
+                    return pred + (pred - null_pred * alpha) * cfg_strength  # alpha == 1 gives the plain CFG rule below
+
+                return pred + (pred - null_pred) * cfg_strength
+
+            return fn
 
         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
@@ -215,6 +239,12 @@ def fn(t, x):
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
 
+        if cfg_zero_init_steps > 0 and cfg_zero_init_steps < t.numel():
+            zero_init_threshold = float(t[cfg_zero_init_steps].item())
+        else:
+            zero_init_threshold = -1.0
+        fn = make_fn(zero_init_threshold)
+
         trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
         self.transformer.clear_cache()
 
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 782923222..3165fa82d 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -53,6 +53,7 @@ def __init__(
         is_local_vocoder: bool = False,  # use local path vocoder
         local_vocoder_path: str = "",  # local vocoder path
         model_cfg_dict: dict = dict(),  # training config
+        torch_compile_mode: str | None = None,  # None | "default" | "reduce-overhead" | "max-autotune"
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -103,6 +104,18 @@ def __init__(
 
         self.model = model
 
+        # torch.compile (PyTorch 2.x). Applied BEFORE accelerator.prepare so DDP wraps the
+        # compiled module. Variable mel-frame counts trigger recompiles — keep cache size
+        # generous and prefer "default" mode unless seq-lens are bucketed externally.
+ if torch_compile_mode is not None: + try: + import torch._dynamo as _dynamo # noqa: F401 + + torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) + except Exception: + pass + self.model = torch.compile(self.model, mode=torch_compile_mode) + if self.is_main: self.ema_model = EMA(model, include_online_model=False, **ema_kwargs) self.ema_model.to(self.accelerator.device) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index cdf42a9ac..fc8ec21a7 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -71,6 +71,23 @@ def parse_args(): action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes", ) + parser.add_argument( + "--torch_compile", + type=str, + default=None, + choices=[None, "default", "reduce-overhead", "max-autotune"], + help="torch.compile() mode for the training model. None disables (default).", + ) + parser.add_argument( + "--gc_checkpoint_interval", + type=int, + default=0, + help=( + "Gradient-checkpointing interval. 0 disables (default). 1 = checkpoint every transformer " + "block (max memory savings). 2 = every 2nd block (~50%% activation mem at ~10%% throughput " + "cost — often net positive once batch size is doubled). DiT/MMDiT/UNetT only." + ), + ) return parser.parse_args() @@ -174,6 +191,19 @@ def main(): mel_spec_type=mel_spec_type, ) + # Forward optional gradient-checkpointing flags to the transformer __init__. + # UNetT doesn't expose checkpoint_activations yet — guard via signature inspection. 
+ if args.gc_checkpoint_interval > 0: + import inspect + + params = inspect.signature(model_cls.__init__).parameters + if "checkpoint_activations" in params: + model_cfg["checkpoint_activations"] = True + if "gc_checkpoint_interval" in params: + model_cfg["gc_checkpoint_interval"] = args.gc_checkpoint_interval + if "checkpoint_activations" not in params: + print(f"warning: {model_cls.__name__} does not support --gc_checkpoint_interval; flag ignored") + model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=mel_spec_kwargs, @@ -200,6 +230,7 @@ def main(): log_samples=args.log_samples, last_per_updates=args.last_per_updates, bnb_optimizer=args.bnb_optimizer, + torch_compile_mode=args.torch_compile, ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)