File size: 4,742 Bytes
1076673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cab0f57
1076673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b84d9de
1076673
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch

from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import InputExample

class MyClassifier:
    """Per-sentence inference wrapper around an ALBERT/BERT classifier.

    The wrapped ``model`` is expected to expose its encoder as ``model.albert``
    or ``model.bert`` (selected by ``model_type``), and that encoder must offer
    the exit-control hooks used below: ``set_mode``, ``reset_stats``,
    ``set_patience``, ``set_confi_threshold``, ``set_exit_pos``,
    ``current_exit_layer`` and ``config.num_hidden_layers``.

    For any other ``model_type`` the control methods silently do nothing and
    the getters return ``None``.
    """

    # Values of ``model_type`` for which an encoder attribute is looked up.
    _SUPPORTED_MODEL_TYPES = ('albert', 'bert')

    def __init__(
        self,
        model,
        tokenizer,
        label_list,
        output_mode,
        exit_type,
        exit_value,
        model_type='albert',
        max_length=128,
    ):
        """Store the configuration and put the model into its initial state.

        Args:
            model: Classifier to wrap; switched to eval mode here.
            tokenizer: Tokenizer handed to the GLUE feature converter.
            label_list: Label list handed to the GLUE feature converter.
            output_mode: Output mode handed to the GLUE feature converter.
            exit_type: Mode passed to the encoder's ``set_mode`` by
                :meth:`get_prob`. ``'patience'`` and ``'confi'`` additionally
                cause ``exit_value`` to be installed as the patience or the
                confidence threshold respectively.
            exit_value: Patience or confidence threshold (see ``exit_type``).
            model_type: ``'albert'`` or ``'bert'``.
            max_length: Maximum sequence length used when building features.
        """
        self.model = model
        self.model.eval()
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type
        self.exit_value = exit_value
        # Running number of sentences processed; used to build example guids.
        self.count = 0
        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)

    def _backbone(self):
        """Return ``model.albert`` / ``model.bert``, or ``None`` if unsupported."""
        if self.model_type in self._SUPPORTED_MODEL_TYPES:
            return getattr(self.model, self.model_type)
        return None

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if unknown)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, TypeError, StopIteration):
            return torch.device('cpu')

    def tokenize(self, input_, idx):
        """Convert one sentence (pair) into model input tensors.

        Args:
            input_: Sequence whose first item is the first text and whose
                optional second item is the second text. A second item equal
                to ``"<none>"`` (or a missing one) means "single sentence".
            idx: Number used to build the example guid.

        Returns:
            Tuple ``(input_ids, attention_mask, token_type_ids)`` of
            ``torch.long`` tensors, each with a leading dimension of 1.
        """
        text_a = input_[0]
        text_b = input_[1] if len(input_) > 1 else None
        if text_b == "<none>":
            text_b = None
        examples = [
            InputExample(guid=f"dev_{idx}", text_a=text_a, text_b=text_b, label=None)
        ]
        features = convert_examples_to_features(
            examples,
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        return all_input_ids, all_attention_mask, all_token_type_ids

    def set_threshold(self, confidence_threshold):
        """Forward ``confidence_threshold`` to the encoder's ``set_confi_threshold``."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        """Forward ``patience`` to the encoder's ``set_patience``."""
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_patience(patience)

    def set_exit_position(self, exit_pos):
        """Forward ``exit_pos`` to the encoder's ``set_exit_pos``.

        Handles both supported model types, in line with the other control
        methods; previously the call was skipped for ``'bert'``.
        """
        backbone = self._backbone()
        if backbone is not None:
            backbone.set_exit_pos(exit_pos)

    def reset_status(self, mode, stats=False):
        """Set the encoder's mode and, when ``stats`` is true, reset its stats."""
        backbone = self._backbone()
        if backbone is None:
            return
        backbone.set_mode(mode)
        if stats:
            backbone.reset_stats()

    def get_exit_number(self):
        """Return the encoder's ``config.num_hidden_layers`` (``None`` if unsupported)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.config.num_hidden_layers

    def get_current_exit(self):
        """Return the encoder's ``current_exit_layer`` (``None`` if unsupported)."""
        backbone = self._backbone()
        if backbone is None:
            return None
        return backbone.current_exit_layer

    # TODO: revise the prediction algorithm used to derive the predicted result.
    def get_pred(self, input_):
        """Return ``argmax`` over axis 2 of :meth:`get_prob`'s output.

        NOTE(review): axis 2 only exists when each per-sentence probability
        array is at least 2-D -- confirm against the model's output shape.
        """
        return self.get_prob(input_).argmax(axis=2)

    def _predict_proba(self, input_):
        """Run the model on every sentence and stack the softmax outputs.

        Shared by :meth:`get_prob` and :meth:`get_prob_time`; the caller is
        responsible for selecting the encoder mode beforehand.
        """
        device = self._model_device()
        probabilities = []
        # Results leave as NumPy arrays, so no autograd graph is needed.
        with torch.no_grad():
            for sent in input_:
                self.count += 1
                input_ids, attention_mask, token_type_ids = self.tokenize(sent, idx=self.count)
                outputs = self.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    token_type_ids=token_type_ids.to(device),
                )[0]
                # Softmax over dim 1, then keep the first entry of dim 0
                # (presumably the single batch item -- TODO confirm).
                probabilities.append(torch.softmax(outputs, dim=1)[0].detach().cpu().numpy())
        return np.array(probabilities)

    def get_prob(self, input_):
        """Return softmax outputs for ``input_`` using the configured ``exit_type`` mode.

        Args:
            input_: Iterable of sentences in the format accepted by :meth:`tokenize`.

        Returns:
            ``numpy.ndarray`` with one entry per sentence.
        """
        # Switch the encoder to the configured exit mode; stats are left untouched.
        self.reset_status(mode=self.exit_type, stats=False)
        return self._predict_proba(input_)

    def get_prob_time(self, input_, exit_position):
        """Return softmax outputs for ``input_`` in ``'exact'`` mode at ``exit_position``.

        Args:
            input_: Iterable of sentences in the format accepted by :meth:`tokenize`.
            exit_position: Value forwarded to the encoder's ``set_exit_pos``.

        Returns:
            ``numpy.ndarray`` with one entry per sentence.
        """
        self.reset_status(mode='exact', stats=False)
        self.set_exit_position(exit_position)
        return self._predict_proba(input_)