anon5's picture
Update app.py
90123a4 verified
raw
history blame contribute delete
No virus
1.1 kB
import io
import logging

import requests
import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Single Hugging Face checkpoint supplying both the preprocessing config and
# the classifier weights, so the two can never drift apart.
_MODEL_CHECKPOINT = 'google/vit-base-patch32-384'

print("Loading models...")
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; consider migrating when upgrading the library.
feature_extractor = ViTFeatureExtractor.from_pretrained(_MODEL_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(_MODEL_CHECKPOINT)

print("Starting webapp...")
app = Flask(__name__)

# Turn off the 'werkzeug' logger (used for per-request access lines) and
# Flask's own application logger so stdout only shows the startup prints.
log = logging.getLogger('werkzeug')
log.disabled = True
app.logger.disabled = True

print("Ready")
# Upper bound, in seconds, for connecting to and reading from the remote image
# host, so a single slow or unresponsive URL cannot hold a worker forever.
_FETCH_TIMEOUT_SECONDS = 10


@app.route("/")
def hello_world():
    """Classify the image found at the ``url`` query parameter.

    Query parameters:
        url: Address of the image to download and classify.

    Returns:
        JSON ``{"url": ..., "label": ...}`` where ``label`` is the model's
        top-1 class name taken from ``model.config.id2label``.
        On failure, the same shape is returned with an ``error`` message and
        ``label=None``:
        - HTTP 200 when ``url`` is missing or empty (the pre-existing response).
        - HTTP 502 when the image cannot be downloaded (network error,
          timeout, or non-2xx upstream status).
        - HTTP 400 when the downloaded bytes are not a readable image.
    """
    url = request.args.get('url')
    # `not url` also rejects `?url=` (empty string), which would otherwise
    # blow up inside requests with an unhandled MissingSchema error.
    if not url:
        return jsonify(error="Url is required", url=None, label=None)

    # SECURITY NOTE(review): the server fetches an arbitrary client-supplied
    # URL (SSRF risk: internal hosts, cloud metadata endpoints). Consider a
    # scheme/host allow-list before exposing this endpoint publicly. The body
    # size is also unbounded, as it was before this change.
    try:
        response = requests.get(url, timeout=_FETCH_TIMEOUT_SECONDS)
        response.raise_for_status()
    except requests.RequestException:
        return jsonify(error="Could not fetch url", url=url, label=None), 502

    try:
        # `response.content` is already decoded from gzip/deflate transfer
        # encodings. Image.open is lazy, so convert() forces the full decode
        # here, inside the try. It also normalizes palette, RGBA and greyscale
        # inputs to 3-channel RGB before preprocessing.
        with Image.open(io.BytesIO(response.content)) as raw_image:
            image = raw_image.convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # non-image payloads, and OSError for truncated or corrupt files.
        return jsonify(error="Url does not point to a readable image",
                       url=url, label=None), 400

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # The model predicts one of the 1000 ImageNet classes; take the top-1 index.
    predicted_class_idx = logits.argmax(-1).item()
    return jsonify(url=url, label=model.config.id2label[predicted_class_idx])