import os
import json
import torch
import numpy as np
from PIL import Image
from avsbench_eval import calc_color_miou_fscore, scores_gzsl

def get_v2_pallete(label_to_idx_path, num_cls=71):
    def _getpallete(num_cls=71):
        """Build the unified color palette for AVSBench-object (V1) and
        AVSBench-semantic (V2); 71 is the total category number of the V2
        dataset and should not be changed."""
        n = num_cls
        pallete = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            pallete[j * 3 + 0] = 0
            pallete[j * 3 + 1] = 0
            pallete[j * 3 + 2] = 0
            i = 0
            while lab > 0:
                # Spread the low three bits of lab across R/G/B, starting at
                # the most significant bit of each channel and moving down.
                pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i = i + 1
                lab >>= 3
        return pallete  # flat list, length is n_classes * 3

    with open(label_to_idx_path, 'r') as fr:
        label_to_pallete_idx = json.load(fr)
    v2_pallete = _getpallete(num_cls)  # flat list
    v2_pallete = np.array(v2_pallete).reshape(-1, 3)  # [num_cls, 3]
    assert len(v2_pallete) == len(label_to_pallete_idx)
    return v2_pallete
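
# A hedged sanity check of the palette construction above (illustration only,
# not part of the original pipeline): each class index j is expanded into an
# RGB triple by spreading the base-8 digits of j across the three channels,
# one bit per digit, starting from the most significant bit of each byte:
#   j = 0 -> (0, 0, 0)       background stays black
#   j = 1 -> (128, 0, 0)     bit 0 of the first digit lands in R at bit 7
#   j = 2 -> (0, 128, 0)
#   j = 3 -> (128, 128, 0)
#   j = 8 -> (64, 0, 0)      the second base-8 digit lands one bit lower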

def color_mask_to_label(mask, v_pallete):
    mask_array = np.array(mask).astype('int32')
    semantic_map = []
    for colour in v_pallete:
        # One boolean plane per class: True where the pixel's RGB triple
        # equals this palette colour.
        equality = np.equal(mask_array, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    label = np.argmax(semantic_map, axis=-1)  # [H, W], palette index per pixel
    return label
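
# Note on the mapping above (read off the code, not stated by the authors): a
# pixel whose RGB triple matches v_pallete[k] gets label k, while a pixel
# matching no palette entry leaves every plane of semantic_map at zero, so
# np.argmax falls through to label 0, i.e. background.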

def load_color_mask_in_PIL_to_Tensor(path, v_pallete, mode='RGB'):
    color_mask_PIL = Image.open(path).convert(mode)
    color_label = color_mask_to_label(color_mask_PIL, v_pallete)
    color_label = torch.from_numpy(color_label)  # [H, W]
    color_label = color_label.unsqueeze(0)
    return color_label  # [1, H, W]
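
# Minimal usage sketch (hedged: the frame filename is an assumption for
# illustration; the JSON path is the one hard-coded below):
#   v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
#   gt_label = load_color_mask_in_PIL_to_Tensor("frame_0.png", v_pallete)
#   # gt_label: torch.int64 tensor of palette indices, shape [1, H, W]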

def save_and_compute(predictions_all_labels, predictions_all_masks, save_base_path, vid_frames_name, gt_masks):
    v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
    if not os.path.exists(save_base_path):
        os.makedirs(save_base_path, exist_ok=True)
    h = predictions_all_masks[0][0].shape[0]
    w = predictions_all_masks[0][0].shape[1]
    pred_masks = torch.zeros((len(predictions_all_masks), h, w))
    # Copy the nested per-frame mask lists so the merge below does not
    # mutate the caller's tensors in place.
    predictions_all_masks_list = []
    for frame_masks in predictions_all_masks:
        predictions_all_masks_list.append(list(frame_masks))
    for i in range(len(predictions_all_masks_list)):
        if not predictions_all_masks_list[i]:
            pred_masks[i] = torch.zeros((h, w))
        else:
            # Encode each binary mask with its class id shifted by one, so
            # that 0 remains the background value.
            for j in range(len(predictions_all_masks_list[i])):
                predictions_all_masks_list[i][j] = predictions_all_masks_list[i][j].int() * (predictions_all_labels[j] + 1)
            if len(predictions_all_masks_list[i]) < 2:
                pred_masks[i] = predictions_all_masks_list[i][0]
            else:
                # Merge front to back: later masks only fill pixels that are
                # still background in the running result.
                result = predictions_all_masks_list[i][0]
                for mask in predictions_all_masks_list[i][1:]:
                    mask_merge = mask * (result == 0)
                    result += mask_merge
                pred_masks[i] = result
    # Colorize the label maps with the palette and save one PNG per frame.
    pred_rgb_masks = np.zeros((pred_masks.shape + (3,)), np.uint8)  # [T, H, W, 3]
    for cls_idx in range(71):
        rgb = v_pallete[cls_idx]
        pred_rgb_masks[pred_masks == cls_idx] = rgb
    for idx in range(len(vid_frames_name)):
        frame_name = vid_frames_name[idx]
        frame_mask = pred_rgb_masks[idx]  # [H, W, 3]
        output_name = "%s.png" % (frame_name[0].split(".")[0])
        im = Image.fromarray(frame_mask)
        im.save(os.path.join(save_base_path, output_name), format='PNG')
    # Resize predictions to the ground-truth resolution with nearest-neighbour
    # interpolation so label ids are preserved, then score.
    up = torch.nn.Upsample(size=(224, 224), mode="nearest")
    pred_masks_up = up(pred_masks.unsqueeze(dim=1))
    pred_masks_up = pred_masks_up.squeeze()
    gt_masks = torch.stack(gt_masks, dim=0).squeeze()
    _miou_pc, _fscore_pc, _cls_pc = calc_color_miou_fscore(pred_masks_up, gt_masks)
    return _miou_pc, _fscore_pc, _cls_pc, pred_masks_up
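
if __name__ == "__main__":
    # Hedged end-to-end sketch with synthetic inputs, assuming label2idx.json
    # exists at the hard-coded path above and avsbench_eval is importable.
    # Shapes mirror what save_and_compute expects: a per-frame list of binary
    # [H, W] masks, one class label per predicted mask, per-frame filename
    # lists, and per-frame [1, H, W] ground-truth label maps.
    T, H, W = 2, 224, 224
    fake_masks = [[torch.zeros(H, W, dtype=torch.bool)] for _ in range(T)]
    fake_labels = [0]  # one class label per predicted mask
    fake_names = [["frame_%d.jpg" % t] for t in range(T)]
    fake_gt = [torch.zeros(1, H, W) for _ in range(T)]
    miou_pc, fscore_pc, cls_pc, pred = save_and_compute(
        fake_labels, fake_masks, "./debug_masks", fake_names, fake_gt)
    print(miou_pc, fscore_pc, cls_pc, pred.shape)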