-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathtraining.py
1822 lines (1630 loc) · 67.2 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version
#
"""Pretrain utilities."""
from datetime import datetime
from functools import partial
from collections import defaultdict
import math
import sys
from contextlib import nullcontext
import torch
import torch.nn.functional as F
import deepspeed
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
import numpy as np
from megatron.utils import (
Timers,
init_wandb,
get_ltor_masks_and_position_ids,
reduce_losses,
)
from megatron import print_rank_0, mpu
from megatron.model import (
GPT2ModelPipe,
SoftEmbedding,
get_params_for_weight_decay_optimization,
mark_norms_for_sequence_parallel_grad_sync,
)
from megatron.mpu.mappings import gather_from_model_parallel_region
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import (
build_train_valid_test_data_loaders,
shift_and_wrap_data_loaders,
)
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.logging import tb_wandb_log, training_log
from megatron.utils import (
OverflowMonitor,
get_noise_scale_logger,
get_total_params,
CharCounter,
)
from megatron.model.weight_server import start_server
from megatron.model.gpt2_model import cross_entropy
from megatron.mpu import vocab_parallel_cross_entropy
from pickle import dump
import os
def mup_weights_reinit(neox_args, model):
    """Apply muP parameter rescaling and weight re-initialisation to ``model``.

    Every module yielded by ``model.modules()`` is visited once, in order:

    * if the module has a truthy ``mup_rescale_parameters`` attribute, its
      ``_rescale_parameters()`` method is called;
    * if the module exposes a callable ``mup_reinitialize_weights``, it is
      called with ``neox_args``.

    Arguments:
        neox_args: configuration object forwarded to ``mup_reinitialize_weights``.
        model: any object exposing a ``modules()`` iterable.
    """
    for module in model.modules():
        # This normally would happen in set_base_shapes if we actually were
        # able to use the MuReadout class
        if getattr(module, "mup_rescale_parameters", False):
            module._rescale_parameters()

        reinitialize = getattr(module, "mup_reinitialize_weights", None)
        if callable(reinitialize):
            reinitialize(neox_args)
def save_base_shapes(neox_args, base_shapes, use_cache):
    """Compute and save the muP base-shapes file for this rank, then exit.

    Two models are instantiated with muP temporarily disabled: a "base" model
    at ``neox_args.hidden_size`` and a "delta" model at
    ``neox_args.hidden_size * neox_args.mup_width_scale``. Their shapes are
    passed to ``mup.make_base_shapes`` and written to
    ``f"{neox_args.base_shapes_file}.{rank}"``.

    Arguments:
        neox_args: NeoX configuration; ``use_mup`` and ``hidden_size`` are
            modified while the models are built and restored afterwards.
        base_shapes: ignored on input; kept for interface compatibility (the
            value is recomputed from the base model).
        use_cache: forwarded to ``GPT2ModelPipe``.

    Raises:
        ModuleNotFoundError: if the ``mup`` package is not installed. This is
            checked before any model is built or any argument is mutated.
        SystemExit: always, with status 1, once the shapes have been saved.
    """
    # Fail fast: do not build two models only to find that mup is missing.
    try:
        import mup
    except ModuleNotFoundError:
        print("Please install mup https://github.com/microsoft/mup")
        # Re-raise the original error so the real cause stays in the traceback.
        raise

    def _build_model():
        """Instantiate a model from the current ``neox_args``."""
        built = GPT2ModelPipe(
            neox_args=neox_args,
            num_tokentypes=0,
            parallel_output=True if neox_args.train_impl != "rm" else False,
            topology=mpu.get_topology(),
            use_cache=use_cache,
        )
        if not neox_args.is_pipe_parallel:
            built = built.to_sequential()
        return built

    old_hidden_size = neox_args.hidden_size
    # Instantiation of the base model fails in the init function
    # (init_functions.py) because we haven't called set_base_shapes on it at
    # this point, so disable it temporarily here
    neox_args.use_mup = False
    try:
        base_model = _build_model()
        base_shapes = mup.get_shapes(base_model)
        # Drop the base model before the wider delta model is allocated.
        del base_model

        neox_args.hidden_size = old_hidden_size * neox_args.mup_width_scale
        delta_model = _build_model()
        delta_shapes = mup.get_shapes(delta_model)
    finally:
        # change back, even when model construction raised
        neox_args.use_mup = True
        neox_args.hidden_size = old_hidden_size

    save_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}"
    print(f"saving base shapes at {save_shapes}")
    mup.make_base_shapes(base_shapes, delta_shapes, savefile=save_shapes)
    print("base shapes saved...exiting")
    # NOTE(review): exit status 1 is kept from the original even though this is
    # the success path — confirm whether launchers depend on it.
    sys.exit(1)
def mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator):
    """Run the muP coordinate check over a sweep of hidden sizes, then exit.

    Coordinate data is collected twice through ``get_coord_data`` — once with
    ``neox_args.use_mup`` set to True and once with it set to False — and each
    result is plotted to a per-rank JPEG. The process then exits with status 1.

    Arguments:
        neox_args: NeoX configuration; ``use_mup`` is left False on exit and
            ``hidden_size`` is temporarily overridden while each model is built.
        timers, lr_scheduler, train_data_iterator: forwarded unchanged to
            ``get_coord_data``.
    """
    from megatron.mup_substitute import get_coord_data
    from mup.coord_check import plot_coord_data

    def lazy_model(width):
        """Return a zero-argument factory building a model of hidden size ``width``."""

        def build():
            saved_width = neox_args.hidden_size
            neox_args.hidden_size = width
            model, _optimizer, _, _ = setup_model_and_optimizer(
                neox_args=neox_args, use_cache=False
            )
            neox_args.hidden_size = saved_width
            return model

        return build

    # Hidden size needs to be divisible by num attention heads
    models = {
        width: lazy_model(width)
        for width in (
            neox_args.num_attention_heads * (2**power) for power in range(2, 9)
        )
    }

    neox_args.use_mup = True
    df_up = get_coord_data(
        neox_args, timers, lr_scheduler, models, train_data_iterator, mup=True
    )

    neox_args.use_mup = False
    df_sp = get_coord_data(
        neox_args, timers, lr_scheduler, models, train_data_iterator, mup=False
    )

    plot_coord_data(df_up, save_to=f"coord_check_up.{torch.distributed.get_rank()}.jpg")
    plot_coord_data(df_sp, save_to=f"coord_check_sp.{torch.distributed.get_rank()}.jpg")

    print_rank_0("Saved coord check plots... exiting")
    sys.exit(1)
def update_iterations(neox_args, data_loaders):
    """
    Derive ``neox_args.train_iters`` from ``neox_args.train_epochs`` when the
    iteration count was not given explicitly; updates ``neox_args`` in place.

    Nothing happens when training is disabled or ``train_iters`` is already
    set. When neither ``train_iters`` nor ``train_epochs`` is set, an error is
    printed and ``neox_args`` is left untouched.

    Note that if len(train_dataloader) % gradient_accumulation_steps != 0, this
    will configure neox to do as many iterations as possible while ensuring
    that each example is seen *at most* train_epochs times.
    """
    if not neox_args.do_train or neox_args.train_iters is not None:
        return

    if neox_args.train_epochs is None:
        print_rank_0(
            "ERROR:Failed to specify either train_epochs or train_iters in config file"
        )
        return

    # Only rank 0 reads the dataloader length; the other ranks contribute a
    # placeholder that the broadcast overwrites.
    if torch.distributed.get_rank() == 0:
        batches_per_epoch = len(data_loaders["train"])
        total_iters = (
            batches_per_epoch * neox_args.train_epochs
        ) // neox_args.gradient_accumulation_steps
    else:
        total_iters = 0
    train_iters_tensor = torch.cuda.LongTensor([total_iters])

    torch.distributed.broadcast(train_iters_tensor, src=0)
    neox_args.train_iters = train_iters_tensor[0].item()

    print_rank_0(
        f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs."
    )
def pretrain(neox_args):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) get train/val/test datasets.
        3) setup model, optimizer and lr schedule.
        4) configure data loading
        5) train the model.

    After training it optionally evaluates on validation data, saves a final
    checkpoint (when ``neox_args.save`` is set and the iteration is non-zero)
    and evaluates on test data.

    Arguments:
        neox_args: an instance of NeoXArgs containing the configuration for pretrain
    """
    loaders_timer = "train/valid/test data loaders"
    model_timer = "model and optimizer"
    iterators_timer = "train/valid/test data iterators"

    # setup logging and timers
    init_wandb(neox_args=neox_args)
    timers = Timers(
        use_wandb=neox_args.use_wandb,
        tensorboard_writer=neox_args.tensorboard_writer,
        comet_experiment=neox_args.comet_experiment,
    )

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(neox_args=neox_args)

    # Create data loaders
    timers(loaders_timer).start()
    data_loaders = build_train_valid_test_data_loaders(neox_args=neox_args)
    update_iterations(neox_args=neox_args, data_loaders=data_loaders)
    timers(loaders_timer).stop()

    # Model, optimizer, and learning rate.
    timers(model_timer).start()
    model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
        neox_args=neox_args, use_cache=False, iteration=neox_args.iteration
    )
    timers(model_timer).stop()

    def _checkpoint(step):
        """Save a checkpoint of the current model/optimizer state at ``step``."""
        save_checkpoint(
            neox_args=neox_args,
            iteration=step,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    def _evaluate(prefix, data_iterator, step, **extra):
        """Evaluate on ``data_iterator`` and print the results."""
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=data_iterator,
            model=model,
            iteration=step,
            timers=timers,
            reference_model=reference_model,
            **extra,
        )

    if neox_args.serve_model_weights:
        start_server(model)

    # sync...
    torch.distributed.barrier()

    # Start data stuff:
    # Make and configure iterators
    timers(iterators_timer).start()
    iterators = shift_and_wrap_data_loaders(
        neox_args=neox_args, data_loaders=data_loaders
    )
    train_data_iterator, valid_data_iterator, test_data_iterator = iterators
    timers(iterators_timer).stop()

    if neox_args.use_mup and neox_args.coord_check:
        mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator)

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log([loaders_timer, model_timer, iterators_timer])
    print_rank_0("training ...")

    iteration = neox_args.iteration

    # edge case: save step 0 checkpoint if requested and we're starting from step 0
    if (
        neox_args.save
        and neox_args.extra_save_iters
        and 0 in neox_args.extra_save_iters
        and iteration == 0
    ):
        _checkpoint(iteration)

    if neox_args.do_train and neox_args.train_iters > 0:
        iteration = train(
            neox_args=neox_args,
            timers=timers,
            model=model,
            reference_model=reference_model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_data_iterator=train_data_iterator,
            valid_data_iterator=valid_data_iterator,
        )

    if neox_args.do_valid:
        _evaluate(
            "the end of training for val data",
            valid_data_iterator,
            iteration,
            verbose=False,
        )

    if neox_args.save and iteration != 0:
        _checkpoint(iteration)

    if neox_args.do_test:
        # Run on test data.
        _evaluate(
            "the end of training for test data",
            test_data_iterator,
            iteration,
            verbose=True,
            chart_name="test",
        )
def _get_batch(neox_args, tokenizer, keys, data, datatype, label_mask_zero=False):
    """Support function for get_batch / get_batch pipe (to avoid code repetition)

    Broadcasts ``data`` for ``keys`` via ``mpu.broadcast_data`` and then builds
    next-token inputs and targets: inputs drop the last position, labels and
    the label mask drop the first.

    Arguments:
        neox_args: configuration; supplies ``tokenizer.eod``, ``eod_mask_loss``
            and ``sliding_window_width``.
        tokenizer: accepted for interface compatibility; not read here.
        keys: ``keys[0]`` names the token field, optional ``keys[1]`` the label field.
        data: batch passed to ``mpu.broadcast_data``.
        datatype: dtype passed to ``mpu.broadcast_data``.
        label_mask_zero: when True, multiply the labels by the label mask.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)``
    """
    broadcast = mpu.broadcast_data(keys, data, datatype)
    token_key = keys[0]
    label_key = keys[1] if len(keys) > 1 else None

    # Unpack.
    full_tokens = broadcast[token_key].long()
    if label_key in broadcast:
        full_labels = broadcast[label_key].long()
        # Negative label ids mark positions excluded from the loss; they are
        # replaced by 0 in the labels themselves.
        is_valid = full_labels >= 0
        label_mask = is_valid[:, 1:].contiguous()
        labels = torch.where(
            is_valid, full_labels, torch.zeros_like(full_labels)
        )[:, 1:].contiguous()
    else:
        label_mask = (full_tokens.long() >= 0)[:, 1:].contiguous()
        labels = full_tokens[:, 1:].contiguous()
    if label_mask_zero:
        labels = labels * label_mask
    tokens = full_tokens[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
        sliding_window_width=neox_args.sliding_window_width,
    )

    # combine loss masks from get_ltor_masks_and_position_ids with loss masks from data
    loss_mask = label_mask.to(loss_mask.dtype) * loss_mask
    return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch(neox_args, data_iterator):
    """Generate a batch for the configured ``neox_args.train_impl``.

    Arguments:
        neox_args: configuration; ``train_impl`` selects the batch layout.
        data_iterator: iterator yielding raw batches, or None (in which case
            ``None`` is handed to the broadcast helpers as the data).

    Returns:
        * ``"normal"``: ``(tokens, labels, loss_mask, attention_mask, position_ids)``
        * ``"kto"``: the five items above plus ``(rewards, ref_logp_or_None)``
        * ``"reinforce"``: the five items above plus ``(rewards, raw_rewards)``
        * ``"dpo"`` / ``"rm"``: a list of the five items with positive and
          negative examples concatenated along dim 0, followed by the
          concatenated reference log-probs (or None when not precomputed).

    Raises:
        ValueError: if ``train_impl`` is not one of the supported values.
            Raised before the iterator is advanced, so no batch is consumed.
        AssertionError: for KTO when ``train_micro_batch_size_per_gpu <= 1``.
    """
    train_impl = neox_args.train_impl

    # Items and their type.
    if train_impl in ["normal", "kto", "reinforce"]:
        keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
    elif train_impl in ["dpo", "rm"]:
        keys = (
            [["pos", "pos_label"], ["neg", "neg_label"]]
            if neox_args.pos_train_label_data_paths
            else [["pos"], ["neg"]]
        )
    else:
        # Previously an unknown value fell through every branch and returned
        # None, which only failed later when the caller unpacked the result.
        raise ValueError(
            f"Unsupported train_impl: {train_impl!r}; expected one of "
            "'normal', 'kto', 'reinforce', 'dpo', 'rm'"
        )
    datatype = torch.int64

    # Broadcast data.
    data = next(data_iterator) if data_iterator is not None else None

    def _fetch(batch_keys, **extra):
        """Build the five-item batch tuple for ``batch_keys``."""
        return _get_batch(
            neox_args=neox_args,
            tokenizer=neox_args.tokenizer,
            keys=batch_keys,
            data=data,
            datatype=datatype,
            **extra,
        )

    if train_impl == "normal":
        return _fetch(keys)

    if train_impl == "kto":
        assert (
            neox_args.train_micro_batch_size_per_gpu > 1
        ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1."
        tup = _fetch(keys)
        # Remove the last token from the reward since we predict the next token, so
        # Reward of <current prediction> will be based on the label of <next token>
        rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][
            :, :-1
        ].contiguous()
        ref_data = (
            mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous()
            if neox_args.precompute_model_name
            else None
        )
        return tup + (rw_data, ref_data)

    if train_impl == "reinforce":
        tup = _fetch(keys)
        rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"]
        raw_rw_data = mpu.broadcast_data(["raw_reward"], data, torch.float)[
            "raw_reward"
        ]
        return tup + (rw_data, raw_rw_data)

    # Remaining case: "dpo" / "rm" (validated above).
    pos_tup = _fetch(keys[0], label_mask_zero=True)
    neg_tup = _fetch(keys[1], label_mask_zero=True)
    if neox_args.precompute_model_name:
        ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float)
    else:
        ref_data = {"pos_ref": None}
    return [
        torch.cat((pos_item, neg_item), dim=0)
        for pos_item, neg_item in zip(pos_tup, neg_tup)
    ] + [
        torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[
            :, :-1
        ].contiguous()
        if ref_data["pos_ref"] is not None
        else None
    ]
def get_batch_pipe(data, neox_args, curr_scheduler=None):
    """A modification of get_batch() to work with the latest batch instead of an iterator.

    Arguments:
        data: the raw batch handed to ``_get_batch``.
        neox_args: configuration; ``train_impl`` must not be kto, dpo or rm.
        curr_scheduler: optional curriculum scheduler; when given, all tensors
            are truncated to the sequence length it reports whenever that is
            shorter than the token sequence.

    Returns:
        ``((tokens, position_ids, attention_mask), (labels, loss_mask))``
    """
    assert neox_args.train_impl not in [
        "kto",
        "dpo",
        "rm",
    ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0"

    # Items and their type.
    if neox_args.train_label_data_paths:
        keys = ["text", "label"]
    else:
        keys = ["text"]
    datatype = torch.int64

    batch = _get_batch(neox_args, neox_args.tokenizer, keys, data, datatype)
    tokens, labels, loss_mask, attention_mask, position_ids = batch

    if curr_scheduler is not None:
        # iteration + 1 to align with how/when DeepSpeed updates the buffers
        seqlen = curr_scheduler.update_difficulty(neox_args.iteration + 1)
        if seqlen < tokens.size()[1]:
            # seqlen-based curriculum learning
            # tokens, position_ids, labels have size [batch size, seqlen]
            tokens = tokens[:, :seqlen].contiguous()
            position_ids = position_ids[:, :seqlen].contiguous()
            if labels is not None:
                labels = labels[:, :seqlen].contiguous()
            if loss_mask is not None:
                loss_mask = loss_mask[:, :seqlen].contiguous()
            # attention_mask has size [1, 1, seqlen, seqlen]
            attention_mask = attention_mask[:, :, :seqlen, :seqlen].contiguous()

    # unpack data
    model_inputs = (tokens, position_ids, attention_mask)
    loss_inputs = (labels, loss_mask)
    return model_inputs, loss_inputs
def get_batch_sequential(forward_input, neox_args):
    """A modification of get_batch() to work with the latest batch instead of an iterator.

    Builds an attention mask for ``forward_input[0]`` and returns
    ``(forward_input[0], forward_input[1], attention_mask)``. The loss mask and
    position ids that the mask helper also produces are discarded.
    """
    # NOTE(review): unlike _get_batch, no sliding_window_width is forwarded
    # here — confirm that this is intended.
    attention_mask, _loss_mask, _position_ids = get_ltor_masks_and_position_ids(
        data=forward_input[0],
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
    )
    first, second = forward_input[0], forward_input[1]
    return (first, second, attention_mask)
def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs.

    Each entry of ``losses`` is detached, flattened to one element and
    concatenated; the result is all-reduced over the data-parallel group and
    divided by that group's world size.

    Returns:
        A 1-D tensor with one averaged value per entry of ``losses``.
    """
    flattened = [item.clone().detach().view(1) for item in losses]
    averaged_losses = torch.cat(flattened)
    torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group())
    group_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group()
    )
    return averaged_losses / group_size
def mb_moe_loss_func(args, loss_mask, output_tensor=None):
    """Compute the MegaBlocks MoE load-balancing loss, averaged over data parallel ranks.

    Arguments:
        args: NeoX configuration; ``checkpoint_activations`` and ``num_layers``
            are read, and the whole object is converted with
            ``megablocks_utils.as_megablocks_args``.
        loss_mask: unused; kept for interface compatibility.
        output_tensor: unused; kept for interface compatibility.

    Returns:
        ``(averaged_lbl, loss_dict)`` where ``averaged_lbl`` is the tensor
        returned by ``average_losses_across_data_parallel_group`` and
        ``loss_dict["load balancing loss"]`` is its first element.
    """
    from megatron.model import megablocks_utils
    from megatron.model.megablocks_utils import moe

    # NOTE: For pipeline parallelism this function will be run on the
    # non-final stages to calculate load balancing loss contribution
    # for the MoE layers within the stage. For these cases, output_tensor
    # will be None.
    loss_dict = {}

    # NOTE: If recompute is enabled we will collect duplicate load
    # balancing loss contributions. Prune these before calculating
    # the load balancing loss.
    if args.checkpoint_activations:
        # Ignore load balancing loss contributions computed during
        # the forward pass if recompute is turned on.
        load_balancing_loss_data = moe.get_load_balancing_loss()
        if args.num_layers * 2 == len(load_balancing_loss_data):
            load_balancing_loss_data = load_balancing_loss_data[args.num_layers :]
            moe.clear_load_balancing_loss()
            for x in load_balancing_loss_data:
                moe.save_load_balancing_loss(x)

    # Compute the load balancing loss for all MoE layers.
    megablocks_args = megablocks_utils.as_megablocks_args(args)
    lbl = moe.batched_load_balancing_loss(megablocks_args)
    moe.clear_load_balancing_loss()

    # Average the load balancing loss across data parallel
    # replicas and save for logging.
    averaged_lbl = average_losses_across_data_parallel_group([lbl])
    loss_dict["load balancing loss"] = averaged_lbl[0]
    return averaged_lbl, loss_dict
def get_logp(logits, labels, force_fp32=False):
    """Return the per-token log-probability of ``labels`` under ``logits``.

    Rather than reimplementing logp, this negates the vocab-parallel cross
    entropy loss, which is logp with the sign inverted. When ``force_fp32`` is
    true the logits are cast to float32 first.
    """
    scores = logits.float() if force_fp32 else logits
    return -vocab_parallel_cross_entropy(scores, labels)
def get_pos_neg_logp(logits, labels, force_fp32=False):
    """Return per-token log-probabilities split into two chunks along dim 0.

    Rather than reimplementing logp, this negates the vocab-parallel cross
    entropy loss, which is logp with the sign inverted. When ``force_fp32`` is
    true the logits are cast to float32 first. The result is the tuple
    produced by ``torch.chunk(..., 2, 0)``.
    """
    scores = logits.float() if force_fp32 else logits
    token_logp = -vocab_parallel_cross_entropy(scores, labels)
    return torch.chunk(token_logp, 2, 0)
def forward_step(
    data_iterator,
    model,
    neox_args,
    timers,
    return_logits=False,
    is_train=False,
    reference_model=None,
):
    """Forward step: fetch one batch, run the model, and compute the loss.

    The loss is selected by ``neox_args.train_impl`` ("normal", "rm", "dpo",
    "kto" or "reinforce").

    Arguments:
        data_iterator: source of batches; passed to ``get_batch`` (or directly
            to ``model.eval_batch`` when pipeline parallel).
        model: the model being trained/evaluated.
        neox_args: NeoX configuration.
        timers: timers callable, or None to skip timing the batch generator.
        return_logits: when True, also return the model outputs.
        is_train: enables curriculum-learning truncation in the "normal" path.
        reference_model: reference policy used by dpo/kto/reinforce when
            reference log-probs are not precomputed.

    Returns:
        ``model.eval_batch(...)`` when ``neox_args.is_pipe_parallel``; otherwise
        ``(loss, metrics)``, or ``(loss, outputs, metrics)`` when
        ``return_logits`` is True.
    """
    if neox_args.is_pipe_parallel:
        return model.eval_batch(data_iterator, return_logits=return_logits)

    # Get the batch.
    # NOTE(review): this push is additionally gated on `neox_args.iteration`
    # while the matching pop below is not — on iteration 0 with memory
    # profiling enabled the pop appears to run without a push; confirm.
    if neox_args.memory_profiling and neox_args.iteration:
        torch.cuda.nvtx.range_push(f"Get batch")
    if timers is not None:
        timers("batch generator").start()
    if neox_args.train_impl == "normal":
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
            neox_args=neox_args, data_iterator=data_iterator
        )
    elif neox_args.train_impl == "kto":
        (
            tokens,
            labels,
            loss_mask,
            attention_mask,
            position_ids,
            rewards,
            ref_logp,
        ) = get_batch(neox_args=neox_args, data_iterator=data_iterator)
    elif neox_args.train_impl == "reinforce":
        (
            tokens,
            labels,
            loss_mask,
            attention_mask,
            position_ids,
            rewards,
            raw_rewards,
        ) = get_batch(neox_args=neox_args, data_iterator=data_iterator)
    if neox_args.train_impl in ["dpo", "rm"]:
        tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
            neox_args=neox_args, data_iterator=data_iterator
        )

    if timers is not None:
        timers("batch generator").stop()
    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_pop()

    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_push(f"Forward pass")
    # Per-step scalar metrics collected by the non-"normal" objectives.
    metrics = {}
    if neox_args.train_impl == "normal":
        # Sequential returns moe_losses, but this is not yet supported by pipe parallel
        maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
        if type(maybe_tuple) is tuple:
            outputs, moe_losses = maybe_tuple
        else:
            outputs = maybe_tuple
            moe_losses = []
        # Under curriculum learning, trim targets to the current curriculum length.
        if (
            is_train
            and neox_args.curriculum_learning
            and neox_args.curriculum_seqlen < neox_args.seq_length
        ):
            loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous()
            labels = labels[:, : neox_args.curriculum_seqlen].contiguous()
        main_loss = cross_entropy(
            outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
        )
        if neox_args.moe_num_experts > 1:
            if neox_args.moe_type == "deepspeed":
                moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses)
            elif neox_args.moe_type == "megablocks":
                moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0]
            else:
                raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}")
        else:
            moe_loss = 0.0
        loss = main_loss + moe_loss
    elif neox_args.train_impl == "rm":
        maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
        if type(maybe_tuple) is tuple:
            outputs, _ = maybe_tuple
        else:
            outputs = maybe_tuple
        # The batch holds positives and negatives concatenated along dim 0.
        pos, neg = torch.chunk(outputs, 2, 0)
        pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
        # We assume that each pos, neg pair occur in the same order
        # e.g. second nonzero pos is the corresponding second nonzero neg
        # and that there are also an equal number of pos and neg in each sequence.
        pos_indx = pos_loss_mask.nonzero()
        neg_indx = neg_loss_mask.nonzero()
        # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
        pos_indx = pos_indx[:, 1].unsqueeze(1)
        neg_indx = neg_indx[:, 1].unsqueeze(1)
        pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
        neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
        with torch.no_grad():
            metrics["pos_values"] = pos.clone().detach().mean()
            metrics["neg_values"] = neg.clone().detach().mean()
            metrics["margin"] = (pos - neg).clone().detach().mean()
            metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
        # Pairwise logistic loss plus a z_loss-weighted penalty on squared values.
        loss = (-F.logsigmoid(pos - neg).mean()) + (
            (neox_args.z_loss * (pos**2 + neg**2)).mean()
        )
    elif neox_args.train_impl == "dpo":
        # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
        with torch.inference_mode():
            # So we can gather token logps...
            token_logp_labels = labels.clone()
            pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
            if neox_args.dpo_reference_free:
                ref_pos = 0
                ref_neg = 0
            elif ref_logp is None:
                # Reference log-probs were not precomputed: run the reference model.
                ref_maybe_tuple = reference_model(
                    (tokens, position_ids, attention_mask), neox_args=neox_args
                )
                if type(ref_maybe_tuple) is tuple:
                    # We should ignore MoE losses yeah?
                    ref_outputs, _ = ref_maybe_tuple
                else:
                    ref_outputs = ref_maybe_tuple
                ref_pos, ref_neg = get_pos_neg_logp(
                    ref_outputs, token_logp_labels, neox_args.dpo_fp32
                )
            else:
                ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0)
            # Masked sum of token log-probs over the last dimension.
            ref_pos = (ref_pos * pos_loss_mask).sum(-1)
            ref_neg = (ref_neg * neg_loss_mask).sum(-1)
        chosen_maybe_tuple = model(
            (tokens, position_ids, attention_mask), neox_args=neox_args
        )
        if type(chosen_maybe_tuple) is tuple:
            # We should ignore MoE losses yeah?
            chosen_outputs, _ = chosen_maybe_tuple
        else:
            chosen_outputs = chosen_maybe_tuple
        chosen_pos, chosen_neg = get_pos_neg_logp(
            chosen_outputs, token_logp_labels, neox_args.dpo_fp32
        )
        chosen_pos = (chosen_pos * pos_loss_mask).sum(-1)
        chosen_neg = (chosen_neg * neg_loss_mask).sum(-1)
        with torch.no_grad():
            # Collect metrics...
            if not neox_args.dpo_reference_free:
                metrics["ref_neg"] = ref_neg.clone().detach().mean()
                metrics["ref_pos"] = ref_pos.clone().detach().mean()
            metrics["chosen_neg"] = chosen_neg.clone().detach().mean()
            metrics["chosen_pos"] = chosen_pos.clone().detach().mean()
            if not neox_args.dpo_reference_free:
                chosen_rewards = neox_args.dpo_beta * (
                    chosen_pos.clone().detach() - ref_pos.clone().detach()
                )
                rejected_rewards = neox_args.dpo_beta * (
                    chosen_neg.clone().detach() - ref_neg.clone().detach()
                )
                metrics["chosen_rewards"] = chosen_rewards.mean()
                metrics["rejected_rewards"] = rejected_rewards.mean()
                reward_acc = (chosen_rewards > rejected_rewards).float()
                metrics["reward_acc"] = reward_acc.mean()
                metrics["margins"] = (chosen_rewards - rejected_rewards).mean()
        # DPO objective: -logsigmoid(beta * (policy log-ratio - reference log-ratio)).
        pi_logrations = chosen_pos - chosen_neg
        ref_logrations = ref_pos - ref_neg
        logits = pi_logrations - ref_logrations
        loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean()
    elif neox_args.train_impl == "kto":
        # Based on https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py
        # Except we don't have an extra input for KL logp, we just split the batch in half
        with torch.no_grad():
            # So we can gather token logps...
            token_logp_labels = labels.clone()
            token_logp_labels[token_logp_labels == -100] = 0
            if ref_logp is None:
                # Did not precompute logits....
                ref_maybe_tuple = reference_model(
                    (tokens, position_ids, attention_mask), neox_args=neox_args
                )
                if type(ref_maybe_tuple) is tuple:
                    # We should ignore MoE losses yeah?
                    ref_outputs, _ = ref_maybe_tuple
                else:
                    ref_outputs = ref_maybe_tuple
                # gather across tensor parallel group
                ref_outputs = gather_from_model_parallel_region(ref_outputs)

                ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32)
            else:
                # NOTE(review): unconditional debug print on every rank — confirm it is intended.
                print(f"REF LOGP: {ref_logp.clone().detach().mean()}")
            ref_logp = ref_logp * loss_mask
            # Per-sequence weight chosen by the sign of the summed reward.
            scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight
            scaling += (
                rewards.sum(-1) < -0.001
            ).float() * neox_args.kto_undesirable_weight
            pos_mask = (rewards > 0.001).float()
            neg_mask = (rewards < -0.001).float()
        chosen_maybe_tuple = model(
            (tokens, position_ids, attention_mask), neox_args=neox_args
        )
        if type(chosen_maybe_tuple) is tuple:
            # We should ignore MoE losses yeah?
            chosen_outputs, _ = chosen_maybe_tuple
        else:
            chosen_outputs = chosen_maybe_tuple
        chosen_outputs = gather_from_model_parallel_region(chosen_outputs)
        chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32)
        chosen_logp = chosen_logp * loss_mask
        with torch.no_grad():
            # Collect metrics...
            metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean()
            metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean()
            metrics["pos_ref_logp"] = (
                (ref_logp * pos_mask).clone().detach().sum(-1).mean()
            )
            metrics["neg_ref_logp"] = (
                (ref_logp * neg_mask).clone().detach().sum(-1).mean()
            )
            metrics["pos_policy_logp"] = (
                (chosen_logp * pos_mask).clone().detach().sum(-1).mean()
            )
            metrics["neg_policy_logp"] = (
                (chosen_logp * neg_mask).clone().detach().sum(-1).mean()
            )
            metrics["kl"] = (
                chosen_logp.clone().detach() - ref_logp.clone().detach()
            ).sum() / loss_mask.sum()
            policy_rewards = (
                neox_args.kto_beta
                * rewards
                * (chosen_logp.clone().detach() - ref_logp.clone().detach())
            )
            reward_acc = (policy_rewards.sum(-1) > 0.0).float()
            metrics["reward_acc"] = reward_acc.mean()
            metrics["policy_rewards"] = policy_rewards.sum()
            # NOTE(review): unconditional debug print of the metrics dict — confirm it is intended.
            print(metrics)
        # Split the batch in half; each half's KL estimate is used as the
        # (detached) baseline for the other half.
        pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0)
        ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0)
        reward1, reward2 = torch.chunk(rewards, 2, 0)
        scaling1, scaling2 = torch.chunk(scaling, 2, 0)
        kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean()
        kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean()
        log_ratio1 = pol_logp1 - ref_logp1
        log_ratio2 = pol_logp2 - ref_logp2

        # TODO: Add pack_until_overflow sequence support
        loss = (
            0.5
            * scaling1.mean(-1)
            * (
                1
                - F.sigmoid(
                    (
                        neox_args.kto_beta
                        * reward1.mean(-1)
                        * (log_ratio1.sum(-1) - kl2.clone().detach())
                    )
                )
            )
        ) + (
            0.5
            * scaling2.mean(-1)
            * (
                1
                - F.sigmoid(
                    (
                        neox_args.kto_beta
                        * reward2.mean(-1)
                        * (log_ratio2.sum(-1) - kl1.clone().detach())
                    )
                )
            )
        )
        # print(loss.shape)
        loss = loss.mean()
        # print(loss.shape)
    elif neox_args.train_impl == "reinforce":
        if reference_model is not None:
            with torch.no_grad():
                ref_outputs = reference_model(
                    (tokens, position_ids, attention_mask), neox_args=neox_args
                )
                if type(ref_outputs) is tuple:
                    ref_outputs, _ = ref_outputs
                ref_outputs = ref_outputs
                if neox_args.kl_impl == "full":
                    # Have to do the loss over all tokens...
                    ref_outputs = gather_from_model_parallel_region(ref_outputs)
                    if neox_args.fp32_reinforce:
                        ref_outputs = ref_outputs.float()
                    ref_logp = ref_outputs.log_softmax(dim=-1).detach()
                    ref_per_token_logp = torch.gather(
                        ref_logp.clone(), dim=2, index=labels.unsqueeze(2)
                    ).squeeze(2)
                else:
                    ref_per_token_logp = get_logp(
                        ref_outputs, labels, neox_args.fp32_reinforce
                    )
                metrics["ref_logp"] = ref_per_token_logp.clone().detach().mean()
        outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
        if type(outputs) is tuple:
            outputs, _ = outputs
        if neox_args.kl_impl == "full":
            # Have to do the loss over all tokens...
            outputs = gather_from_model_parallel_region(outputs)
            if neox_args.fp32_reinforce:
                outputs = outputs.float()
            logp = outputs.log_softmax(dim=-1)
            per_token_logp = torch.gather(
                logp.clone(), dim=2, index=labels.unsqueeze(2)
            ).squeeze(2)
        else:
            per_token_logp = get_logp(outputs, labels, neox_args.fp32_reinforce)
        with torch.no_grad():
            metrics["logp"] = per_token_logp.clone().detach().mean()
            metrics["reward"] = raw_rewards.clone().detach().mean()
            metrics["reward_std"] = raw_rewards.clone().detach().std()
        loss_mask_sum = loss_mask.sum()
        if reference_model is not None:
            if neox_args.kl_impl == "full":
                # Following along with
                # https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1109
                kl = F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1)
            else:
                kl = per_token_logp - ref_per_token_logp
                if neox_args.kl_impl == "abs":
                    kl = kl.abs()
                elif neox_args.kl_impl == "mse":
                    kl = 0.5 * (kl).square()
                elif neox_args.kl_impl == "kl":
                    pass
            with torch.no_grad():
                metrics["kl"] = kl.clone().detach().mean()
            # Reward-weighted negative log-prob plus a kl_div_beta-weighted KL
            # penalty, normalised by the total number of unmasked tokens.
            loss = (-per_token_logp * rewards) + (neox_args.kl_div_beta * kl)
            loss = (loss * loss_mask).sum(-1) / loss_mask_sum
            loss = loss.mean()
        else:
            loss = -(rewards * per_token_logp)
            loss = (loss * loss_mask).sum(-1) / loss_mask_sum
            loss = loss.mean()
    if neox_args.memory_profiling:
        torch.cuda.nvtx.range_pop()
    if return_logits:
        # NOTE(review): `outputs` is only assigned in the "normal", "rm" and
        # "reinforce" branches; the "dpo" and "kto" branches bind
        # `chosen_outputs` instead, so return_logits=True looks unsupported
        # there — confirm.
        return loss, outputs, metrics
    return loss, metrics
def get_model(neox_args, use_cache=False):
"""Build the model."""
# Build model on cpu.
print_rank_0("building GPT2 model ...")
# Temporarily disable mup so that the base model does not use the mup init functions before set_base_shapes is called below.
# If mup isn't being used anyways, this has no effect.
old_use_mup = neox_args.use_mup
neox_args.use_mup = False
if neox_args.zero_stage in [2, 3]:
if neox_args.pipe_parallel_size == 1:
print_rank_0(
"ZeRO stage 2/3 and the PipelineModule are incompatible, please set 'pipe_parallel_size' to 0 instead"
)
exit()
if neox_args.pipe_parallel_size > 1:
print_rank_0(
"ZeRO stage 2/3 and pipeline paralleism are not supported simultaneously"
)
exit()
if neox_args.model_parallel_size > 1:
print_rank_0(
"ZeRO stage 2/3 and model paralleism are not currently supported simultaneously"
)
exit()
with deepspeed.zero.Init(
config_dict_or_path=neox_args.deepspeed_config
) if neox_args.zero_stage == 3 else nullcontext() as gs:
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
### soft prompt tuning stuff ###
if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(
"enabled", False
):
soft_prompt = SoftEmbedding(
neox_args,
wte=getattr(model, "0").word_embeddings,
n_tokens=neox_args.soft_prompt_tuning.get("n_tokens", 10),
init_string=neox_args.soft_prompt_tuning.get("init_string", ""),
init_range=neox_args.soft_prompt_tuning.get("init_range", 0.5),
)
model.insert_layers(
layers=soft_prompt, idx=1
) # insert the soft prompt layer directly after the word embeddings