tonyzhao6 committed on
Commit 02b3a3b
1 Parent(s): 81aa6ea

Update README.md

Files changed (1)
  1. README.md +54 -18
README.md CHANGED
@@ -50,30 +50,66 @@ Then, copy the snippet from the section that is relevant for your usecase.

```python
# pip install accelerate
- from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
+
+ from gemma2_inference_hf import get_completions
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
+
+ DTYPE = torch.bfloat16
+ MODEL_ID = "idinsight/gemma-2-2b-it-ud"
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
model = AutoModelForCausalLM.from_pretrained(
-     "google/gemma-2-2b-it",
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
+     MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
)
- input_text = "Write me a poem about Machine Learning."
- input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
- outputs = model.generate(**input_ids, max_new_tokens=32)
- print(tokenizer.decode(outputs[0]))
- ```

- You can ensure the correct chat template is applied by using `tokenizer.apply_chat_template` as follows:
- ```python
- messages = [
-     {"role": "user", "content": "Write me a poem about Machine Learning."},
- ]
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True).to("cuda")
- outputs = model.generate(**input_ids, max_new_tokens=256)
- print(tokenizer.decode(outputs[0]))
+ text_generation_params = {
+     "do_sample": True,
+     "eos_token_id": tokenizer.eos_token_id,
+     "max_new_tokens": 1024,
+     "num_return_sequences": 1,
+     "repetition_penalty": 1.1,
+     "temperature": 1e-6,
+     "top_p": 0.9,
+ }
+
+ response = get_completions(
+     model=model,
+     rules_list=[
+         "NOT URGENT",
+         "Bleeding from the vagina",
+         "Bad tummy pain",
+         "Bad headache that won’t go away",
+         "Bad headache that won’t go away",
+         "Changes to vision",
+         "Trouble breathing",
+         "Hot or very cold, and very weak",
+         "Fits or uncontrolled shaking",
+         "Baby moves less",
+         "Fluid from the vagina",
+         "Feeding problems",
+         "Fits or uncontrolled shaking",
+         "Fast, slow or difficult breathing",
+         "Too hot or cold",
+         "Baby’s colour changes",
+         "Vomiting and watery poo",
+         "Infected belly button",
+         "Swollen or infected eyes",
+         "Bulging or sunken soft spot",
+     ],
+     skip_special_tokens_during_decode=False,
+     text_generation_params=text_generation_params,
+     tokenizer=tokenizer,
+     user_message="If my newborn can't able to breathe what can i do",
+ )
+ print(f"{response = }")
```

+ The `gemma2_inference_hf.py` module is provided for download with the model files.
+
### Evaluation

  XXX
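
For reference, here is a minimal sketch of what a `get_completions` helper with the call signature used above might look like. It is inferred purely from the call site in the diff: the prompt wording, rule numbering, and decoding flow are assumptions, and the authoritative implementation is the `gemma2_inference_hf.py` module distributed with the model files.

```python
# Hypothetical sketch of get_completions, inferred from its call site in the
# diff above. The real implementation ships in gemma2_inference_hf.py and may
# differ in how the rules are formatted and how the prompt is worded.
from typing import Any

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def get_completions(
    *,
    model: PreTrainedModel,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool,
    text_generation_params: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> str:
    """Match `user_message` against `rules_list` and return the decoded reply."""
    # Assumed prompt format: enumerate the urgency rules, then ask the model
    # to classify the user's message against them.
    rules_block = "\n".join(f"{i}. {rule}" for i, rule in enumerate(rules_list))
    prompt = (
        "Classify the following message against these rules:\n"
        f"{rules_block}\n\nMessage: {user_message}"
    )

    # Gemma-2 instruction-tuned checkpoints expect the chat template.
    chat = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(input_ids, **text_generation_params)

    # Decode only the newly generated tokens, honoring the caller's choice
    # about keeping special tokens in the output.
    return tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=skip_special_tokens_during_decode,
    )
```

One decoding detail worth noting from the snippet itself: `do_sample=True` combined with `temperature=1e-6` collapses the sampling distribution onto the highest-probability token, so generation is effectively greedy; `do_sample=False` would express the same intent directly and skip the sampling machinery.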