Support custom partitioning patterns for AutoTP (deepspeedai#7806)
This PR introduces a flexible, configuration-driven API for AutoTP
(Automatic Tensor Parallelism) that allows users to define custom layer
partitioning patterns for training.
@inkcherry @delock
## Motivation
Previously, AutoTP relied on hardcoded layer detection logic that was
difficult to customize for new model architectures. This PR enables:
1. **Custom models**: Users can define exact regex patterns to match
their model's parameter names
2. **Fused layers**: Support for fused QKV, gate_up_proj, and other
packed weight matrices with unequal sub-parameter sizes (e.g., GQA with
different Q/K/V dimensions)
3. **Extensibility**: Easy to add new model presets or customize
existing ones
Here is an example of a config including custom partitioning patterns:
```json
{
  "tensor_parallel": {
    "autotp_size": 4,
    "partition_config": {
      "use_default_specs": false,
      "layer_specs": [
        {
          "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
          "partition_type": "row"
        },
        {
          "patterns": [".*\\.[qkv]_proj\\.weight$"],
          "partition_type": "column"
        },
        {
          "patterns": [".*\\.gate_up_proj\\.weight$"],
          "partition_type": "column",
          "shape": [2, -1],
          "partition_dim": 0
        }
      ]
    }
  }
}
```
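To make the last spec more concrete, here is an illustrative sketch (not the actual AutoTP implementation) of how a fused ``gate_up_proj`` weight can be column-sharded so that every rank receives a contiguous slice of both packed sub-matrices. This is my reading of the ``"shape": [2, -1]`` / ``"partition_dim": 0`` fields; the helper name ``slice_fused_weight`` and the toy sizes are made up for the example, and the linked documentation is the authoritative reference.

```python
import torch

def slice_fused_weight(weight: torch.Tensor, tp_rank: int, tp_size: int,
                       num_fused: int = 2) -> torch.Tensor:
    """Illustrative column-parallel shard of a fused weight (e.g. gate_up_proj).

    The fused output dimension is viewed as [num_fused, -1] (mirroring the
    "shape": [2, -1] spec above) so that each rank gets a contiguous slice of
    every packed sub-matrix instead of a naive split of the concatenated rows.
    """
    out_features, in_features = weight.shape
    # View the rows as [num_fused, rows_per_sub_matrix, in_features].
    sub = weight.view(num_fused, out_features // num_fused, in_features)
    # Shard each sub-matrix along its own row dimension.
    shards = torch.chunk(sub, tp_size, dim=1)
    # Re-pack this rank's gate and up slices into one fused local weight.
    return shards[tp_rank].reshape(-1, in_features)

# Toy example: intermediate size 4, hidden size 8 -> fused weight is [2*4, 8].
w = torch.arange(64, dtype=torch.float32).reshape(8, 8)
local = slice_fused_weight(w, tp_rank=1, tp_size=4)
print(local.shape)  # torch.Size([2, 8]): one gate row + one up row on rank 1
```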
Refer to the
[documentation](https://github.com/tohtana/DeepSpeed/blob/tohtana/autotp_custom_patterns/docs/code-docs/source/training.rst)
for more details, including preset models and how to define partitioning
for fused models.
We also opened a new
[PR](deepspeedai/DeepSpeedExamples#998) in DeepSpeedExamples that
demonstrates the usage.
## Simplified initialization step
AutoTP previously required calling ``set_autotp_mode(training=True)``
and ``deepspeed.tp_model_init`` before ``deepspeed.initialize``. Now all
the necessary settings can be included in the DeepSpeed config.
The traditional initialization path is still supported for backward
compatibility.
If you use both (i.e., call ``set_autotp_mode(training=True)`` and
``deepspeed.tp_model_init`` and also pass the config to
``deepspeed.initialize``), the settings are merged at initialization;
conflicting settings raise an error.
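For reference, a minimal sketch of the config-driven path (the model name, batch sizes, and optimizer settings are illustrative placeholders; the ``tensor_parallel`` block simply mirrors the example above):

```python
import deepspeed
from transformers import AutoModelForCausalLM

# Illustrative model; any HF causal LM whose parameter names match the
# regex patterns in partition_config works the same way.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

ds_config = {
    "train_batch_size": 32,                 # illustrative values
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
    "tensor_parallel": {
        "autotp_size": 4,
        "partition_config": {
            "use_default_specs": False,
            "layer_specs": [
                {"patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
                 "partition_type": "row"},
                {"patterns": [".*\\.[qkv]_proj\\.weight$"],
                 "partition_type": "column"},
            ],
        },
    },
}

# No explicit set_autotp_mode(training=True) or deepspeed.tp_model_init call
# is needed; the engine reads the tensor_parallel settings from the config.
# The legacy calls are still accepted for backward compatibility.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```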
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
The PR also updates the accompanying AutoTP write-up; diff excerpt:

```diff
-- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing,
+- **Usability**: Support the [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple the AutoTP parser, and expand model testing.
+- [UPDATE] We now support [custom partitioning](https://deepspeed.readthedocs.io/en/latest/training.html#custom-layer-specs) in the same spirit as HF's partitioning plan, and will build Transformers TP plan support on top of that ([PR](http://github.com/deepspeedai/DeepSpeed/pull/7806)).

 Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending.

 Welcome bug reports, enhancement, and additional model training examples.

 # Contributors

 This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft.
```