heegyu commited on
Commit
250f909
β€’
1 Parent(s): d9a8850

simple demo

Browse files
Files changed (2) hide show
  1. app.py +63 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+
4
+ @st.cache_resource()
5
+ def get_pipe():
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ model_name = "heegyu/ajoublue-gpt2-medium-dialog"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ return model, tokenizer
11
+
12
+ def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
13
+ context = []
14
+ for i, text in enumerate(history):
15
+ context.append(f"{i % 2} : {text}</s>")
16
+
17
+ if len(context) > max_context:
18
+ context = context[-max_context:]
19
+ context = "".join(context) + f"{bot_id} : "
20
+ inputs = tokenizer(context, return_tensors="pt")
21
+
22
+ generation_args = dict(
23
+ max_new_tokens=64,
24
+ min_length=inputs["input_ids"].shape[1] + 5,
25
+ eos_token_id=2,
26
+ do_sample=True,
27
+ top_p=0.6,
28
+ temperature=0.8,
29
+ repetition_penalty=1.5,
30
+ early_stopping=True
31
+ )
32
+
33
+ outputs = model.generate(**inputs, **generation_args)
34
+ response = tokenizer.decode(outputs[0])
35
+ print(context)
36
+ print(response)
37
+ response = response[len(context):].replace("</s>", "")
38
+
39
+ return response
40
+
41
+ st.title("ν•œκ΅­μ–΄ λŒ€ν™” λͺ¨λΈ demo")
42
+
43
+ with st.spinner("loading model..."):
44
+ model, tokenizer = get_pipe()
45
+
46
+ if 'message_history' not in st.session_state:
47
+ st.session_state.message_history = []
48
+ history = st.session_state.message_history
49
+
50
+ # print(st.session_state.message_history)
51
+ for i, message_ in enumerate(st.session_state.message_history):
52
+ message(message_,is_user=i % 2 == 0, key=i) # display all the previous message
53
+
54
+ # placeholder = st.empty() # placeholder for latest message
55
+ input_ = st.text_input("아무 λ§μ΄λ‚˜ ν•΄λ³΄μ„Έμš”", value="")
56
+
57
+ if input_ is not None and len(input_) > 0:
58
+ if len(history) <= 1 or history[-2] != input_:
59
+ with st.spinner("λŒ€λ‹΅μ„ μƒμ„±μ€‘μž…λ‹ˆλ‹€..."):
60
+ st.session_state.message_history.append(input_)
61
+ response = get_response(tokenizer, model, history)
62
+ st.session_state.message_history.append(response)
63
+ st.experimental_rerun()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ streamlit_chat
3
+ torch