Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/f5_tts/model/backbones/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
gc_checkpoint_interval: int = 1, # when checkpoint_activations=True, only ckpt every Nth block
):
super().__init__()

Expand Down Expand Up @@ -231,6 +232,11 @@ def __init__(
self.proj_out = nn.Linear(dim, mel_dim)

self.checkpoint_activations = checkpoint_activations
# gc_checkpoint_interval > 1 → selective checkpointing: only every Nth transformer block is
# checkpointed instead of all of them. The figures below are rough estimates for interval=2
# (every other block): about half of the activation-memory savings of full GC at about half
# of its recompute overhead, i.e. trading ~10% throughput for headroom toward ~2x batch size.
# interval=1 (default) preserves the prior behavior of checkpointing every block when
# checkpoint_activations=True.
self.gc_checkpoint_interval = max(1, int(gc_checkpoint_interval))

self.initialize_weights()

Expand Down Expand Up @@ -354,8 +360,9 @@ def forward(
if self.long_skip_connection is not None:
residual = x

for block in self.transformer_blocks:
if self.checkpoint_activations:
for i, block in enumerate(self.transformer_blocks):
do_ckpt = self.checkpoint_activations and (i % self.gc_checkpoint_interval == 0)
if do_ckpt:
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
Expand Down
8 changes: 6 additions & 2 deletions src/f5_tts/model/backbones/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
text_mask_padding=True,
qk_norm=None,
checkpoint_activations=False,
gc_checkpoint_interval: int = 1,
attn_backend="torch",
attn_mask_enabled=False,
):
Expand Down Expand Up @@ -133,6 +134,8 @@ def __init__(
self.proj_out = nn.Linear(dim, mel_dim)

self.checkpoint_activations = checkpoint_activations
# Selective GC: when checkpoint_activations=True, checkpoint every Nth block (default 1 = every block).
self.gc_checkpoint_interval = max(1, int(gc_checkpoint_interval))

self.initialize_weights()

Expand Down Expand Up @@ -248,8 +251,9 @@ def forward(
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
rope_text = self.rotary_embed.forward_from_seq_len(text_len)

for block in self.transformer_blocks:
if self.checkpoint_activations:
for i, block in enumerate(self.transformer_blocks):
do_ckpt = self.checkpoint_activations and (i % self.gc_checkpoint_interval == 0)
if do_ckpt:
c, x = torch.utils.checkpoint.checkpoint(
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False
)
Expand Down
78 changes: 54 additions & 24 deletions src/f5_tts/model/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def sample(
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
cfg_zero_init_steps: int = 0,
cfg_zero_star_velocity: bool = False,
):
self.eval()
# raw wave
Expand Down Expand Up @@ -159,36 +161,58 @@ def sample(

# neural ode

def fn(t, x):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
# CFG-Zero* (https://arxiv.org/abs/2503.18886): zero-init the first N solver steps
# (no DiT forward; fn returns zeros, so the integrator stays at x0) and optionally rescale
# the unconditional prediction by the optimized projection scalar α* at each step
# (cfg_strength is still applied on top of the rescaled prediction). Drop-in,
# inference-time only; no retraining needed.
# Composes with EPSS — zero-init skips the cheapest early-noise steps.
# `zero_init_threshold` is computed against the t schedule below.

def make_fn(zero_init_threshold: float):
def fn(t_scalar, x):
if zero_init_threshold > -1.0 and float(t_scalar) < zero_init_threshold:
return torch.zeros_like(x)

# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t_scalar,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred

# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
time=t_scalar,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cfg_infer=True,
cache=True,
)
return pred

# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)

if cfg_zero_star_velocity:
# Optimized scale α* = <pred, null_pred> / ||null_pred||², computed per sample (one scalar per batch element).
pos = pred.reshape(pred.shape[0], -1)
neg = null_pred.reshape(null_pred.shape[0], -1)
dot = (pos * neg).sum(dim=1, keepdim=True)
sq_norm = (neg * neg).sum(dim=1, keepdim=True).clamp_min(1e-8)
alpha = (dot / sq_norm).view(pred.shape[0], 1, 1)
return null_pred * alpha + cfg_strength * (pred - null_pred * alpha)

return pred + (pred - null_pred) * cfg_strength

return fn

# noise input
# to make sure batch inference result is same with different batch size, and for sure single inference
Expand All @@ -215,6 +239,12 @@ def fn(t, x):
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

if cfg_zero_init_steps > 0 and cfg_zero_init_steps < t.numel():
zero_init_threshold = float(t[cfg_zero_init_steps].item())
else:
zero_init_threshold = -1.0
fn = make_fn(zero_init_threshold)

trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()

Expand Down
13 changes: 13 additions & 0 deletions src/f5_tts/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
model_cfg_dict: dict = dict(), # training config
torch_compile_mode: str | None = None, # None | "default" | "reduce-overhead" | "max-autotune"
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

Expand Down Expand Up @@ -103,6 +104,18 @@ def __init__(

self.model = model

# torch.compile (PyTorch 2.x). Applied BEFORE accelerator.prepare so DDP wraps the
# compiled module. Variable mel-frame counts trigger recompiles — keep cache size
# generous and prefer "default" mode unless seq-lens are bucketed externally.
if torch_compile_mode is not None:
try:
import torch._dynamo as _dynamo # noqa: F401

torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit)
except Exception:
pass
self.model = torch.compile(self.model, mode=torch_compile_mode)

if self.is_main:
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
self.ema_model.to(self.accelerator.device)
Expand Down
31 changes: 31 additions & 0 deletions src/f5_tts/train/finetune_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ def parse_args():
action="store_true",
help="Use 8-bit Adam optimizer from bitsandbytes",
)
parser.add_argument(
"--torch_compile",
type=str,
default=None,
choices=[None, "default", "reduce-overhead", "max-autotune"],
help="torch.compile() mode for the training model. None disables (default).",
)
parser.add_argument(
"--gc_checkpoint_interval",
type=int,
default=0,
help=(
"Gradient-checkpointing interval. 0 disables (default). 1 = checkpoint every transformer "
"block (max memory savings). 2 = every 2nd block (~50%% activation mem at ~10%% throughput "
"cost — often net positive once batch size is doubled). DiT/MMDiT/UNetT only."
),
)

return parser.parse_args()

Expand Down Expand Up @@ -174,6 +191,19 @@ def main():
mel_spec_type=mel_spec_type,
)

# Forward optional gradient-checkpointing flags to the transformer __init__.
# UNetT doesn't expose checkpoint_activations yet — guard via signature inspection.
if args.gc_checkpoint_interval > 0:
import inspect

params = inspect.signature(model_cls.__init__).parameters
if "checkpoint_activations" in params:
model_cfg["checkpoint_activations"] = True
if "gc_checkpoint_interval" in params:
model_cfg["gc_checkpoint_interval"] = args.gc_checkpoint_interval
if "checkpoint_activations" not in params:
print(f"warning: {model_cls.__name__} does not support --gc_checkpoint_interval; flag ignored")

model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=mel_spec_kwargs,
Expand All @@ -200,6 +230,7 @@ def main():
log_samples=args.log_samples,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
torch_compile_mode=args.torch_compile,
)

train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
Expand Down