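"""Post-processing utilities for AVSBench-semantic predictions: build the
dataset's unified color palette, convert color masks to class-index label
maps, merge per-instance predictions into per-frame semantic maps, save
them as PNGs, and score them with calc_color_miou_fscore."""
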
import os
import json
import torch
import numpy as np
from PIL import Image

from avsbench_eval import calc_color_miou_fscore


def get_v2_pallete(label_to_idx_path, num_cls=71):
    def _getpallete(num_cls=71):
        """Build the unified color palette for AVSBench-object (V1) and
        AVSBench-semantic (V2); 71 is the total number of categories in
        the V2 dataset and should not be changed."""
        n = num_cls
        pallete = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            pallete[j * 3 + 0] = 0
            pallete[j * 3 + 1] = 0
            pallete[j * 3 + 2] = 0
            i = 0
            while (lab > 0):
                pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i = i + 1
                lab >>= 3
        return pallete  # list of length num_cls * 3

    with open(label_to_idx_path, 'r') as fr:
        label_to_pallete_idx = json.load(fr)
    v2_pallete = _getpallete(num_cls)  # flat list of length num_cls * 3
    v2_pallete = np.array(v2_pallete).reshape(-1, 3)  # [num_cls, 3] RGB rows
    assert len(v2_pallete) == len(label_to_pallete_idx)
    return v2_pallete
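
# Usage sketch: row k of the returned palette is the RGB color for class index k.
#   v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
#   v_pallete.shape  # -> (71, 3)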


def color_mask_to_label(mask, v_pallete):
    """Convert an RGB color mask into an [H, W] map of class indices."""
    mask_array = np.array(mask).astype('int32')
    semantic_map = []
    for colour in v_pallete:
        # Per-pixel test: does the pixel exactly match this class colour?
        equality = np.equal(mask_array, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)  # [H, W, num_cls]
    label = np.argmax(semantic_map, axis=-1)  # [H, W]
    return label
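
# Note (behavioral detail): a pixel whose color matches no palette entry ends
# up with label 0, since argmax over an all-zero one-hot stack returns index 0.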


def load_color_mask_in_PIL_to_Tensor(path, v_pallete, mode='RGB'):
    color_mask_PIL = Image.open(path).convert(mode)
    color_label = color_mask_to_label(color_mask_PIL, v_pallete)
    color_label = torch.from_numpy(color_label) # [H, W]
    color_label = color_label.unsqueeze(0)
    return color_label  # [1, H, W]
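
# Usage sketch ('some_mask.png' is a hypothetical path): recover the
# class-index map of a saved color mask as a [1, H, W] tensor.
#   label = load_color_mask_in_PIL_to_Tensor("some_mask.png", v_pallete)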


def save_and_compute(predictions_all_labels, predictions_all_masks, save_base_path, vid_frames_name, gt_masks):
    """Merge per-instance binary masks into per-frame semantic maps, save them
    as color PNGs, and compute per-class mIoU / F-score against the ground truth."""
    v_pallete = get_v2_pallete("./datasets/AVSBench-semantic/label2idx.json")
    os.makedirs(save_base_path, exist_ok=True)  # exist_ok=True covers the existence check

    h = predictions_all_masks[0][0].shape[0]
    w = predictions_all_masks[0][0].shape[1]
    pred_masks = torch.zeros((len(predictions_all_masks), h, w))  # one merged [H, W] map per frame

    # Copy each frame's list of masks so the merge below does not mutate the input lists.
    predictions_all_masks_list = [list(frame_masks) for frame_masks in predictions_all_masks]

    for i in range(len(predictions_all_masks_list)):
        if predictions_all_masks_list[i] == []:
            # No detections for this frame: background everywhere.
            pred_masks[i] = torch.zeros((h, w))
        else:
            # Turn each binary instance mask into a map of (class index + 1),
            # reserving 0 for background.
            for j in range(len(predictions_all_masks_list[i])):
                predictions_all_masks_list[i][j] = predictions_all_masks_list[i][j].int() * (predictions_all_labels[j] + 1)
            if len(predictions_all_masks_list[i]) < 2:
                pred_masks[i] = predictions_all_masks_list[i][0]
            else:
                # Merge instances in order; earlier masks take precedence on overlap.
                result = predictions_all_masks_list[i][0]
                for mask in predictions_all_masks_list[i][1:]:
                    mask_merge = mask * (result == 0)
                    result += mask_merge
                pred_masks[i] = result

    pred_rgb_masks = np.zeros((pred_masks.shape + (3,)), np.uint8)  # [T, H, W, 3]
    for cls_idx in range(71):
        rgb = v_pallete[cls_idx]
        # Convert the torch boolean mask to numpy before indexing the numpy array.
        pred_rgb_masks[(pred_masks == cls_idx).numpy()] = rgb

    for idx in range(len(vid_frames_name)):
        frame_name = vid_frames_name[idx]
        frame_mask = pred_rgb_masks[idx]  # [H, W, 3]

        output_name = "%s.png" % os.path.splitext(frame_name[0])[0]
        im = Image.fromarray(frame_mask)
        im.save(os.path.join(save_base_path, output_name), format='PNG')

    # Resize to the ground-truth resolution; 'nearest' preserves the discrete
    # class indices (no interpolation between labels).
    up = torch.nn.Upsample(size=(224, 224), mode="nearest")
    pred_masks_up = up(pred_masks.unsqueeze(dim=1))  # [T, 1, 224, 224]
    pred_masks_up = pred_masks_up.squeeze()
    gt_masks = torch.stack(gt_masks, dim=0).squeeze()
    _miou_pc, _fscore_pc, _cls_pc = calc_color_miou_fscore(pred_masks_up, gt_masks)

    return _miou_pc, _fscore_pc, _cls_pc, pred_masks_up
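

# A minimal smoke test (a sketch, not part of the original pipeline). It
# assumes ./datasets/AVSBench-semantic/label2idx.json exists and that
# avsbench_eval is importable; all inputs below are synthetic stand-ins.
if __name__ == "__main__":
    labels = [2]  # hypothetical: each frame's single instance is class 2
    masks = [[torch.ones(224, 224)], [torch.ones(224, 224)]]  # 2 frames, 1 mask each
    names = [("frame_0.jpg",), ("frame_1.jpg",)]  # mirrors the frame_name[0] indexing above
    gts = [torch.zeros(1, 224, 224), torch.zeros(1, 224, 224)]
    miou_pc, fscore_pc, cls_pc, _ = save_and_compute(labels, masks, "./tmp_pred_masks", names, gts)
    print(miou_pc, fscore_pc, cls_pc)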