|
import os |
|
import csv |
|
import glob |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
from avsbench_utils import get_v2_pallete |
|
|
|
|
|
|
|
# Category table for the dataset: one dict per class with keys
# 'supercategory' (always None), 'id' (1..70, in the order listed below),
# 'name' and 'frequency'.
# NOTE(review): the frequency codes ('f', 'c', 'r', 'n') are not explained
# anywhere in this file; they look like LVIS-style frequency buckets plus an
# extra 'n' group -- confirm against the dataset documentation.
# Ids start at 1; id 0 appears to be reserved for background (the dataset
# class seeds its kept-id list with 0) -- TODO confirm.
categories = [
    {'supercategory': None, 'id': cat_id, 'name': name, 'frequency': freq}
    for cat_id, (name, freq) in enumerate(
        (
            ('accordion', 'n'),
            ('airplane', 'f'),
            ('axe', 'r'),
            ('baby', 'f'),
            ('bassoon', 'n'),
            ('bell', 'f'),
            ('bird', 'f'),
            ('boat', 'f'),
            ('boy', 'f'),
            ('bus', 'f'),
            ('car', 'f'),
            ('cat', 'f'),
            ('cello', 'n'),
            ('clarinet', 'r'),
            ('clipper', 'r'),
            ('clock', 'f'),
            ('dog', 'f'),
            ('donkey', 'c'),
            ('drum', 'c'),
            ('duck', 'f'),
            ('elephant', 'f'),
            ('emergency-car', 'n'),
            ('erhu', 'n'),
            ('flute', 'n'),
            ('frying-food', 'n'),
            ('girl', 'f'),
            ('goose', 'c'),
            ('guitar', 'f'),
            ('gun', 'c'),
            ('guzheng', 'n'),
            ('hair-dryer', 'f'),
            ('handpan', 'n'),
            ('harmonica', 'n'),
            ('harp', 'n'),
            ('helicopter', 'c'),
            ('hen', 'c'),
            ('horse', 'f'),
            ('keyboard', 'f'),
            ('leopard', 'n'),
            ('lion', 'c'),
            ('man', 'f'),
            ('marimba', 'n'),
            ('missile-rocket', 'c'),
            ('motorcycle', 'f'),
            ('mower', 'r'),
            ('parrot', 'c'),
            ('piano', 'f'),
            ('pig', 'c'),
            ('pipa', 'n'),
            ('saw', 'r'),
            ('saxophone', 'r'),
            ('sheep', 'f'),
            ('sitar', 'n'),
            ('sorna', 'n'),
            ('squirrel', 'c'),
            ('tabla', 'n'),
            ('tank', 'n'),
            ('tiger', 'c'),
            ('tractor', 'c'),
            ('train', 'f'),
            ('trombone', 'n'),
            ('truck', 'f'),
            ('trumpet', 'c'),
            ('tuba', 'r'),
            ('ukulele', 'n'),
            ('utv', 'n'),
            ('vacuum-cleaner', 'c'),
            ('violin', 'r'),
            ('wolf', 'r'),
            ('woman', 'f'),
        ),
        start=1,
    )
]
|
|
|
|
|
def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a map of palette indices.

    Args:
        mask: Array-like whose last axis holds the colour channels
            (anything ``np.array`` accepts, e.g. a PIL image).
        v_pallete: Non-empty iterable of colours. Each colour must be
            comparable (broadcastable) against the last axis of ``mask``.

    Returns:
        Integer ndarray with the shape of ``mask`` minus its last axis. Each
        entry is the index of the first palette colour that equals the pixel
        on every channel. Pixels matching no palette colour get index 0.

    Raises:
        ValueError: If ``v_pallete`` is empty (raised by ``np.stack``).
    """
    mask_array = np.array(mask).astype('int32')

    # One boolean map per palette entry: True where every channel matches.
    class_maps = [
        np.all(np.equal(mask_array, colour), axis=-1)
        for colour in v_pallete
    ]

    # argmax over the boolean stack picks the first True along the palette
    # axis and falls back to 0 when nothing matched, so no float cast is
    # needed before it.
    return np.argmax(np.stack(class_maps, axis=-1), axis=-1)
|
|
|
def load_color_mask_in_PIL_to_Tensor(path, v_pallete, size, mode='RGB'):
    """Load a colour mask image and return it as a tensor of palette indices.

    Args:
        path: Path of the mask image file.
        v_pallete: Palette passed through to ``color_mask_to_label``.
        size: Target size handed to ``PIL.Image.resize``
            (PIL expects ``(width, height)``).
        mode: PIL mode the image is converted to before matching colours.

    Returns:
        ``torch.Tensor`` of palette indices with a leading singleton
        dimension added by ``unsqueeze(0)``.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixels have been converted, instead of leaving that to garbage
    # collection.
    with Image.open(path) as mask_file:
        # Nearest-neighbour resampling keeps every pixel an exact source
        # colour; interpolating filters would blend colours that no longer
        # match any palette entry.
        color_mask_PIL = mask_file.convert(mode).resize(size, Image.NEAREST)

    color_label = color_mask_to_label(color_mask_PIL, v_pallete)
    return torch.from_numpy(color_label).unsqueeze(0)
|
|
|
|
|
class AVSBenchTest(Dataset):
    """Test-split dataset that yields one video per item.

    Each item bundles the frame image paths of a video, its ground-truth
    label maps (restricted to the classes named in the video's metadata row),
    and the path of its audio track.
    """

    def __init__(self, test_img_dir="./datasets/AVSBench-openvoc/test/JPEGImages/",
                 test_mask_dir="./datasets/AVSBench-semantic/",
                 pallete_dir="./datasets/AVSBench-semantic/label2idx.json",
                 avsbench_csv="./datasets/AVSBench-semantic/metadata.csv"):
        """Index the test videos and load the colour palette.

        Args:
            test_img_dir: Directory holding one sub-directory of frames per
                test video; the sub-directory names are the video names.
            test_mask_dir: Root under which masks and audio are looked up as
                ``<label>/<video>/labels_rgb/*.png`` and
                ``<label>/<video>/audio.wav``.
            pallete_dir: Path handed to ``get_v2_pallete``.
            avsbench_csv: Metadata CSV; column 1 is matched against the video
                name, column 4 holds the ``_``-joined class names and column
                6 the label sub-directory.
        """
        super(AVSBenchTest, self).__init__()

        self.test_img_dir = test_img_dir
        self.test_mask_dir = test_mask_dir
        self.pallete_dir = pallete_dir
        self.avsbench_csv = avsbench_csv

        self.frame_num = 5
        self.videos = os.listdir(test_img_dir)
        print("{} videos are used for test.".format(len(self.videos)))
        # NOTE(review): 71 presumably = 70 categories + background; confirm.
        self.v2_pallete = get_v2_pallete(pallete_dir, num_cls=71)

        # Filled on first __getitem__ so construction never touches the CSV.
        self._video_meta = None

    def _load_video_meta(self):
        """Parse the metadata CSV into ``{video: (class string, label dir)}``."""
        meta = {}
        with open(self.avsbench_csv, "r", newline="") as csvfile:
            for info in csv.reader(csvfile):
                # Rows too short to carry both columns cannot describe a video.
                if len(info) > 6:
                    # Later rows overwrite earlier ones for the same video.
                    meta[info[1]] = (info[4], info[6])
        return meta

    def __getitem__(self, index):
        """Return the data of the video at ``index``.

        Returns:
            Tuple ``(imgs_pth, annos, video, imgs_names, audios)``:
            sorted frame paths, list of label arrays (one per ``*.png`` mask,
            resized to 224x224, classes not listed for the video zeroed),
            the video name, sorted frame file names, and the audio file path.

        Raises:
            KeyError: If the video has no usable row in the metadata CSV.
        """
        video = self.videos[index]

        # Parse the CSV once instead of re-scanning it for every item.
        if self._video_meta is None:
            self._video_meta = self._load_video_meta()
        try:
            video_class, video_label = self._video_meta[video]
        except KeyError:
            raise KeyError(
                "video {!r} has no row in {}".format(video, self.avsbench_csv)
            ) from None

        # Id 0 is always kept; add the id of every class named for this video.
        name_to_id = {cate["name"]: cate["id"] for cate in categories}
        video_class_id = [0]
        for class_name in video_class.split("_"):
            if class_name in name_to_id:
                video_class_id.append(name_to_id[class_name])

        img_pth = os.path.join(self.test_img_dir, video)
        imgs_names = sorted(os.listdir(img_pth))
        imgs_pth = [os.path.join(img_pth, img_name) for img_name in imgs_names]

        annos = []
        anno_pth = os.path.join(self.test_mask_dir, video_label, video, "labels_rgb")
        for mask_image in sorted(glob.glob(anno_pth + "/*.png")):
            color_label = np.array(
                load_color_mask_in_PIL_to_Tensor(mask_image, self.v2_pallete, (224, 224))
            )
            # Zero out every label that is not among this video's classes.
            color_label[~np.isin(color_label, video_class_id)] = 0
            annos.append(color_label)

        audios = os.path.join(self.test_mask_dir, video_label, video, "audio.wav")

        return imgs_pth, annos, video, imgs_names, audios

    def __len__(self):
        """Return the number of test videos."""
        return len(self.videos)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset with its default paths and print the
    # frame paths of every video, one video per batch.
    test_dataset = AVSBenchTest()

    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    for batch_data in test_dataloader:
        # Unpacking also checks that each batch still has the five fields.
        imgs_pth, annos, video_name, imgs_names, audios = batch_data

        print(imgs_pth)
|
|