tmm1 commited on
Commit
9643121
2 Parent(s): f5c11f8 9c31410

Merge pull request #354 from tmm1/gpu-util

Browse files
requirements.txt CHANGED
@@ -19,3 +19,4 @@ evaluate==0.4.0
19
  rouge-score==0.1.2
20
  scipy
21
  scikit-learn==1.2.2
 
 
19
  rouge-score==0.1.2
20
  scipy
21
  scikit-learn==1.2.2
22
+ pynvml
scripts/finetune.py CHANGED
@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
  from axolotl.logging_config import configure_logging
 
21
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
22
  from axolotl.utils.dict import DictDefault
23
  from axolotl.utils.models import load_model, load_tokenizer
@@ -250,6 +251,8 @@ def train(
250
  LOG.info("Finished preparing dataset. Exiting...")
251
  return
252
 
 
 
253
  # Load the model and tokenizer
254
  LOG.info("loading model and peft_config...")
255
  model, peft_config = load_model(
 
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
  from axolotl.logging_config import configure_logging
21
+ from axolotl.utils.bench import log_gpu_memory_usage
22
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
23
  from axolotl.utils.dict import DictDefault
24
  from axolotl.utils.models import load_model, load_tokenizer
 
251
  LOG.info("Finished preparing dataset. Exiting...")
252
  return
253
 
254
+ log_gpu_memory_usage(LOG, "baseline", cfg.device)
255
+
256
  # Load the model and tokenizer
257
  LOG.info("loading model and peft_config...")
258
  model, peft_config = load_model(
src/axolotl/utils/bench.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmarking and measurement utilities"""
2
+
3
+ import pynvml
4
+ import torch
5
+
6
+
7
+ def gpu_memory_usage(device):
8
+ if isinstance(device, torch.device):
9
+ device = device.index
10
+ if isinstance(device, str) and device.startswith("cuda:"):
11
+ device = int(device[5:])
12
+
13
+ # NB torch.cuda.memory_usage returns zero so we use lower level api
14
+ pynvml.nvmlInit()
15
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device)
16
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
17
+ return info.used / 1024.0**3
18
+
19
+
20
+ def log_gpu_memory_usage(log, msg, device):
21
+ log.info(
22
+ f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
23
+ )
src/axolotl/utils/callbacks.py CHANGED
@@ -1,5 +1,6 @@
1
  """Callbacks for Trainer class"""
2
 
 
3
  import os
4
 
5
  from optimum.bettertransformer import BetterTransformer
@@ -11,6 +12,10 @@ from transformers import (
11
  )
12
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
13
 
 
 
 
 
14
 
15
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
16
  """Callback to save the PEFT adapter"""
@@ -67,3 +72,25 @@ class SaveBetterTransformerModelCallback(
67
  # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
68
  control.should_save = False
69
  return control
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Callbacks for Trainer class"""
2
 
3
+ import logging
4
  import os
5
 
6
  from optimum.bettertransformer import BetterTransformer
 
12
  )
13
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
14
 
15
+ from axolotl.utils.bench import log_gpu_memory_usage
16
+
17
+ LOG = logging.getLogger("axolotl.callbacks")
18
+
19
 
20
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
21
  """Callback to save the PEFT adapter"""
 
72
  # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
73
  control.should_save = False
74
  return control
75
+
76
+
77
+ class PrintGPUStatsCallback(
78
+ TrainerCallback
79
+ ): # pylint: disable=too-few-public-methods disable=unused-argument
80
+ """Callback to print GPU utilization"""
81
+
82
+ def __init__(self, cfg):
83
+ self.cfg = cfg
84
+ self.logged = False
85
+
86
+ def on_step_end(
87
+ self,
88
+ args: TrainingArguments,
89
+ state: TrainerState,
90
+ control: TrainerControl,
91
+ **kwargs,
92
+ ):
93
+ if not self.logged:
94
+ log_gpu_memory_usage(LOG, "while training", self.cfg.device)
95
+ self.logged = True
96
+ return control
src/axolotl/utils/models.py CHANGED
@@ -22,6 +22,7 @@ from transformers import ( # noqa: F401
22
  )
23
 
24
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
25
 
26
  LOG = logging.getLogger("axolotl")
27
 
@@ -324,6 +325,9 @@ def load_model(
324
  )
325
  model.config.max_position_embeddings = cfg.sequence_len
326
 
 
 
 
327
  if not cfg.gptq and (
328
  (cfg.adapter == "lora" and load_in_8bit)
329
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -360,6 +364,9 @@ def load_model(
360
  module.scales = module.scales.half()
361
  module.bias = module.bias.half()
362
 
 
 
 
363
  if (
364
  torch.cuda.device_count() > 1
365
  and int(os.getenv("WORLD_SIZE", "1")) > 1
 
22
  )
23
 
24
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
25
+ from axolotl.utils.bench import log_gpu_memory_usage
26
 
27
  LOG = logging.getLogger("axolotl")
28
 
 
325
  )
326
  model.config.max_position_embeddings = cfg.sequence_len
327
 
328
+ if model.device.type == "cuda":
329
+ log_gpu_memory_usage(LOG, "after model load", model.device)
330
+
331
  if not cfg.gptq and (
332
  (cfg.adapter == "lora" and load_in_8bit)
333
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
 
364
  module.scales = module.scales.half()
365
  module.bias = module.bias.half()
366
 
367
+ if model.device.type == "cuda":
368
+ log_gpu_memory_usage(LOG, "after adapters", model.device)
369
+
370
  if (
371
  torch.cuda.device_count() > 1
372
  and int(os.getenv("WORLD_SIZE", "1")) > 1
src/axolotl/utils/trainer.py CHANGED
@@ -18,6 +18,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
20
  from axolotl.utils.callbacks import (
 
21
  SaveBetterTransformerModelCallback,
22
  SavePeftModelCallback,
23
  )
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
292
  trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
293
 
294
  callbacks = []
 
295
  # TODO on_save callback to sync checkpoints to GCP/AWS in background
296
  if cfg.early_stopping_patience:
297
  early_stop_cb = EarlyStoppingCallback(
 
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
20
  from axolotl.utils.callbacks import (
21
+ PrintGPUStatsCallback,
22
  SaveBetterTransformerModelCallback,
23
  SavePeftModelCallback,
24
  )
 
293
  trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
294
 
295
  callbacks = []
296
+ callbacks.append(PrintGPUStatsCallback(cfg))
297
  # TODO on_save callback to sync checkpoints to GCP/AWS in background
298
  if cfg.early_stopping_patience:
299
  early_stop_cb = EarlyStoppingCallback(