"""
Helper util to calculate the tokenized length of each dataset sample.
"""
import numpy as np


def get_dataset_lengths(dataset):
    if "length" in dataset.data.column_names:
        # Fast path: the dataset already carries a precomputed "length" column.
        lengths = np.array(dataset.data.column("length"))
    else:
        # Fallback: recover each sample's length from its "position_ids";
        # the last position id plus one equals the sequence length.
        lengths = (
            dataset.data.column("position_ids")
            .to_pandas()
            .apply(lambda x: x[-1] + 1)
            .values
        )
    return lengths
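

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the
    # original file): exercise both code paths with toy Hugging Face datasets.
    from datasets import Dataset

    # Fast path: a dataset that already has a "length" column.
    ds_with_length = Dataset.from_dict({"length": [5, 3, 7]})
    print(get_dataset_lengths(ds_with_length))  # [5 3 7]

    # Fallback path: lengths derived from per-sample "position_ids".
    ds_with_positions = Dataset.from_dict(
        {"position_ids": [[0, 1, 2, 3, 4], [0, 1, 2]]}
    )
    print(get_dataset_lengths(ds_with_positions))  # [5 3]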