qwerrwe / src /axolotl /datasets.py
winglian's picture
update table for rwkv4 support, fix process count for dataset (#822)
cdc71f7 unverified
raw
history blame
No virus
7.13 kB
"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional
import torch
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# let's use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset):
    """
    Dataset that returns tokenized prompts from a stream of text files.

    The wrapped dataset is tokenized eagerly at construction time (see `process`)
    and the resulting table is handed to the parent `Dataset` constructor.

    Args:
        prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        process_count (Optional[int]): Number of worker processes to use for
            tokenization. When unset (or 0) the machine's CPU count is used.
            The value is always capped at 64.
        **kwargs: Forwarded unchanged to the parent `Dataset` constructor.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
        process_count: Optional[int] = None,
        **kwargs,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.process_count = process_count
        super().__init__(self.process(dataset).data, **kwargs)

    def process(self, dataset):
        """
        Tokenize every row of `dataset` and drop its original columns.

        Args:
            dataset: Object exposing `features` and `map` (as used below).

        Returns:
            Whatever `dataset.map` returns for the tokenizing function.
        """
        features = dataset.features.keys()

        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single worker instead of failing inside min().
        worker_limit = self.process_count or os.cpu_count() or 1
        num_proc = min(64, worker_limit)

        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100

        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
            remove_columns=features,
            **map_kwargs,
        )
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that packs tokenized examples into sequences of at most `seq_length` tokens.

    Examples are concatenated in order, each terminated by the tokenizer's EOS
    token, and a packed sequence is emitted as soon as the next example would
    not fit. No padding is applied here, so yielded tensors may be shorter
    than `seq_length`.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data. Only
            `eos_token_id` and `get_vocab()` are read from it.
        datasets (List[IterableDataset]): Datasets of tokenized examples; each
            example provides `input_ids`, `attention_mask` and `labels`.
        seq_length (int): Maximum length of the token sequences to return.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length

        # Pick the narrowest integer dtype that can hold the vocabulary size.
        vocab_size = len(tokenizer.get_vocab())
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    @staticmethod
    def _new_buffer():
        """Return an empty accumulator holding one tensor list per output field."""
        return {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "position_ids": [],
        }

    def _pack_buffer(self, buffer):
        """Concatenate the buffered tensors into one sample; return None on a size mismatch."""
        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
        position_ids = torch.cat(buffer["position_ids"], dim=-1)[: self.seq_length]
        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]

        if labels.size() == input_ids.size() and (
            attention_mask.size() == input_ids.size()
        ):
            return {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }

        LOG.warning(
            f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
        )
        return None

    def __iter__(self):
        """
        Yield packed samples as dicts of tensors.

        Each yielded dict has `input_ids`, `labels`, `attention_mask` and
        `position_ids`. The attention mask values are the example's own mask
        multiplied by the example's index within the packed sequence, and
        `position_ids` restart from 0 for every example.

        The source examples are never modified; examples with no tokens are
        skipped and examples longer than `seq_length` are dropped.
        """
        buffer = self._new_buffer()
        buffer_len = 0
        for dataset in self.datasets:
            idx = 0
            iterator = iter(dataset)
            more_examples = True
            while more_examples:
                try:
                    example = next(iterator)
                except StopIteration:
                    # A final pass with example=None flushes what is buffered.
                    more_examples = False
                    example = None

                add_concat_token = False
                example_len = 0
                if example:
                    example_len = len(example["input_ids"])
                    if example_len == 0:
                        # Nothing to pack, and indexing [-1] below would fail.
                        continue
                    idx += 1
                    add_concat_token = example["input_ids"][-1] != self.concat_token_id

                # Flush when the input is exhausted or the example does not fit.
                if not example_len or (
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    if buffer["input_ids"]:
                        packed = self._pack_buffer(buffer)
                        if packed is not None:
                            yield packed
                    buffer = self._new_buffer()
                    buffer_len = 0
                    idx = 1

                if example:
                    # FIXME
                    # just going to drop data points that are too long
                    if example_len <= self.seq_length:
                        input_ids = example["input_ids"]
                        attention_mask = example["attention_mask"]
                        labels = example["labels"]

                        if add_concat_token:
                            # Build new lists so the caller's example is left untouched.
                            input_ids = [*input_ids, self.concat_token_id]
                            attention_mask = [*attention_mask, 1]
                            labels = [*labels, self.concat_token_id]

                        input_ids_with_concat = torch.tensor(
                            input_ids, dtype=self.tokens_dtype
                        )
                        attention_mask_with_concat = torch.tensor(
                            [idx * m for m in attention_mask], dtype=torch.int16
                        )
                        labels_with_concat = torch.tensor(
                            labels, dtype=self.tokens_dtype
                        )
                        position_ids = torch.arange(
                            len(input_ids), dtype=self.tokens_dtype
                        )

                        buffer["input_ids"].append(input_ids_with_concat)
                        buffer["attention_mask"].append(attention_mask_with_concat)
                        buffer["labels"].append(labels_with_concat)
                        buffer["position_ids"].append(position_ids)
                        buffer_len += len(input_ids)