""" helper util to calculate dataset lengths """ import numpy as np def get_dataset_lengths(dataset): if "length" in dataset.data.column_names: lengths = np.array(dataset.data.column("length")) else: lengths = ( dataset.data.column("position_ids") .to_pandas() .apply(lambda x: x[-1] + 1) .values ) return lengths