"""Streamlit chat demo for the heegyu/ajoublue-gpt2-medium-dialog Korean
dialogue model: the conversation is kept in session state and a bot reply
is generated for each new user utterance."""
import streamlit as st
from streamlit_chat import message


@st.cache(allow_output_mutation=True)
def get_pipe():
    """Load and cache the dialog model and its tokenizer (downloaded once)."""
    # Imported lazily so the slow `transformers` import only happens inside
    # this cached, spinner-wrapped call.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    model_name = "heegyu/ajoublue-gpt2-medium-dialog"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer


def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
    """Generate the bot's next utterance from the recent chat history.

    `history` alternates user ("0") and bot ("1") turns; only the last
    `max_context` turns are kept, and the prompt ends with "{bot_id}: "
    so the model continues speaking as the bot.

    NOTE(review): the original special-token literals were lost from this
    file (`"".join(...)`, `.replace("", "")`, and `.split("")` — the last
    of which raises ValueError: empty separator).  They are reconstructed
    here from the tokenizer's own special tokens; confirm the exact prompt
    format against the model card.
    """
    eos = tokenizer.eos_token or ""
    turns = [f"{i % 2}: {text}" for i, text in enumerate(history)]
    if len(turns) > max_context:
        turns = turns[-max_context:]
    # Each turn is terminated by the EOS token; the trailing speaker tag
    # prompts the model to answer as the bot.
    context = "".join(f"{turn}{eos}" for turn in turns) + f"{bot_id}: "

    inputs = tokenizer(context, return_tensors="pt")
    generation_args = dict(
        max_new_tokens=128,
        min_length=inputs["input_ids"].shape[1] + 5,
        eos_token_id=2,
        do_sample=True,
        top_p=0.95,
        temperature=1.35,
        early_stopping=True,
    )
    outputs = model.generate(**inputs, **generation_args)

    print("Context:", tokenizer.decode(inputs["input_ids"][0]))
    print("Response:", tokenizer.decode(outputs[0], skip_special_tokens=False))

    # Keep only the newly generated tokens.  Slicing token ids is more
    # robust than slicing the decoded string by len(context), since
    # decoding need not round-trip the prompt text exactly.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=False)
    if eos:
        response = response.split(eos)[0]  # drop anything after the first EOS
    if tokenizer.bos_token:
        response = response.replace(tokenizer.bos_token, "")
    return response.replace("\n", "")


st.title("ajoublue-gpt2-medium 한국어 대화 모델 demo")

with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

if 'message_history' not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history

# Replay the conversation so far (even indices are the user's turns).
for i, message_ in enumerate(history):
    message(message_, is_user=i % 2 == 0, key=i)

input_ = st.text_input("아무 말이나 해보세요", value="")

if input_ is not None and len(input_) > 0:
    # text_input keeps its value across the rerun triggered below, so only
    # generate when this exact input has not already been answered
    # (history[-2] is the most recent user turn once a reply was appended).
    if len(history) <= 1 or history[-2] != input_:
        with st.spinner("대답을 생성중입니다..."):
            st.session_state.message_history.append(input_)
            response = get_response(tokenizer, model, history)
            st.session_state.message_history.append(response)
            st.experimental_rerun()