forked from deepspeedai/DeepSpeedExamples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautotp_example.py
More file actions
137 lines (104 loc) · 4.55 KB
/
autotp_example.py
File metadata and controls
137 lines (104 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
from dataclasses import dataclass
import torch
import torch.distributed as dist
import deepspeed
from transformers import AutoModelForCausalLM
@dataclass
class ModelParallelUnit:
    """Minimal MPU for DeepSpeed TP+DP.

    Exposes the six accessors DeepSpeed expects from an MPU object; the
    process groups, sizes, and ranks are supplied at construction time.
    """

    tp_group: dist.ProcessGroup  # group spanning this rank's tensor-parallel peers
    dp_group: dist.ProcessGroup  # group spanning this rank's data-parallel peers
    tp_size: int                 # tensor-parallel world size
    dp_size: int                 # data-parallel world size
    tp_rank: int                 # position of this rank within its TP group
    dp_rank: int                 # position of this rank within its DP group

    # --- model-parallel (tensor-parallel) accessors ---
    def get_model_parallel_group(self):
        return self.tp_group

    def get_model_parallel_world_size(self):
        return self.tp_size

    def get_model_parallel_rank(self):
        return self.tp_rank

    # --- data-parallel accessors ---
    def get_data_parallel_group(self):
        return self.dp_group

    def get_data_parallel_world_size(self):
        return self.dp_size

    def get_data_parallel_rank(self):
        return self.dp_rank
def parse_args():
    """Parse the command-line options for the AutoTP example and return the namespace."""
    p = argparse.ArgumentParser(description="AutoTP training example (distilled from verify_autotp).")
    add = p.add_argument
    add("--local_rank", type=int, default=-1, help="Passed by deepspeed/torchrun.")
    add("--model_name", type=str, default="meta-llama/Llama-3.1-8B")
    add("--tp_size", type=int, default=4)
    add("--dp_size", type=int, default=2)
    add("--zero_stage", type=int, default=2)
    add("--batch_size", type=int, default=1)
    add("--seq_length", type=int, default=1024)
    add("--num_steps", type=int, default=10)
    add("--learning_rate", type=float, default=2e-5)
    add("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    return p.parse_args()
def build_tp_dp_groups(rank, world_size, tp_size, dp_size):
    """Create the TP and DP process groups and return this rank's membership.

    Ranks are laid out TP-major: ranks [d*tp_size, (d+1)*tp_size) form the
    d-th TP group, and ranks occupying the same slot across TP groups form
    a DP group.  `dist.new_group` is collective, so every rank creates every
    group and keeps only the two it belongs to.

    Returns (tp_group, dp_group, tp_rank, dp_rank).
    Raises ValueError when tp_size * dp_size != world_size.
    """
    if world_size != tp_size * dp_size:
        raise ValueError(f"tp_size ({tp_size}) * dp_size ({dp_size}) must equal world_size ({world_size})")

    my_tp_group = None
    my_dp_group = None

    # Tensor-parallel groups: one per DP replica, contiguous rank blocks.
    for d in range(dp_size):
        members = list(range(d * tp_size, (d + 1) * tp_size))
        group = dist.new_group(members)
        if rank in members:
            my_tp_group = group

    # Data-parallel groups: one per TP slot, strided across replicas.
    for t in range(tp_size):
        members = [t + d * tp_size for d in range(dp_size)]
        group = dist.new_group(members)
        if rank in members:
            my_dp_group = group

    return my_tp_group, my_dp_group, rank % tp_size, rank // tp_size
def broadcast_inputs(input_ids, labels, tp_group, tp_src_rank):
    """Broadcast the batch tensors from the TP source rank to its whole TP group."""
    for tensor in (input_ids, labels):
        dist.broadcast(tensor, src=tp_src_rank, group=tp_group)
def main():
    """Run a short AutoTP (tensor-parallel + data-parallel) training loop.

    Initializes torch.distributed via DeepSpeed, builds TP/DP process groups,
    wraps the model with `deepspeed.initialize` (AutoTP enabled through the
    config's `tensor_parallel.autotp_size`), and trains on random token data
    for `--num_steps` steps.
    """
    args = parse_args()

    deepspeed.init_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Bug fix: honor the launcher-provided --local_rank for device placement.
    # The old `rank % device_count` heuristic is only correct on a single
    # node; fall back to it when --local_rank was not passed (default -1).
    local_rank = args.local_rank if args.local_rank >= 0 else rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{local_rank}")

    tp_group, dp_group, tp_rank, dp_rank = build_tp_dp_groups(
        rank, world_size, args.tp_size, args.dp_size
    )

    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    model = model.to(device)

    # AutoTP is enabled via the DeepSpeed config.
    ds_config = {
        "train_batch_size": args.batch_size * args.dp_size,
        "train_micro_batch_size_per_gpu": args.batch_size,
        "zero_optimization": {"stage": args.zero_stage},
        "tensor_parallel": {"autotp_size": args.tp_size},
        "data_parallel_size": args.dp_size,
    }
    if args.precision == "bf16":
        ds_config["bf16"] = {"enabled": True}
    elif args.precision == "fp16":
        ds_config["fp16"] = {"enabled": True}

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank)
    engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu)

    vocab_size = model.config.vocab_size
    # Rank 0 of each TP group generates the random batch; the other TP ranks
    # receive it via broadcast so the whole group sees identical inputs.
    # The broadcast source is loop-invariant, so compute it once.
    tp_src_rank = dp_rank * args.tp_size
    for _ in range(args.num_steps):
        if tp_rank == 0:
            input_ids = torch.randint(0, vocab_size, (args.batch_size, args.seq_length), device=device)
            labels = input_ids.clone()
        else:
            input_ids = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
            labels = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device)
        broadcast_inputs(input_ids, labels, tp_group, tp_src_rank)
        outputs = engine(input_ids=input_ids, labels=labels)
        engine.backward(outputs.loss)
        engine.step()

    if rank == 0:
        print("AutoTP example completed.")
# Script entry point: run the training loop only when executed directly.
if __name__ == "__main__":
    main()