winglian committed on
Commit 7b5e762
1 Parent(s): 3f6017d

fix merge conflict failure, black format

src/axolotl/utils/data.py CHANGED
@@ -112,14 +112,10 @@ def load_tokenized_prepared_datasets(
                 raise Exception("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
-<<<<<<< Updated upstream
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
-=======
                 if "train" in ds:
                     ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
                 else:
                     ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
->>>>>>> Stashed changes
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
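The resolved branch keeps the behavior of sharding only the "train" split when the loaded object is a DatasetDict, and sharding the dataset directly otherwise. A minimal standalone sketch of that logic, assuming the Hugging Face datasets library; the in-memory dataset and the shards variable are illustrative stand-ins for the loaded ds and cfg.shards:

from datasets import Dataset, DatasetDict

# illustrative stand-ins for the loaded dataset and cfg.shards
ds = DatasetDict({"train": Dataset.from_dict({"text": [str(i) for i in range(100)]})})
shards = 4

if "train" in ds:
    # DatasetDict: shuffle, pick the train split, then keep shard 0 of `shards`
    subset = ds.shuffle(seed=42)["train"].shard(num_shards=shards, index=0)
else:
    # plain Dataset: shuffle and shard directly
    subset = ds.shuffle(seed=42).shard(num_shards=shards, index=0)

print(len(subset))  # 25 rows, i.e. 100 / 4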
src/axolotl/utils/models.py CHANGED
@@ -247,8 +247,10 @@ def load_model(
         model.resize_token_embeddings(embeddings_len)
 
     if (
-        (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
+        ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
+        and not cfg.load_4bit
+        and (load_in_8bit or cfg.load_in_4bit)
+    ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +299,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter in ["lora" , "qlora"]:
+    if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
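Black only re-wraps the guard in load_model; the condition itself is unchanged, and both "lora" and "qlora" still dispatch to load_lora. A small self-contained sketch of that logic, using a SimpleNamespace as an illustrative stand-in for the real cfg and print calls in place of peft's prepare_model_for_int8_training and axolotl's load_lora:

from types import SimpleNamespace

# illustrative stand-ins for the values load_model() would see
cfg = SimpleNamespace(adapter="qlora", load_4bit=False, load_in_4bit=True)
load_in_8bit = False

if (
    ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
    and not cfg.load_4bit
    and (load_in_8bit or cfg.load_in_4bit)
):
    # the real code calls peft's prepare_model_for_int8_training(model) here
    print("would prepare the model for int8 training")

# adapter dispatch after the fix: lora and qlora share the same path
adapter = "qlora"
if adapter in ["lora", "qlora"]:
    print("would call load_lora(model, cfg)")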