# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WSC dataset."""

import json

from megatron import print_rank_0
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("WSC")


class WSCDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label="false"):
        self.test_label = test_label
        super().__init__('WSC', name, datapaths, tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0
        first = True
        is_test = False
        drop_cnt = 0  # No samples are dropped for WSC; kept for log parity.
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if "id" not in row:
                row["id"] = index
            text_a = row['text']
            target = row['target']
            query = target['span1_text']        # candidate antecedent
            query_idx = target['span1_index']
            pronoun = target['span2_text']      # pronoun to resolve
            pronoun_idx = target['span2_index']
            assert text_a[pronoun_idx:pronoun_idx + len(pronoun)] == pronoun, \
                "pronoun: {}".format(pronoun)
            assert text_a[query_idx:query_idx + len(query)] == query, \
                "query: {}".format(query)
            # Mark both spans in place: the candidate antecedent is wrapped in
            # underscores and the pronoun in brackets. Whichever span occurs
            # later in the sentence is offset by the two marker characters
            # already inserted around the earlier span.
            text_a_list = list(text_a)
            if pronoun_idx > query_idx:
                text_a_list.insert(query_idx, "_")
                text_a_list.insert(query_idx + len(query) + 1, "_")
                text_a_list.insert(pronoun_idx + 2, "[")
                text_a_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
            else:
                text_a_list.insert(pronoun_idx, "[")
                text_a_list.insert(pronoun_idx + len(pronoun) + 1, "]")
                text_a_list.insert(query_idx + 2, "_")
                text_a_list.insert(query_idx + len(query) + 2 + 1, "_")
            text_a = "".join(text_a_list)
            # Alternative second segment kept for reference; the Chinese
            # template reads "In this sentence, {pronoun} refers to {query}".
            # text_b = "在这句话中,{}指代的是{}".format(pronoun, query)
            text_b = None
            if first:
                first = False
                if "label" not in row:
                    is_test = True
                    print_rank_0('   reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], text_a, text_b,
                                     self.test_label))
                else:
                    is_test = False
                    print_rank_0('   reading {}, {}, {}, and {} columns '
                                 '...'.format(row["id"], text_a, text_b,
                                              row["label"].strip()))
            unique_id = int(row["id"])
            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()
            assert len(text_a) > 0
            assert label in LABELS, "found label {} {} {}".format(
                label, row, type(label))
            assert unique_id >= 0
            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)
            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))
        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
        return samples
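

# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the original module) showing what the
# span-marking arithmetic in `process_samples_from_single_path` produces.
# `_mark_spans` and the sample sentence below are hypothetical illustrations;
# the real CLUE WSC data is Chinese, but the insertion logic is identical.

def _mark_spans(text, query, query_idx, pronoun, pronoun_idx):
    """Wrap the candidate antecedent in underscores and the pronoun in brackets."""
    chars = list(text)
    if pronoun_idx > query_idx:
        # Query comes first: its two markers shift the pronoun right by 2.
        chars.insert(query_idx, "_")
        chars.insert(query_idx + len(query) + 1, "_")
        chars.insert(pronoun_idx + 2, "[")
        chars.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
    else:
        # Pronoun comes first: its two markers shift the query right by 2.
        chars.insert(pronoun_idx, "[")
        chars.insert(pronoun_idx + len(pronoun) + 1, "]")
        chars.insert(query_idx + 2, "_")
        chars.insert(query_idx + len(query) + 2 + 1, "_")
    return "".join(chars)


if __name__ == "__main__":
    sentence = "The trophy does not fit in the suitcase because it is too big."
    # Prints:
    #   The _trophy_ does not fit in the suitcase because [it] is too big.
    print(_mark_spans(sentence, "trophy", 4, "it", 48))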