ovavss / demo_video /avsbench_utils.py
ruohguo's picture
Upload 91 files
bcc611b verified
raw
history blame contribute delete
No virus
4.32 kB
import os
import json
import torch
import numpy as np
from PIL import Image
from avsbench_eval import calc_color_miou_fscore, scores_gzsl
def get_v2_pallete(label_to_idx_path, num_cls=71):
    """Build the unified RGB colour palette for AVSBench-object (V1) and AVSBench-semantic (V2).

    Args:
        label_to_idx_path: path to a JSON file of label entries. Only its number
            of entries is used here, to check it against the palette size.
        num_cls: number of palette entries to generate. 71 is the total category
            number of the V2 dataset and should not be changed for V2 data.

    Returns:
        np.ndarray of shape (num_cls, 3): one RGB triple per class index.

    Raises:
        ValueError: if the JSON file does not hold exactly ``num_cls`` entries.
    """
    def _getpallete(num_cls=71):
        """Return a flat palette list ``[r0, g0, b0, r1, g1, b1, ...]`` of length ``num_cls * 3``.

        For each class index, successive 3-bit groups (lowest first) are spread
        over the channels: bit 0 -> R, bit 1 -> G, bit 2 -> B. Group ``i`` is
        written at bit position ``7 - i`` of each channel byte.
        """
        pallete = [0] * (num_cls * 3)  # zero-initialised, so index 0 is black
        for j in range(num_cls):
            lab = j
            i = 0
            while lab > 0:
                pallete[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
                pallete[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
                pallete[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
                i += 1
                lab >>= 3
        return pallete  # list, length is num_cls * 3

    with open(label_to_idx_path, 'r', encoding='utf-8') as fr:
        label_to_pallete_idx = json.load(fr)
    v2_pallete = np.array(_getpallete(num_cls)).reshape(-1, 3)
    # Raise explicitly rather than `assert`, so the check also runs under `python -O`.
    if len(v2_pallete) != len(label_to_pallete_idx):
        raise ValueError(
            f"palette has {len(v2_pallete)} entries but {label_to_idx_path} "
            f"defines {len(label_to_pallete_idx)} labels"
        )
    return v2_pallete
def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a map of palette indices.

    Every pixel (the mask's last axis) is compared against each palette colour.
    The resulting label is the index of the first palette entry that matches
    exactly. Pixels matching no palette colour get index 0, because argmax over
    an all-zero row returns 0.

    Args:
        mask: anything ``np.array`` accepts, e.g. a PIL RGB image.
        v_pallete: iterable of colours, e.g. the (num_cls, 3) array from get_v2_pallete.

    Returns:
        np.ndarray of integer class indices, shaped like ``mask`` without its last axis.
    """
    pixels = np.array(mask).astype('int32')
    per_class_hits = [np.all(np.equal(pixels, colour), axis=-1) for colour in v_pallete]
    hit_stack = np.stack(per_class_hits, axis=-1).astype(np.float32)
    return np.argmax(hit_stack, axis=-1)
def load_color_mask_in_PIL_to_Tensor(path, v_pallete, mode='RGB'):
    """Read a colour-coded mask image and return its palette-index map as a tensor.

    Args:
        path: image file path, opened with PIL and converted to ``mode``.
        v_pallete: palette colours passed through to ``color_mask_to_label``.
        mode: PIL conversion mode (default 'RGB').

    Returns:
        torch tensor with a leading singleton axis added. For an RGB image of
        size H x W this is [1, H, W].
    """
    index_map = color_mask_to_label(Image.open(path).convert(mode), v_pallete)
    return torch.from_numpy(index_map).unsqueeze(0)
def _first_mask_hw(frames, gt_masks):
    """Return (h, w) of the first predicted mask found in any frame, else of the first GT mask."""
    for frame_masks in frames:
        for mask in frame_masks:
            return mask.shape[0], mask.shape[1]
    # No frame has any predicted mask, so the prediction is all background.
    # Size it like the ground truth; assumes each GT mask ends in (H, W) — TODO confirm.
    return gt_masks[0].shape[-2], gt_masks[0].shape[-1]


def _merge_frame_masks(frame_masks, predictions_all_labels):
    """Collapse one frame's masks into a single label map (0 = background).

    Mask ``j`` is painted with ``predictions_all_labels[j] + 1``. Where masks
    overlap, the earlier mask in the list keeps its label.
    """
    painted = [mask.int() * (predictions_all_labels[j] + 1) for j, mask in enumerate(frame_masks)]
    result = painted[0]
    for mask in painted[1:]:
        # Only fill pixels that no earlier mask has claimed yet.
        result += mask * (result == 0)
    return result


def _save_color_masks(pred_masks, v_pallete, vid_frames_name, save_base_path):
    """Write each frame's label map as a palette-coloured RGB PNG into save_base_path."""
    labels = pred_masks.cpu().numpy()
    rgb_masks = np.zeros(labels.shape + (3,), np.uint8)  # [T, H, W, 3]
    for cls_idx, rgb in enumerate(v_pallete):
        rgb_masks[labels == cls_idx] = rgb
    for idx, frame_name in enumerate(vid_frames_name):
        frame_mask = rgb_masks[idx]  # [H, W, 3]
        output_name = "%s.png" % (frame_name[0].split(".")[0])
        Image.fromarray(frame_mask).save(os.path.join(save_base_path, output_name), format='PNG')


def save_and_compute(predictions_all_labels, predictions_all_masks, save_base_path, vid_frames_name, gt_masks):
    """Merge per-frame predicted masks into label maps, save them as coloured PNGs, and score them.

    Args:
        predictions_all_labels: per-mask class labels. Mask ``j`` of every frame
            is painted with ``predictions_all_labels[j] + 1`` (0 is background).
        predictions_all_masks: one iterable of 2-D (presumably binary) masks per
            frame. A frame may have no masks, in which case it stays background.
        save_base_path: directory the PNGs are written to (created if missing).
        vid_frames_name: per-frame names. ``frame_name[0]`` up to its first '.'
            names the PNG.
        gt_masks: list of ground-truth label tensors, stacked along a new
            leading axis and squeezed before scoring.

    Returns:
        The three values returned by ``calc_color_miou_fscore``, plus the
        predicted label maps nearest-upsampled to 224x224 with singleton axes
        squeezed.
    """
    v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
    os.makedirs(save_base_path, exist_ok=True)

    frames = [list(frame_masks) for frame_masks in predictions_all_masks]
    # Take the size from the first non-empty frame; frame 0 may have no masks.
    h, w = _first_mask_hw(frames, gt_masks)
    pred_masks = torch.zeros((len(frames), h, w))
    for i, frame_masks in enumerate(frames):
        # Empty frames keep their all-zero (background) label map.
        if frame_masks:
            pred_masks[i] = _merge_frame_masks(frame_masks, predictions_all_labels)

    _save_color_masks(pred_masks, v_pallete, vid_frames_name, save_base_path)

    up = torch.nn.Upsample(size=(224, 224), mode="nearest")
    pred_masks_up = up(pred_masks.unsqueeze(dim=1)).squeeze()
    gt_masks = torch.stack(gt_masks, dim=0).squeeze()
    _miou_pc, _fscore_pc, _cls_pc = calc_color_miou_fscore(pred_masks_up, gt_masks)
    return _miou_pc, _fscore_pc, _cls_pc, pred_masks_up