import os
import argparse
import collections

import numpy as np
import torch
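
# Ensemble classifier predictions: average the softmax outputs saved by
# several checkpoints, optionally calibrate a decision threshold against
# labels, and write the ensembled predictions to per-dataset TSV files.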


def process_files(args):
    # Accumulate predictions (and labels/uids on the first pass) from every
    # checkpoint directory. Predictions are summed here and averaged later.
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path, map_location='cpu')
        except Exception as e:
            print(f'Error loading {path}: {e}')
            continue
        for name, d in data:
            predictions, labels, uid = d
            if name not in all_predictions:
                all_predictions[name] = np.array(predictions)
                if args.labels is None:
                    args.labels = list(range(all_predictions[name].shape[1]))
                if args.eval:
                    all_labels[name] = np.array(labels)
                all_uid[name] = np.array(uid)
            else:
                all_predictions[name] += np.array(predictions)
                # Every checkpoint must have scored the same examples in the
                # same order.
                assert np.allclose(all_uid[name], np.array(uid))
    return all_predictions, all_labels, all_uid
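
# Each prediction file is assumed to hold an iterable of
# (dataset_name, (predictions, labels, uids)) entries, with predictions
# shaped (num_examples, num_classes) and labels/uids shaped (num_examples,).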


def get_threshold(all_predictions, all_labels, one_threshold=False):
    if one_threshold:
        # Pool all datasets together so a single threshold is fit for
        # everything.
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh


def calc_threshold(p, l):
    # Sweep candidate thresholds 0.00, 0.01, ..., 0.99 and return the first
    # one that maximizes accuracy against the labels.
    trials = [i / 100. for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = (apply_threshold(p, t).argmax(-1) == l).astype(float).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh
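
# Illustrative example: if the positive-class probabilities are [0.4, 0.6]
# with labels [0, 1], every threshold in (0.4, 0.6] scores 100% accuracy and
# the sweep returns the first such candidate, 0.41.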


def apply_threshold(preds, t):
    # Expects rows of probabilities that sum to 1 (e.g. a binary softmax).
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    # Predict the positive (last) class iff its probability reaches t, then
    # re-encode the hard decisions as one-hot rows.
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds
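
# Illustrative example: with preds = [[0.7, 0.3], [0.2, 0.8]] and t = 0.5,
# the positive-class probabilities are [0.3, 0.8], so apply_threshold
# returns [[1, 0], [0, 1]].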


def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        # Pad with the last supplied threshold so every dataset gets one.
        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    # Turn the summed predictions into an average over the ensemble members.
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            count += num
            all_correct += correct
            print(dataset, 'accuracy:', correct / num)
        os.makedirs(os.path.join(args.outdir, dataset), exist_ok=True)
        outpath = os.path.join(
            args.outdir, dataset,
            os.path.splitext(args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print('overall accuracy:', all_correct / count)


def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in the ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of the prediction file in each checkpoint directory')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate classification threshold(s) from the labels')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user-supplied threshold(s) for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace-separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()
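
# Example invocation (hypothetical directory names; script name assumed):
#   python ensemble_classifier.py --paths run1 run2 run3 \
#       --outdir ensembled --eval --calc-threshold --one-threshold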