Commit 3610631

tohtana authored and ksugama committed
Fix AutoTP custom patterns: respect use_default_specs (deepspeedai#7827)
The current code has the following issues:

- `use_default_specs: false` doesn't work
- Injection by the traditional pattern runs even when custom patterns are set
- `mpu` needs to be passed to `deepspeed.initialize` (the HF integration doesn't pass mpu)

This PR fixes the AutoTP setup to respect `use_default_specs: false` and disables the traditional injection path when custom patterns are enabled. Also, when `mpu` is not passed, we create a TP group during initialization.

With these changes, the [related tests](https://github.com/deepspeedai/DeepSpeed/tree/master/tests/unit/model_parallelism) pass and [all AutoTP examples](https://github.com/tohtana/DeepSpeedExamples/tree/tohtana/custom_auto_tp/training/tensor_parallel) in DeepSpeedExamples now work ([PR](deepspeedai/DeepSpeedExamples#998)).

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
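For context, a minimal sketch of the configuration shape these fixes target, mirroring the new unit tests; the regex patterns are taken from the tests and `model` stands for any `torch.nn.Module` you provide (run under the usual distributed launcher):

```python
# Minimal sketch (mirrors the new unit tests): custom AutoTP patterns with
# use_default_specs disabled, so layers that match no pattern stay unsharded.
# `model` is a placeholder for your own torch.nn.Module.
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "tensor_parallel": {
        "autotp_size": 2,
        "partition_config": {
            "use_default_specs": False,
            "layer_specs": [
                {"patterns": [".*linears\\.0\\.weight$"], "partition_type": "row"},
                {"patterns": [".*linears\\.1\\.weight$"], "partition_type": "column"},
            ],
        },
    },
    "zero_optimization": {"stage": 0},
}

# No mpu argument: with this change, a TP process group is created during initialization.
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)
```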
1 parent 60d5cb9 commit 3610631

4 files changed: 254 additions & 15 deletions


deepspeed/module_inject/auto_tp.py

Lines changed: 35 additions & 3 deletions
@@ -424,8 +424,8 @@ def _replace_with_config(self, child, name):
             # No matching spec found
             if self.partition_config.strict_mode:
                 raise ValueError(f"No matching spec for {param_name}")
-            # Default: column parallel for Linear layers
-            spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)
+            # With partition_config, rely only on explicit specs and skip unmatched layers.
+            return child
 
         setattr(child, "replaced", True)

@@ -439,6 +439,8 @@ def _replace_with_config(self, child, name):
 
     def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
         """Create row-parallel layer (AllReduce after forward)."""
+        if self.conv_linear_layer:
+            return Conv_LinearALlreduce(module, self.mp_group, name=name)
         # Check for lm_head / embed_out
         if name == "lm_head" or name == 'embed_out':
             return LmHeadLinearAllreduce(module, self.mp_group)
@@ -455,6 +457,12 @@ def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
 
     def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
         """Create column-parallel layer (AllReduce in backward)."""
+        if self.conv_linear_layer:
+            return conv_LinearLayer(module, self.mp_group, name=name)
+        # Only use fused-QKV heuristics when no partition_config is provided.
+        elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size):
+            # Check and handle fused qkv for TP
+            return fused_LinearLayer(module, self.mp_group, fused_module=self.module)
         if spec.shape is not None:
             return SubParamLinearLayer(
                 module,
@@ -488,6 +496,7 @@ def _get_model_type(self) -> Optional[str]:
     def _slice_embedding(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
+
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
 
         if hasattr(child.weight, 'ds_tensor'):
@@ -551,7 +560,30 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
                 continue
             if len(child._buffers) != 0 and self.state_dict is not None:
                 Loading.load_buffer(child, self.state_dict, checking_key)
-            if child.__class__ in self.linear_policies:
+
+            # When using partition_config (custom patterns/presets), use pattern-based routing
+            # instead of linear_policies. This keeps all pattern logic centralized here.
+            if self.partition_config is not None:
+                full_name = prev_name + '.' + name if prev_name else name
+                if isinstance(child, nn.Embedding):
+                    # Check if embedding matches any pattern
+                    param_name = full_name + ".weight"
+                    model_type = self._get_model_type()
+                    spec = self.partition_config.find_matching_spec(param_name, model_type)
+                    if spec is not None and spec.partition_type != PartitionType.SKIP:
+                        new_child = self._slice_embedding(child, full_name, False)
+                        if new_child is not None:
+                            setattr(r_module, name, new_child)
+                    # If no pattern matched or skip, leave embedding unchanged
+                elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2:
+                    new_child = self._replace_with_config(child, full_name)
+                    if new_child is not None:
+                        setattr(r_module, name, new_child)
+                else:
+                    self.update_mp_params(child)
+                    self._replace_module(child, name, class_name)
+            # Traditional path: use linear_policies for type-based routing
+            elif child.__class__ in self.linear_policies:
                 setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                               self.conv_linear_layer))
             elif any(isinstance(child, lp) for lp in self.linear_policies):
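To make the routing above concrete, here is a small, illustrative sketch of pattern matching with first-match precedence. It is not DeepSpeed's `find_matching_spec` implementation; the pattern strings and parameter names are examples taken from the tests below.

```python
# Illustrative only: match a parameter's full dotted name against configured
# regex patterns; the first matching spec wins, and an unmatched parameter is
# left unsharded when use_default_specs is false.
import re

layer_specs = [
    {"patterns": [r".*self_attn\.qkv_proj\.weight$"], "partition_type": "column"},
    {"patterns": [r".*linears\.0\.weight$"], "partition_type": "row"},
]


def find_first_match(param_name):
    for spec in layer_specs:
        if any(re.search(p, param_name) for p in spec["patterns"]):
            return spec
    return None  # no match: keep the layer as-is


print(find_first_match("model.layers.0.self_attn.qkv_proj.weight"))  # column spec
print(find_first_match("lm_head.weight"))                            # None
```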

deepspeed/runtime/tensor_parallel/init_utils.py

Lines changed: 3 additions & 1 deletion
@@ -87,7 +87,9 @@ def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_mod
     if tp_group is not None and mpu is not None:
         raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.")
     if tp_group is None and mpu is None and mesh_param is None:
-        raise ValueError("tp_model_init did not provide tp_group; deepspeed.initialize requires mpu or mesh_param.")
+        # Auto-create TP groups for compatibility with HF Trainer (mpu is not passed).
+        from deepspeed.utils import groups
+        groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size)
 
     tp_section = config_dict.get("tensor_parallel")
     if tp_section is None:
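The updated test in `test_autotp_training.py` (further below) exercises this path. A hedged sketch of the flow it enables, where `model` and `ds_config` are placeholders defined as in the tests and the dtype is only an example:

```python
# Sketch: HF-Trainer-style usage without an mpu. deepspeed.initialize now
# auto-creates the TP process group instead of raising a ValueError.
import torch
import deepspeed

deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)
assert engine.autotp_size() == 2
```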

tests/unit/model_parallelism/test_autotp_custom_patterns.py

Lines changed: 199 additions & 7 deletions
@@ -13,7 +13,7 @@
 from unit.common import DistributedTest, preferred_dtype
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils import groups
-from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer)
+from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer, fused_LinearLayer)
 from deepspeed.module_inject.autotp_config import AutoTPConfig
 from deepspeed.module_inject.auto_tp import AutoTP
 
@@ -35,6 +35,49 @@ def forward(self, x):
         return x
 
 
+class CustomLinearModule(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(CustomLinearModule, self).__init__()
+        self.weight = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
+        self.bias = torch.nn.Parameter(torch.empty(hidden_dim))
+        torch.nn.init.uniform_(self.weight, -0.02, 0.02)
+        torch.nn.init.uniform_(self.bias, -0.02, 0.02)
+
+    def forward(self, x):
+        return torch.matmul(x, self.weight.transpose(-1, -2)) + self.bias
+
+
+class CustomLinearModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(CustomLinearModel, self).__init__()
+        self.custom = CustomLinearModule(hidden_dim)
+
+    def forward(self, x):
+        return self.custom(x)
+
+
+class QKVLinearModule(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(QKVLinearModule, self).__init__()
+        self.qkv_proj = torch.nn.Linear(hidden_dim, hidden_dim * 3)
+
+    def forward(self, x):
+        return self.qkv_proj(x)
+
+
+class QKVLinearModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(QKVLinearModel, self).__init__()
+        self.self_attn = QKVLinearModule(hidden_dim)
+
+    def forward(self, x):
+        return self.self_attn(x)
+
+
 def init_tp_engine(tp_size, partition_config=None):
     config_dict = {
         "train_micro_batch_size_per_gpu": 1,
@@ -100,6 +143,15 @@ def gather_subparam_output(output, subparam_sizes, mp_group):
     return torch.cat(gathered_chunks, dim=-1)
 
 
+def assert_close_for_preferred_dtype(actual, expected):
+    atol = 1e-3
+    rtol = 2e-2
+    if preferred_dtype() is torch.float32:
+        atol = 1e-5
+        rtol = 1e-5
+    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
+
+
 class TestAutoTPCustomPatterns(DistributedTest):
     world_size = 2
     reuse_dist_env = False
@@ -178,6 +230,151 @@ def test_custom_patterns_applied_via_config(self):
         assert isinstance(engine.module.linears[1], LinearLayer)
         assert isinstance(engine.module.linears[2], nn.Linear)
 
+    def test_use_default_specs_false_skips_unmatched_layers(self):
+        skip_on_device()
+        # Verify unmatched layers remain unsharded when defaults are disabled.
+        partition_config = {
+            "use_default_specs":
+            False,
+            "layer_specs": [
+                {
+                    "patterns": [".*linears\\.0\\.weight$"],
+                    "partition_type": "row",
+                },
+                {
+                    "patterns": [".*linears\\.1\\.weight$"],
+                    "partition_type": "column",
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = SequentialLinearModel(hidden_dim=16, nlayers=3)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        assert isinstance(engine.module.linears[0], LinearAllreduce)
+        assert isinstance(engine.module.linears[1], LinearLayer)
+        assert isinstance(engine.module.linears[2], nn.Linear)
+
+    def test_custom_module_replacement_with_patterns(self):
+        skip_on_device()
+        # Verify custom linear-like modules are partitioned via patterns.
+        partition_config = {
+            "use_default_specs": False,
+            "layer_specs": [
+                {
+                    "patterns": [".*custom\\.weight$"],
+                    "partition_type": "column",
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = CustomLinearModel(hidden_dim=16)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        assert isinstance(engine.module.custom, LinearLayer)
+
+    def test_custom_pattern_disables_fused_qkv_heuristic(self):
+        skip_on_device()
+        # Use a qkv_proj name that would trigger the fused-QKV heuristic, then
+        # verify custom patterns override that path and preserve correctness.
+        torch.manual_seed(1234)
+        hidden_dim = 16
+        qkv_sizes = (hidden_dim, hidden_dim, hidden_dim)
+        partition_config = {
+            "use_default_specs":
+            False,
+            "layer_specs": [
+                {
+                    "patterns": [".*self_attn\\.qkv_proj\\.weight$"],
+                    "partition_type": "column",
+                    "shape": [list(qkv_sizes), -1],
+                    "partition_dim": 0,
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = QKVLinearModel(hidden_dim=hidden_dim)
+        baseline = deepcopy(model).to(get_accelerator().current_device(), dtype=preferred_dtype())
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        qkv_layer = engine.module.self_attn.qkv_proj
+        # Custom pattern should force SubParamLinearLayer (shape-based path),
+        # and avoid the legacy fused-QKV heuristic despite the qkv_proj name.
+        assert isinstance(qkv_layer, SubParamLinearLayer)
+        assert not isinstance(qkv_layer, fused_LinearLayer)
+
+        assert qkv_layer.partition_dim == 0
+        assert qkv_layer._subparam_sizes == qkv_sizes
+        assert qkv_layer._orig_weight_shape == (hidden_dim * 3, hidden_dim)
+
+        qkv_layer.gather_params([qkv_layer.weight, qkv_layer.bias])
+        torch.testing.assert_close(qkv_layer.weight, baseline.self_attn.qkv_proj.weight)
+        if qkv_layer.bias is not None:
+            torch.testing.assert_close(qkv_layer.bias, baseline.self_attn.qkv_proj.bias)
+
+        torch.manual_seed(4321)
+        inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device())
+        full_output = baseline(inputs)
+        tp_output = engine.module(inputs)
+        assert_close_for_preferred_dtype(tp_output, full_output)
+
     def test_first_match_precedence(self):
         skip_on_device()
         partition_config = {
@@ -294,9 +491,4 @@ def test_gqa_uneven_qkv_fused_forward(self):
 
         gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size),
                                                  groups.get_tensor_model_parallel_group())
-        atol = 1e-3
-        rtol = 2e-2
-        if preferred_dtype() is torch.float32:
-            atol = 1e-5
-            rtol = 1e-5
-        torch.testing.assert_close(gathered_output, full_output, atol=atol, rtol=rtol)
+        assert_close_for_preferred_dtype(gathered_output, full_output)

tests/unit/model_parallelism/test_autotp_training.py

Lines changed: 17 additions & 4 deletions
@@ -165,19 +165,32 @@ def test_tp_model_init_config_autotp_size_mismatch(self):
         with pytest.raises(ValueError, match="tensor_parallel.autotp_size"):
             deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())
 
-    def test_tp_model_init_requires_mpu_or_mesh_param(self):
+    def test_tp_model_init_autocreates_tp_group(self):
         skip_on_device()
         reset_tp_model_init_state()
+        # Verify tp_model_init creates TP groups when no mpu is provided.
         model = SimpleModel(hidden_dim=8)
-        deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
+        tp_size = 2
+        deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype())
         config_dict = {
             "train_micro_batch_size_per_gpu": 1,
+            "tensor_parallel": {
+                "partition_config": {
+                    "use_default_specs": False,
+                    "layer_specs": [{
+                        "patterns": [".*\\.weight$"],
+                        "partition_type": "skip",
+                    }],
+                }
+            },
             "zero_optimization": {
                 "stage": 0,
             }
         }
-        with pytest.raises(ValueError, match="requires mpu or mesh_param"):
-            deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        assert engine.autotp_size() == tp_size
+        assert groups.get_tensor_model_parallel_world_size() == tp_size
+        assert groups.get_data_parallel_world_size() == dist.get_world_size() // tp_size
 
     def test_tp_model_init_tp_group_rejects_mpu(self):
         skip_on_device()
