File size: 5,007 Bytes
cb9e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
import argparse
import json
import os
import random
import string


def reformat_jsonl(input_file):  # noqa: C901
    output_file = os.path.splitext(input_file)[0] + "_reformatted.jsonl"
    skipped_samples = []

    with open(input_file, "r") as infile, open(output_file, "w") as outfile:
        for i, line in enumerate(infile):
            reformat_data = True
            data = json.loads(line)

            # Extract function description
            try:
                function_desc = json.loads(data["function_description"])
            except json.decoder.JSONDecodeError:
                function_desc = (
                    data["function_description"].replace("\n", "").replace("}{", "},{").replace("\\t", "")
                )
                function_desc = "[{" + function_desc[1:-1] + "}]"
                function_desc = json.loads(function_desc)

            function_desc = function_desc if isinstance(function_desc, list) else [function_desc]

            # Reformat tools section
            if len(function_desc) == 1 and function_desc[0] == {}:
                tools = None
            else:
                tools = []
                for f in function_desc:
                    if f["parameters"] is None:
                        f["parameters"] = {}
                    tools.append({"type": "function", "function": f})

            messages = []

            # Process conversations
            for idx, msg in enumerate(data["conversations"]):
                role = msg["from"]
                content = msg["value"]

                if role == "system":
                    messages.append(
                        {"role": "system", "content": content.split(" -")[0]}
                    )
                elif role == "human":
                    messages.append({"role": "user", "content": content})
                elif role == "function-call":
                    try:
                        function_call = json.loads(content)
                    except json.decoder.JSONDecodeError:
                        content = content.replace("'", "").replace("\\", "'")
                        try:
                            function_call = json.loads(content)
                        except:  # noqa: E722
                            skipped_samples.append(str(i))
                            reformat_data = False
                            break

                    if not isinstance(function_call, list):
                        function_calls = [function_call]
                    else:
                        function_calls = function_call

                    tool_calls = []
                    for function_call in function_calls:
                        assert not isinstance(function_call, list)
                        tool_call_id = "".join(
                            random.choices(string.ascii_letters + string.digits, k=9)
                        )

                        if "arguments" in function_call and not isinstance(function_call["arguments"], str):
                            function_call["arguments"] = str(function_call["arguments"])
                        elif "arguments" not in function_call:
                            function_call["arguments"] = ""

                        tool_calls.append({"id": tool_call_id, "type": "function", "function": function_call})

                    messages.append(
                        {
                            "role": "assistant",
                            "tool_calls": tool_calls
                        }
                    )
                elif role == "function-response":
                    if "tool_calls" not in messages[-1]:
                        skipped_samples.append(str(i))
                        reformat_data = False
                        break

                    assert len(messages[-1]["tool_calls"]) == 1
                    tool_call_id = messages[-1]["tool_calls"][0]["id"]
                    messages.append(
                        {
                            "role": "tool",
                            "content": content,
                            "tool_call_id": tool_call_id,
                        }
                    )
                elif role == "gpt":
                    messages.append({"role": "assistant", "content": content})

            output_data = {"messages": messages}

            if tools is not None:
                output_data["tools"] = tools

            if reformat_data:
                outfile.write(json.dumps(output_data) + "\n")

    os.rename(output_file, input_file)
    print(
        f"Skipped {len(skipped_samples)} samples ({len(skipped_samples) / i:.2%}). The following samples are incorrectly formated: \n\n {', '.join(skipped_samples)}"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Reformat a JSONL file.")
    parser.add_argument("file", type=str, help="The input JSONL file")

    args = parser.parse_args()
    reformat_jsonl(args.file)