-
Notifications
You must be signed in to change notification settings - Fork 689
Expand file tree
/
Copy pathgeneration_batch.py
More file actions
315 lines (293 loc) · 15.4 KB
/
generation_batch.py
File metadata and controls
315 lines (293 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanGu predict run
"""
import os
import time
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import moxing as mox
import numpy as np
from mindspore import context, Tensor
from mindspore import export
from mindspore.context import ParallelMode
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel.nn.transformer import TransformerOpParallelConfig
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.code_tokenizer import CodeTokenizer
from src.pangu_alpha import EvalNet, PanguAlphaModel
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils import get_args
def load_model(args_opt):
    r"""
    The main function for load model.

    Builds the PanGu-Alpha network and its evaluation wrapper, configures
    (semi-auto-parallel) execution, compiles the predict graph(s) to obtain a
    tensor layout, and loads this rank's checkpoint into the network.

    Args:
        args_opt: parsed command-line options (see ``src.utils.get_args``).

    Returns:
        tuple: ``(model_predict, config, rank)`` -- the ``Model`` wrapping the
        eval network, the ``PanguAlphaConfig`` it was built with, and this
        process's rank id (0 when not distributed).
    """
    # Set execution mode
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target)
    # Cap the device memory pool reserved for variables.
    context.set_context(variable_memory_max_size="30GB")
    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("rank_id is {}, device_num is {}".format(rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        # Single-device fallback: no communication group is initialized.
        rank = 0
        device_num = 1
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
    context.set_context(
        save_graphs=False,
        save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
    )
    use_past = (args_opt.use_past == "true")
    print('local_rank:{}, start to run...'.format(rank), flush=True)
    if args_opt.export:
        # Export needs the incremental-inference (use_past) graphs regardless
        # of the command-line flag.
        use_past = True
    # Set model property
    model_parallel_num = args_opt.op_level_model_parallel_num
    # Data-parallel dimension is whatever is left after model parallelism.
    data_parallel_num = int(device_num / model_parallel_num)
    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                  model_parallel=model_parallel_num,
                                                  pipeline_stage=args_opt.stage_num,
                                                  micro_batch_num=args_opt.micro_size,
                                                  optimizer_shard=False,
                                                  vocab_emb_dp=bool(args_opt.word_emb_dp),
                                                  recompute=True)
    per_batch_size = args_opt.per_batch_size
    # Global batch = per-device batch x data-parallel replicas.
    batch_size = per_batch_size * data_parallel_num
    config = PanguAlphaConfig(
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        hidden_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        post_layernorm_residual=False,
        dropout_rate=0.0,
        ffn_hidden_size=args_opt.embedding_size * 4,
        use_past=use_past,
        eod_token=args_opt.eod_id,
        eod_reset=False,
        parallel_config=parallel_config,
        load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32
        if args_opt.param_init_type == 'fp32'
        else mstype.float16,
    )
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    ckpt_name = args_opt.load_ckpt_name
    # Define network
    pangu_alpha = PanguAlphaModel(config)
    eval_net = EvalNet(pangu_alpha, pad_token=50256)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    # Compile network and obtain tensor layout for loading ckpt
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32)
    if args_opt.distribute == "false":
        predict_layout = None
    elif config.use_past:
        # Incremental inference: compile both the first-iteration (full
        # prompt) graph and the one-token-per-step graph so each has a layout.
        batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32)
        init_true = Tensor([True], mstype.bool_)
        print("Input shape:", inputs_np.shape, flush=True)
        inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
        print("is_first_iteration=True", flush=True)
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
        print("is_first_iteration=False", flush=True)
        init_false = Tensor([False], mstype.bool_)
        _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length)
    else:
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
    if context.get_context("save_graphs"):
        # Upload the saved graphs to OBS for offline inspection.
        # NOTE(review): bucket path "s3://wudao-1/yyf/..." is hard-coded.
        print("==============save_graph", flush=True)
        jobid = os.environ["BATCH_JOB_ID"]
        rank_id = rank
        mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid)
        mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id),
                               dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id))
    print("======start load_distributed checkpoint", flush=True)
    if args_opt.load_ckpt_epoch > 0:
        # Stagger ranks slightly so they do not hammer storage simultaneously.
        time.sleep(rank * 0.1)
        os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}"))
        ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt"
        # NOTE(review): a missing checkpoint is only reported, not skipped --
        # the copy below is still attempted and will raise. Confirm intent.
        if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)):
            print(f"Checkpoint from rank {rank} doesn't exist!")
        mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name),
                      os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name))
        # Restore training progress counters if the checkpoint recorded them.
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        # Crude filesystem barrier: each rank creates a marker directory and
        # waits until all device_num markers exist before loading parameters.
        os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}')
        while True:
            num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1'))
            if num == device_num:
                break
            if rank % 8 == 0:
                print("Loaded ckpt in step 1: ", num)
            time.sleep(1)
        net_not_load = load_param_into_net(pangu_alpha, param_dict)
        print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True)
    return model_predict, config, rank
def export_mindir(model_predict, config):
    """Export mindir model.

    Exports two MINDIR graphs from the compiled predict network: the
    first-iteration graph that consumes the full prompt, and the incremental
    graph that consumes one token per step.
    """
    net = model_predict.predict_network
    prompt_inputs = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    step_inputs = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
    index = Tensor(np.array([0]), mstype.int32)
    valid_length = Tensor(np.array([0]), mstype.int32)
    init_flag = Tensor([True], mstype.bool_)
    # Full-prompt (first iteration) graph.
    net.add_flags_recursive(is_first_iteration=True)
    export(net, prompt_inputs, index, init_flag, valid_length,
           file_name='pangu_alpha_1024', file_format='MINDIR')
    # Single-token incremental graph.
    net.add_flags_recursive(is_first_iteration=False)
    export(net, step_inputs, index, init_flag, valid_length,
           file_name='pangu_alpha_1', file_format='MINDIR')
    print("Export finished and now exit.")
def run_predict(model_predict, config, args_opt, rank):
    """run predict

    Tokenizes a fixed list of code-generation prompts, groups them into
    batches of ``config.batch_size`` (padding the final partial batch with
    dummy rows marked by input length -1), runs incremental generation, and
    appends every generated sample to a temperature-specific text file.

    Args:
        model_predict: compiled Model wrapping the evaluation network.
        config: PanguAlphaConfig the network was built with (batch_size used).
        args_opt: parsed command-line options; ``end_token``, ``temperature``
            and (optionally) ``output_path`` are read here, others inside
            ``generate_increment``.
        rank: this process's rank id. Only rank 0 writes the output file;
            ranks with ``rank % 8 == 0`` print progress.
    """
    from src.generate_finetune import generate_increment
    # Define tokenizer
    tokenizer = CodeTokenizer(mode='6b')
    # Tokenize input sentence to ids
    samples = [
        "Hello there!",
        "# language: Python\ndef add(a, b):\n    '''\n    Find the sum of a and b.\n    '''\n",
        "def add(a, b):\n    '''\n    Find the sum of a and b.\n    '''\n",
        "# language: Python\ndef optimization():\n    '''\n    Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n    '''\n",
        "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n",
        "// language: JavaScript\nfunction prime(n) {\n    // Find whether n is a prime number.\n",
        "string morse_encoder(string text) {\n    // Translate text into Morse code\n",
        "def morse_encoder(text):\n    # Translate text into Morse code separated by spaces\n",
        f"% language: MATLAB\nfunction x = solve(A, b)\n    % Solve Ax = b\n",
        f"% language: MATLAB\nfunction [L, U] = lu(A)\n    % Return LU factorization of A\n",
        "def TCPState(state):\n    # given a state in TCP protocol, return a list of next possible states\n",
        "def coordinates(p1, p2, precision=0)\n    # p1 is (x1, y1), p2 is (x2, y2), return the distance between p1 and p2 on a cartesian plane, rounded to precision\n",
        "double travel(double total_time, double run_time, double rest_time, double speed) {\n    // the horse runs for run_time with speed speed and rests for rest_time, return the distance it travels after total_time\n",
        "def travel(total_time, run_time, rest_time, speed):\n    # the horse runs for run_time with speed speed and rests for rest_time, return the distance it travels after total_time\n",
        "// language: C++\nint add(int a, int b) {\n    /* Find the sum of a and b. */\n",
        "int add(int a, int b) {\n    /* Find the sum of a and b. */\n",
        "// language: C++\nvoid sort(int *array, int len) {\n    // Sort the array with length len\n",
        "bool prime(int n) {\n    // Find whether n is a prime number\n",
        "def prime(n):\n    # Find whether n is a prime number\n",
        f"% language: MATLAB\nfunction H = hilbert(n)\n    % Return Hilbert matrix of size n * n\n",
        f"% language: MATLAB\nfunction L = cholesky(A)\n    % Return Cholesky factorization of symmetric positive definete matrix A\n",
        "// language: JavaScript\nfunction add(a, b) {\n    // Find the sum of a and b.\n",
        "# language: R\nadd<-function(a, b) {\n    # Find the sum of a and b.\n",
    ]
    samples = [tokenizer.encode_code(prompt) for prompt in samples]
    generations = []
    batch_size = config.batch_size
    verbose = (rank % 8 == 0)
    # Use configurable output path
    output_dir = getattr(args_opt, 'output_path', './output')
    save_dir = os.path.join(output_dir, 'generation_batch')
    save_path = os.path.join(save_dir, f'temp_{args_opt.temperature}.txt')
    if rank == 0:
        os.makedirs(save_dir, exist_ok=True)
        if not os.path.exists(save_path):
            with open(save_path, 'w') as f:
                pass  # Create empty file
        if os.name != 'nt':  # Only on Unix-like systems
            os.system(f'chmod 777 -R {save_dir}')
    batch = []
    input_length = []
    sample_ids = []
    for i, sample in enumerate(samples):
        tokenized_token = sample
        input_ids = np.array(tokenized_token).reshape(1, -1)
        batch.append(input_ids)
        input_length.append(input_ids.shape[1])
        sample_ids.append(i)
        if (i + 1) % batch_size == 0:
            # Right-pad every prompt in the batch to the longest one.
            valid_length = max(input_length)
            for j in range(len(batch)):
                batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - input_length[j])),
                                  'constant', constant_values=(args_opt.end_token, args_opt.end_token))
            input_ids = np.concatenate(batch, axis=0)
            t0 = time.perf_counter()
            output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose)
            t1 = time.perf_counter()
            batch, input_length = [], []
            if rank % 8 == 0:
                print(f"=== Batch time: {t1 - t0}s")
            for k, out in enumerate(output_ids):
                if not out.endswith('\n'):
                    out = out + '\n'
                print(f"=================== generation {sample_ids[k]} ====================")
                print(out, flush=True)
                generations.append(out)
                if rank == 0:
                    # 'with' guarantees the handle is closed even if write fails.
                    with open(save_path, 'a') as f:
                        f.write(generations[-1])
            sample_ids = []
    if len(batch) > 0:
        # Pad the final partial batch with dummy rows; input_length == -1
        # marks them so their outputs are skipped below.
        for _ in range(batch_size - len(sample_ids)):
            batch.append(np.zeros((1, 1)))
            input_length.append(-1)
        valid_length = max(input_length)
        for j in range(len(batch)):
            batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - batch[j].shape[1])),
                              'constant', constant_values=(args_opt.end_token, args_opt.end_token))
        input_ids = np.concatenate(batch, axis=0)
        t0 = time.perf_counter()
        output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose)
        t1 = time.perf_counter()
        if rank % 8 == 0:
            print(f"=== Batch time: {t1 - t0}s")
        for k, out in enumerate(output_ids):
            if input_length[k] == -1:
                # Dummy padding rows are appended last; nothing real follows.
                break
            if not out.endswith('\n'):
                out = out + '\n'
            print(f"=================== generation {sample_ids[k]} ====================")
            print(out, flush=True)
            generations.append(out)
            if rank == 0:
                with open(save_path, 'a') as f:
                    f.write(generations[-1])
def main():
    """Main process for predict or export model"""
    print("===Enter main!")
    args = get_args(True)
    set_parse(args)
    model_predict, config, rank = load_model(args)
    # Either dump MINDIR archives or run batched generation.
    if not args.export:
        run_predict(model_predict, config, args, rank)
    else:
        export_mindir(model_predict, config)


if __name__ == "__main__":
    main()