Hi authors, I think there may be a problem in your code when calculating lm_loss for distillm. As shown in the following snippet (distillm/finetune.py, lines 321 to 339 in d47e77f), lm_loss is computed from the logits and the labels in no_model_batch. However, model_batch and no_model_batch sometimes come from SGO (student-generated outputs), and computing lm_loss on this data pushes the student model to imitate its own generations.
elif "adaptive" in args.type and r < adaptive_threshold:
    model_batch, no_model_batch = replay_buffer.sample()
    model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)

    model.train()

    outputs = model(**model_batch, use_cache=False)

    logits = outputs.logits
    if args.model_parallel:
        raise NotImplementedError
    else:
        lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))

    if teacher_model is not None:
        distil_loss = get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits)
        loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
    else:
        loss = lm_loss
In our experiments, we found that this leads to model collapse when the threshold increases (after several epochs).
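If it helps, below is a minimal sketch of one possible workaround (just my assumption, not a claim about your intended design): skip the lm_loss term whenever the batch was sampled from the replay buffer, so SGO data is only supervised through the teacher. The is_sgo_batch flag is hypothetical; the actual code would need to record whether the batch came from replay_buffer.sample() or from the data loader.

# Hypothetical sketch, reusing the names from the snippet above.
# `is_sgo_batch` is an assumed flag, not something that exists in the repo.
outputs = model(**model_batch, use_cache=False)
logits = outputs.logits

if teacher_model is not None and is_sgo_batch:
    # SGO labels are the student's own outputs; supervising on them only
    # reinforces the student's mistakes, so keep the distillation term alone.
    loss = get_distil_loss(args, tokenizer, model, teacher_model,
                           model_batch, no_model_batch, logits)
else:
    lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]),
                        no_model_batch["label"].view(-1))
    if teacher_model is not None:
        distil_loss = get_distil_loss(args, tokenizer, model, teacher_model,
                                      model_batch, no_model_batch, logits)
        loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
    else:
        loss = lm_loss

Alternatively, masking the SGO labels with the ignore index used by loss_func would have a similar effect without changing the control flow.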