File size: 5,440 Bytes
3542be4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# -*- coding: utf-8 -*-
# https://github.com/kohya-ss/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py

import csv
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

from PIL import Image
import cv2
import numpy as np
from pathlib import Path
import onnx
import onnxruntime as ort

# from wd14 tagger
IMAGE_SIZE = 448  # square input resolution used by preprocess_image for the WD14 model

model = None  # Initialize model variable  # module-level slot; presumably filled by modelLoad() elsewhere — confirm caller


def convert_array_to_bgr(array):
    """
    Convert a NumPy image array to 3-channel BGR format.

    Parameters:
    - array: NumPy array of the image. Supported shapes:
        (H, W)      grayscale
        (H, W, 3)   RGB
        (H, W, 4)   RGBA (the alpha channel is dropped)

    Returns:
    - A NumPy array of shape (H, W, 3) in BGR channel order.

    Raises:
    - ValueError: if the rank or channel count is unsupported.
    """
    # Grayscale (2-D array): replicate the single channel into B, G and R.
    if array.ndim == 2:
        bgr_array = np.stack((array,) * 3, axis=-1)
    # RGB/RGBA (3-D array). Validate the channel count explicitly: previously
    # an array with e.g. 2 channels slipped through and was silently reversed.
    elif array.ndim == 3 and array.shape[2] in (3, 4):
        # RGBA: drop the alpha channel first.
        if array.shape[2] == 4:
            array = array[:, :, :3]
        # Reverse the channel axis: RGB -> BGR.
        bgr_array = array[:, :, ::-1]
    else:
        raise ValueError("Unsupported array shape.")

    return bgr_array


def preprocess_image(image):
    """
    Prepare an image for WD14 inference.

    The image is converted to BGR, centered on a white square canvas, and
    resized to IMAGE_SIZE x IMAGE_SIZE as float32.

    Parameters:
    - image: PIL image or NumPy array.

    Returns:
    - float32 NumPy array of shape (IMAGE_SIZE, IMAGE_SIZE, 3).
    """
    arr = convert_array_to_bgr(np.array(image))

    height, width = arr.shape[:2]
    side = max(height, width)
    # Symmetric white padding so the image sits centered on a square canvas.
    top = (side - height) // 2
    left = (side - width) // 2
    arr = np.pad(
        arr,
        ((top, side - height - top), (left, side - width - left), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # Shrinking favours INTER_AREA; enlarging favours INTER_LANCZOS4.
    method = cv2.INTER_AREA if side > IMAGE_SIZE else cv2.INTER_LANCZOS4
    arr = cv2.resize(arr, (IMAGE_SIZE, IMAGE_SIZE), interpolation=method)

    return arr.astype(np.float32)

def modelLoad(model_dir):
    """
    Create an ONNX Runtime inference session for the model in *model_dir*.

    Parameters:
    - model_dir: directory containing "model.onnx".

    Returns:
    - [ort_session, input_name]: the session and the name of its first input.
    """
    onnx_path = os.path.join(model_dir, "model.onnx")
    # Restrict execution to the CPU provider only.
    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    first_input = session.get_inputs()[0].name

    # Report which provider the session actually selected.
    actual_provider = session.get_providers()[0]
    print(f"Using provider: {actual_provider}")

    return [session, first_input]

def analysis(image_path, model_dir, model):
    """
    Run the WD14 tagger on a single image and return its predicted tags.

    Parameters:
    - image_path: path to the image file to tag. Must be non-empty.
    - model_dir: directory containing "selected_tags.csv".
    - model: [ort_session, input_name] pair as returned by modelLoad().

    Returns:
    - A ", "-joined string of general and character tag names whose
      probabilities meet the 0.35 thresholds.

    Raises:
    - ValueError: if image_path is empty/None.
    - AssertionError: if selected_tags.csv has an unexpected header.
    """
    ort_session = model[0]
    input_name = model[1]

    # Load the tag vocabulary shipped alongside the model.
    with open(os.path.join(model_dir, "selected_tags.csv"), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

    # NOTE(review): rows[1:] skips the first data row in addition to the
    # header (already removed above). This is harmless only if the leading
    # rows are rating tags (neither category "0" nor "4"), as in the standard
    # WD14 selected_tags.csv — confirm against the shipped CSV.
    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]

    tag_freq = {}
    undesired_tags = ["transparent background"]

    # Fail fast on a missing path. Previously a falsy image_path fell through
    # to a NameError on image_pil below; raise an explicit error instead.
    if not image_path:
        raise ValueError("image_path must be a non-empty path")

    # Open as RGBA to preserve transparency information.
    img = Image.open(image_path)
    img = img.convert("RGBA")

    # Paste onto an opaque white canvas so transparent pixels become white,
    # then drop the alpha channel.
    canvas_image = Image.new('RGBA', img.size, (255, 255, 255, 255))
    canvas_image.paste(img, (0, 0), img)
    image_pil = canvas_image.convert("RGB")

    image_preprocessed = preprocess_image(image_pil)
    image_preprocessed = np.expand_dims(image_preprocessed, axis=0)

    # Run inference; take the first output's first batch element.
    prob = ort_session.run(None, {input_name: image_preprocessed})[0][0]

    # Build the tag list from the probabilities.
    combined_tags = []
    general_tag_text = ""
    character_tag_text = ""
    remove_underscore = True
    caption_separator = ", "
    general_threshold = 0.35
    character_threshold = 0.35

    # prob[:4] are rating scores; tag probabilities start at index 4,
    # general tags first, then character tags.
    for i, p in enumerate(prob[4:]):
        if i < len(general_tags) and p >= general_threshold:
            tag_name = general_tags[i]
            if remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
                tag_name = tag_name.replace("_", " ")

            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                general_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)
        elif i >= len(general_tags) and p >= character_threshold:
            tag_name = character_tags[i - len(general_tags)]
            if remove_underscore and len(tag_name) > 3:
                tag_name = tag_name.replace("_", " ")

            if tag_name not in undesired_tags:
                tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                character_tag_text += caption_separator + tag_name
                combined_tags.append(tag_name)

    # Strip the leading separator. These per-category strings are not
    # returned; kept for parity with the upstream kohya-ss script.
    if len(general_tag_text) > 0:
        general_tag_text = general_tag_text[len(caption_separator) :]
    if len(character_tag_text) > 0:
        character_tag_text = character_tag_text[len(caption_separator) :]
    tag_text = caption_separator.join(combined_tags)
    return tag_text