winglian committed
Commit 131afdb
1 Parent(s): 00dce35

add bf16 check (#587)

Files changed (1)
  src/axolotl/utils/config.py +9 -0
src/axolotl/utils/config.py CHANGED

@@ -4,6 +4,7 @@ import logging
 import os
 
 import torch
+from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.models import load_model_config
@@ -89,6 +90,14 @@ def normalize_config(cfg):
 
 
 def validate_config(cfg):
+    if is_torch_bf16_gpu_available():
+        if not cfg.bf16 and not cfg.bfloat16:
+            LOG.info("bf16 support detected, but not enabled for this configuration.")
+    else:
+        if cfg.bf16 or cfg.bfloat16:
+            raise ValueError(
+                "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
+            )
     if cfg.max_packed_sequence_len and cfg.sample_packing:
         raise ValueError(
             "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"