sarahai's picture
Update app.py
4bc4ee9 verified
raw
history blame contribute delete
No virus
4.67 kB
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, NllbTokenizer, T5Tokenizer
import easyocr
from PIL import Image
import numpy as np
# Load models and tokenizers.
#
# Streamlit re-executes this script from the top on every widget interaction,
# so the checkpoints are loaded through ``st.cache_resource`` when the
# installed Streamlit provides it; the loaded objects are then reused across
# reruns instead of being read again each time. When the attribute is missing
# the decorator degrades to a no-op and loading happens on every run.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_translation(model_name):
    """Return the ``(model, tokenizer)`` pair for the translation checkpoint."""
    return (
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
        NllbTokenizer.from_pretrained(model_name),
    )


@_cache_resource
def _load_summarization(model_name):
    """Return the ``(model, tokenizer)`` pair for the summarization checkpoint."""
    return (
        T5ForConditionalGeneration.from_pretrained(model_name),
        T5Tokenizer.from_pretrained(model_name),
    )


translation_model_name = 'sarahai/nllb-uzbek-cyrillic-to-russian'
translation_model, translation_tokenizer = _load_translation(translation_model_name)
summarization_model_name = 'sarahai/ruT5-base-summarizer'
summarization_model, summarization_tokenizer = _load_summarization(summarization_model_name)
# Building an EasyOCR reader is the expensive part of OCR, so readers are kept
# per language through ``st.cache_resource`` when the installed Streamlit has
# it. Without it the decorator is a no-op and a reader is built on every call.
_cache_ocr_reader = getattr(st, "cache_resource", lambda func: func)


@_cache_ocr_reader
def _get_ocr_reader(lang):
    """Return an EasyOCR reader configured for the single language *lang*."""
    return easyocr.Reader([lang])


def extract_text(image_path, lang='tjk'):
    """Run OCR on an image and return the recognised text with its confidence.

    Args:
        image_path: The image to read. Despite the name, the caller in this
            file passes an already-opened PIL image, which is converted to a
            NumPy array. A ``str`` path or raw ``bytes`` is forwarded to
            EasyOCR unchanged.
        lang: EasyOCR language code used to build the reader.

    Returns:
        A ``(text, confidence)`` tuple. ``text`` is every recognised fragment
        joined by single spaces and stripped; ``confidence`` is the mean of
        the per-fragment probabilities, or ``0`` when nothing was recognised.
    """
    reader = _get_ocr_reader(lang)
    # Wrapping a path string in np.array would produce a 0-d string array
    # rather than pixel data, so only array-like images are converted.
    if isinstance(image_path, (str, bytes)):
        source = image_path
    else:
        source = np.array(image_path)
    results = reader.readtext(source)
    fragments = [text for _bbox, text, _prob in results]
    confidences = [prob for _bbox, _text, prob in results]
    final_confidence = sum(confidences) / len(confidences) if confidences else 0
    return ' '.join(fragments).strip(), final_confidence
def split_into_chunks(text, tokenizer, max_length=150):
    """Split *text* into strings of at most *max_length* tokens each.

    The text is tokenized with ``tokenizer.tokenize``; consecutive runs of
    tokens are turned back into strings with
    ``tokenizer.convert_tokens_to_string``.

    Args:
        text: The text to split.
        tokenizer: Object providing ``tokenize`` and
            ``convert_tokens_to_string``.
        max_length: Maximum number of tokens per chunk.

    Returns:
        The list of chunk strings in original order; empty when the text
        yields no tokens.
    """
    pieces = tokenizer.tokenize(text)
    # A chunk is only closed after receiving a token, so any limit below one
    # still produces one-token chunks.
    step = max(max_length, 1)
    return [
        tokenizer.convert_tokens_to_string(pieces[start:start + step])
        for start in range(0, len(pieces), step)
    ]
def translate(text, model, tokenizer, src_lang='uzb_Cyrl', tgt_lang='rus_Cyrl', max_input_length=128):
    """Translate *text* chunk by chunk and return the joined translation.

    Args:
        text: Source text to translate.
        model: Sequence-to-sequence model exposing ``generate``.
        tokenizer: Tokenizer matching *model*; its ``src_lang`` and
            ``tgt_lang`` attributes are overwritten by this call.
        src_lang: Language code assigned to ``tokenizer.src_lang``.
        tgt_lang: Language code assigned to ``tokenizer.tgt_lang`` and forced
            as the first generated token.
        max_input_length: Truncation limit, in tokens, applied when a chunk is
            encoded for the model.

    Returns:
        The translated chunks joined by single spaces; an empty string when
        *text* produces no chunks.
    """
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    # Look the target-language token up through the vocabulary; the
    # ``lang_code_to_id`` mapping is not present on the tokenizer in recent
    # transformers releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    # Chunks must fit under the truncation limit used below, otherwise the
    # tail of every full chunk is silently dropped. Keep some headroom for
    # special tokens added during encoding and for small differences when a
    # decoded chunk is tokenized again.
    chunk_length = max(max_input_length - 8, 1)
    chunks = split_into_chunks(text, tokenizer, max_length=chunk_length)
    translated_chunks = []
    for chunk in chunks:
        inputs = tokenizer(
            chunk,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_input_length,
        )
        # Pass the whole encoding so the attention mask reaches generate().
        outputs = model.generate(**inputs, forced_bos_token_id=forced_bos_token_id)
        translated_chunks.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return ' '.join(translated_chunks)
def summarize(text, model, tokenizer, max_length=250):
    """Generate a summary of *text* with a sequence-to-sequence model.

    The text is prefixed with ``"summarize: "``, encoded with truncation at
    2048 tokens, and decoded from the first sequence produced by a 4-beam
    search.

    Args:
        text: Text to summarize.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        max_length: Maximum length passed to ``model.generate``.

    Returns:
        The decoded summary with special tokens removed.
    """
    prompt = "summarize: " + text
    encoded = tokenizer.encode(
        prompt,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
    )
    generation_options = dict(
        max_length=max_length,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    generated = model.generate(encoded, **generation_options)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Custom CSS styles
# Defines the ``.big-font`` and ``.small-font`` classes; raw HTML must be
# explicitly allowed for the <style> element to be emitted.
_CUSTOM_CSS = """
<style>
.big-font {
font-size:30px !important;
font-weight: bold;
}
.small-font {
font-size:18px !important;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Sidebar: heading, image upload, preview and the button that starts processing.
sidebar = st.sidebar
sidebar.markdown('## Навигация')
uploaded_file = sidebar.file_uploader(
    "Загрузите изображение с узбекским текстом...",
    type=["jpg", "jpeg", "png"],
)
# The button is only rendered once a file has been uploaded, so its state
# defaults to "not pressed".
process_btn = False
if uploaded_file:
    image = Image.open(uploaded_file)
    sidebar.image(
        image,
        caption='Загруженное изображение',
        use_column_width=True,
    )
    process_btn = sidebar.button("Перевести и суммаризировать")
# Title and Description
st.markdown(
    '<h1 class="big-font">Текстовая обработка изображений</h1>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="big-font">Перевод с узбекского на русский и суммаризация</div>',
    unsafe_allow_html=True,
)

# Pipeline: OCR -> translation -> summarization, run only when the button was
# pressed for an uploaded image.
if process_btn and uploaded_file:
    st.write("Процесс извлечения текста...")
    # Adjust the language code if necessary
    extracted_text, confidence = extract_text(image, 'tjk')
    st.write("Извлеченный текст:")
    st.text_area("Результат", extracted_text, height=150)
    accuracy_percent = confidence * 100
    st.write(f"Точность распознавания: {accuracy_percent:.2f}%")

    if not extracted_text:
        # Nothing was recognised, so there is nothing to translate.
        st.error("Текст для перевода не найден.")
    else:
        with st.spinner('Переводим...'):
            translated_text = translate(
                extracted_text,
                translation_model,
                translation_tokenizer,
            )
        st.text_area("Переведенный текст (на русском):", value=translated_text, height=200)

        with st.spinner('Суммаризируем...'):
            summary_text = summarize(
                translated_text,
                summarization_model,
                summarization_tokenizer,
                max_length=250,
            )
        st.text_area("Суммаризация (на русском):", value=summary_text, height=100)