import argparse
import json
import faiss
import gradio as gr
import numpy as np
import requests
from imgutils.tagging import wd14
TITLE = "## Danbooru Explorer"
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) as feature extractor
- [dghs-imgutils](https://github.com/deepghs/imgutils) for feature extraction
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing

Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
return parser.parse_args()
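
# Hedged sketch (illustration only; not called by this app): the files under
# index/ that SimilaritySearcher loads are assumed to be built offline from
# precomputed WD14 embeddings. The Space credits autofaiss for the actual
# build; the helper below only shows the idea with a plain Faiss flat index
# and does not produce index/cosine_infos.json (which autofaiss writes).
def build_cosine_index_sketch(embeddings: np.ndarray, image_ids: np.ndarray):
    """Illustrative only: save a cosine-similarity index and its id mapping."""
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    faiss.normalize_L2(embeddings)  # unit norm -> inner product == cosine
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, "index/cosine_knn.index")
    np.save("index/cosine_ids.npy", np.asarray(image_ids, dtype=np.int64))
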
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
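    """Resolve a Danbooru post id to an image URL, or return None if the post
    cannot be fetched or its rating is not among selected_ratings."""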
headers = {"User-Agent": "image_similarity_tool"}
ratings_to_letters = {
"General": "g",
"Sensitive": "s",
"Questionable": "q",
"Explicit": "e",
}
acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]
image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
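    # Optional Danbooru authentication; credentials are passed as query parameters.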
if api_username != "" and api_key != "":
image_url = f"{image_url}?api_key={api_key}&login={api_username}"
r = requests.get(image_url, headers=headers)
if r.status_code != 200:
return None
content = json.loads(r.text)
    image_url = content.get("large_file_url")
    image_url = image_url if content.get("rating") in acceptable_ratings else None
return image_url
class SimilaritySearcher:
def __init__(self):
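        """Load the post id mapping and the prebuilt Faiss index from index/."""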
self.images_ids = np.load("index/cosine_ids.npy")
self.knn_index = faiss.read_index("index/cosine_knn.index")
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
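        # Apply the search-time parameters recorded at index build time
        # (the "index_param" string understood by faiss.ParameterSpace).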
faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
def predict(
self,
img_input,
selected_ratings,
n_neighbours,
api_username,
api_key,
):
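        """Return (image_url, caption) pairs for the most similar Danbooru posts."""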
embeddings = wd14.get_wd14_tags(
img_input,
model_name="SwinV2_v3",
            fmt="embedding",
)
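        # Faiss expects a 2-D batch; L2-normalising the query vector makes
        # inner-product search equivalent to cosine similarity.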
embeddings = np.expand_dims(embeddings, 0)
faiss.normalize_L2(embeddings)
dists, indexes = self.knn_index.search(embeddings, k=n_neighbours)
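        # Map row positions in the Faiss index back to Danbooru post ids.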
neighbours_ids = self.images_ids[indexes][0]
neighbours_ids = [int(x) for x in neighbours_ids]
captions = []
image_urls = []
for image_id, dist in zip(neighbours_ids, dists[0]):
current_url = danbooru_id_to_url(
image_id,
selected_ratings,
api_username,
api_key,
)
if current_url is not None:
image_urls.append(current_url)
captions.append(f"{image_id}/{dist:.2f}")
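        # gr.Gallery accepts a list of (image, caption) pairs.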
return list(zip(image_urls, captions))
def main():
args = parse_args()
searcher = SimilaritySearcher()
with gr.Blocks() as demo:
gr.Markdown(TITLE)
gr.Markdown(DESCRIPTION)
with gr.Row():
img_input = gr.Image(type="pil", label="Input")
with gr.Column():
with gr.Row():
api_username = gr.Textbox(label="Danbooru API Username")
api_key = gr.Textbox(label="Danbooru API Key")
selected_ratings = gr.CheckboxGroup(
choices=["General", "Sensitive", "Questionable", "Explicit"],
value=["General", "Sensitive"],
label="Ratings",
)
with gr.Row():
n_neighbours = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="# of images",
)
find_btn = gr.Button("Find similar images")
similar_images = gr.Gallery(label="Similar images", columns=[5])
find_btn.click(
fn=searcher.predict,
inputs=[
img_input,
selected_ratings,
n_neighbours,
api_username,
api_key,
],
outputs=[similar_images],
)
demo.queue()
demo.launch(share=args.share)
if __name__ == "__main__":
main()