themeetjani commited on
Commit
dae16ed
1 Parent(s): 8b8017b

Upload 3 files

Browse files
Files changed (3) hide show
  1. application.py +91 -0
  2. requirements.txt +6 -0
  3. tweet_model_v1.bin +3 -0
application.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import transformers
4
+ import json
5
+ from flask import Flask, jsonify, request
6
+ import torch.nn.functional as F
7
+ import boto3
8
+ import pandas as pd
9
+ bucket = 'data-ai-dev2'
10
+ from transformers import BertTokenizer, BertModel
11
+ from torch import cuda
12
+ device = 'cuda' if cuda.is_available() else 'cpu'
13
+
14
+ class RobertaClass(torch.nn.Module):
15
+ def __init__(self):
16
+ super(RobertaClass, self).__init__()
17
+ self.l1 = BertModel.from_pretrained("bert-base-multilingual-cased")
18
+ self.pre_classifier = torch.nn.Linear(768, 768)
19
+ self.dropout = torch.nn.Dropout(0.3)
20
+ self.classifier = torch.nn.Linear(768, 8)
21
+
22
+ def forward(self, input_ids, attention_mask, token_type_ids):
23
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
24
+ hidden_state = output_1[0]
25
+ pooler = hidden_state[:, 0]
26
+ pooler = self.pre_classifier(pooler)
27
+ pooler = torch.nn.ReLU()(pooler)
28
+ pooler = self.dropout(pooler)
29
+ output = self.classifier(pooler)
30
+ return output
31
+
32
+ model = RobertaClass()
33
+ model.to(device)
34
+
35
+ s3 = boto3.client('s3', aws_access_key_id='AKIAW5BGUY6ZRCSQBSIJ', aws_secret_access_key= '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')
36
+ s3.download_file(Bucket=bucket, Key='model_hf/tweet_model/tweet_model_v1.bin', Filename = './tweet_model_v1.bin')
37
+
38
+ model = torch.load('tweet_model_v1.bin', map_location=torch.device('cpu'))
39
+
40
+ tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', truncation=True, do_lower_case=True)
41
+
42
+ def id2class_fun(lst, map_cl):
43
+ s = pd.Series(lst)
44
+ return s.map(map_cl).tolist()
45
+
46
+ application = Flask(__name__)
47
+ @application.route('/')
48
+ def home():
49
+ return "Working!"
50
+
51
+ @application.route('/process/', methods=['POST'])
52
+ def process():
53
+ content_type = request.headers.get('Content-Type')
54
+ if (content_type == 'application/json'):
55
+ json_file = request.json
56
+ loaded = json.dumps(json_file)
57
+ new_loaded = json.loads(loaded)
58
+ text = new_loaded['text']
59
+ id2class = {0: 'InappropriateUndesirable', 1 : 'GreenContent', 2 : 'IllegalActivities',
60
+ 3 : 'DiscriminatoryHate', 4 :'ViolentGraphic', 5:'PotentialAddiction',
61
+ 6 : 'ExtremismTerrorism', 7 : 'SexualExplicit'}
62
+ try:
63
+
64
+ inputs = (
65
+ tokenizer.encode_plus(
66
+ text, None, add_special_tokens=True, max_length = 512,
67
+ return_token_type_ids=True, padding=True,
68
+ truncation=True, return_tensors='pt'))
69
+ ids = inputs['input_ids']
70
+ mask = inputs['attention_mask']
71
+ token_type_ids = inputs["token_type_ids"]
72
+ outputs = model(ids, mask, token_type_ids)
73
+ top_values, top_indices = torch.topk(outputs.data, k=2, dim=1)
74
+ probs_values = F.softmax(top_values, dim=0)
75
+ prd_cls = top_indices.cpu().detach().numpy().tolist()
76
+ prd_cls = [item for sublist in prd_cls for item in sublist]
77
+ prd_cls_1 = id2class_fun(prd_cls, id2class)
78
+ prd_score = top_values.cpu().detach().numpy().tolist()
79
+ prd_score = [item for sublist in prd_score for item in sublist]
80
+ otp = dict(zip(prd_cls_1, prd_score))
81
+ # .replace(map_class, inplace=True)
82
+ return jsonify({'output':otp})
83
+ except:
84
+ return jsonify({'output':'something went wrong'})
85
+
86
+
87
+
88
+ if __name__ == "__main__":
89
+ application.debug = True
90
+ application.run()
91
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers==4.31.0
2
+ numpy==1.25.2
3
+ Flask==2.3.2
4
+ boto3==1.26.157
5
+ torch==2.0.0
6
+ pandas==1.5.3
tweet_model_v1.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaa1adad810a4ec32ba1e5e7226eafc7f083953355d902d5d67cfebab2a72359
3
+ size 713927888