winglian committed
Commit 196ff11
1 Parent(s): 92512c3

skip the gpu memory checks if the device is set to 'auto' (#609)

* skip the gpu memory checks if the device is set to 'auto'

* skip gpu mem logging if cpu too

* don't worry about log_gpu_memory_usage since it calls another annotated fn

* rename decorator internal

Files changed (1)
  1. src/axolotl/utils/bench.py +27 -3
src/axolotl/utils/bench.py CHANGED

@@ -1,14 +1,40 @@
 """Benchmarking and measurement utilities"""
+import functools
 
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
 
 
+def check_cuda_device(default_value):
+    """
+    wraps a function and returns the default value instead of running the
+    wrapped function if cuda isn't available or the device is auto
+    :param default_value:
+    :return:
+    """
+
+    def deco(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            device = kwargs.get("device", args[0] if args else None)
+
+            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
+                return default_value
+
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return deco
+
+
+@check_cuda_device(0.0)
 def gpu_memory_usage(device=0):
     return torch.cuda.memory_allocated(device) / 1024.0**3
 
 
+@check_cuda_device((0.0, 0.0, 0.0))
 def gpu_memory_usage_all(device=0):
     usage = torch.cuda.memory_allocated(device) / 1024.0**3
     reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -16,6 +42,7 @@ def gpu_memory_usage_all(device=0):
     return usage, reserved - usage, max(0, smi - reserved)
 
 
+@check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
@@ -31,9 +58,6 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available() or device == "auto":
-        return (0, 0, 0)
-
     usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:
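
After this change, callers no longer need their own CUDA guards: the decorated helpers return their defaults whenever the device cannot be probed. A minimal sketch of the resulting behavior (the logger setup and the device values below are illustrative, not part of the commit):

import logging

from axolotl.utils.bench import gpu_memory_usage, log_gpu_memory_usage

log = logging.getLogger(__name__)

# With device="auto" or device="cpu", or on a machine without CUDA,
# the check_cuda_device wrapper short-circuits and returns the
# decorated function's default (0.0 here) instead of calling into
# torch.cuda and raising.
print(gpu_memory_usage(device="auto"))  # 0.0
print(gpu_memory_usage(device="cpu"))   # 0.0

# On a CUDA machine this reports allocated memory in GiB for device 0.
print(gpu_memory_usage(device=0))

# log_gpu_memory_usage itself stays undecorated: it delegates to
# gpu_memory_usage_all, which is decorated and yields (0.0, 0.0, 0.0)
# when the device is "auto" or "cpu", so the log line is harmless there.
log_gpu_memory_usage(log, "baseline", "auto")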