Redaction / main.py
import os

# Point the Hugging Face cache at a writable location before importing transformers.
os.environ["HF_HOME"] = "/.cache"

import re

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the redaction model and its tokenizer (downloaded on first run).
model_dir = 'edithram23/Redaction'
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
def mask_generation(text):
    # Prefix the task instruction expected by the seq2seq model.
    inputs = ["Mask Generation: " + text]
    inputs = tokenizer(inputs, max_length=500, truncation=True, return_tensors="pt")

    # Generate the masked version of the input, capping length relative to the input text.
    output = model.generate(**inputs, num_beams=8, do_sample=True, max_length=len(text) + 10)
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    predicted_title = decoded_output.strip()

    # Replace every bracketed span emitted by the model with a uniform [redacted] marker.
    pattern = r'\[.*?\]'
    redacted_text = re.sub(pattern, '[redacted]', predicted_title)
    return redacted_text
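
# Illustrative example (the exact output depends on the model's sampling):
#   mask_generation("My name is John and I live in Paris")
#   -> something like "My name is [redacted] and I live in [redacted]"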
from fastapi import FastAPI
import uvicorn

app = FastAPI()


@app.get("/")
async def hello():
    # Simple health-check endpoint.
    return {"msg": "Live"}


@app.post("/mask")
async def mask_input(query: str):
    # `query` is passed as a query parameter; return the redacted text.
    output = mask_generation(query)
    return {"data": output}
if __name__ == '__main__':
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True, workers=1)
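
# Illustrative usage once the server is running (assumes it is reachable on localhost:7860):
#   curl http://localhost:7860/
#   curl -X POST "http://localhost:7860/mask?query=My%20name%20is%20John"
# The exact masked output depends on the edithram23/Redaction model.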