# File size: 1,819 Bytes
# c98c191
import argparse
import json
import pathlib

# Prompt templates, worded as in Stanford Alpaca's training script.
# "prompt_input" has {instruction} and {input} placeholders;
# "prompt_no_input" has only the {instruction} placeholder.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request."
        "\n\n"
        "### Instruction:\n{instruction}"
        "\n\n"
        "### Input:\n{input}"
        "\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        "\n\n"
        "### Instruction:\n{instruction}"
        "\n\n"
        "### Response:"
    ),
)


def main(args):
    """Convert an Alpaca-style JSON dataset into a conversation-style JSON file.

    Loads the JSON document at ``args.data_path`` and iterates over its
    examples. Each example is rendered into a prompt with one of the
    ``PROMPT_DICT`` templates: ``prompt_input`` when the example has an
    ``"input"`` value other than the empty string, ``prompt_no_input``
    otherwise. The result is written to ``args.output_path`` as a list of
    records of the form::

        {"id": "<1-based index>",
         "conversations": [{"from": "human", "value": <prompt>},
                           {"from": "gpt", "value": <example["output"]>}]}

    Args:
        args: Object exposing ``data_path`` and ``output_path`` attributes
            (the ``argparse`` namespace built by the script entry point).

    Raises:
        KeyError: If an example lacks ``"output"`` or a field required by
            the selected template.
    """
    data_path = pathlib.Path(args.data_path)
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with data_path.open(encoding="utf-8") as f:
        data = json.load(f)

    prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]

    new_data = []
    for cnt, example in enumerate(data, start=1):
        # Only a missing or empty-string "input" selects the no-input template.
        template = prompt_input if example.get("input", "") != "" else prompt_no_input
        new_data.append({
            'id': str(cnt),
            'conversations': [
                {
                    'from': 'human',
                    'value': template.format_map(example),
                },
                {
                    'from': 'gpt',
                    'value': example['output'],
                }
            ]
        })

    # Use a context manager so the output is flushed and closed deterministically.
    with open(args.output_path, 'w', encoding="utf-8") as out_file:
        json.dump(new_data, out_file, indent=2)

if __name__ == '__main__':
    # Script entry point: both flags are optional strings with file-name defaults.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ('--data_path', 'alpaca-data.json'),
        ('--output_path', 'alpaca-data-conversation.json'),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    main(cli.parse_args())