# IDfy-Avatarify / tests/test_predict.py
# Author: yashvii — "Upload folder using huggingface_hub", commit b2cbfed (verified).
# Hugging Face file-viewer metadata: raw | history | blame | "No virus" | 4.29 kB
import base64
from io import BytesIO
import os
import time
import numpy as np
from PIL import Image, ImageChops
import pytest
import requests
def local_run(
    model_endpoint: str,
    model_input: dict,
    *,
    max_wait_time: float = 1000,
    retry_interval: float = 100,
):
    """POST ``{"input": model_input}`` to *model_endpoint* and decode the first output image.

    While the server answers with ``detail == "Already running a prediction"``,
    the request is retried every *retry_interval* seconds until a total of
    *max_wait_time* seconds has been spent waiting.

    Args:
        model_endpoint: URL of the prediction endpoint.
        model_input: Payload sent under the ``"input"`` key of the JSON body.
        max_wait_time: Maximum total seconds to wait while the server is busy.
        retry_interval: Seconds to sleep between retries while the server is busy.

    Returns:
        A PIL image decoded from ``data["output"][0]``. That element is
        expected to be a data URI whose base64 payload follows the first comma.

    Raises:
        ValueError: if *retry_interval* is not positive (the loop would never end).
        RuntimeError: if the server returns an unexpected response, or is
            still busy after *max_wait_time* seconds.
        Exception: any error raised while decoding the output is re-raised
            unchanged after the offending input and response are printed.
    """
    if retry_interval <= 0:
        raise ValueError("retry_interval must be positive")

    total_wait_time = 0
    while total_wait_time < max_wait_time:
        response = requests.post(model_endpoint, json={"input": model_input})
        data = response.json()

        if "output" in data:
            try:
                datauri = data["output"][0]
                # Drop the "data:<mime>;base64," header; the payload follows the first comma.
                base64_encoded_data = datauri.split(",", 1)[1]
                decoded_data = base64.b64decode(base64_encoded_data)
                return Image.open(BytesIO(decoded_data))
            except Exception:
                # Dump context for debugging, then propagate the original error.
                print("Error while processing output:")
                print("input:", model_input)
                print(data)
                raise

        if "detail" in data and data["detail"] == "Already running a prediction":
            print(f"Prediction in progress, waited {total_wait_time}s, waiting more...")
            time.sleep(retry_interval)
            total_wait_time += retry_interval
            continue

        # Fail loudly instead of silently returning None to the caller.
        print("Unexpected response data:", data)
        raise RuntimeError(f"Unexpected response data: {data}")

    raise RuntimeError("Max wait time exceeded, unable to get valid response")
def image_equal_fuzzy(img_expected, img_actual, test_name="default", tol=20):
    """Return True when two images differ by less than *tol* on average.

    The mean absolute per-pixel, per-channel difference is compared against
    *tol*. Tol was determined empirically: holding everything else equal but
    varying the seed generates images that differ by at least 50.

    Images whose array shapes differ (size or channel count) are reported as
    not equal rather than raising inside numpy or broadcasting silently.

    On a mismatch, the expected and actual images are saved to
    ``/tmp/<test_name>/`` for quick inspection. A difference image is also
    saved when the two images have the same shape and mode.

    Args:
        img_expected: Reference image (anything ``np.array`` accepts; must
            have ``.save``/``.mode`` on the failure path, e.g. a PIL image).
        img_actual: Image under test, same requirements as *img_expected*.
        test_name: Sub-directory name under ``/tmp`` for failure artifacts.
        tol: Maximum allowed mean absolute difference (exclusive).

    Returns:
        bool: True if the images are considered equal.
    """
    # int32 avoids uint8 wrap-around when subtracting.
    img1 = np.array(img_expected, dtype=np.int32)
    img2 = np.array(img_actual, dtype=np.int32)

    shapes_match = img1.shape == img2.shape
    imgs_equal = shapes_match and bool(np.mean(np.abs(img1 - img2)) < tol)

    if not imgs_equal:
        # save failures for quick inspection
        save_dir = f"/tmp/{test_name}"
        os.makedirs(save_dir, exist_ok=True)
        img_expected.save(os.path.join(save_dir, "expected.png"))
        img_actual.save(os.path.join(save_dir, "actual.png"))
        # ImageChops.difference raises if the modes differ, so only write a
        # delta when the two images are directly comparable.
        if shapes_match and img_expected.mode == img_actual.mode:
            difference = ImageChops.difference(img_expected, img_actual)
            difference.save(os.path.join(save_dir, "delta.png"))
    return imgs_equal
@pytest.fixture
def expected_image():
    """Yield the reference output image ``assets/out.png`` next to this test file.

    The path is resolved relative to this file rather than the current working
    directory. The image is opened in a ``with`` block, so its file handle is
    closed once the test using the fixture has finished.
    """
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "out.png")
    with Image.open(path) as img:
        img.load()  # read pixel data eagerly while the file is open
        yield img
def test_seeded_prediction(expected_image):
    """A fixed-seed prediction must reproduce the stored reference image.

    Sends a fixed input (including ``seed``) to the prediction server at
    http://localhost:5000/predictions. The returned image is then compared
    against the ``expected_image`` fixture using image_equal_fuzzy.
    """
    data = {
        "image": "https://replicate.delivery/pbxt/KIIutO7jIleskKaWebhvurgBUlHR6M6KN7KHaMMWSt4OnVrF/musk_resize.jpeg",
        "prompt": "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
        "scheduler": "EulerDiscreteScheduler",
        "enable_lcm": False,
        "pose_image": "https://replicate.delivery/pbxt/KJmFdQRQVDXGDVdVXftLvFrrvgOPXXRXbzIVEyExPYYOFPyF/80048a6e6586759dbcb529e74a9042ca.jpeg",
        "sdxl_weights": "protovision-xl-high-fidel",
        "pose_strength": 0.4,
        "canny_strength": 0.3,
        "depth_strength": 0.5,
        "guidance_scale": 5,
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured",
        "ip_adapter_scale": 0.8,
        "lcm_guidance_scale": 1.5,
        "num_inference_steps": 30,
        "enable_pose_controlnet": True,
        "enhance_nonface_region": True,
        "enable_canny_controlnet": False,
        "enable_depth_controlnet": False,
        "lcm_num_inference_steps": 5,
        "controlnet_conditioning_scale": 0.8,
        "seed": 1337,
    }
    actual_image = local_run("http://localhost:5000/predictions", data)
    assert actual_image is not None, "local_run returned no image (unexpected server response)"

    # Use the fixture directly; pass expected first so saved artifacts are labeled correctly.
    test_result = image_equal_fuzzy(
        expected_image, actual_image, test_name="test_seeded_prediction"
    )
    assert test_result, (
        "Generated image differs from the reference; "
        "inspect /tmp/test_seeded_prediction/ for the saved images"
    )