anon5 commited on
Commit
32580f5
1 Parent(s): 0887a0d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
3
+ from PIL import Image
4
+ import requests
5
+
6
+ print("Loading models...")
7
+
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384')
9
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch32-384')
10
+
11
+ print("Starting webapp...")
12
+
13
+ app = Flask(__name__)
14
+
15
+ print("Ready")
16
+
17
+ @app.route("/")
18
+ def hello_world():
19
+ global feature_extractor, model
20
+
21
+ url = request.args.get('url')
22
+
23
+ if url is None:
24
+ return jsonify(error="Url is required", url=None, classes=[])
25
+
26
+ image = Image.open(requests.get(url, stream=True).raw)
27
+
28
+ inputs = feature_extractor(images=image, return_tensors="pt")
29
+ outputs = model(**inputs)
30
+ logits = outputs.logits
31
+
32
+ print(logits)
33
+
34
+ # model predicts one of the 1000 ImageNet classes
35
+ predicted_class_idx = logits.argmax(-1).item()
36
+
37
+ return jsonify(url=url, classes=model.config.id2label[predicted_class_idx])