Chat template issue

#13
by mneilly - opened

I'm using the model with the default chat template defined in tokenizer_config.json and have noticed two problems with it. First, per the model card, the system turn should open with '<|im_start|>system' before 'You are a function calling AI model', but the template doesn't emit it. Second, the tool calls rendered from the chat history are malformed: instead of the JSON format specified in the system message, an incorrect format is used, which in turn causes the model to generate the wrong format as well.
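For reference, ChatML frames every turn, including the system turn, as `<|im_start|>{role}\n...<|im_end|>`. The minimal sketch below shows that framing in Python. `chatml_turn` is a hypothetical helper and the shortened system prompt is a placeholder, not the template's exact text; the real template's whitespace may differ.

```python
def chatml_turn(role: str, content: str) -> str:
    """Wrap one message in ChatML start/end markers (a sketch of the framing)."""
    return f"<|im_start|>{role}\n{content}<|im_end|>\n"


# Placeholder system prompt; the real one also lists the tool signatures.
system_prompt = "You are a function calling AI model."

prompt = chatml_turn("system", system_prompt) + chatml_turn("user", "What time is it?")
prompt += "<|im_start|>assistant\n"  # generation prompt for the model's reply

print(prompt)
```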

Per the template's instructions:

For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>, "arguments": <args-dict>}
</tool_call>

The expected format of the get_current_time call is:

<tool_call>
{"name": "get_current_time", "arguments": {}}
</tool_call>

But after calling apply_chat_template() with the chat history included, the output is:

<tool_call>
{"name": "get_current_time"}, "arguments": {}
</tool_call>

This, in turn, leads the model to use the incorrect format for the new calls it generates.
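One way to see why the rendered call is wrong is to check each block against the `FunctionCall` schema quoted in the system prompt (`name` a string, `arguments` an object). The sketch below uses only the standard library; `check_tool_call` is a hypothetical helper, and its regex assumes the tags sit on their own lines, as in the examples above.

```python
import json
import re

# Assumes <tool_call> and </tool_call> each sit on their own line.
TAG_RE = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)


def check_tool_call(block: str) -> str:
    """Return 'ok' or a short reason the block violates the FunctionCall schema."""
    match = TAG_RE.search(block)
    if match is None:
        return "no <tool_call> tags"
    try:
        call = json.loads(match.group(1))
    except json.JSONDecodeError as e:
        return f"invalid JSON: {e.msg}"
    if not isinstance(call, dict):
        return "not a JSON object"
    if not isinstance(call.get("name"), str):
        return "missing string 'name'"
    if not isinstance(call.get("arguments"), dict):
        return "missing object 'arguments'"
    return "ok"


expected = '<tool_call>\n{"name": "get_current_time", "arguments": {}}\n</tool_call>'
observed = '<tool_call>\n{"name": "get_current_time"}, "arguments": {}\n</tool_call>'

print(check_tool_call(expected))  # ok
print(check_tool_call(observed))  # invalid JSON: Extra data
```

The observed block fails before the schema is even consulted: `json.loads` stops after `{"name": "get_current_time"}` and reports the rest as extra data.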

This is the patch I have applied locally:

$ diff -p /tmp/original-chat-template.j2 /tmp/modified-chat-template.j2
*** /tmp/original-chat-template.j2      2024-08-31 17:36:53.776076955 -0700
--- /tmp/modified-chat-template.j2      2024-09-01 07:28:46.092541228 -0700
***************
*** 32,37 ****
--- 32,38 ----


  {{- bos_token }}
+ {{- '<|im_start|>system\n' }}
  {{- "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
  {%- for tool in tools %}
      {%- if tool.function is defined %}
***************
*** 96,107 ****
              {{- '{' }}
              {{- '"name": "' }}
              {{- tool_call.name }}
!             {{- '"}' }}
              {{- ', '}}
              {%- if tool_call.arguments is defined %}
                  {{- '"arguments": ' }}
                  {{- tool_call.arguments|tojson }}
              {%- endif %}
              {{- '\n</tool_call>' }}
      {%- endfor %}
          {{- '<|im_end|>\n' }}
--- 97,109 ----
              {{- '{' }}
              {{- '"name": "' }}
              {{- tool_call.name }}
!             {{- '"' }}
              {{- ', '}}
              {%- if tool_call.arguments is defined %}
                  {{- '"arguments": ' }}
                  {{- tool_call.arguments|tojson }}
              {%- endif %}
+             {{- '}' }}
              {{- '\n</tool_call>' }}
      {%- endfor %}
          {{- '<|im_end|>\n' }}
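The two versions of the second hunk differ only in where the closing brace is emitted. The sketch below mirrors that string concatenation in plain Python so both outputs can be fed to `json.loads`. The function names are hypothetical, and `json.dumps` stands in for the template's `tojson` filter as an approximation.

```python
import json


def render_original(name: str, arguments: dict) -> str:
    # Original template: the closing brace is emitted right after the name.
    return '{"name": "' + name + '"}' + ", " + '"arguments": ' + json.dumps(arguments)


def render_patched(name: str, arguments: dict) -> str:
    # Patched template: the closing brace is emitted after the arguments.
    return '{"name": "' + name + '"' + ", " + '"arguments": ' + json.dumps(arguments) + "}"


for render in (render_original, render_patched):
    text = render("get_weather", {"location": "Sunnyvale, CA"})
    try:
        json.loads(text)
        print(render.__name__, "parses:", text)
    except json.JSONDecodeError as e:
        print(render.__name__, "fails:", e.msg)
```

As in the patched Jinja, the `', '` separator here is emitted unconditionally, so a tool call without `arguments` would still leave a trailing comma; serializing the whole call object with `tojson` may be the sturdier option.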

Here is a script that can reproduce what I'm seeing:

import argparse
import datetime
import json
import random
import re
import string
import subprocess
import urllib.parse

import torch
from transformers import LlamaForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser(description="Chatbot with Llama and tool calls.")
    parser.add_argument("-t", "--template", type=str, help="The chat template to use")
    return parser.parse_args()


def extract_toolcall(response):
    """
    Extract a tool call contained between <tool_call></tool_call> tags from the LLM response.
    """
    pattern = r"(<\|im_start\|>)?assistant\n<tool_call>\n(.*?)\n</tool_call>"
    matches = re.findall(pattern, response)
    if matches:
        try:
            return json.loads(matches[-1][1])
        except json.JSONDecodeError as e:
            print(f"Error extracting tool call: {e}")
    return None


def get_current_time() -> str:
    """Returns the current time in HH:MM AM/PM format."""
    return datetime.datetime.now().strftime("%I:%M %p")


def get_weather(location: str) -> str:
    """
    Performs a google search asking about the weather in a specific location

    Args:
        location: the weather for this place will be retrieved
    """
    query = urllib.parse.quote(f"current weather in {location}")
    command = ["lynx", "-dump", f"https://www.google.com/search?q={query}"]
    result = subprocess.run(command, capture_output=True, text=True)
    return result.stdout[:200]


def load_chat_template(template_path):
    """Load the chat template from a file if provided."""
    if template_path:
        with open(template_path, "r") as f:
            return f.read()
    return None


def initialize_model(model_id):
    """Load the model and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        quantization_config={
            'load_in_4bit': True,
            'bnb_4bit_compute_dtype': torch.bfloat16,
        }
    )
    return tokenizer, model


def get_response(conversation, chat_template=None):
    """Send input to the model and get a response."""
    inputs = tokenizer.apply_chat_template(
        conversation,
        tools=[get_current_time, get_weather],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        chat_template=chat_template,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def handle_tool_call(tool_call):
    """Execute a parsed tool call and return (result, tool_id)."""
    tool_name = tool_call.get("name")
    tool_id = ''.join(random.choices(string.ascii_letters + string.digits, k=9))
    tool_functions = {
        "get_current_time": get_current_time,
        "get_weather": lambda: get_weather(tool_call["arguments"]["location"])
    }

    if tool_name in tool_functions:
        return tool_functions[tool_name](), tool_id

    return None, tool_id


if __name__ == "__main__":
    args = parse_args()
    chat_template = load_chat_template(args.template)
    model_id = "NousResearch/Hermes-3-Llama-3.1-8B"
    tokenizer, model = initialize_model(model_id)

    prompts = ["What time is it?", "What is the weather in Sunnyvale, CA right now?", "Remind me what the time is."]
    conversation = []

    for prompt in prompts:
        conversation.append({"role": "user", "content": prompt})
        response = get_response(conversation, chat_template)
        print(response)

        tool_call = extract_toolcall(response)
        if tool_call:
            tool_result, tool_id = handle_tool_call(tool_call)
            # Record the assistant's tool-call turn first, then the tool result,
            # so the history is rendered in the order the calls actually happened.
            conversation.append({
                "role": "assistant",
                "tool_calls": [{"type": "function", "id": tool_id, "function": tool_call}]
            })
            if tool_result is not None:
                conversation.append({"role": "tool", "name": tool_call["name"], "content": tool_result})

            response = get_response(conversation, chat_template)
            print(response)
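The parsing side of the script can be checked without loading the model by feeding `extract_toolcall` hand-written decoded strings. The sketch below copies the script's regex (dropping the error print). The sample strings are assumptions about what a decoded response looks like once `skip_special_tokens=True` has removed `<|im_start|>` and, as the regex expects, the `<tool_call>` tags survive decoding. Because the script decodes the full sequence, prompt history included, the function returns the last tool call anywhere in the text.

```python
import json
import re

# Same pattern as extract_toolcall() in the script above.
PATTERN = r"(<\|im_start\|>)?assistant\n<tool_call>\n(.*?)\n</tool_call>"


def extract_toolcall(response):
    matches = re.findall(PATTERN, response)
    if matches:
        try:
            # findall returns (prefix, json_text) tuples; use the last match.
            return json.loads(matches[-1][1])
        except json.JSONDecodeError:
            return None
    return None


history = (
    'user\nWhat time is it?\nassistant\n<tool_call>\n'
    '{"name": "get_current_time", "arguments": {}}\n</tool_call>\n'
    'user\nWhat is the weather?\nassistant\n<tool_call>\n'
    '{"name": "get_weather", "arguments": {"location": "Sunnyvale, CA"}}\n</tool_call>'
)
bad = 'assistant\n<tool_call>\n{"name": "get_current_time"}, "arguments": {}\n</tool_call>'

print(extract_toolcall(history))  # the later get_weather call
print(extract_toolcall(bad))      # None: the malformed block is not valid JSON
```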
NousResearch org

The system role issue has been fixed, but the tool arguments seem to work on my end with the existing template:
https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B/commit/5e3fd4817e213a9fd15a5bb3ec95788a6f04d4a5

<|im_start|>system
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>  </tools>Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>, "arguments": <args-dict>}
</tool_call><|im_end|><|im_start|>user
Get the current weather forecast for Paris and get stock price of Hermes International<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_weather_forecast"}, "arguments": "{\"location\": \"Paris\"}"
</tool_call>
<tool_call>
{"name": "get_stock_price"}, "arguments": "{\"symbol\": \"Hermen.PA\"}"
</tool_call><|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_weather_forecast", "content": {"location": "Paris", "forecast": "Sunny", "temperature": "+84\u00b0F"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": "Unable to fetch stock price for Hermen.PA. Response: {'Global Quote': {}}"}
</tool_response><|im_end|>
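In the rendered transcript above, `arguments` appears as a quoted string (`"{\"location\": \"Paris\"}"`) rather than as an object, which is what JSON-serializing a string produces. If a conversation stores arguments as a JSON string, as OpenAI-style messages do, one possible workaround is to decode them before calling apply_chat_template(). `normalize_tool_call` below is a hypothetical helper, not part of transformers.

```python
import json


def normalize_tool_call(call: dict) -> dict:
    """Return a copy of a tool call whose `arguments` is a dict (empty if missing)."""
    args = call.get("arguments", {})
    if isinstance(args, str):
        # A JSON-string payload would otherwise be serialized again by the
        # template's tojson filter and show up as a quoted string.
        args = json.loads(args)
    return {**call, "arguments": args}


call = {"name": "get_weather_forecast", "arguments": "{\"location\": \"Paris\"}"}
fixed = normalize_tool_call(call)
print(json.dumps(fixed))
```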

You can test the new template here:
https://github.com/NousResearch/Hermes-Function-Calling/blob/main/template_tests/hermes_template_test.ipynb

interstellarninja changed discussion status to closed
