 from unit.common import DistributedTest, preferred_dtype
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils import groups
-from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer)
+from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, SubParamLinearLayer, fused_LinearLayer)
 from deepspeed.module_inject.autotp_config import AutoTPConfig
 from deepspeed.module_inject.auto_tp import AutoTP
 
@@ -35,6 +35,49 @@ def forward(self, x):
         return x
 
 
+class CustomLinearModule(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(CustomLinearModule, self).__init__()
+        self.weight = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
+        self.bias = torch.nn.Parameter(torch.empty(hidden_dim))
+        torch.nn.init.uniform_(self.weight, -0.02, 0.02)
+        torch.nn.init.uniform_(self.bias, -0.02, 0.02)
+
+    def forward(self, x):
+        return torch.matmul(x, self.weight.transpose(-1, -2)) + self.bias
+
+
+class CustomLinearModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(CustomLinearModel, self).__init__()
+        self.custom = CustomLinearModule(hidden_dim)
+
+    def forward(self, x):
+        return self.custom(x)
+
+
+class QKVLinearModule(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(QKVLinearModule, self).__init__()
+        self.qkv_proj = torch.nn.Linear(hidden_dim, hidden_dim * 3)
+
+    def forward(self, x):
+        return self.qkv_proj(x)
+
+
+class QKVLinearModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(QKVLinearModel, self).__init__()
+        self.self_attn = QKVLinearModule(hidden_dim)
+
+    def forward(self, x):
+        return self.self_attn(x)
+
+
 def init_tp_engine(tp_size, partition_config=None):
     config_dict = {
         "train_micro_batch_size_per_gpu": 1,
@@ -100,6 +143,15 @@ def gather_subparam_output(output, subparam_sizes, mp_group):
     return torch.cat(gathered_chunks, dim=-1)
 
 
+def assert_close_for_preferred_dtype(actual, expected):
+    atol = 1e-3
+    rtol = 2e-2
+    if preferred_dtype() is torch.float32:
+        atol = 1e-5
+        rtol = 1e-5
+    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
+
+
 class TestAutoTPCustomPatterns(DistributedTest):
     world_size = 2
     reuse_dist_env = False
@@ -178,6 +230,151 @@ def test_custom_patterns_applied_via_config(self):
         assert isinstance(engine.module.linears[1], LinearLayer)
         assert isinstance(engine.module.linears[2], nn.Linear)
 
+    def test_use_default_specs_false_skips_unmatched_layers(self):
+        skip_on_device()
+        # Verify unmatched layers remain unsharded when defaults are disabled.
+        partition_config = {
+            "use_default_specs":
+            False,
+            "layer_specs": [
+                {
+                    "patterns": [".*linears\\.0\\.weight$"],
+                    "partition_type": "row",
+                },
+                {
+                    "patterns": [".*linears\\.1\\.weight$"],
+                    "partition_type": "column",
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = SequentialLinearModel(hidden_dim=16, nlayers=3)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        assert isinstance(engine.module.linears[0], LinearAllreduce)
+        assert isinstance(engine.module.linears[1], LinearLayer)
+        assert isinstance(engine.module.linears[2], nn.Linear)
+
+    def test_custom_module_replacement_with_patterns(self):
+        skip_on_device()
+        # Verify custom linear-like modules are partitioned via patterns.
+        partition_config = {
+            "use_default_specs": False,
+            "layer_specs": [
+                {
+                    "patterns": [".*custom\\.weight$"],
+                    "partition_type": "column",
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = CustomLinearModel(hidden_dim=16)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        assert isinstance(engine.module.custom, LinearLayer)
+
+    def test_custom_pattern_disables_fused_qkv_heuristic(self):
+        skip_on_device()
+        # Use a qkv_proj name that would trigger the fused-QKV heuristic, then
+        # verify custom patterns override that path and preserve correctness.
+        torch.manual_seed(1234)
+        hidden_dim = 16
+        qkv_sizes = (hidden_dim, hidden_dim, hidden_dim)
+        partition_config = {
+            "use_default_specs":
+            False,
+            "layer_specs": [
+                {
+                    "patterns": [".*self_attn\\.qkv_proj\\.weight$"],
+                    "partition_type": "column",
+                    "shape": [list(qkv_sizes), -1],
+                    "partition_dim": 0,
+                },
+            ],
+        }
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "tensor_parallel": {
+                "autotp_size": 2,
+                "partition_config": partition_config,
+            },
+            "zero_optimization": {
+                "stage": 0,
+            }
+        }
+        if preferred_dtype() is torch.float16:
+            config_dict["fp16"] = {"enabled": True}
+        elif preferred_dtype() is torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        model = QKVLinearModel(hidden_dim=hidden_dim)
+        baseline = deepcopy(model).to(get_accelerator().current_device(), dtype=preferred_dtype())
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        qkv_layer = engine.module.self_attn.qkv_proj
+        # Custom pattern should force SubParamLinearLayer (shape-based path),
+        # and avoid the legacy fused-QKV heuristic despite the qkv_proj name.
+        assert isinstance(qkv_layer, SubParamLinearLayer)
+        assert not isinstance(qkv_layer, fused_LinearLayer)
+
+        assert qkv_layer.partition_dim == 0
+        assert qkv_layer._subparam_sizes == qkv_sizes
+        assert qkv_layer._orig_weight_shape == (hidden_dim * 3, hidden_dim)
+
+        qkv_layer.gather_params([qkv_layer.weight, qkv_layer.bias])
+        torch.testing.assert_close(qkv_layer.weight, baseline.self_attn.qkv_proj.weight)
+        if qkv_layer.bias is not None:
+            torch.testing.assert_close(qkv_layer.bias, baseline.self_attn.qkv_proj.bias)
+
+        torch.manual_seed(4321)
+        inputs = torch.randn(2, hidden_dim, dtype=preferred_dtype(), device=get_accelerator().current_device())
+        full_output = baseline(inputs)
+        tp_output = engine.module(inputs)
+        assert_close_for_preferred_dtype(tp_output, full_output)
+
     def test_first_match_precedence(self):
         skip_on_device()
         partition_config = {
@@ -294,9 +491,4 @@ def test_gqa_uneven_qkv_fused_forward(self):
 
         gathered_output = gather_subparam_output(tp_output, (q_size, k_size, v_size),
                                                  groups.get_tensor_model_parallel_group())
-        atol = 1e-3
-        rtol = 2e-2
-        if preferred_dtype() is torch.float32:
-            atol = 1e-5
-            rtol = 1e-5
-        torch.testing.assert_close(gathered_output, full_output, atol=atol, rtol=rtol)
+        assert_close_for_preferred_dtype(gathered_output, full_output)
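
For reference, a minimal usage sketch of the feature these tests exercise: passing a custom AutoTP partition_config to deepspeed.initialize. The keys mirror the "tensor_parallel"/"partition_config" blocks used above; the model class and sizes are placeholders taken from the test fixtures, not a prescribed setup.

# Sketch only: shard a custom weight by regex pattern, assuming the same
# config schema exercised in the tests above (autotp_size, partition_config,
# use_default_specs, layer_specs). Run under a distributed launcher with
# two ranks, as the tests do (world_size = 2).
import deepspeed

model = CustomLinearModel(hidden_dim=16)  # placeholder: any torch.nn.Module with a matching parameter name

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
    "zero_optimization": {"stage": 0},
    "tensor_parallel": {
        "autotp_size": 2,  # tensor-parallel degree
        "partition_config": {
            "use_default_specs": False,  # only shard layers matched by the specs below
            "layer_specs": [
                {"patterns": [".*custom\\.weight$"], "partition_type": "column"},
            ],
        },
    },
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
# engine.module.custom is expected to be replaced by LinearLayer, as asserted
# in test_custom_module_replacement_with_patterns above.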