Update README.md

README.md

@@ -50,30 +50,66 @@ Then, copy the snippet from the section that is relevant for your usecase.

Removed:

```python
# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```

Added:

```python
# pip install accelerate
import torch

from gemma2_inference_hf import get_completions
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

DTYPE = torch.bfloat16
MODEL_ID = "idinsight/gemma-2-2b-it-ud"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
)

text_generation_params = {
    "do_sample": True,
    "eos_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 1024,
    "num_return_sequences": 1,
    "repetition_penalty": 1.1,
    "temperature": 1e-6,
    "top_p": 0.9,
}

response = get_completions(
    model=model,
    rules_list=[
        "NOT URGENT",
        "Bleeding from the vagina",
        "Bad tummy pain",
        "Bad headache that won’t go away",
        "Bad headache that won’t go away",
        "Changes to vision",
        "Trouble breathing",
        "Hot or very cold, and very weak",
        "Fits or uncontrolled shaking",
        "Baby moves less",
        "Fluid from the vagina",
        "Feeding problems",
        "Fits or uncontrolled shaking",
        "Fast, slow or difficult breathing",
        "Too hot or cold",
        "Baby’s colour changes",
        "Vomiting and watery poo",
        "Infected belly button",
        "Swollen or infected eyes",
        "Bulging or sunken soft spot",
    ],
    skip_special_tokens_during_decode=False,
    text_generation_params=text_generation_params,
    tokenizer=tokenizer,
    user_message="If my newborn can't able to breathe what can i do",
)
print(f"{response = }")
```
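
Note that `temperature` is set to `1e-6`: combined with `do_sample=True`, this makes generation effectively greedy (deterministic) while `repetition_penalty=1.1` still discourages repeated text.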

The `gemma2_inference_hf.py` module is provided with the model files.
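
The module itself is not reproduced in this README. As a minimal sketch of the shape such a helper could take — the prompt format and internals below are assumptions for illustration, not the actual implementation that ships with the model:

```python
# Hypothetical sketch only -- the real get_completions lives in gemma2_inference_hf.py.
from typing import Any

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def get_completions(
    *,
    model: PreTrainedModel,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool,
    text_generation_params: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> str:
    """Ask the model which of the listed rules the user message matches."""
    # Assumed prompt format: enumerate the rules, then append the user message.
    rules_block = "\n".join(f"{i}. {rule}" for i, rule in enumerate(rules_list))
    prompt = (
        "Classify the user's message against the following rules:\n"
        f"{rules_block}\n\nUser message: {user_message}"
    )
    # Gemma 2 instruction-tuned checkpoints expect the chat template.
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(input_ids, **text_generation_params)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=skip_special_tokens_during_decode,
    )
```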

### Evaluation

XXX