|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from avsbench_eval import calc_color_miou_fscore, scores_gzsl |
|
|
|
|
|
def get_v2_pallete(label_to_idx_path, num_cls=71):
    """Return the unified AVSBench colour table as a ``(num_cls, 3)`` array.

    The same table serves AVSBench-object (V1) and AVSBench-semantic (V2).
    71 is the total number of V2 categories and should not be changed.

    Args:
        label_to_idx_path: path of the JSON file mapping labels to palette
            indices; only its number of entries is used, as a sanity check.
        num_cls: number of colours to generate.

    Returns:
        Integer numpy array with one RGB row per class index.

    Raises:
        AssertionError: if the JSON file does not hold ``num_cls`` entries.
    """

    def _class_colour(class_id):
        """Spread the bits of ``class_id`` over R, G and B, high bit first."""
        red = green = blue = 0
        bit_pos = 7
        # Each round consumes three bits of the id: one per channel, written
        # from the most significant channel bit downwards.
        while class_id > 0:
            red |= (class_id & 1) << bit_pos
            green |= ((class_id >> 1) & 1) << bit_pos
            blue |= ((class_id >> 2) & 1) << bit_pos
            bit_pos -= 1
            class_id >>= 3
        return red, green, blue

    with open(label_to_idx_path, 'r') as json_file:
        label_to_pallete_idx = json.load(json_file)

    flat_colours = []
    for class_id in range(num_cls):
        flat_colours.extend(_class_colour(class_id))

    colour_table = np.array(flat_colours).reshape(-1, 3)
    assert len(colour_table) == len(label_to_pallete_idx)
    return colour_table
|
|
|
|
|
def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a map of palette indices.

    Args:
        mask: array-like (for example a PIL image) whose last axis holds the
            colour channels.
        v_pallete: iterable of colours; the position of a colour is its label.

    Returns:
        Integer array shaped like ``mask`` without its last axis. Each entry
        is the index of the first palette colour equal to that pixel, or 0
        when no palette colour matches.
    """
    pixels = np.array(mask).astype('int32')

    # One boolean plane per palette colour: True where every channel matches.
    per_class_hits = [(pixels == colour).all(axis=-1) for colour in v_pallete]

    one_hot = np.stack(per_class_hits, axis=-1).astype(np.float32)
    return one_hot.argmax(axis=-1)
|
|
|
|
|
def load_color_mask_in_PIL_to_Tensor(path, v_pallete, mode='RGB'):
    """Load a colour mask image from disk as a tensor of palette indices.

    Args:
        path: location of the mask image.
        v_pallete: colour table handed to ``color_mask_to_label``.
        mode: PIL mode the image is converted to before the colour lookup.

    Returns:
        The label map produced by ``color_mask_to_label`` as a torch tensor
        with one extra leading dimension.
    """
    converted_image = Image.open(path).convert(mode)
    label_map = color_mask_to_label(converted_image, v_pallete)
    return torch.from_numpy(label_map).unsqueeze(0)
|
|
|
|
|
def save_and_compute(predictions_all_labels, predictions_all_masks, save_base_path, vid_frames_name, gt_masks):
    """Merge per-object masks into label maps, save them as PNGs and score them.

    Args:
        predictions_all_labels: indexable by object position; entry ``j`` is
            the class id applied to the ``j``-th mask of every frame. It is
            stored as ``label + 1`` so that 0 remains the background value.
        predictions_all_masks: one entry per frame, each an iterable of mask
            tensors (presumably binary - confirm against the caller). The
            first frame must hold at least one mask because its shape fixes
            the height and width of every label map.
        save_base_path: directory receiving the colour PNGs; created if
            missing.
        vid_frames_name: one entry per saved frame; element 0 of each entry
            is the frame file name, and the text before its first ``.``
            becomes the PNG stem.
        gt_masks: list of ground-truth tensors, stacked along a new first
            dimension before scoring.

    Returns:
        ``(miou, fscore, cls, pred_masks_up)``: the three values returned by
        ``calc_color_miou_fscore`` followed by the predicted label maps
        resized to 224x224.
    """
    v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
    os.makedirs(save_base_path, exist_ok=True)

    first_mask = predictions_all_masks[0][0]
    h, w = first_mask.shape[0], first_mask.shape[1]
    pred_masks = torch.zeros((len(predictions_all_masks), h, w))

    for frame_idx, frame_masks in enumerate(predictions_all_masks):
        # Multiplying creates fresh tensors, so the caller's masks are never
        # modified by the in-place merge below.
        labelled = [
            obj_mask.int() * (predictions_all_labels[obj_idx] + 1)
            for obj_idx, obj_mask in enumerate(frame_masks)
        ]
        if not labelled:
            continue  # no predictions: the frame stays all background

        # First-mask-wins merge: a later mask only fills pixels that are
        # still background. The decision is taken per frame; it used to look
        # at frame 0 only, which dropped every mask but the first whenever
        # frame 0 held fewer than two masks.
        merged = labelled[0]
        for extra in labelled[1:]:
            merged += extra * (merged == 0)
        pred_masks[frame_idx] = merged

    # Colourise the label maps with the palette for visual inspection.
    pred_rgb_masks = np.zeros((pred_masks.shape + (3,)), np.uint8)
    for cls_idx in range(len(v_pallete)):
        pred_rgb_masks[(pred_masks == cls_idx).numpy()] = v_pallete[cls_idx]

    for idx, frame_name in enumerate(vid_frames_name):
        output_name = "%s.png" % (frame_name[0].split(".")[0])
        frame_image = Image.fromarray(pred_rgb_masks[idx])
        frame_image.save(os.path.join(save_base_path, output_name), format='PNG')

    # Nearest-neighbour resizing keeps label ids intact (no interpolation
    # between neighbouring classes).
    up = torch.nn.Upsample(size=(224, 224), mode="nearest")
    pred_masks_up = up(pred_masks.unsqueeze(dim=1)).squeeze()
    gt_masks = torch.stack(gt_masks, dim=0).squeeze()
    _miou_pc, _fscore_pc, _cls_pc = calc_color_miou_fscore(pred_masks_up, gt_masks)

    return _miou_pc, _fscore_pc, _cls_pc, pred_masks_up
|
|