tmm1 committed
Commit 7b55fe6
1 Parent(s): e029ab3

improve GPU logging to break out pytorch cache and system mem

scripts/finetune.py CHANGED
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
-from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
@@ -226,8 +225,6 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
-    log_gpu_memory_usage(LOG, "baseline", cfg.device)
-
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)

The baseline measurement removed from train() is not dropped; it moves into normalize_config (see src/axolotl/utils/config.py below).
src/axolotl/utils/bench.py CHANGED
@@ -4,13 +4,23 @@ import pynvml
 import torch
 
 
-def gpu_memory_usage(device):
+def gpu_memory_usage(device=0):
+    return torch.cuda.memory_allocated(device) / 1024.0**3
+
+
+def gpu_memory_usage_all(device=0):
+    usage = torch.cuda.memory_allocated(device) / 1024.0**3
+    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
+    smi = gpu_memory_usage_smi(device)
+    return usage, reserved - usage, max(0, smi - reserved)
+
+
+def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
 
-    # NB torch.cuda.memory_usage returns zero so we use lower level api
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
 
 
 def log_gpu_memory_usage(log, msg, device):
+    usage, cache, misc = gpu_memory_usage_all(device)
+    extras = []
+    if cache > 0:
+        extras.append(f"+{cache:.03f}GB cache")
+    if misc > 0:
+        extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
     )
+    return usage, cache, misc
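For reference, the breakdown these helpers report can be reproduced directly from torch.cuda and pynvml. A minimal sketch, assuming CUDA device 0 and the pynvml package are installed; the variable names are illustrative and not part of the repo:

    import pynvml
    import torch

    device = 0  # first CUDA device

    # What PyTorch itself accounts for on this device, in GB
    allocated = torch.cuda.memory_allocated(device) / 1024.0**3  # live tensors
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3  # allocator pool (tensors + cache)

    # What the driver reports for the whole device (the nvidia-smi number)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    smi_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024.0**3

    cache = reserved - allocated  # blocks PyTorch holds but has not handed out to tensors
    misc = max(0, smi_used - reserved)  # CUDA context, other processes, anything outside the allocator
    print(f"{allocated:.03f}GB tensors +{cache:.03f}GB cache +{misc:.03f}GB misc")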
src/axolotl/utils/callbacks.py CHANGED
@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
         return control
 
 
-class PrintGPUStatsCallback(
+class GPUStatsCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to print GPU utilization"""
+    """Callback to track GPU utilization"""
 
     def __init__(self, cfg):
         self.cfg = cfg
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if not self.logged:
+        if not self.logged and state.global_step > 1:
             log_gpu_memory_usage(LOG, "while training", self.cfg.device)
             self.logged = True
         return control
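Taken together, the callback now defers its one-shot measurement until training is actually underway (state.global_step > 1), so gradients and optimizer state are resident when memory is sampled. A rough sketch of the resulting class; the hook name (on_step_end) and the self.logged = False initialization are not visible in the hunks above and are assumptions here:

    import logging

    from transformers import (
        TrainerCallback,
        TrainerControl,
        TrainerState,
        TrainingArguments,
    )

    from axolotl.utils.bench import log_gpu_memory_usage

    LOG = logging.getLogger("axolotl")


    class GPUStatsCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
        """Callback to track GPU utilization"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False  # assumption: start False so the log fires exactly once

        def on_step_end(  # assumption: the hunks do not show which Trainer hook this is
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs,
        ):
            if not self.logged and state.global_step > 1:
                log_gpu_memory_usage(LOG, "while training", self.cfg.device)
                self.logged = True
            return control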
src/axolotl/utils/config.py CHANGED
@@ -5,6 +5,8 @@ import os
 
 import torch
 
+from axolotl.utils.bench import log_gpu_memory_usage
+
 LOG = logging.getLogger("axolotl")
 
 
@@ -54,6 +56,8 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
 
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
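With the import and call above, anything that runs normalize_config now logs a "baseline" reading before any model weights are loaded. A standalone illustration of that baseline, assuming a CUDA-capable machine with axolotl and pynvml installed (the logging setup and the explicit torch.cuda.init() are example-only):

    import logging

    import torch

    from axolotl.utils.bench import log_gpu_memory_usage

    logging.basicConfig(level=logging.INFO)
    LOG = logging.getLogger("axolotl")

    if torch.cuda.is_available():
        torch.cuda.init()  # make sure a CUDA context exists so the numbers are meaningful
        # Before any weights load, the tensor and cache terms should be near zero;
        # what remains is the CUDA context plus whatever other processes hold.
        log_gpu_memory_usage(LOG, "baseline", torch.device("cuda:0"))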
src/axolotl/utils/models.py CHANGED
@@ -381,9 +381,6 @@ def load_model(
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()
 
-    if model.device.type == "cuda":
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
-
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +403,9 @@
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
 
+    if cfg.adapter is not None:
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
src/axolotl/utils/trainer.py CHANGED
@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
-    PrintGPUStatsCallback,
+    GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
-    callbacks.append(PrintGPUStatsCallback(cfg))
+    callbacks.append(GPUStatsCallback(cfg))
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(