Nanobit committed
Commit 813aab3
Parent: a27d594

Fix Trainer() got multiple values for keyword argument 'callbacks'
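The error is plain Python keyword-argument semantics: the old code set `trainer_kwargs["callbacks"]` in one branch and also built a separate `callbacks` list, and both reached the `transformers.Trainer(...)` call. A minimal, self-contained sketch of the failure mode, with a hypothetical `trainer()` standing in for `transformers.Trainer`:

```python
# Hypothetical stand-in for transformers.Trainer, just to show the semantics:
# passing the same keyword explicitly AND via **kwargs raises TypeError.
def trainer(model=None, callbacks=None):
    return callbacks

trainer_kwargs = {"callbacks": ["early_stop_cb"]}  # filled in one code path
callbacks = ["SavePeftModelCallback"]              # built in another

trainer(callbacks=callbacks, **trainer_kwargs)
# TypeError: trainer() got multiple values for keyword argument 'callbacks'
```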

Files changed (1):
  1. src/axolotl/utils/trainer.py +5 -5
src/axolotl/utils/trainer.py CHANGED
```diff
@@ -175,12 +175,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     )
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
+    callbacks = []
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
         )
-        trainer_kwargs["callbacks"] = [early_stop_cb]
+        callbacks.append(early_stop_cb)
+
+    if cfg.local_rank == 0 and cfg.adapter == 'lora':  # only save in rank 0
+        callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
         "padding": True,
@@ -190,10 +194,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    callbacks = []
-    if cfg.adapter == 'lora':
-        callbacks.append(SavePeftModelCallback)
-
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
```
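After this change every code path appends to a single `callbacks` list, which is presumably handed to `transformers.Trainer` through exactly one keyword at the call site (not shown in these hunks). The `cfg.local_rank == 0` guard matters in distributed runs: without it, every DDP process would register `SavePeftModelCallback` and race to write the same PEFT checkpoint. A runnable sketch of the fixed pattern, reusing the hypothetical stand-ins from above:

```python
# Fixed pattern: one list, appended to by every branch, passed exactly once.
def trainer(model=None, callbacks=None, optimizers=None):
    return callbacks

trainer_kwargs = {"optimizers": ("optimizer", "lr_scheduler")}  # no "callbacks" key

callbacks = []
early_stopping_patience = 3        # stand-in for cfg.early_stopping_patience
if early_stopping_patience:
    callbacks.append("early_stop_cb")

local_rank, adapter = 0, "lora"    # stand-ins for cfg.local_rank / cfg.adapter
if local_rank == 0 and adapter == "lora":
    callbacks.append("SavePeftModelCallback")

print(trainer(callbacks=callbacks, **trainer_kwargs))
# ['early_stop_cb', 'SavePeftModelCallback'] -- no duplicate keyword error
```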