File size: 1,209 Bytes
1581f35
53d7ebc
1581f35
d8cad3b
 
 
 
1581f35
 
d8cad3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53d7ebc
d8cad3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
# Set before `transformers` is imported below — presumably so the Hugging Face
# cache location is already configured when the library loads (TODO confirm
# the container actually has a writable /.cache). Keep this ordering.
os.environ["HF_HOME"] = "/.cache"

import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Identifier handed to both from_pretrained() calls below.
model_dir = 'edithram23/Redaction'
# Tokenizer and model are loaded once, at import time, and then shared as
# module globals by every call to mask_generation().
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# Matches one bracketed span such as "[NAME]". Non-greedy, so two adjacent
# spans are replaced separately instead of being merged into one match.
_BRACKETED_SPAN = re.compile(r'\[.*?\]')


def mask_generation(text):
    """Run *text* through the seq2seq model and normalise its bracketed spans.

    The input is prefixed with ``"Mask Generation: "``, tokenized (truncated
    to at most 500 tokens), and passed to ``model.generate``. The decoded
    output is stripped and every ``[...]`` span in it is replaced with the
    literal ``[redacted]``. Presumably the model emits bracketed placeholders
    where it found sensitive content -- confirm against the model card.

    Args:
        text: The string to redact.

    Returns:
        The model's rewritten text with all bracketed spans replaced by
        ``[redacted]``. An empty string is returned for empty or
        whitespace-only input without invoking the model.

    Raises:
        TypeError: If *text* is not a string (raised by the prompt
            concatenation, as before).
    """
    prompt = "Mask Generation: " + text

    # Nothing to redact: skip the expensive beam search rather than letting
    # the model generate arbitrary text from a bare prompt.
    if not text.strip():
        return ""

    encoded = tokenizer([prompt], max_length=500, truncation=True, return_tensors="pt")
    # NOTE(review): do_sample=True means repeated calls with the same input
    # may produce different output. The generation cap is derived from the
    # character length of the input, not its token count.
    generated = model.generate(
        **encoded,
        num_beams=8,
        do_sample=True,
        max_length=len(text) + 10,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0].strip()
    # Collapse every bracketed label into the single generic marker.
    return _BRACKETED_SPAN.sub('[redacted]', decoded)

from fastapi import FastAPI
import uvicorn

# Application object that the route decorators below register handlers on;
# it is also the `app` named in the "main:app" string passed to uvicorn.run().
app = FastAPI()

@app.get("/")
async def hello():
    """Report that the service is up."""
    liveness = {"msg": "Live"}
    return liveness

@app.post("/mask")
async def mask_input(query):
    """Redact *query* and return the result under the ``data`` key."""
    # NOTE(review): mask_generation() is a plain synchronous call made from
    # an async handler, so the event loop is occupied until it returns.
    redacted = mask_generation(query)
    return {"data": redacted}

if __name__ == '__main__':
    # Repeats the assignment made at the top of the module; kept as-is.
    os.environ["HF_HOME"] = "/.cache"

    # "main:app" is an import string, so this presumably requires the file to
    # be importable as `main` (i.e. saved as main.py) -- verify on deploy.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": True,
        "workers": 1,
    }
    uvicorn.run("main:app", **server_options)