Younes Belkada winglian commited on
Commit
db9094d
1 Parent(s): 6ef46f8

FEAT: add tagging support to axolotl (#1004)

Browse files

* add tagging support to axolotl

* chore: lint

* fix method w self

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +35 -1
src/axolotl/core/trainer_builder.py CHANGED
@@ -9,7 +9,7 @@ import math
9
  import sys
10
  from abc import abstractmethod
11
  from dataclasses import dataclass, field
12
- from functools import partial
13
  from pathlib import Path
14
  from typing import Optional
15
 
@@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
120
  """
121
 
122
  args = None # type: AxolotlTrainingArguments
 
123
 
124
  def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
125
  self.num_epochs = num_epochs
@@ -290,12 +291,41 @@ class AxolotlTrainer(Trainer):
290
  # return (loss, outputs) if return_outputs else loss
291
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  class AxolotlMambaTrainer(AxolotlTrainer):
295
  """
296
  Mamba specific trainer to handle loss calculation
297
  """
298
 
 
 
299
  def compute_loss(
300
  self,
301
  model,
@@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
322
  Trainer subclass that uses the OneCycleLR scheduler
323
  """
324
 
 
 
325
  def __init__(self, *args, **kwargs):
326
  super().__init__(*args, **kwargs)
327
  self.lr_scheduler = None
@@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
351
  Trainer subclass that uses the OneCycleLR scheduler
352
  """
353
 
 
 
354
  def __init__(self, *args, **kwargs):
355
  super().__init__(*args, **kwargs)
356
  self.lr_scheduler = None
 
9
  import sys
10
  from abc import abstractmethod
11
  from dataclasses import dataclass, field
12
+ from functools import partial, wraps
13
  from pathlib import Path
14
  from typing import Optional
15
 
 
120
  """
121
 
122
  args = None # type: AxolotlTrainingArguments
123
+ tag_names = ["axolotl"]
124
 
125
  def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
126
  self.num_epochs = num_epochs
 
291
  # return (loss, outputs) if return_outputs else loss
292
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
293
 
294
+ def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
295
+ if isinstance(tag_names, str):
296
+ tag_names = [tag_names]
297
+
298
+ if kwargs is not None:
299
+ if "tags" not in kwargs:
300
+ kwargs["tags"] = tag_names
301
+ elif "tags" in kwargs and isinstance(kwargs["tags"], list):
302
+ kwargs["tags"].extend(tag_names)
303
+ elif "tags" in kwargs and isinstance(kwargs["tags"], str):
304
+ tag_names.append(kwargs["tags"])
305
+ kwargs["tags"] = tag_names
306
+
307
+ return kwargs
308
+
309
+ @wraps(Trainer.push_to_hub)
310
+ def push_to_hub(self, *args, **kwargs) -> str:
311
+ """
312
+ Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
313
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
314
+ """
315
+ kwargs = self._sanitize_kwargs_for_tagging(
316
+ tag_names=self.tag_names, kwargs=kwargs
317
+ )
318
+
319
+ return super().push_to_hub(*args, **kwargs)
320
+
321
 
322
  class AxolotlMambaTrainer(AxolotlTrainer):
323
  """
324
  Mamba specific trainer to handle loss calculation
325
  """
326
 
327
+ tag_names = ["axolotl", "mamba"]
328
+
329
  def compute_loss(
330
  self,
331
  model,
 
352
  Trainer subclass that uses the OneCycleLR scheduler
353
  """
354
 
355
+ tag_names = ["axolotl", "onecycle"]
356
+
357
  def __init__(self, *args, **kwargs):
358
  super().__init__(*args, **kwargs)
359
  self.lr_scheduler = None
 
383
  Trainer subclass that uses the OneCycleLR scheduler
384
  """
385
 
386
+ tag_names = ["axolotl", "relora"]
387
+
388
  def __init__(self, *args, **kwargs):
389
  super().__init__(*args, **kwargs)
390
  self.lr_scheduler = None