# File: ovavss/demo_video/avsbench_dataloader.py
# Hugging Face file-viewer header (not part of the program): uploaded by
# ruohguo, commit "Upload 91 files" (bcc611b, verified), file size 9.54 kB.
import os
import csv
import glob
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from avsbench_utils import get_v2_pallete
# Category table: one dict per class, with integer ids running from 1 to 70.
# Within this file only the 'name' and 'id' keys are read (by
# AVSBenchTest.__getitem__, which maps the "_"-separated class names of a video
# to ids). Id 0 is absent from the table; __getitem__ always adds 0 to the set
# of allowed labels on its own.
# NOTE(review): 'supercategory' and 'frequency' are not read anywhere in the
# visible code; 'frequency' presumably is a frequency bucket ('f', 'c', 'n',
# 'r') -- confirm against the dataset documentation.
categories = [{'supercategory': None, 'id': 1, 'name': 'accordion', 'frequency': 'n'},
{'supercategory': None, 'id': 2, 'name': 'airplane', 'frequency': 'f'},
{'supercategory': None, 'id': 3, 'name': 'axe', 'frequency': 'r'},
{'supercategory': None, 'id': 4, 'name': 'baby', 'frequency': 'f'},
{'supercategory': None, 'id': 5, 'name': 'bassoon', 'frequency': 'n'},
{'supercategory': None, 'id': 6, 'name': 'bell', 'frequency': 'f'},
{'supercategory': None, 'id': 7, 'name': 'bird', 'frequency': 'f'},
{'supercategory': None, 'id': 8, 'name': 'boat', 'frequency': 'f'},
{'supercategory': None, 'id': 9, 'name': 'boy', 'frequency': 'f'},
{'supercategory': None, 'id': 10, 'name': 'bus', 'frequency': 'f'},
{'supercategory': None, 'id': 11, 'name': 'car', 'frequency': 'f'},
{'supercategory': None, 'id': 12, 'name': 'cat', 'frequency': 'f'},
{'supercategory': None, 'id': 13, 'name': 'cello', 'frequency': 'n'},
{'supercategory': None, 'id': 14, 'name': 'clarinet', 'frequency': 'r'},
{'supercategory': None, 'id': 15, 'name': 'clipper', 'frequency': 'r'},
{'supercategory': None, 'id': 16, 'name': 'clock', 'frequency': 'f'},
{'supercategory': None, 'id': 17, 'name': 'dog', 'frequency': 'f'},
{'supercategory': None, 'id': 18, 'name': 'donkey', 'frequency': 'c'},
{'supercategory': None, 'id': 19, 'name': 'drum', 'frequency': 'c'},
{'supercategory': None, 'id': 20, 'name': 'duck', 'frequency': 'f'},
{'supercategory': None, 'id': 21, 'name': 'elephant', 'frequency': 'f'},
{'supercategory': None, 'id': 22, 'name': 'emergency-car', 'frequency': 'n'},
{'supercategory': None, 'id': 23, 'name': 'erhu', 'frequency': 'n'},
{'supercategory': None, 'id': 24, 'name': 'flute', 'frequency': 'n'},
{'supercategory': None, 'id': 25, 'name': 'frying-food', 'frequency': 'n'},
{'supercategory': None, 'id': 26, 'name': 'girl', 'frequency': 'f'},
{'supercategory': None, 'id': 27, 'name': 'goose', 'frequency': 'c'},
{'supercategory': None, 'id': 28, 'name': 'guitar', 'frequency': 'f'},
{'supercategory': None, 'id': 29, 'name': 'gun', 'frequency': 'c'},
{'supercategory': None, 'id': 30, 'name': 'guzheng', 'frequency': 'n'},
{'supercategory': None, 'id': 31, 'name': 'hair-dryer', 'frequency': 'f'},
{'supercategory': None, 'id': 32, 'name': 'handpan', 'frequency': 'n'},
{'supercategory': None, 'id': 33, 'name': 'harmonica', 'frequency': 'n'},
{'supercategory': None, 'id': 34, 'name': 'harp', 'frequency': 'n'},
{'supercategory': None, 'id': 35, 'name': 'helicopter', 'frequency': 'c'},
{'supercategory': None, 'id': 36, 'name': 'hen', 'frequency': 'c'},
{'supercategory': None, 'id': 37, 'name': 'horse', 'frequency': 'f'},
{'supercategory': None, 'id': 38, 'name': 'keyboard', 'frequency': 'f'},
{'supercategory': None, 'id': 39, 'name': 'leopard', 'frequency': 'n'},
{'supercategory': None, 'id': 40, 'name': 'lion', 'frequency': 'c'},
{'supercategory': None, 'id': 41, 'name': 'man', 'frequency': 'f'},
{'supercategory': None, 'id': 42, 'name': 'marimba', 'frequency': 'n'},
{'supercategory': None, 'id': 43, 'name': 'missile-rocket', 'frequency': 'c'},
{'supercategory': None, 'id': 44, 'name': 'motorcycle', 'frequency': 'f'},
{'supercategory': None, 'id': 45, 'name': 'mower', 'frequency': 'r'},
{'supercategory': None, 'id': 46, 'name': 'parrot', 'frequency': 'c'},
{'supercategory': None, 'id': 47, 'name': 'piano', 'frequency': 'f'},
{'supercategory': None, 'id': 48, 'name': 'pig', 'frequency': 'c'},
{'supercategory': None, 'id': 49, 'name': 'pipa', 'frequency': 'n'},
{'supercategory': None, 'id': 50, 'name': 'saw', 'frequency': 'r'},
{'supercategory': None, 'id': 51, 'name': 'saxophone', 'frequency': 'r'},
{'supercategory': None, 'id': 52, 'name': 'sheep', 'frequency': 'f'},
{'supercategory': None, 'id': 53, 'name': 'sitar', 'frequency': 'n'},
{'supercategory': None, 'id': 54, 'name': 'sorna', 'frequency': 'n'},
{'supercategory': None, 'id': 55, 'name': 'squirrel', 'frequency': 'c'},
{'supercategory': None, 'id': 56, 'name': 'tabla', 'frequency': 'n'},
{'supercategory': None, 'id': 57, 'name': 'tank', 'frequency': 'n'},
{'supercategory': None, 'id': 58, 'name': 'tiger', 'frequency': 'c'},
{'supercategory': None, 'id': 59, 'name': 'tractor', 'frequency': 'c'},
{'supercategory': None, 'id': 60, 'name': 'train', 'frequency': 'f'},
{'supercategory': None, 'id': 61, 'name': 'trombone', 'frequency': 'n'},
{'supercategory': None, 'id': 62, 'name': 'truck', 'frequency': 'f'},
{'supercategory': None, 'id': 63, 'name': 'trumpet', 'frequency': 'c'},
{'supercategory': None, 'id': 64, 'name': 'tuba', 'frequency': 'r'},
{'supercategory': None, 'id': 65, 'name': 'ukulele', 'frequency': 'n'},
{'supercategory': None, 'id': 66, 'name': 'utv', 'frequency': 'n'},
{'supercategory': None, 'id': 67, 'name': 'vacuum-cleaner', 'frequency': 'c'},
{'supercategory': None, 'id': 68, 'name': 'violin', 'frequency': 'r'},
{'supercategory': None, 'id': 69, 'name': 'wolf', 'frequency': 'r'},
{'supercategory': None, 'id': 70, 'name': 'woman', 'frequency': 'f'}]
def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a map of palette indices.

    Args:
        mask: Anything ``np.array`` accepts (e.g. a PIL image) whose last axis
            holds the colour channels.
        v_pallete: Iterable of colours; the position of a colour in this
            iterable is the label written for pixels of that colour.

    Returns:
        Integer array with the channel axis removed. Each entry is the index
        of the first palette colour equal to that pixel; a pixel matching no
        palette colour gets 0, because argmax over an all-zero row is 0.
    """
    pixels = np.array(mask).astype('int32')
    # One boolean plane per palette colour: True where every channel matches.
    hit_planes = [np.all(np.equal(pixels, entry), axis=-1) for entry in v_pallete]
    one_hot = np.stack(hit_planes, axis=-1).astype(np.float32)
    return np.argmax(one_hot, axis=-1)
def load_color_mask_in_PIL_to_Tensor(path, v_pallete, size, mode='RGB'):
    """Load a colour mask from disk and return it as a label tensor.

    Args:
        path: Location of the mask image, passed to ``Image.open``.
        v_pallete: Colour palette forwarded to ``color_mask_to_label``.
        size: Target size handed to ``Image.resize``.
        mode: PIL mode the image is converted to before resizing.

    Returns:
        Tensor of palette indices with a leading singleton axis, i.e.
        ``[1, H, W]`` for an image whose label map is ``[H, W]``.
    """
    pil_mask = Image.open(path)
    pil_mask = pil_mask.convert(mode)
    # Nearest-neighbour resampling keeps pixels at exact palette colours;
    # interpolated colours would not match any palette entry.
    pil_mask = pil_mask.resize(size, Image.NEAREST)
    label_map = color_mask_to_label(pil_mask, v_pallete)  # [H, W]
    return torch.from_numpy(label_map).unsqueeze(0)  # [1, H, W]
class AVSBenchTest(Dataset):
    """Dataset over the test videos found in ``test_img_dir``.

    Each item describes one video: the paths of its frame images, its
    ground-truth label maps (restricted to the classes listed for the video
    in the metadata CSV), the video name, the frame file names, and the path
    of its audio track.

    Args:
        test_img_dir: Directory containing one sub-directory of frames per video.
        test_mask_dir: Root directory under which
            ``<label>/<video>/labels_rgb/*.png`` and ``<label>/<video>/audio.wav``
            are looked up.
        pallete_dir: Path handed to ``get_v2_pallete`` to build the colour palette.
        avsbench_csv: Metadata CSV. Column 1 is matched against the video
            name, column 4 is split on ``"_"`` into category names, and
            column 6 is used as the ``<label>`` path component.
    """

    def __init__(self, test_img_dir="./datasets/AVSBench-openvoc/test/JPEGImages/",
                 test_mask_dir="./datasets/AVSBench-semantic/",
                 pallete_dir="./datasets/AVSBench-semantic/label2idx.json",
                 avsbench_csv="./datasets/AVSBench-semantic/metadata.csv"):
        super(AVSBenchTest, self).__init__()
        self.test_img_dir = test_img_dir
        self.test_mask_dir = test_mask_dir
        self.pallete_dir = pallete_dir
        self.avsbench_csv = avsbench_csv
        # Not read anywhere inside this class; kept because it is a public attribute.
        self.frame_num = 5
        self.videos = os.listdir(test_img_dir)
        print("{} videos are used for test.".format(len(self.videos)))
        self.v2_pallete = get_v2_pallete(pallete_dir, num_cls=71)
        # Lazily built {video name: (class string, label dir)} index of the CSV.
        self._video_meta = None

    def _lookup_video_meta(self, video):
        """Return ``(class string, label dir)`` for ``video`` from the metadata CSV.

        The CSV is parsed once, on first use, and cached on the instance
        instead of being re-read for every item.

        Raises:
            KeyError: If no CSV row has ``video`` in column 1.
        """
        meta = getattr(self, "_video_meta", None)
        if meta is None:
            meta = {}
            with open(self.avsbench_csv, "r", newline="") as csvfile:
                for info in csv.reader(csvfile):
                    # Skip blank/short rows that lack the columns used below.
                    if len(info) > 6:
                        # Later rows overwrite earlier ones, matching the
                        # previous full-scan behaviour (last match wins).
                        meta[info[1]] = (info[4], info[6])
            self._video_meta = meta
        try:
            return meta[video]
        except KeyError:
            raise KeyError(
                "video {!r} has no row in {}".format(video, self.avsbench_csv)
            ) from None

    def __getitem__(self, index):
        """Return ``(imgs_pth, annos, video, imgs_names, audios)`` for one video.

        Returns:
            imgs_pth: Sorted list of frame image paths.
            annos: List of integer numpy arrays of shape ``[1, 224, 224]``, one
                per ``*.png`` mask, with every label not listed for the video
                reset to 0.
            video: Name of the video directory.
            imgs_names: Sorted list of frame file names.
            audios: Path of the video's ``audio.wav``.

        Raises:
            KeyError: If the video is missing from the metadata CSV.
        """
        video = self.videos[index]
        video_class, video_label = self._lookup_video_meta(video)

        # Label 0 is always allowed; add the id of every named category.
        name_to_id = {cate["name"]: cate["id"] for cate in categories}
        video_class_id = [0]
        for class_name in video_class.split("_"):
            if class_name in name_to_id:
                video_class_id.append(name_to_id[class_name])

        # img:
        img_pth = os.path.join(self.test_img_dir, video)
        imgs_names = sorted(os.listdir(img_pth))
        imgs_pth = [os.path.join(img_pth, img_name) for img_name in imgs_names]

        # gt:
        annos = []
        anno_pth = os.path.join(self.test_mask_dir, video_label, video, "labels_rgb")
        for mask_image in sorted(glob.glob(os.path.join(anno_pth, "*.png"))):
            color_label = np.array(
                load_color_mask_in_PIL_to_Tensor(mask_image, self.v2_pallete, (224, 224))
            )
            # filter error annotations: labels outside this video's classes become 0
            color_label_mask = np.isin(color_label, video_class_id)
            color_label[~color_label_mask] = 0
            annos.append(color_label)

        # audios:
        audios = os.path.join(self.test_mask_dir, video_label, video, "audio.wav")
        return imgs_pth, annos, video, imgs_names, audios

    def __len__(self):
        """Return the number of test videos."""
        return len(self.videos)
if __name__ == "__main__":
    # Smoke test: walk the test set one video per batch, in dataset order,
    # and print the frame paths of each batch.
    loader = torch.utils.data.DataLoader(AVSBenchTest(), batch_size=1, shuffle=False)
    for batch in loader:
        # Each batch carries the five fields returned by AVSBenchTest.__getitem__.
        frame_paths, _annos, _video_name, _frame_names, _audio_path = batch
        print(frame_paths)