Saad0KH's picture
Rename to
dc22aeb verified
history blame contribute delete
No virus
5.34 kB
from flask import Flask, request, jsonify ,send_file
from PIL import Image
import base64
import spaces
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import threading
import logging
import uuid
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
app = Flask(__name__)

# Configure logging: the root logger defaults to WARNING, so without this
# call INFO-level records logged by this module would be dropped.
logging.basicConfig(level=logging.INFO)

# Person-detector handles. Both stay None until load_model() runs, which
# detect_and_segment_persons() triggers on first use (lazy loading).
model = None
detector = None
def load_model():
    """Download and initialise the SCRFD person detector.

    Fetches ``models/scrfd_person_2.5g.onnx`` from the
    ``public-data/insightface`` Hub repo, wraps it in an ONNX Runtime session
    and insightface's RetinaFace helper, and stores the result in the
    module-level ``model`` and ``detector`` globals (both refer to the same
    object).
    """
    global model, detector
    path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )
    options = ort.SessionOptions()
    options.intra_op_num_threads = 8
    options.inter_op_num_threads = 8
    # Execution providers are listed in priority order.
    session = ort.InferenceSession(
        path,
        sess_options=options,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )
    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=path, session=session
    )
    # NOTE(review): a negative ctx_id appears to make insightface pin the
    # session to CPUExecutionProvider; confirm whether GPU use was intended.
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
    logging.info("Model loaded successfully.")
# Allow faster, reduced-internal-precision kernels (e.g. TF32) for float32
# matrix multiplications.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model used for background removal.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
# The model must live on the same device as the input tensors fed to it;
# use the GPU when present and fall back to CPU otherwise.
birefnet.to("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing: resize to 1024x1024, convert the PIL image to a float tensor
# (Normalize only accepts tensors, so ToTensor must come first), then
# normalise with the ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
def save_image(img):
    """Write *img* to the working directory under a fresh random PNG name.

    Args:
        img: a PIL image (any object with a ``save(path)`` method).

    Returns:
        The generated file name, ``"<uuid4>.png"``.
    """
    unique_name = f"{uuid.uuid4()}.png"
    img.save(unique_name)
    return unique_name
def decode_image_from_base64(image_data):
    """Decode base64-encoded image data into an RGB ``PIL.Image.Image``.

    Args:
        image_data: base64 text (``str`` or ``bytes``). A data-URL header
            such as ``data:image/png;base64,`` is tolerated and stripped.

    Returns:
        The decoded image, converted to RGB.
    """
    if isinstance(image_data, str) and image_data.startswith("data:"):
        # b64decode discards non-alphabet characters by default, so the
        # letters of a data-URL header would otherwise corrupt the payload.
        image_data = image_data.partition(",")[2]
    raw = base64.b64decode(image_data)
    return Image.open(BytesIO(raw)).convert("RGB")
def encode_image_to_base64(image):
    """Encode a PIL image as base64 PNG text.

    PNG is used because it is lossless and preserves an alpha channel (RGBA).

    Returns:
        The base64 encoding of the PNG bytes, as a ``str``.
    """
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
def rm_background(image):
    """Remove the background from *image* with BiRefNet.

    Args:
        image: anything ``loadimg.load_img`` accepts (callers here pass PIL
            images).

    Returns:
        A new RGBA ``PIL.Image.Image`` at the input's original size whose
        alpha channel is the predicted foreground mask. The caller's image
        is not modified (``convert`` returns a new image).
    """
    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    # Send the input to whichever device the model was placed on.
    device = next(birefnet.parameters()).device
    input_images = transform_image(im).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the model's last output and squash it into (0, 1).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Resize the predicted mask back to the original image size.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    im.putalpha(mask)
    return im
def remove_background(image):
    """Remove the background from *image* with transparent_background's Remover.

    Args:
        image: a ``PIL.Image.Image`` or a ``numpy.ndarray`` (the latter is
            converted with ``Image.fromarray``).

    Returns:
        Whatever ``Remover.process`` returns for the image.

    Raises:
        TypeError: if *image* is neither a PIL image nor a NumPy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")
    # Validate first: constructing a Remover is presumably costly (model
    # setup), so don't pay for it on bad input.
    remover = Remover()
    return remover.process(image)
def detect_and_segment_persons(image, clothes):
    """Detect people in *image* and segment their clothing.

    Args:
        image: RGB ``PIL.Image.Image`` (``np.array(image)`` must give an
            H x W x 3 RGB array).
        clothes: clothing labels forwarded to ``segment_clothing``.

    Returns:
        List of file names written by ``save_image``. When no person is
        detected it holds one background-removed copy of the whole image;
        otherwise it holds every clothing segment of every detected person,
        in detection order.
    """
    rgb = np.array(image)
    # The detector is fed BGR channel order.
    bgr = rgb[:, :, ::-1]

    if detector is None:
        load_model()  # lazy, first-use model load
    bboxes, _ = detector.detect(bgr)

    if bboxes.shape[0] == 0:
        return [save_image(rm_background(image))]

    height, width = rgb.shape[:2]
    # Keep the first four columns (box corners) and clamp them to the image.
    boxes = np.round(bboxes[:, :4]).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        # A box lying outside the frame collapses to zero area after
        # clipping; an empty crop cannot be processed.
        if x2 <= x1 or y2 <= y1:
            continue
        # Crop from the RGB array directly (no BGR round-trip needed).
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        person_no_bg = rm_background(person)
        segments = segment_clothing(person_no_bg, clothes)
        all_segmented_images.extend(save_image(segment) for segment in segments)
    return all_segmented_images
@app.route('/', methods=['GET'])
def welcome():
    """Landing endpoint: answer ``GET /`` with a fixed greeting string."""
    greeting = "Welcome to Clothing Segmentation API"
    return greeting
@app.route('/api/detect', methods=['POST'])
def detect():
    """Detect people in an uploaded image and segment their clothing.

    Expects a JSON object body whose ``image`` field holds base64 image data.

    Returns:
        ``{"images": [...]}`` with the file names produced by
        detect_and_segment_persons(); HTTP 400 with ``{"error": ...}`` when
        the body or the image data is invalid; HTTP 500 with
        ``{"error": ...}`` when processing fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': "Request body must be a JSON object with an 'image' field"}), 400

    try:
        image = decode_image_from_base64(data['image'])
    except (TypeError, ValueError, OSError) as e:
        # Malformed base64 (binascii.Error is a ValueError) or bytes that are
        # not a readable image are client errors, not server failures.
        logging.warning("Invalid image data: %s", e)
        return jsonify({'error': f'Invalid image data: {e}'}), 400

    clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
    try:
        result = detect_and_segment_persons(image, clothes)
    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        logging.exception("Error occurred: %s", e)
        return jsonify({'error': str(e)}), 500
    return jsonify({'images': result})
# Route to retrieve a generated image.
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Return a PNG previously written by save_image().

    Args:
        image_id: file name taken from the URL. Only names of the form
            ``<uuid4>.png`` (what save_image() generates) are accepted.

    Returns:
        The PNG file, or HTTP 404 with ``{"error": "Image not found"}`` when
        the name is not a generated image name or the file does not exist.
    """
    # image_id is untrusted: accept only canonical "<uuid>.png" names so this
    # route cannot be used to download other files (e.g. source or config).
    stem, suffix = image_id[:-4], image_id[-4:]
    try:
        is_generated_name = suffix == '.png' and str(uuid.UUID(stem)) == stem
    except ValueError:
        is_generated_name = False
    if not is_generated_name:
        return jsonify({'error': 'Image not found'}), 404

    try:
        return send_file(image_id, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
if __name__ == "__main__":
    # Listen on all interfaces, not just localhost, so the server is
    # reachable from outside the host/container. An empty host would make
    # Flask fall back to 127.0.0.1.
    app.run(host="0.0.0.0", port=7860)