Nanobit committed
Commit
3cc67d2
1 Parent(s): 1bc1186

Feat: Add dataset loading from S3, GCS (#765)


* Feat: Add dataset loading from S3, GCS

* chore: update docs

* chore: add more info on cloud loading

Files changed (3)
  1. README.md +7 -1
  2. requirements.txt +6 -1
  3. src/axolotl/utils/data.py +97 -19
README.md CHANGED
@@ -426,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: knowrohit07/know_sql
     type: context_qa.load_v2
     train_on_split: validation
+
+  # loading from s3 or gcs
+  # s3 creds will be loaded from the system default and gcs only supports public access
+  dataset:
+  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
+    ...
 ```
 
 - loading
@@ -520,7 +526,7 @@ float16: true
 
 # A list of one or more datasets to finetune the model with
 datasets:
-  # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
+  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
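
Note: for orientation, here is a minimal Python sketch of what an `s3://` dataset path triggers under the hood, mirroring the calls introduced in `src/axolotl/utils/data.py` below. The bucket path is a placeholder from the example above, and default AWS credentials from `~/.aws/credentials` are assumed; this is a sketch, not the loader itself.

```python
# Sketch only: load a dataset from S3 the same way the new loader code does.
import aiobotocore.session
import s3fs
from datasets import load_dataset, load_from_disk

# Credentials come from the default AWS profile, as noted in the README comment.
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
fs = s3fs.S3FileSystem(**storage_options)

path = "s3://path_to_ds"  # placeholder path from the README example
if fs.isdir(path):
    # a folder of arrow files written by Dataset.save_to_disk
    ds = load_from_disk(path, storage_options=storage_options)
elif fs.isfile(path):
    # a single file; the loader type is inferred from the extension (parquet here)
    ds = load_dataset("parquet", data_files=path, storage_options=storage_options)
```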
requirements.txt CHANGED
@@ -11,7 +11,7 @@ deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets
+datasets>=2.14.0
 flash-attn>=2.3.0
 sentencepiece
 wandb
@@ -33,3 +33,8 @@ art
 fschat==0.2.29
 gradio
 tensorboard
+
+# remote filesystems
+s3fs
+gcsfs
+# adlfs
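
The new `s3fs` and `gcsfs` entries back the remote paths shown in the README; `data.py` imports them lazily, so a missing package surfaces as a clear `ImportError`. For the GCS side, a minimal sketch assuming a publicly readable bucket (the path below is a placeholder), matching the anonymous-access note in the README:

```python
# Sketch only: anonymous GCS access via the newly added gcsfs dependency.
# token=None mirrors the storage_options used in data.py; gcsfs falls back to
# anonymous access when no credentials are configured in the environment.
import gcsfs
from datasets import load_from_disk

storage_options = {"token": None}
fs = gcsfs.GCSFileSystem(**storage_options)

path = "gs://example-public-bucket/path_to_ds"  # hypothetical public bucket
if fs.isdir(path):
    # a directory produced by Dataset.save_to_disk
    ds = load_from_disk(path, storage_options=storage_options)
```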
src/axolotl/utils/data.py CHANGED
@@ -170,30 +170,74 @@ def load_tokenized_prepared_datasets(
             except (FileNotFoundError, ConnectionError):
                 pass
 
+            ds_from_cloud = False
+            storage_options = {}
+            remote_file_system = None
+            if config_dataset.path.startswith("s3://"):
+                try:
+                    import aiobotocore.session  # type: ignore
+                    import s3fs  # type: ignore
+                except ImportError as exc:
+                    raise ImportError(
+                        "s3:// paths require aiobotocore and s3fs to be installed"
+                    ) from exc
+
+                # Takes credentials from ~/.aws/credentials for default profile
+                s3_session = aiobotocore.session.AioSession(profile="default")
+                storage_options = {"session": s3_session}
+                remote_file_system = s3fs.S3FileSystem(**storage_options)
+            elif config_dataset.path.startswith(
+                "gs://"
+            ) or config_dataset.path.startswith("gcs://"):
+                try:
+                    import gcsfs  # type: ignore
+                except ImportError as exc:
+                    raise ImportError(
+                        "gs:// or gcs:// paths require gcsfs to be installed"
+                    ) from exc
+
+                # gcsfs will use default credentials from the environment else anon
+                # https://gcsfs.readthedocs.io/en/latest/#credentials
+                storage_options = {"token": None}
+                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
+            # TODO: Figure out how to get auth creds passed
+            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
+            #     try:
+            #         import adlfs
+            #     except ImportError as exc:
+            #         raise ImportError(
+            #             "adl:// or abfs:// paths require adlfs to be installed"
+            #         ) from exc
+
+            #     # Gen 1
+            #     storage_options = {
+            #         "tenant_id": TENANT_ID,
+            #         "client_id": CLIENT_ID,
+            #         "client_secret": CLIENT_SECRET,
+            #     }
+            #     # Gen 2
+            #     storage_options = {
+            #         "account_name": ACCOUNT_NAME,
+            #         "account_key": ACCOUNT_KEY,
+            #     }
+
+            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
+            try:
+                if remote_file_system and remote_file_system.exists(
+                    config_dataset.path
+                ):
+                    ds_from_cloud = True
+            except (FileNotFoundError, ConnectionError):
+                pass
+
             # prefer local dataset, even if hub exists
             local_path = Path(config_dataset.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
-                    ds = load_dataset(
-                        config_dataset.path,
-                        name=config_dataset.name,
-                        data_files=config_dataset.data_files,
-                        streaming=False,
-                        split=None,
-                    )
+                    ds = load_from_disk(config_dataset.path)
                 elif local_path.is_file():
-                    ds_type = "json"
-                    if config_dataset.ds_type:
-                        ds_type = config_dataset.ds_type
-                    elif ".parquet" in config_dataset.path:
-                        ds_type = "parquet"
-                    elif ".arrow" in config_dataset.path:
-                        ds_type = "arrow"
-                    elif ".csv" in config_dataset.path:
-                        ds_type = "csv"
-                    elif ".txt" in config_dataset.path:
-                        ds_type = "text"
+                    ds_type = get_ds_type(config_dataset)
+
                     ds = load_dataset(
                         ds_type,
                         name=config_dataset.name,
@@ -213,6 +257,22 @@ def load_tokenized_prepared_datasets(
                     data_files=config_dataset.data_files,
                     token=use_auth_token,
                 )
+            elif ds_from_cloud and remote_file_system:
+                if remote_file_system.isdir(config_dataset.path):
+                    ds = load_from_disk(
+                        config_dataset.path,
+                        storage_options=storage_options,
+                    )
+                elif remote_file_system.isfile(config_dataset.path):
+                    ds_type = get_ds_type(config_dataset)
+                    ds = load_dataset(
+                        ds_type,
+                        name=config_dataset.name,
+                        data_files=config_dataset.path,
+                        streaming=False,
+                        split=None,
+                        storage_options=storage_options,
+                    )
             else:
                 if isinstance(config_dataset.data_files, str):
                     fp = hf_hub_download(
@@ -304,6 +364,24 @@ def load_tokenized_prepared_datasets(
     return dataset, prompters
 
 
+def get_ds_type(config_dataset: DictDefault):
+    """
+    Get the dataset type from the path if it's not specified
+    """
+    ds_type = "json"
+    if config_dataset.ds_type:
+        ds_type = config_dataset.ds_type
+    elif ".parquet" in config_dataset.path:
+        ds_type = "parquet"
+    elif ".arrow" in config_dataset.path:
+        ds_type = "arrow"
+    elif ".csv" in config_dataset.path:
+        ds_type = "csv"
+    elif ".txt" in config_dataset.path:
+        ds_type = "text"
+    return ds_type
+
+
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
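
For a quick sense of the new `get_ds_type` helper: it prefers an explicit `ds_type` from the dataset config and otherwise infers the loader from the file extension, defaulting to `json`. A small usage sketch, assuming `DictDefault` behaves as the addict-style mapping used throughout axolotl (missing keys are falsy) and that it is importable from `axolotl.utils.dict`:

```python
# Sketch only: exercising get_ds_type() from this commit.
from axolotl.utils.data import get_ds_type
from axolotl.utils.dict import DictDefault  # assumed import path

print(get_ds_type(DictDefault({"path": "s3://bucket/train.parquet"})))  # -> "parquet"
print(get_ds_type(DictDefault({"path": "gs://bucket/train.csv"})))      # -> "csv"
print(get_ds_type(DictDefault({"path": "data/train.jsonl"})))           # -> "json" (default)
print(get_ds_type(DictDefault({"path": "x", "ds_type": "arrow"})))      # -> "arrow" (explicit wins)
```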