hanzla committed on
Commit 145ecb9
1 Parent(s): 1b3204d

model added

Files changed (1)
  1. app.py +31 -12
app.py CHANGED
@@ -6,28 +6,47 @@ import transformers
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_name = "meta-llama/Meta-Llama-3-8B"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16,device_map="auto")
+model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
 
+pipeline = transformers.pipeline(
+    "text-generation",
+    model=model_name,
+    model_kwargs={"torch_dtype": torch.bfloat16},
+    device="cuda",
+)
 
 @spaces.GPU
-def yes_man(message, history):
-    input_ids = tokenizer(message, return_tensors="pt").input_ids.to(model.device)
-    output = model.generate(input_ids, max_length=512, num_return_sequences=1)
-    detailed_prompt = tokenizer.decode(output[0], skip_special_tokens=True)
-    return detailed_prompt
+def chat_function(message, history):
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant!"},
+        {"role": "user", "content": message},
+    ]
+    prompt = pipeline.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    terminators = [
+        pipeline.tokenizer.eos_token_id,
+        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    ]
+    outputs = pipeline(
+        prompt,
+        max_new_tokens=256,
+        eos_token_id=terminators,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
+    )
+    return outputs[0]["generated_text"][len(prompt):]
 
 gr.ChatInterface(
-    yes_man,
+    chat_function,
     chatbot=gr.Chatbot(height=300),
     textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
     title="LLAMA 3 8B Chat",
     description="Ask Yes Man any question",
     theme="soft",
-    examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"],
-    cache_examples=True,
     retry_btn=None,
     undo_btn="Delete Previous",
     clear_btn="Clear",
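
The new chat_function leans on two details of the Llama 3 Instruct format: apply_chat_template renders the message list into a single prompt string, and every turn is closed by the <|eot_id|> token, which is why that token is added to the terminators next to eos_token_id. The pipeline returns the prompt followed by the completion, hence the [len(prompt):] slice. Below is a minimal sketch, not part of the commit, of roughly what the rendered prompt looks like, assuming the stock chat template of meta-llama/Meta-Llama-3-8B-Instruct; the user message "Hello" is hypothetical.

# Minimal sketch (not part of the commit): approximate output of
# pipeline.tokenizer.apply_chat_template(messages, tokenize=False,
# add_generation_prompt=True) for meta-llama/Meta-Llama-3-8B-Instruct,
# assuming its stock chat template. "Hello" is a hypothetical user message.
example_prompt = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant!<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Hello<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
# The model's reply is appended after the final assistant header and ends with
# "<|eot_id|>"; generation stops there because that token id is in `terminators`,
# and chat_function returns only the text generated past `len(prompt)`.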