
Commit 46ead1a

tohtana authored and nathon-lee committed
Support custom partitioning patterns for AutoTP (deepspeedai#7806)
This PR introduces a flexible, configuration-driven API for AutoTP (Automatic Tensor Parallelism) that allows users to define custom layer partitioning patterns for training. @inkcherry @delock

## Motivation

Previously, AutoTP relied on hardcoded layer detection logic that was difficult to customize for new model architectures. This PR enables:

1. **Custom models**: Users can define exact regex patterns to match their model's parameter names
2. **Fused layers**: Support for fused QKV, gate_up_proj, and other packed weight matrices with unequal sub-parameter sizes (e.g., GQA with different Q/K/V dimensions)
3. **Extensibility**: Easy to add new model presets or customize existing ones

Here is an example of a config including custom partitioning patterns:

```json
{
  "tensor_parallel": {
    "autotp_size": 4,
    "partition_config": {
      "use_default_specs": false,
      "layer_specs": [
        {
          "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
          "partition_type": "row"
        },
        {
          "patterns": [".*\\.[qkv]_proj\\.weight$"],
          "partition_type": "column"
        },
        {
          "patterns": [".*\\.gate_up_proj\\.weight$"],
          "partition_type": "column",
          "shape": [2, -1],
          "partition_dim": 0
        }
      ]
    }
  }
}
```

Refer to the [document](https://github.com/tohtana/DeepSpeed/blob/tohtana/autotp_custom_patterns/docs/code-docs/source/training.rst) for more details (including preset models and how to define partitioning for fused models). We also opened a new [PR](deepspeedai/DeepSpeedExamples#998) to show the usage.

## Simplified initialization step

AutoTP previously required calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` before ``deepspeed.initialize``. Now we can include all the necessary configuration in the DeepSpeed config. We still support the traditional initialization path for backward compatibility. When you use both (i.e., calling ``set_autotp_mode(training=True)`` and ``deepspeed.tp_model_init`` and passing the config to ``deepspeed.initialize``), we merge the settings at initialization. Conflicting settings produce an error.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
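For reference, a minimal training-side sketch of the config-driven path described above. This is illustrative only and not part of the PR: the placeholder model, the optimizer settings, and a 4-GPU launch via the deepspeed launcher are assumptions; only the ``tensor_parallel`` block mirrors the config shown above.

```python
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    "bf16": {"enabled": True},
    "tensor_parallel": {
        "autotp_size": 4,
        "partition_config": {
            "use_default_specs": False,
            "layer_specs": [
                {"patterns": [r".*\.o_proj\.weight$", r".*\.down_proj\.weight$"],
                 "partition_type": "row"},
                {"patterns": [r".*\.[qkv]_proj\.weight$"],
                 "partition_type": "column"},
            ],
        },
    },
}

# Placeholder module; in practice this would be a transformer whose parameter
# names match the regex patterns above (e.g., a Hugging Face Llama model).
model = torch.nn.Linear(8, 8)

# No set_autotp_mode()/tp_model_init() call is needed on this path:
# AutoTP settings come from the config and sharding happens inside initialize().
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```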
1 parent 6d11126 commit 46ead1a

19 files changed

Lines changed: 2412 additions & 152 deletions


blogs/huggingface-tp/README.md

Lines changed: 2 additions & 4 deletions
```diff
@@ -230,13 +230,11 @@ Furthermore, if users are not using transformers library, you can replace the ``
 
 # Ongoing Work
 - **Optimization**: Communication/Activation optimization.
-- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing,
-
+- **Usability**: Support the [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple the AutoTP parser, and expand model testing.
+- [UPDATE] We now support [custom partitioning](https://deepspeed.readthedocs.io/en/latest/training.html#custom-layer-specs) in the same spirit as HF's partitioning plan, and will build Transformers TP plan support on top of that ([PR](http://github.com/deepspeedai/DeepSpeed/pull/7806)).
 
 Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending.
-
 Welcome bug reports, enhancement, and additional model training examples.
 
-
 # Contributors
 This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft.
```

deepspeed/__init__.py

Lines changed: 71 additions & 26 deletions
```diff
@@ -48,6 +48,8 @@
 from .pipe import PipelineModule
 
 from .git_version_info import version, git_hash, git_branch
+from .runtime.tensor_parallel.init_utils import (load_ds_config, merge_tp_model_init_into_config,
+                                                 record_tp_model_init_args)
 
 
 def _parse_version(version_str):
@@ -159,17 +161,6 @@ def initialize(args=None,
     if config is None and config_params is not None:
         config = config_params
 
-    mesh_device = None
-    if mesh_param:
-        logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
-        mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
-    #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
-    elif config is not None:
-        if "sequence_parallel_size" in config and "data_parallel_size" in config:
-            logger.info(f"config to Initialize mesh device: {config}")
-            mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
-                                                      ("data_parallel", "sequence_parallel"))
-
     # Check for deepscale_config for backwards compat
     if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
         logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -184,6 +175,26 @@ def initialize(args=None,
         assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
         config = args.deepspeed_config
     assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
+
+    if not isinstance(config, dict):
+        config = load_ds_config(config)
+
+    mesh_device = None
+    if mesh_param:
+        logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
+        mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
+    #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
+    else:
+        if "sequence_parallel_size" in config and "data_parallel_size" in config:
+            logger.info(f"config to Initialize mesh device: {config}")
+            mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
+                                                      ("data_parallel", "sequence_parallel"))
+
+    merge_tp_model_init_into_config(config, mpu, mesh_param, dist)
+
+    autotp_size = config.get("tensor_parallel", {}).get("autotp_size", 0)
+    if autotp_size and autotp_size > 0:
+        set_autotp_mode(training=True)
     if not isinstance(model, PipelineModule):
         config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
         set_optimizer_flags(config_class, model)
@@ -379,31 +390,65 @@ def init_inference(model, config=None, **kwargs):
 
 def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
     """
-    Initialize the model for tensor parallelism.
+    Record tensor-parallel initialization arguments for training.
+
+    Note (compatibility and initialization behavior):
+    AutoTP sharding is applied during ``deepspeed.initialize(...)``. This
+    function exists for backward compatibility and only records TP arguments so
+    they can be validated and merged with the DeepSpeed config at initialization.
+    When you use both (i.e., calling ``set_autotp_mode(training=True)`` and
+    ``deepspeed.tp_model_init`` while also passing the config to
+    ``deepspeed.initialize``), DeepSpeed merges the settings at initialization.
+    Conflicting settings raise an error. The table below summarizes the behavior
+    across input combinations.
+
+    Inputs:
+    - TPI: tp_model_init was called? (Y/N)
+    - TPG: tp_model_init provided tp_group? (Y/N)
+    - CFG: tensor_parallel in DeepSpeed config? (Y/N)
+    - MPU: mpu passed to deepspeed.initialize()? (Y/N)
+
+    | TPI | TPG | CFG | MPU | Outcome                        | Notes |
+    |-----|-----|-----|-----|--------------------------------|-------|
+    | N   | N   | N   | N   | Error                          | No TP intent; nothing to initialize |
+    | N   | N   | N   | Y   | No AutoTP                      | mpu may be used for other MP, but TP not enabled |
+    | N   | N   | Y   | N   | Init AutoTP from config        | Use config; need TP group via config-driven init |
+    | N   | N   | Y   | Y   | Init AutoTP from config        | mpu used to build TP group |
+    | Y   | N   | N   | N   | Error                          | No TP group source |
+    | Y   | N   | N   | Y   | Init AutoTP from tp_model_init | Use recorded args + mpu for TP group |
+    | Y   | N   | Y   | N   | Init AutoTP from config        | Fill missing from TPI; error on mismatches; need TP group source |
+    | Y   | N   | Y   | Y   | Init AutoTP from config        | Fill missing from TPI; error on mismatches |
+    | Y   | Y   | N   | N   | Init AutoTP from tp_model_init | Use recorded tp_group; config absent |
+    | Y   | Y   | N   | Y   | Error                          | tp_group + mpu conflict |
+    | Y   | Y   | Y   | N   | Init AutoTP from config        | Error on mismatches; use tp_group from TPI; reject mpu |
+    | Y   | Y   | Y   | Y   | Error                          | tp_group + mpu conflict |
+
+    Field-level merge rules when both tp_model_init and config exist:
+    - Canonical source: config
+    - Allowed: fill missing config fields from tp_model_init
+    - Error on mismatch: autotp_size, dtype, tp_group size or identity
+
+    Extra checks:
+    - If tp_group is provided, reject mpu.
+    - If tp_group is not provided, require mpu (or another TP group source).
+    - If tensor_parallel is absent and only tp_model_init was called, require
+      a TP group source (direct tp_group or mpu).
 
     Args:
        model (torch.nn.Module): The model to be initialized.
        tp_size (int): The tensor parallelism size.
        dtype (torch.dtype): The data type to be used for the model.
 
    Returns:
-       torch.nn.Module: The initialized model with tensor parallelism.
+       torch.nn.Module: The original model (no sharding applied here).
    """
-    # avoid re-entry
     if hasattr(model, 'ds_autotp_parsed'):
-        logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.")
-        return
-
-    set_autotp_mode(training=True)
+        logger.warning("ds_autotp_parsed' attribute already exists in the model; tp_model_init is now record-only.")
 
-    from deepspeed.runtime.tensor_parallel import TpTrainingManager
-    # The expected usage here is for it to be invoked by transformers package.
+    tp_group = kwargs.get("tp_group")
+    record_tp_model_init_args(tp_size=tp_size, dtype=dtype, tp_group=tp_group, dist_module=dist)
 
-    #TODO: We should provide a custom TP mapping solution without using autoTP
-    #as modifying the autoTP logic may be more difficult for users compared to configuring it
-
-    model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module
-
-    setattr(model, 'ds_autotp_parsed', True)
+    # Keep AutoTP training mode active for backward compatibility.
+    set_autotp_mode(training=True)
 
     return model
```
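To make the behavior matrix in the docstring above easier to scan, here is a small standalone Python sketch that encodes the same twelve rows as a decision function. It is purely illustrative and does not call into the actual merge logic in ``init_utils``.

```python
def autotp_outcome(tpi: bool, tpg: bool, cfg: bool, mpu: bool) -> str:
    """Return the outcome from the docstring table for one input combination.

    tpi: tp_model_init was called; tpg: it received a tp_group;
    cfg: the DeepSpeed config has a tensor_parallel section;
    mpu: an mpu was passed to deepspeed.initialize().
    """
    if tpg and mpu:
        return "Error: tp_group + mpu conflict"
    if not tpi and not cfg:
        return "No AutoTP" if mpu else "Error: no TP intent"
    if cfg:
        return "Init AutoTP from config (fill missing from tp_model_init; error on mismatches)"
    # tp_model_init was called but no tensor_parallel config exists
    if tpg:
        return "Init AutoTP from tp_model_init (recorded tp_group)"
    if mpu:
        return "Init AutoTP from tp_model_init (mpu builds the TP group)"
    return "Error: no TP group source"


# Spot-check two rows of the table
assert autotp_outcome(tpi=False, tpg=False, cfg=True, mpu=True).startswith("Init AutoTP from config")
assert autotp_outcome(tpi=True, tpg=True, cfg=False, mpu=True).startswith("Error")
```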

deepspeed/module_inject/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -6,5 +6,6 @@
 from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
 from .module_quantize import quantize_transformer_layer
 from .replace_policy import HFBertLayerPolicy
-from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
+from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode, SubParamLinearLayer, SubParamLinearAllreduce
 from .policy import DSPolicy
+from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType, AutoTPPresets, merge_autotp_configs
```

deepspeed/module_inject/auto_tp.py

Lines changed: 90 additions & 1 deletion
```diff
@@ -17,6 +17,7 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 from deepspeed.utils import groups
 from deepspeed.module_inject.layers import is_autotp_training_mode
+from .autotp_config import TPLayerSpec, AutoTPConfig, PartitionType
 
 
 def move(tensor, device, copy=True):
@@ -199,7 +200,8 @@ def __init__(self,
                  state_dict,
                  linear_layer_setting,
                  orig_layer_impl,
-                 keep_module_on_host=False):
+                 keep_module_on_host=False,
+                 partition_config: Optional[AutoTPConfig] = None):
         self.module = module
         self.all_reduce_linears = all_reduce_linears
         self.prefix = prefix
@@ -211,6 +213,7 @@ def __init__(self,
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
+        self.partition_config = partition_config
         TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)
 
     def in_module_list(module, module_list):
@@ -353,6 +356,11 @@ def _replace(self, child, name, conv_linear_layer):
 
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+
+        # If partition_config is provided, use the new configurable API
+        if self.partition_config is not None:
+            return self._replace_with_config(child, name)
+
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
         if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
                 ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
@@ -396,6 +404,87 @@ def _replace(self, child, name, conv_linear_layer):
 
         return LinearLayer(child, self.mp_group, name=name)
 
+    def _replace_with_config(self, child, name):
+        """
+        Replace layer using the new configurable AutoTP API.
+
+        This method uses TPLayerSpec to determine how to partition the layer.
+        """
+        if getattr(child, "replaced", False) == True:
+            return child
+
+        # Build the full parameter name for pattern matching
+        param_name = name + ".weight" if not name.endswith(".weight") else name
+
+        # Find matching spec
+        model_type = self._get_model_type()
+        spec = self.partition_config.find_matching_spec(param_name, model_type)
+
+        if spec is None:
+            # No matching spec found
+            if self.partition_config.strict_mode:
+                raise ValueError(f"No matching spec for {param_name}")
+            # Default: column parallel for Linear layers
+            spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)
+
+        setattr(child, "replaced", True)
+
+        if spec.partition_type == PartitionType.SKIP:
+            return child
+
+        if spec.partition_type == PartitionType.ROW:
+            return self._create_row_parallel_layer(child, spec, name)
+        else:
+            return self._create_column_parallel_layer(child, spec, name)
+
+    def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
+        """Create row-parallel layer (AllReduce after forward)."""
+        # Check for lm_head / embed_out
+        if name == "lm_head" or name == 'embed_out':
+            return LmHeadLinearAllreduce(module, self.mp_group)
+
+        if spec.shape is not None:
+            return SubParamLinearAllreduce(
+                module,
+                self.mp_group,
+                shape=spec.shape,
+                partition_dim=spec.get_partition_dim(),
+                name=name,
+            )
+        return LinearAllreduce(module, self.mp_group, name=name)
+
+    def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
+        """Create column-parallel layer (AllReduce in backward)."""
+        if spec.shape is not None:
+            return SubParamLinearLayer(
+                module,
+                self.mp_group,
+                shape=spec.shape,
+                partition_dim=spec.get_partition_dim(),
+                name=name,
+            )
+        return LinearLayer(module, self.mp_group, name=name)
+
+    def _get_model_type(self) -> Optional[str]:
+        """Extract model type from module config or class name."""
+        config = getattr(self.module, "config", None)
+        if config is not None:
+            model_type = getattr(config, "model_type", None)
+            if model_type:
+                return str(model_type).lower()
+        module_str = str(type(self.module))
+        # Try to extract model type from class name (e.g., "LlamaDecoderLayer" -> "llama")
+        patterns = [
+            r"(\w+)DecoderLayer",
+            r"(\w+)Block",
+            r"(\w+)Layer",
+        ]
+        for pattern in patterns:
+            match = re.search(pattern, module_str)
+            if match:
+                return match.group(1).lower()
+        return None
+
     def _slice_embedding(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
```
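The spec lookup above delegates to ``find_matching_spec`` in ``autotp_config.py``, which is not shown in this diff. As a rough illustration of how the regex patterns from the example config map onto fully qualified parameter names (assuming plain ``re`` matching; the real matcher may differ in details):

```python
import re

# Patterns copied from the example config in the commit message.
row_patterns = [r".*\.o_proj\.weight$", r".*\.down_proj\.weight$"]
column_patterns = [r".*\.[qkv]_proj\.weight$"]

def matches_any(param_name: str, patterns: list[str]) -> bool:
    """True if any pattern matches the fully qualified parameter name."""
    return any(re.match(p, param_name) for p in patterns)

print(matches_any("model.layers.0.self_attn.o_proj.weight", row_patterns))     # True  -> row-parallel
print(matches_any("model.layers.0.self_attn.q_proj.weight", column_patterns))  # True  -> column-parallel
print(matches_any("model.layers.0.mlp.gate_proj.weight", row_patterns))        # False -> falls through
```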
