diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..015f0df3405332e2023a4a157d80ade3dd6616f2 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,40 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Replicate +/safety-cache/ +/gradio_cached_examples/ +*.mp4 +*.pth +*.pt +*.bin +*.ckpt +*.onnx +*.tar +*.tar.gz +*.h5 +*.pb +*.caffemodel +*.weights +*.tar +*.jpg +*.jpeg +*.png +*.webp +.vscode + +# Exclude Git files +.git +.github +.gitignore + +# Exclude Python cache files +__pycache__ +.pytest_cache/ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude Python virtual environment +/venv diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..1663a3da43569acf6d11e6a4fab02d79bfed37d3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_053704_668578_0.png filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_053801_148984_0.png filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_053853_022841_0.png filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_053948_468290_0.png filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_054025_692605_0.png filter=lfs diff=lfs merge=lfs -text +generated_images/20240723_054124_697176_0.png filter=lfs diff=lfs merge=lfs -text +images/aa.ll_gallery1.png filter=lfs diff=lfs merge=lfs -text +images/yashvi_gallery1.png filter=lfs diff=lfs merge=lfs -text +images/yashvi_gallery4.png filter=lfs diff=lfs merge=lfs -text +images/yashviwhy@instantid.com_gallery1.png filter=lfs diff=lfs merge=lfs -text +images/yashviwhy@instantid.com_gallery2.png filter=lfs diff=lfs merge=lfs -text +images/yashviwhy@instantid.com_gallery3.png filter=lfs diff=lfs merge=lfs -text +images/yashviwhy@instantid.com_gallery4.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4c33885a3527999932b31c675c7965f6ac9bb48d --- /dev/null +++ b/.gitignore @@ -0,0 +1,186 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +huggingface/ + +# Cog +/.cog/ +/safety-cache/ +*.tar +.vscode +gradio_cached_examples +cog/test_batchsize.py +input.png +output_*.png +output.*.png +output_image_*.png +output_image.*.png +output_*.webp +output.*.webp +output_image_*.webp +output_image.*.webp +output_*.jpg +output.*.jpg +output_image_*.jpg +output_image.*.jpg +output_*.jpeg +output.*.jpeg +output_image_*.jpeg +output_image.*.jpeg \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/README.md b/README.md index bcd7d198a47e82e57236373e11b176d9faf30abb..830c30f332083b5174a7e964a1f0ddb9722d37f6 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,6 @@ --- -title: IDfy Avatarify -emoji: 👁 -colorFrom: indigo -colorTo: purple +title: IDfy-Avatarify +app_file: gradio_demo/app.py sdk: gradio -sdk_version: 4.42.0 -app_file: app.py -pinned: false +sdk_version: 4.38.1 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/Untitled.ipynb b/Untitled.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0fd3d83634875a032a328cf71c9d6bb66bbe99b0 --- /dev/null +++ b/Untitled.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "82c97d2c-16bf-4b2c-b16c-d9f7a8a5b12f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting diffusers==0.25.1 (from -r gradio_demo/requirements.txt (line 1))\n", + " Downloading diffusers-0.25.1-py3-none-any.whl.metadata (19 kB)\n", + "Requirement already satisfied: torch==2.0.0 in /opt/conda/lib/python3.10/site-packages (from -r gradio_demo/requirements.txt (line 2)) (2.0.0+cu118)\n", + "Requirement already satisfied: torchvision==0.15.1 in /opt/conda/lib/python3.10/site-packages (from -r gradio_demo/requirements.txt (line 3)) (0.15.1+cu118)\n", + "Collecting transformers==4.37.1 (from -r gradio_demo/requirements.txt (line 4))\n", + " Downloading transformers-4.37.1-py3-none-any.whl.metadata (129 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.4/129.4 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hCollecting accelerate (from -r gradio_demo/requirements.txt (line 5))\n", + " Downloading accelerate-0.32.1-py3-none-any.whl.metadata (18 kB)\n", + "Collecting safetensors (from -r gradio_demo/requirements.txt (line 6))\n", + " Downloading safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", + "Collecting einops (from -r gradio_demo/requirements.txt (line 7))\n", + " Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\n", + "Collecting onnxruntime-gpu (from -r gradio_demo/requirements.txt (line 8))\n", + " Downloading onnxruntime_gpu-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.4 kB)\n", + "Collecting spaces==0.19.4 (from -r gradio_demo/requirements.txt (line 9))\n", + " Downloading spaces-0.19.4-py3-none-any.whl.metadata (972 bytes)\n", + "Collecting omegaconf (from -r gradio_demo/requirements.txt (line 10))\n", + " Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)\n", + "Collecting peft (from -r gradio_demo/requirements.txt (line 11))\n", + " Downloading peft-0.11.1-py3-none-any.whl.metadata (13 kB)\n", + "Collecting huggingface-hub==0.20.2 (from -r gradio_demo/requirements.txt (line 12))\n", + " Downloading huggingface_hub-0.20.2-py3-none-any.whl.metadata (12 kB)\n", + "Collecting opencv-python (from -r gradio_demo/requirements.txt (line 13))\n", + " Downloading opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n", + "Collecting insightface (from -r gradio_demo/requirements.txt (line 14))\n", + " Downloading insightface-0.7.3.tar.gz (439 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.5/439.5 kB\u001b[0m \u001b[31m10.5 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25hCollecting gradio (from -r gradio_demo/requirements.txt (line 15))\n", + " Downloading gradio-4.38.1-py3-none-any.whl.metadata (15 kB)\n", + "Collecting controlnet_aux (from -r gradio_demo/requirements.txt (line 16))\n", + " Downloading controlnet_aux-0.0.9-py3-none-any.whl.metadata (6.5 kB)\n", + "Collecting gdown (from -r gradio_demo/requirements.txt (line 17))\n", + " Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)\n", + "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.10/site-packages (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (7.0.0)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (3.15.4)\n", + "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (1.25.2)\n", + "Collecting regex!=2019.12.17 (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1))\n", + " Downloading regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (2.32.3)\n", + "Requirement already satisfied: Pillow in /opt/conda/lib/python3.10/site-packages (from diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (10.4.0)\n", + "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (4.12.2)\n", + "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (1.13.0)\n", + "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (3.3)\n", + "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (3.1.4)\n", + "Requirement already satisfied: triton==2.0.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (2.0.0)\n", + "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers==4.37.1->-r gradio_demo/requirements.txt (line 4)) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.37.1->-r gradio_demo/requirements.txt (line 4)) (6.0.1)\n", + "Collecting tokenizers<0.19,>=0.14 (from transformers==4.37.1->-r gradio_demo/requirements.txt (line 4))\n", + " Downloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", + "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers==4.37.1->-r gradio_demo/requirements.txt (line 4)) (4.66.4)\n", + "Requirement already satisfied: 
httpx>=0.20 in /opt/conda/lib/python3.10/site-packages (from spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (0.27.0)\n", + "Requirement already satisfied: psutil<6,>=2 in /opt/conda/lib/python3.10/site-packages (from spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (5.9.3)\n", + "Requirement already satisfied: pydantic<3,>=1 in /opt/conda/lib/python3.10/site-packages (from spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (1.10.17)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.20.2->-r gradio_demo/requirements.txt (line 12)) (2024.6.1)\n", + "Requirement already satisfied: cmake in /opt/conda/lib/python3.10/site-packages (from triton==2.0.0->torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (3.30.0)\n", + "Requirement already satisfied: lit in /opt/conda/lib/python3.10/site-packages (from triton==2.0.0->torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (18.1.8)\n", + "Collecting coloredlogs (from onnxruntime-gpu->-r gradio_demo/requirements.txt (line 8))\n", + " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n", + "Collecting flatbuffers (from onnxruntime-gpu->-r gradio_demo/requirements.txt (line 8))\n", + " Downloading flatbuffers-24.3.25-py2.py3-none-any.whl.metadata (850 bytes)\n", + "Requirement already satisfied: protobuf in /opt/conda/lib/python3.10/site-packages (from onnxruntime-gpu->-r gradio_demo/requirements.txt (line 8)) (3.20.3)\n", + "Collecting antlr4-python3-runtime==4.9.* (from omegaconf->-r gradio_demo/requirements.txt (line 10))\n", + " Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.0/117.0 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25hCollecting onnx (from insightface->-r gradio_demo/requirements.txt (line 14))\n", + " Downloading onnx-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n", + "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (3.7.3)\n", + "Requirement already satisfied: scipy in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (1.11.4)\n", + "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (1.5.1)\n", + "Requirement already satisfied: scikit-image in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (0.24.0)\n", + "Collecting easydict (from insightface->-r gradio_demo/requirements.txt (line 14))\n", + " Downloading easydict-1.13-py3-none-any.whl.metadata (4.2 kB)\n", + "Requirement already satisfied: cython in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (3.0.10)\n", + "Collecting albumentations (from insightface->-r gradio_demo/requirements.txt (line 14))\n", + " Downloading albumentations-1.4.11-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: prettytable in /opt/conda/lib/python3.10/site-packages (from insightface->-r gradio_demo/requirements.txt (line 14)) (3.10.0)\n", + "Requirement already satisfied: aiofiles<24.0,>=22.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (22.1.0)\n", + "Collecting altair<6.0,>=5.0 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading altair-5.3.0-py3-none-any.whl.metadata (9.2 kB)\n", + "Requirement already satisfied: fastapi in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (0.111.0)\n", + "Collecting ffmpy (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading ffmpy-0.3.2.tar.gz (5.5 kB)\n", + " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25hCollecting gradio-client==1.1.0 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading gradio_client-1.1.0-py3-none-any.whl.metadata (7.1 kB)\n", + "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (6.4.0)\n", + "Requirement already satisfied: markupsafe~=2.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (2.1.5)\n", + "Requirement already satisfied: orjson~=3.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (3.10.6)\n", + "Requirement already satisfied: pandas<3.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (2.0.3)\n", + "Collecting pydantic<3,>=1 (from spaces==0.19.4->-r gradio_demo/requirements.txt (line 9))\n", + " Downloading pydantic-2.8.2-py3-none-any.whl.metadata (125 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.2/125.2 kB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pydub (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)\n", + "Requirement already satisfied: python-multipart>=0.0.9 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (0.0.9)\n", + "Collecting ruff>=0.2.2 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading ruff-0.5.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)\n", + "Collecting semantic-version~=2.0 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)\n", + "Collecting tomlkit==0.12.0 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading tomlkit-0.12.0-py3-none-any.whl.metadata (2.7 kB)\n", + "Requirement already satisfied: typer<1.0,>=0.12 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (0.12.3)\n", + "Collecting urllib3~=2.0 (from gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading urllib3-2.2.2-py3-none-any.whl.metadata (6.4 kB)\n", + "Requirement already satisfied: uvicorn>=0.14.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r gradio_demo/requirements.txt (line 15)) (0.30.1)\n", + "Collecting websockets<12.0,>=10.0 (from gradio-client==1.1.0->gradio->-r gradio_demo/requirements.txt (line 15))\n", + " Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n", + "Collecting opencv-python-headless (from controlnet_aux->-r gradio_demo/requirements.txt (line 16))\n", + " Downloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n", + "Collecting timm<=0.6.7 (from controlnet_aux->-r gradio_demo/requirements.txt (line 16))\n", + " Downloading timm-0.6.7-py3-none-any.whl.metadata (33 kB)\n", + "Requirement already satisfied: beautifulsoup4 in /opt/conda/lib/python3.10/site-packages (from gdown->-r gradio_demo/requirements.txt (line 17)) (4.12.3)\n", + "Requirement already satisfied: jsonschema>=3.0 in /opt/conda/lib/python3.10/site-packages (from altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) 
(4.22.0)\n", + "Requirement already satisfied: toolz in /opt/conda/lib/python3.10/site-packages (from altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) (0.12.1)\n", + "Requirement already satisfied: anyio in /opt/conda/lib/python3.10/site-packages (from httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (4.4.0)\n", + "Requirement already satisfied: certifi in /opt/conda/lib/python3.10/site-packages (from httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (2024.7.4)\n", + "Requirement already satisfied: httpcore==1.* in /opt/conda/lib/python3.10/site-packages (from httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (1.0.5)\n", + "Requirement already satisfied: idna in /opt/conda/lib/python3.10/site-packages (from httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (3.7)\n", + "Requirement already satisfied: sniffio in /opt/conda/lib/python3.10/site-packages (from httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /opt/conda/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (0.14.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (4.53.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (1.4.5)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (2.9.0)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio->-r gradio_demo/requirements.txt (line 15)) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /opt/conda/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio->-r gradio_demo/requirements.txt (line 15)) (2024.1)\n", + "Collecting annotated-types>=0.4.0 (from pydantic<3,>=1->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9))\n", + " Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)\n", + "Collecting pydantic-core==2.20.1 (from pydantic<3,>=1->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9))\n", + " Downloading pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (3.3.2)\n", + "Requirement already satisfied: click>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (8.1.7)\n", + "Requirement already satisfied: shellingham>=1.3.0 in 
/opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (1.5.4)\n", + "Requirement already satisfied: rich>=10.11.0 in /opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (13.7.1)\n", + "Collecting albucore>=0.0.11 (from albumentations->insightface->-r gradio_demo/requirements.txt (line 14))\n", + " Downloading albucore-0.0.12-py3-none-any.whl.metadata (3.1 kB)\n", + "Collecting eval-type-backport (from albumentations->insightface->-r gradio_demo/requirements.txt (line 14))\n", + " Downloading eval_type_backport-0.2.0-py3-none-any.whl.metadata (2.2 kB)\n", + "Requirement already satisfied: imageio>=2.33 in /opt/conda/lib/python3.10/site-packages (from scikit-image->insightface->-r gradio_demo/requirements.txt (line 14)) (2.34.2)\n", + "Requirement already satisfied: tifffile>=2022.8.12 in /opt/conda/lib/python3.10/site-packages (from scikit-image->insightface->-r gradio_demo/requirements.txt (line 14)) (2024.7.2)\n", + "Requirement already satisfied: lazy-loader>=0.4 in /opt/conda/lib/python3.10/site-packages (from scikit-image->insightface->-r gradio_demo/requirements.txt (line 14)) (0.4)\n", + "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn->insightface->-r gradio_demo/requirements.txt (line 14)) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn->insightface->-r gradio_demo/requirements.txt (line 14)) (3.5.0)\n", + "Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.10/site-packages (from beautifulsoup4->gdown->-r gradio_demo/requirements.txt (line 17)) (2.5)\n", + "Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime-gpu->-r gradio_demo/requirements.txt (line 8))\n", + " Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n", + "Requirement already satisfied: starlette<0.38.0,>=0.37.2 in /opt/conda/lib/python3.10/site-packages (from fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (0.37.2)\n", + "Requirement already satisfied: fastapi-cli>=0.0.2 in /opt/conda/lib/python3.10/site-packages (from fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (0.0.4)\n", + "Requirement already satisfied: ujson!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,>=4.0.1 in /opt/conda/lib/python3.10/site-packages (from fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (5.10.0)\n", + "Requirement already satisfied: email_validator>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (2.2.0)\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.10/site-packages (from importlib-metadata->diffusers==0.25.1->-r gradio_demo/requirements.txt (line 1)) (3.19.2)\n", + "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.10/site-packages (from prettytable->insightface->-r gradio_demo/requirements.txt (line 14)) (0.2.13)\n", + "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /opt/conda/lib/python3.10/site-packages (from requests[socks]->gdown->-r gradio_demo/requirements.txt (line 17)) (1.7.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.0.0->-r gradio_demo/requirements.txt (line 2)) (1.3.0)\n", + "Requirement already satisfied: tomli>=2.0.1 in /opt/conda/lib/python3.10/site-packages (from 
albucore>=0.0.11->albumentations->insightface->-r gradio_demo/requirements.txt (line 14)) (2.0.1)\n", + "Requirement already satisfied: dnspython>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from email_validator>=2.0.0->fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (2.6.1)\n", + "Requirement already satisfied: attrs>=22.2.0 in /opt/conda/lib/python3.10/site-packages (from jsonschema>=3.0->altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) (23.2.0)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /opt/conda/lib/python3.10/site-packages (from jsonschema>=3.0->altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) (2023.12.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /opt/conda/lib/python3.10/site-packages (from jsonschema>=3.0->altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) (0.35.1)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /opt/conda/lib/python3.10/site-packages (from jsonschema>=3.0->altair<6.0,>=5.0->gradio->-r gradio_demo/requirements.txt (line 15)) (0.19.0)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib->insightface->-r gradio_demo/requirements.txt (line 14)) (1.16.0)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (2.18.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.2 in /opt/conda/lib/python3.10/site-packages (from anyio->httpx>=0.20->spaces==0.19.4->-r gradio_demo/requirements.txt (line 9)) (1.2.0)\n", + "Requirement already satisfied: httptools>=0.5.0 in /opt/conda/lib/python3.10/site-packages (from uvicorn[standard]>=0.12.0->fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (0.6.1)\n", + "Requirement already satisfied: python-dotenv>=0.13 in /opt/conda/lib/python3.10/site-packages (from uvicorn[standard]>=0.12.0->fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (1.0.1)\n", + "Requirement already satisfied: uvloop!=0.15.0,!=0.15.1,>=0.14.0 in /opt/conda/lib/python3.10/site-packages (from uvicorn[standard]>=0.12.0->fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (0.19.0)\n", + "Requirement already satisfied: watchfiles>=0.13 in /opt/conda/lib/python3.10/site-packages (from uvicorn[standard]>=0.12.0->fastapi->gradio->-r gradio_demo/requirements.txt (line 15)) (0.22.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio->-r gradio_demo/requirements.txt (line 15)) (0.1.2)\n", + "Downloading diffusers-0.25.1-py3-none-any.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m33.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hDownloading transformers-4.37.1-py3-none-any.whl (8.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.4/8.4 MB\u001b[0m \u001b[31m87.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading spaces-0.19.4-py3-none-any.whl (15 kB)\n", + "Downloading 
huggingface_hub-0.20.2-py3-none-any.whl (330 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m330.3/330.3 kB\u001b[0m \u001b[31m37.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading accelerate-0.32.1-py3-none-any.whl (314 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m314.1/314.1 kB\u001b[0m \u001b[31m36.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m74.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading einops-0.8.0-py3-none-any.whl (43 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading onnxruntime_gpu-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (200.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.8/200.8 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading omegaconf-2.3.0-py3-none-any.whl (79 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.5/79.5 kB\u001b[0m \u001b[31m133.1 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hDownloading peft-0.11.1-py3-none-any.whl (251 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.6/251.6 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", + "\u001b[?25hDownloading opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (62.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.5/62.5 MB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading gradio-4.38.1-py3-none-any.whl (12.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m84.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m0:01\u001b[0m\n", + "\u001b[?25hDownloading gradio_client-1.1.0-py3-none-any.whl (318 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m318.1/318.1 kB\u001b[0m \u001b[31m38.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\n", + "Downloading controlnet_aux-0.0.9-py3-none-any.whl (282 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m282.4/282.4 kB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading gdown-5.2.0-py3-none-any.whl (18 kB)\n", + "Downloading altair-5.3.0-py3-none-any.whl (857 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m857.8/857.8 kB\u001b[0m \u001b[31m64.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pydantic-2.8.2-py3-none-any.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m423.9/423.9 kB\u001b[0m \u001b[31m46.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading 
pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m101.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading regex-2024.5.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (775 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m775.1/775.1 kB\u001b[0m \u001b[31m58.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ruff-0.5.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.1/10.1 MB\u001b[0m \u001b[31m82.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m0:01\u001b[0m\n", + "\u001b[?25hDownloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n", + "Downloading timm-0.6.7-py3-none-any.whl (509 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.0/510.0 kB\u001b[0m \u001b[31m45.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m90.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", + "\u001b[?25hDownloading urllib3-2.2.2-py3-none-any.whl (121 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.4/121.4 kB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading albumentations-1.4.11-py3-none-any.whl (165 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m165.3/165.3 kB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.9/49.9 MB\u001b[0m \u001b[31m27.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading easydict-1.13-py3-none-any.whl (6.8 kB)\n", + "Downloading flatbuffers-24.3.25-py2.py3-none-any.whl (26 kB)\n", + "Downloading onnx-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.9/15.9 MB\u001b[0m \u001b[31m80.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n", + "Downloading albucore-0.0.12-py3-none-any.whl (8.4 kB)\n", + "Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)\n", + "Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\n", + 
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading eval_type_backport-0.2.0-py3-none-any.whl (5.9 kB)\n", + "Building wheels for collected packages: antlr4-python3-runtime, insightface, ffmpy\n", + " Building wheel for antlr4-python3-runtime (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144552 sha256=e9ae543340a864dee947980bab8fc7d8fc3b8a5a04b28963252f790444a5cd1f\n", + " Stored in directory: /home/jupyter/.cache/pip/wheels/12/93/dd/1f6a127edc45659556564c5730f6d4e300888f4bca2d4c5a88\n", + " Building wheel for insightface (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for insightface: filename=insightface-0.7.3-cp310-cp310-linux_x86_64.whl size=874168 sha256=f381a87957a87ca37e1795c8ba2b854664428cf5117760cc1f7d58368918523b\n", + " Stored in directory: /home/jupyter/.cache/pip/wheels/e3/d0/80/e3773fb8b6d1cca87ea1d33d9b1f20a223a6493c896da249b5\n", + " Building wheel for ffmpy (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for ffmpy: filename=ffmpy-0.3.2-py3-none-any.whl size=5581 sha256=d671b217ecfc883cea0aa0408a98e3d187bd0e888ba4e85318ea4b8bfa539786\n", + " Stored in directory: /home/jupyter/.cache/pip/wheels/bd/65/9a/671fc6dcde07d4418df0c592f8df512b26d7a0029c2a23dd81\n", + "Successfully built antlr4-python3-runtime insightface ffmpy\n", + "Installing collected packages: pydub, flatbuffers, ffmpy, easydict, antlr4-python3-runtime, websockets, urllib3, tomlkit, semantic-version, safetensors, ruff, regex, pydantic-core, opencv-python-headless, opencv-python, onnx, omegaconf, humanfriendly, eval-type-backport, einops, annotated-types, pydantic, coloredlogs, albucore, onnxruntime-gpu, huggingface-hub, albumentations, tokenizers, insightface, gradio-client, gdown, diffusers, altair, transformers, gradio, spaces, timm, accelerate, peft, controlnet_aux\n", + " Attempting uninstall: websockets\n", + " Found existing installation: websockets 12.0\n", + " Uninstalling websockets-12.0:\n", + " Successfully uninstalled websockets-12.0\n", + " Attempting uninstall: urllib3\n", + " Found existing installation: urllib3 1.26.19\n", + " Uninstalling urllib3-1.26.19:\n", + " Successfully uninstalled urllib3-1.26.19\n", + " Attempting uninstall: pydantic\n", + " Found existing installation: pydantic 1.10.17\n", + " Uninstalling pydantic-1.10.17:\n", + " Successfully uninstalled pydantic-1.10.17\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "dataproc-jupyter-plugin 0.1.79 requires pydantic~=1.10.0, but you have pydantic 2.8.2 which is incompatible.\n", + "kfp 2.5.0 requires urllib3<2.0.0, but you have urllib3 2.2.2 which is incompatible.\n", + "ydata-profiling 4.6.0 requires pydantic<2,>=1.8.1, but you have pydantic 2.8.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed accelerate-0.32.1 albucore-0.0.12 albumentations-1.4.11 altair-5.3.0 annotated-types-0.7.0 antlr4-python3-runtime-4.9.3 coloredlogs-15.0.1 controlnet_aux-0.0.9 diffusers-0.25.1 easydict-1.13 einops-0.8.0 eval-type-backport-0.2.0 ffmpy-0.3.2 flatbuffers-24.3.25 gdown-5.2.0 gradio-4.38.1 gradio-client-1.1.0 huggingface-hub-0.20.2 humanfriendly-10.0 insightface-0.7.3 omegaconf-2.3.0 onnx-1.16.1 onnxruntime-gpu-1.18.1 opencv-python-4.10.0.84 opencv-python-headless-4.10.0.84 peft-0.11.1 pydantic-2.8.2 pydantic-core-2.20.1 pydub-0.25.1 regex-2024.5.15 ruff-0.5.4 safetensors-0.4.3 semantic-version-2.10.0 spaces-0.19.4 timm-0.6.7 tokenizers-0.15.2 tomlkit-0.12.0 transformers-4.37.1 urllib3-2.2.2 websockets-11.0.3\n" + ] + } + ], + "source": [ + "!pip install -r gradio_demo/requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "dec146ca-0832-4c71-8b31-1586af435d67", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ControlNetModel/config.json: 100%|█████████| 1.38k/1.38k [00:00<00:00, 7.07MB/s]\n", + "diffusion_pytorch_model.safetensors: 100%|██| 2.50G/2.50G [00:05<00:00, 449MB/s]\n", + "ip-adapter.bin: 100%|███████████████████████| 1.69G/1.69G [00:09<00:00, 180MB/s]\n", + "pytorch_lora_weights.safetensors: 100%|███████| 394M/394M [00:02<00:00, 169MB/s]\n", + "Downloading...\n", + "From (original): https://drive.google.com/uc?id=18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8\n", + "From (redirected): https://drive.google.com/uc?id=18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8&confirm=t&uuid=abca1ed1-5c28-423b-a1c7-4fe1fa0d4dbc\n", + "To: /home/jupyter/InstantID/models/antelopev2.zip\n", + "100%|████████████████████████████████████████| 361M/361M [00:21<00:00, 16.9MB/s]\n", + "Archive: ./models/antelopev2.zip\n", + " creating: ./models/antelopev2/\n", + " inflating: ./models/antelopev2/genderage.onnx \n", + " inflating: ./models/antelopev2/2d106det.onnx \n", + " inflating: ./models/antelopev2/1k3d68.onnx \n", + " inflating: ./models/antelopev2/glintr100.onnx \n", + " inflating: ./models/antelopev2/scrfd_10g_bnkps.onnx \n" + ] + } + ], + "source": [ + "!python gradio_demo/download_models.py" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e4900619-4519-4ec9-bb32-a620128d1727", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting setuptools==69.5.1\n", + " Downloading setuptools-69.5.1-py3-none-any.whl.metadata (6.2 kB)\n", + "Downloading setuptools-69.5.1-py3-none-any.whl (894 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m894.6/894.6 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: setuptools\n", + " Attempting uninstall: setuptools\n", + " Found existing installation: setuptools 70.1.1\n", + " Uninstalling setuptools-70.1.1:\n", + " Successfully uninstalled setuptools-70.1.1\n", + "Successfully installed setuptools-69.5.1\n" + ] + } + ], + "source": [ + "!pip 
install setuptools==69.5.1" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9584f180-48a2-46bc-968d-9c99bc56f06c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting huggingface-hub==0.23.4\n", + " Downloading huggingface_hub-0.23.4-py3-none-any.whl.metadata (12 kB)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (3.15.4)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (2024.6.1)\n", + "Requirement already satisfied: packaging>=20.9 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (6.0.1)\n", + "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (4.66.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub==0.23.4) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub==0.23.4) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub==0.23.4) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub==0.23.4) (2.2.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub==0.23.4) (2024.7.4)\n", + "Downloading huggingface_hub-0.23.4-py3-none-any.whl (402 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m402.6/402.6 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: huggingface-hub\n", + " Attempting uninstall: huggingface-hub\n", + " Found existing installation: huggingface-hub 0.20.2\n", + " Uninstalling huggingface-hub-0.20.2:\n", + " Successfully uninstalled huggingface-hub-0.20.2\n", + "Successfully installed huggingface-hub-0.23.4\n" + ] + } + ], + "source": [ + "!pip install huggingface-hub==0.23.4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56683081-5e9f-4378-84df-b957c84b23ad", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "environment": { + "kernel": "python3", + "name": ".m123", + "type": "gcloud", + "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/:m123" + }, + "kernelspec": { + "display_name": "Python 3 (Local)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/checkpoints/ControlNetModel/config.json b/checkpoints/ControlNetModel/config.json new file mode 100644 index 0000000000000000000000000000000000000000..7360f57d816de85660f607272eeae301a4eb0dcd --- /dev/null +++ 
b/checkpoints/ControlNetModel/config.json @@ -0,0 +1,57 @@ +{ + "_class_name": "ControlNetModel", + "_diffusers_version": "0.21.2", + "_name_or_path": "/mnt/nj-aigc/usr/guiwan/workspace/diffusion_output/face_xl_ipc_v4_2_XiezhenAnimeForeigner/checkpoint-150000/ControlNetModel", + "act_fn": "silu", + "addition_embed_type": "text_time", + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": 256, + "attention_head_dim": [ + 5, + 10, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280 + ], + "class_embed_type": null, + "conditioning_channels": 3, + "conditioning_embedding_out_channels": [ + 16, + 32, + 96, + 256 + ], + "controlnet_conditioning_channel_order": "rgb", + "cross_attention_dim": 2048, + "down_block_types": [ + "DownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D" + ], + "downsample_padding": 1, + "encoder_hid_dim": null, + "encoder_hid_dim_type": null, + "flip_sin_to_cos": true, + "freq_shift": 0, + "global_pool_conditions": false, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": null, + "num_class_embeds": null, + "only_cross_attention": false, + "projection_class_embeddings_input_dim": 2816, + "resnet_time_scale_shift": "default", + "transformer_layers_per_block": [ + 1, + 2, + 10 + ], + "upcast_attention": null, + "use_linear_projection": true +} diff --git a/checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors b/checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..610fcb2cc99d3414c88fc2de4d75cc594dd48b07 --- /dev/null +++ b/checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8127be9f174101ebdafee9964d856b49b634435cf6daa396d3f593cf0bbbb05 +size 2502139136 diff --git a/checkpoints/ip-adapter.bin b/checkpoints/ip-adapter.bin new file mode 100644 index 0000000000000000000000000000000000000000..55c98e90c7047768538ad83e8f06f44c017fc329 --- /dev/null +++ b/checkpoints/ip-adapter.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02b3618e36d803784166660520098089a81388e61a93ef8002aa79a5b1c546e1 +size 1691134141 diff --git a/checkpoints/pytorch_lora_weights.safetensors b/checkpoints/pytorch_lora_weights.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..e2dca1a02eb0d31dc2b9183c0eef147172cc3c77 --- /dev/null +++ b/checkpoints/pytorch_lora_weights.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a764e6859b6e04047cd761c08ff0cee96413a8e004c9f07707530cd776b19141 +size 393855224 diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1ffb702f1da37310b76dc50de4fe593a81b4e69 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,40 @@ +# Configuration for Cog ⚙️ +# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md + +build: + # set to true if your model requires a GPU + gpu: true + cuda: "12.1" + + # a list of ubuntu apt packages to install + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # a list of packages in the format == + python_packages: + - "opencv-python==4.9.0.80" + - "transformers==4.37.0" + - "accelerate==0.26.1" + - "insightface==0.7.3" + - "diffusers==0.25.1" + - "onnxruntime==1.16.3" + - "omegaconf==2.3.0" + - "gradio==3.50.2" + - "peft==0.8.2" + - 
"transformers==4.37.0" + - "controlnet-aux==0.0.7" + + # fix for pydantic issues in cog + # https://github.com/replicate/cog/issues/1623 + - albumentations==1.4.3 + + # commands run after the environment is setup + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + +# predict.py defines how predictions are run on your model +predict: "cog/predict.py:Predictor" diff --git a/cog/README.md b/cog/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d67a864b915c24b1c5a643d30bf5cdc7aa7cc0e8 --- /dev/null +++ b/cog/README.md @@ -0,0 +1,60 @@ +# InstantID Cog Model + +[![Replicate](https://replicate.com/zsxkib/instant-id/badge)](https://replicate.com/zsxkib/instant-id) + +## Overview +This repository contains the implementation of [InstantID](https://github.com/InstantID/InstantID) as a [Cog](https://github.com/replicate/cog) model. + +Using [Cog](https://github.com/replicate/cog) allows any users with a GPU to run the model locally easily, without the hassle of downloading weights, installing libraries, or managing CUDA versions. Everything just works. + +## Development +To push your own fork of InstantID to [Replicate](https://replicate.com), follow the [Model Pushing Guide](https://replicate.com/docs/guides/push-a-model). + +## Basic Usage +To make predictions using the model, execute the following command from the root of this project: + +```bash +cog predict \ +-i image=@examples/sam_resize.png \ +-i prompt="analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" \ +-i negative_prompt="nsfw" \ +-i width=680 \ +-i height=680 \ +-i ip_adapter_scale=0.8 \ +-i controlnet_conditioning_scale=0.8 \ +-i num_inference_steps=30 \ +-i guidance_scale=5 +``` + + + + + + +
+ | Input | Output |
+ | :---: | :----: |
+ | Sample Input Image | Sample Output Image |
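
Beyond the `cog predict` invocation shown above, roughly the same prediction can be issued from Python against the hosted model, since the badge at the top points at the `zsxkib/instant-id` slug on Replicate. The snippet below is only a minimal sketch under stated assumptions: it presumes the `replicate` client package is installed, a `REPLICATE_API_TOKEN` is exported in the environment, and the input keys mirror the parameters documented in the table below; in practice you would usually pin an explicit model version.

```python
# Minimal sketch: calling the hosted InstantID model with the Replicate Python client.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set in the environment.
import replicate

with open("examples/sam_resize.png", "rb") as face_image:
    output = replicate.run(
        "zsxkib/instant-id",  # slug from the badge above; append ":<version>" to pin a build
        input={
            "image": face_image,
            "prompt": "analog film photo of a man. faded film, desaturated, 35mm photo, grainy",
            "negative_prompt": "nsfw",
            "width": 680,
            "height": 680,
            "ip_adapter_scale": 0.8,
            "controlnet_conditioning_scale": 0.8,
            "num_inference_steps": 30,
            "guidance_scale": 5,
        },
    )

# `output` is a list with one entry per generated image
# (URLs or file-like objects, depending on the client version).
print(output)
```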
+ +## Input Parameters + +The following table provides details about each input parameter for the `predict` function: + +| Parameter | Description | Default Value | Range | +| ------------------------------- | ---------------------------------- | -------------------------------------------------------------------------------------------------------------- | ----------- | +| `image` | Input image | A path to the input image file | Path string | +| `prompt` | Input prompt | "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, ... " | String | +| `negative_prompt` | Input Negative Prompt | (empty string) | String | +| `width` | Width of output image | 640 | 512 - 2048 | +| `height` | Height of output image | 640 | 512 - 2048 | +| `ip_adapter_scale` | Scale for IP adapter | 0.8 | 0.0 - 1.0 | +| `controlnet_conditioning_scale` | Scale for ControlNet conditioning | 0.8 | 0.0 - 1.0 | +| `num_inference_steps` | Number of denoising steps | 30 | 1 - 500 | +| `guidance_scale` | Scale for classifier-free guidance | 5 | 1 - 50 | + +This table provides a quick reference to understand and modify the inputs for generating predictions using the model. + + diff --git a/cog/predict.py b/cog/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..9432b4a33a376a38d0b39f74e199443c0efe87b4 --- /dev/null +++ b/cog/predict.py @@ -0,0 +1,756 @@ +# Prediction interface for Cog ⚙️ +# https://github.com/replicate/cog/blob/main/docs/python.md + +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(os.path.join(os.path.dirname(__file__), "../gradio_demo")) + +import cv2 +import time +import torch +import mimetypes +import subprocess +import numpy as np +from typing import List +from cog import BasePredictor, Input, Path + +import PIL +from PIL import Image + +import diffusers +from diffusers import LCMScheduler +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel + +from model_util import get_torch_device +from insightface.app import FaceAnalysis +from transformers import CLIPImageProcessor +from controlnet_util import openpose, get_depth_map, get_canny_image + +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from pipeline_stable_diffusion_xl_instantid_full import ( + StableDiffusionXLInstantIDPipeline, + draw_kps, +) + +mimetypes.add_type("image/webp", ".webp") + +# GPU global variables +DEVICE = get_torch_device() +DTYPE = torch.float16 if str(DEVICE).__contains__("cuda") else torch.float32 + +# for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0` +CHECKPOINTS_CACHE = "./checkpoints" +CHECKPOINTS_URL = "https://weights.replicate.delivery/default/InstantID/checkpoints.tar" + +# for `models/antelopev2` +MODELS_CACHE = "./models" +MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar" + +# for the safety checker +SAFETY_CACHE = "./safety-cache" +FEATURE_EXTRACTOR = "./feature-extractor" +SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar" + +SDXL_NAME_TO_PATHLIKE = { + # These are all huggingface models that we host via gcp + pget + "stable-diffusion-xl-base-1.0": { + "slug": "stabilityai/stable-diffusion-xl-base-1.0", + "url": "https://weights.replicate.delivery/default/InstantID/models--stabilityai--stable-diffusion-xl-base-1.0.tar", + "path": 
"checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0", + }, + "afrodite-xl-v2": { + "slug": "stablediffusionapi/afrodite-xl-v2", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--afrodite-xl-v2.tar", + "path": "checkpoints/models--stablediffusionapi--afrodite-xl-v2", + }, + "albedobase-xl-20": { + "slug": "stablediffusionapi/albedobase-xl-20", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--albedobase-xl-20.tar", + "path": "checkpoints/models--stablediffusionapi--albedobase-xl-20", + }, + "albedobase-xl-v13": { + "slug": "stablediffusionapi/albedobase-xl-v13", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--albedobase-xl-v13.tar", + "path": "checkpoints/models--stablediffusionapi--albedobase-xl-v13", + }, + "animagine-xl-30": { + "slug": "stablediffusionapi/animagine-xl-30", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--animagine-xl-30.tar", + "path": "checkpoints/models--stablediffusionapi--animagine-xl-30", + }, + "anime-art-diffusion-xl": { + "slug": "stablediffusionapi/anime-art-diffusion-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--anime-art-diffusion-xl.tar", + "path": "checkpoints/models--stablediffusionapi--anime-art-diffusion-xl", + }, + "anime-illust-diffusion-xl": { + "slug": "stablediffusionapi/anime-illust-diffusion-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--anime-illust-diffusion-xl.tar", + "path": "checkpoints/models--stablediffusionapi--anime-illust-diffusion-xl", + }, + "dreamshaper-xl": { + "slug": "stablediffusionapi/dreamshaper-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--dreamshaper-xl.tar", + "path": "checkpoints/models--stablediffusionapi--dreamshaper-xl", + }, + "dynavision-xl-v0610": { + "slug": "stablediffusionapi/dynavision-xl-v0610", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--dynavision-xl-v0610.tar", + "path": "checkpoints/models--stablediffusionapi--dynavision-xl-v0610", + }, + "guofeng4-xl": { + "slug": "stablediffusionapi/guofeng4-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--guofeng4-xl.tar", + "path": "checkpoints/models--stablediffusionapi--guofeng4-xl", + }, + "juggernaut-xl-v8": { + "slug": "stablediffusionapi/juggernaut-xl-v8", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--juggernaut-xl-v8.tar", + "path": "checkpoints/models--stablediffusionapi--juggernaut-xl-v8", + }, + "nightvision-xl-0791": { + "slug": "stablediffusionapi/nightvision-xl-0791", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--nightvision-xl-0791.tar", + "path": "checkpoints/models--stablediffusionapi--nightvision-xl-0791", + }, + "omnigen-xl": { + "slug": "stablediffusionapi/omnigen-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--omnigen-xl.tar", + "path": "checkpoints/models--stablediffusionapi--omnigen-xl", + }, + "pony-diffusion-v6-xl": { + "slug": "stablediffusionapi/pony-diffusion-v6-xl", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--pony-diffusion-v6-xl.tar", + "path": "checkpoints/models--stablediffusionapi--pony-diffusion-v6-xl", + }, + "protovision-xl-high-fidel": { + 
"slug": "stablediffusionapi/protovision-xl-high-fidel", + "url": "https://weights.replicate.delivery/default/InstantID/models--stablediffusionapi--protovision-xl-high-fidel.tar", + "path": "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel", + }, + "RealVisXL_V3.0_Turbo": { + "slug": "SG161222/RealVisXL_V3.0_Turbo", + "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V3.0_Turbo.tar", + "path": "checkpoints/models--SG161222--RealVisXL_V3.0_Turbo", + }, + "RealVisXL_V4.0_Lightning": { + "slug": "SG161222/RealVisXL_V4.0_Lightning", + "url": "https://weights.replicate.delivery/default/InstantID/models--SG161222--RealVisXL_V4.0_Lightning.tar", + "path": "checkpoints/models--SG161222--RealVisXL_V4.0_Lightning", + }, +} + + +def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + +def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + +def resize_img( + input_image, + max_side=1280, + min_side=1024, + size=None, + pad_to_max_side=False, + mode=PIL.Image.BILINEAR, + base_pixel_number=64, +): + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio * w), round(ratio * h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = ( + np.array(input_image) + ) + input_image = Image.fromarray(res) + return input_image + + +def download_weights(url, dest): + start = time.time() + print("[!] Initiating download from URL: ", url) + print("[~] Destination path: ", dest) + command = ["pget", "-vf", url, dest] + if ".tar" in url: + command.append("-x") + try: + subprocess.check_call(command, close_fds=False) + except subprocess.CalledProcessError as e: + print( + f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}." 
+ ) + raise + print("[+] Download completed in: ", time.time() - start, "seconds") + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + + if not os.path.exists(CHECKPOINTS_CACHE): + download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE) + + if not os.path.exists(MODELS_CACHE): + download_weights(MODELS_URL, MODELS_CACHE) + + self.face_detection_input_width, self.face_detection_input_height = 640, 640 + self.app = FaceAnalysis( + name="antelopev2", + root="./", + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + self.app.prepare(ctx_id=0, det_size=(self.face_detection_input_width, self.face_detection_input_height)) + + # Path to InstantID models + self.face_adapter = f"./checkpoints/ip-adapter.bin" + controlnet_path = f"./checkpoints/ControlNetModel" + + # Load pipeline face ControlNetModel + self.controlnet_identitynet = ControlNetModel.from_pretrained( + controlnet_path, + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + ) + self.setup_extra_controlnets() + + self.load_weights("stable-diffusion-xl-base-1.0") + self.setup_safety_checker() + + def setup_safety_checker(self): + print(f"[~] Seting up safety checker") + + if not os.path.exists(SAFETY_CACHE): + download_weights(SAFETY_URL, SAFETY_CACHE) + + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + SAFETY_CACHE, + torch_dtype=DTYPE, + local_files_only=True, + ) + self.safety_checker.to(DEVICE) + self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) + + def run_safety_checker(self, image): + safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( + DEVICE + ) + np_image = np.array(image) + image, has_nsfw_concept = self.safety_checker( + images=[np_image], + clip_input=safety_checker_input.pixel_values.to(DTYPE), + ) + return image, has_nsfw_concept + + def load_weights(self, sdxl_weights): + self.base_weights = sdxl_weights + weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights] + + download_url = weights_info["url"] + path_to_weights_dir = weights_info["path"] + if not os.path.exists(path_to_weights_dir): + download_weights(download_url, path_to_weights_dir) + + is_hugging_face_model = "slug" in weights_info.keys() + path_to_weights_file = os.path.join( + path_to_weights_dir, + weights_info.get("file", ""), + ) + + print(f"[~] Loading new SDXL weights: {path_to_weights_file}") + if is_hugging_face_model: + self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + weights_info["slug"], + controlnet=[self.controlnet_identitynet], + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + safety_checker=None, + feature_extractor=None, + ) + self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config( + self.pipe.scheduler.config + ) + else: # e.g. 
.safetensors, NOTE: This functionality is not being used right now + self.pipe.from_single_file( + path_to_weights_file, + controlnet=self.controlnet_identitynet, + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + ) + + self.pipe.load_ip_adapter_instantid(self.face_adapter) + self.setup_lcm_lora() + self.pipe.cuda() + + def setup_lcm_lora(self): + print(f"[~] Seting up LCM (just in case)") + + lcm_lora_key = "models--latent-consistency--lcm-lora-sdxl" + lcm_lora_path = f"checkpoints/{lcm_lora_key}" + if not os.path.exists(lcm_lora_path): + download_weights( + f"https://weights.replicate.delivery/default/InstantID/{lcm_lora_key}.tar", + lcm_lora_path, + ) + self.pipe.load_lora_weights( + "latent-consistency/lcm-lora-sdxl", + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + weight_name="pytorch_lora_weights.safetensors", + ) + self.pipe.disable_lora() + + def setup_extra_controlnets(self): + print(f"[~] Seting up pose, canny, depth ControlNets") + + controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0" + controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0" + controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small" + + for controlnet_key in [ + "models--diffusers--controlnet-canny-sdxl-1.0", + "models--diffusers--controlnet-depth-sdxl-1.0-small", + "models--thibaud--controlnet-openpose-sdxl-1.0", + ]: + controlnet_path = f"checkpoints/{controlnet_key}" + if not os.path.exists(controlnet_path): + download_weights( + f"https://weights.replicate.delivery/default/InstantID/{controlnet_key}.tar", + controlnet_path, + ) + + controlnet_pose = ControlNetModel.from_pretrained( + controlnet_pose_model, + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + ).to(DEVICE) + controlnet_canny = ControlNetModel.from_pretrained( + controlnet_canny_model, + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + ).to(DEVICE) + controlnet_depth = ControlNetModel.from_pretrained( + controlnet_depth_model, + torch_dtype=DTYPE, + cache_dir=CHECKPOINTS_CACHE, + local_files_only=True, + ).to(DEVICE) + + self.controlnet_map = { + "pose": controlnet_pose, + "canny": controlnet_canny, + "depth": controlnet_depth, + } + self.controlnet_map_fn = { + "pose": openpose, + "canny": get_canny_image, + "depth": get_depth_map, + } + + def generate_image( + self, + face_image_path, + pose_image_path, + prompt, + negative_prompt, + num_steps, + identitynet_strength_ratio, + adapter_strength_ratio, + pose_strength, + canny_strength, + depth_strength, + controlnet_selection, + guidance_scale, + seed, + scheduler, + enable_LCM, + enhance_face_region, + num_images_per_prompt, + ): + if enable_LCM: + self.pipe.enable_lora() + self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + else: + self.pipe.disable_lora() + scheduler_class_name = scheduler.split("-")[0] + + add_kwargs = {} + if len(scheduler.split("-")) > 1: + add_kwargs["use_karras_sigmas"] = True + if len(scheduler.split("-")) > 2: + add_kwargs["algorithm_type"] = "sde-dpmsolver++" + scheduler = getattr(diffusers, scheduler_class_name) + self.pipe.scheduler = scheduler.from_config( + self.pipe.scheduler.config, + **add_kwargs, + ) + + if face_image_path is None: + raise Exception( + f"Cannot find any input face `image`! 
Please upload the face `image`" + ) + + face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = self.app.get(face_image_cv2) + + if len(face_info) == 0: + raise Exception( + "Face detector could not find a face in the `image`. Please use a different `image` as input." + ) + + face_info = sorted( + face_info, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1], + )[ + -1 + ] # only use the maximum face + face_emb = face_info["embedding"] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"]) + + img_controlnet = face_image + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image, max_side=1024) + img_controlnet = pose_image + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = self.app.get(pose_image_cv2) + + if len(face_info) == 0: + raise Exception( + "Face detector could not find a face in the `pose_image`. Please use a different `pose_image` as input." + ) + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info["kps"]) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + if len(controlnet_selection) > 0: + controlnet_scales = { + "pose": pose_strength, + "canny": canny_strength, + "depth": depth_strength, + } + self.pipe.controlnet = MultiControlNetModel( + [self.controlnet_identitynet] + + [self.controlnet_map[s] for s in controlnet_selection] + ) + control_scales = [float(identitynet_strength_ratio)] + [ + controlnet_scales[s] for s in controlnet_selection + ] + control_images = [face_kps] + [ + self.controlnet_map_fn[s](img_controlnet).resize((width, height)) + for s in controlnet_selection + ] + else: + self.pipe.controlnet = self.controlnet_identitynet + control_scales = float(identitynet_strength_ratio) + control_images = face_kps + + generator = torch.Generator(device=DEVICE).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + self.pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=control_images, + control_mask=control_mask, + controlnet_conditioning_scale=control_scales, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + num_images_per_prompt=num_images_per_prompt, + ).images + + return images + + def predict( + self, + image: Path = Input( + description="Input face image", + ), + pose_image: Path = Input( + description="(Optional) reference pose image", + default=None, + ), + prompt: str = Input( + description="Input prompt", + default="a person", + ), + negative_prompt: str = Input( + description="Input Negative Prompt", + default="", + ), + sdxl_weights: str = Input( + description="Pick which base weights you want to use", + default="stable-diffusion-xl-base-1.0", + choices=[ + "stable-diffusion-xl-base-1.0", + "juggernaut-xl-v8", + "afrodite-xl-v2", + "albedobase-xl-20", + "albedobase-xl-v13", + "animagine-xl-30", + "anime-art-diffusion-xl", + 
"anime-illust-diffusion-xl", + "dreamshaper-xl", + "dynavision-xl-v0610", + "guofeng4-xl", + "nightvision-xl-0791", + "omnigen-xl", + "pony-diffusion-v6-xl", + "protovision-xl-high-fidel", + "RealVisXL_V3.0_Turbo", + "RealVisXL_V4.0_Lightning", + ], + ), + face_detection_input_width: int = Input( + description="Width of the input image for face detection", + default=640, + ge=640, + le=4096, + ), + face_detection_input_height: int = Input( + description="Height of the input image for face detection", + default=640, + ge=640, + le=4096, + ), + scheduler: str = Input( + description="Scheduler", + choices=[ + "DEISMultistepScheduler", + "HeunDiscreteScheduler", + "EulerDiscreteScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverMultistepScheduler-Karras", + "DPMSolverMultistepScheduler-Karras-SDE", + ], + default="EulerDiscreteScheduler", + ), + num_inference_steps: int = Input( + description="Number of denoising steps", + default=30, + ge=1, + le=500, + ), + guidance_scale: float = Input( + description="Scale for classifier-free guidance", + default=7.5, + ge=1, + le=50, + ), + ip_adapter_scale: float = Input( + description="Scale for image adapter strength (for detail)", # adapter_strength_ratio + default=0.8, + ge=0, + le=1.5, + ), + controlnet_conditioning_scale: float = Input( + description="Scale for IdentityNet strength (for fidelity)", # identitynet_strength_ratio + default=0.8, + ge=0, + le=1.5, + ), + enable_pose_controlnet: bool = Input( + description="Enable Openpose ControlNet, overrides strength if set to false", + default=True, + ), + pose_strength: float = Input( + description="Openpose ControlNet strength, effective only if `enable_pose_controlnet` is true", + default=0.4, + ge=0, + le=1, + ), + enable_canny_controlnet: bool = Input( + description="Enable Canny ControlNet, overrides strength if set to false", + default=False, + ), + canny_strength: float = Input( + description="Canny ControlNet strength, effective only if `enable_canny_controlnet` is true", + default=0.3, + ge=0, + le=1, + ), + enable_depth_controlnet: bool = Input( + description="Enable Depth ControlNet, overrides strength if set to false", + default=False, + ), + depth_strength: float = Input( + description="Depth ControlNet strength, effective only if `enable_depth_controlnet` is true", + default=0.5, + ge=0, + le=1, + ), + enable_lcm: bool = Input( + description="Enable Fast Inference with LCM (Latent Consistency Models) - speeds up inference steps, trade-off is the quality of the generated image. Performs better with close-up portrait face images", + default=False, + ), + lcm_num_inference_steps: int = Input( + description="Only used when `enable_lcm` is set to True, Number of denoising steps when using LCM", + default=5, + ge=1, + le=10, + ), + lcm_guidance_scale: float = Input( + description="Only used when `enable_lcm` is set to True, Scale for classifier-free guidance when using LCM", + default=1.5, + ge=1, + le=20, + ), + enhance_nonface_region: bool = Input( + description="Enhance non-face region", default=True + ), + output_format: str = Input( + description="Format of the output images", + choices=["webp", "jpg", "png"], + default="webp", + ), + output_quality: int = Input( + description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.", + default=80, + ge=0, + le=100, + ), + seed: int = Input( + description="Random seed. 
Leave blank to randomize the seed", + default=None, + ), + num_outputs: int = Input( + description="Number of images to output", + default=1, + ge=1, + le=8, + ), + disable_safety_checker: bool = Input( + description="Disable safety checker for generated images", + default=False, + ), + ) -> List[Path]: + """Run a single prediction on the model""" + + # If no seed is provided, generate a random seed + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + # Load the weights if they are different from the base weights + if sdxl_weights != self.base_weights: + self.load_weights(sdxl_weights) + + # Resize the output if the provided dimensions are different from the current ones + if self.face_detection_input_width != face_detection_input_width or self.face_detection_input_height != face_detection_input_height: + print(f"[!] Resizing output to {face_detection_input_width}x{face_detection_input_height}") + self.face_detection_input_width = face_detection_input_width + self.face_detection_input_height = face_detection_input_height + self.app.prepare(ctx_id=0, det_size=(self.face_detection_input_width, self.face_detection_input_height)) + + # Set up ControlNet selection and their respective strength values (if any) + controlnet_selection = [] + if pose_strength > 0 and enable_pose_controlnet: + controlnet_selection.append("pose") + if canny_strength > 0 and enable_canny_controlnet: + controlnet_selection.append("canny") + if depth_strength > 0 and enable_depth_controlnet: + controlnet_selection.append("depth") + + # Switch to LCM inference steps and guidance scale if LCM is enabled + if enable_lcm: + num_inference_steps = lcm_num_inference_steps + guidance_scale = lcm_guidance_scale + + # Generate + images = self.generate_image( + face_image_path=str(image), + pose_image_path=str(pose_image) if pose_image else None, + prompt=prompt, + negative_prompt=negative_prompt, + num_steps=num_inference_steps, + identitynet_strength_ratio=controlnet_conditioning_scale, + adapter_strength_ratio=ip_adapter_scale, + pose_strength=pose_strength, + canny_strength=canny_strength, + depth_strength=depth_strength, + controlnet_selection=controlnet_selection, + scheduler=scheduler, + guidance_scale=guidance_scale, + seed=seed, + enable_LCM=enable_lcm, + enhance_face_region=enhance_nonface_region, + num_images_per_prompt=num_outputs, + ) + + # Save the generated images and check for NSFW content + output_paths = [] + for i, output_image in enumerate(images): + if not disable_safety_checker: + _, has_nsfw_content_list = self.run_safety_checker(output_image) + has_nsfw_content = any(has_nsfw_content_list) + print(f"NSFW content detected: {has_nsfw_content}") + if has_nsfw_content: + raise Exception( + "NSFW content detected. Try running it again, or try a different prompt." 
+ ) + + extension = output_format.lower() + extension = "jpeg" if extension == "jpg" else extension + output_path = f"/tmp/out_{i}.{extension}" + + print(f"[~] Saving to {output_path}...") + print(f"[~] Output format: {extension.upper()}") + if output_format != "png": + print(f"[~] Output quality: {output_quality}") + + save_params = {"format": extension.upper()} + if output_format != "png": + save_params["quality"] = output_quality + save_params["optimize"] = True + + output_image.save(output_path, **save_params) + output_paths.append(Path(output_path)) + return output_paths diff --git a/feature-extractor/preprocessor_config.json b/feature-extractor/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..fb1692ff11ed6b3dee0634f577193bb23c0d37ec --- /dev/null +++ b/feature-extractor/preprocessor_config.json @@ -0,0 +1,27 @@ +{ + "crop_size": { + "height": 224, + "width": 224 + }, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "CLIPImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "shortest_edge": 224 + } +} diff --git a/generated_images/20240723_053704_668578_0.png b/generated_images/20240723_053704_668578_0.png new file mode 100644 index 0000000000000000000000000000000000000000..995cffeacf04de15419cb63211c5582c71f3cc06 --- /dev/null +++ b/generated_images/20240723_053704_668578_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:575f77734f6e2ebe3f9d4f322d9b3855e126692c94c1e7bf4e0390f5a07c510b +size 1746153 diff --git a/generated_images/20240723_053801_148984_0.png b/generated_images/20240723_053801_148984_0.png new file mode 100644 index 0000000000000000000000000000000000000000..ceda68208502d94568fa61a4b0a0c1d22f481f26 --- /dev/null +++ b/generated_images/20240723_053801_148984_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c005cc75b52ec5cfa74932ae0c510421e3913baae36edb2962eb04354834c118 +size 1663950 diff --git a/generated_images/20240723_053853_022841_0.png b/generated_images/20240723_053853_022841_0.png new file mode 100644 index 0000000000000000000000000000000000000000..0499cc9fc8f8caf5f2655ab97301045b7249fb3b --- /dev/null +++ b/generated_images/20240723_053853_022841_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:427f133f008a53a7d4e8edf2c1b0f59cb7096c429471f7dd940aad9a2bfa3c84 +size 2468302 diff --git a/generated_images/20240723_053948_468290_0.png b/generated_images/20240723_053948_468290_0.png new file mode 100644 index 0000000000000000000000000000000000000000..f28db120b7a7ee0556e6c4905462afe3d0c05c01 --- /dev/null +++ b/generated_images/20240723_053948_468290_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbb8451142141a23c751f9c4a3efaa57c8f5bbda1048395db512cd68530efb24 +size 1517716 diff --git a/generated_images/20240723_054025_692605_0.png b/generated_images/20240723_054025_692605_0.png new file mode 100644 index 0000000000000000000000000000000000000000..22d24c42dade3954b9e02ff7c69d18a524700a5e --- /dev/null +++ b/generated_images/20240723_054025_692605_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51e93ef112561ae84c08b3eab6dc8c2ce2d1d20be2762d3d266d5ee6a4a231cc +size 1016506 diff --git a/generated_images/20240723_054124_697176_0.png 
b/generated_images/20240723_054124_697176_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b03ab80bad8eeb80e0b5272960d3a05a8fb7f908 --- /dev/null +++ b/generated_images/20240723_054124_697176_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:774cd2a486c071fa2083ca4589d353c6374c659b1e3345d6fc7458bb052257da +size 1662523 diff --git a/generation_log.csv b/generation_log.csv new file mode 100644 index 0000000000000000000000000000000000000000..005e3ed7faf8082a6af507e8c917e50fde0dba69 --- /dev/null +++ b/generation_log.csv @@ -0,0 +1,7 @@ +image_name,new_file_name,identitynet_strength_ratio,adapter_strength_ratio,num_inference_steps,guidance_scale,seed,success,error_message,style_name,prompt,negative_prompt,time_taken,current_timestamp +musk_resize.jpeg,20240723_053704_668578_0.png,1.1491785966677859,0.8654292835406997,50,10.881974934041711,4170043132,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",53.160874,2024-07-23 05:37:05 +sam_resize.png,20240723_053801_148984_0.png,1.0277924316289087,0.9683019180411349,53,11.111615489229361,1039000092,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",56.287061,2024-07-23 05:38:01 +schmidhuber_resize.png,20240723_053853_022841_0.png,1.4917970061395218,0.7393876001187043,48,11.679426057392323,3752244045,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",51.661633,2024-07-23 05:38:53 +kaifu_resize.png,20240723_053948_468290_0.png,1.4485948536834086,0.8122224472625851,52,9.984434112216853,2295950491,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",55.366897,2024-07-23 05:39:49 +pp_0.jpg,20240723_054025_692605_0.png,1.1794069160183727,0.9857350785784462,51,8.76420747179281,2648835109,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",36.771416,2024-07-23 05:40:26 +yann-lecun_resize.jpg,20240723_054124_697176_0.png,1.2770220875965888,0.8245108249424827,56,9.372671733967127,3933691473,True,,(No style),"human, sharp focus","(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",59.069521,2024-07-23 05:41:25 diff --git a/gradio_demo/aaa.py b/gradio_demo/aaa.py new file mode 100644 index 0000000000000000000000000000000000000000..9321b5723f8294b07b289248a575a41546e50785 --- /dev/null +++ b/gradio_demo/aaa.py @@ -0,0 +1,957 @@ +import sys +sys.path.append('./') + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image 
+from diffusers.models import ControlNetModel +from diffusers import LCMScheduler + +from huggingface_hub import hf_hub_download + +import insightface +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(320, 320)) + +# Path to InstantID models +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# Load pipeline +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +logo = Image.open("./gradio_demo/logo.png") + +from cv2 import imencode +import base64 + +# def encode_pil_to_base64_new(pil_image): +# print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# image_arr = np.asarray(pil_image)[:,:,::-1] +# _, byte_data = imencode('.png', image_arr) +# base64_data = base64.b64encode(byte_data) +# base64_string_opencv = base64_data.decode("utf-8") +# return "data:image/png;base64," + base64_string_opencv + +import gradio as gr + +# gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ).to(device) + + else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + controlnet=controlnet, + torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.load_ip_adapter_instantid(face_adapter) + # load and disable LCM + pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") + pipe.disable_lora() + + def remove_tips(): + print("GG") + return gr.update(visible=False) + + + # prompts = [ + # ["superman","Vibrant Color"], ["japanese anime character with white/neon hair","Watercolor"], + # # ["Suited professional","(No style)"], + # ["Scooba diver","Line art"], ["eskimo","Snow"] + # ] + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, 
styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[0], n) + # else: + # raise gr.Error("Email ID is compulsory") + def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[1], n) + def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[2], n) + def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[3], n) + +# def validate_and_process(face_file, style, email): + +# # Your processing logic here +# gallery1, gallery2, gallery3, gallery4 = run_for_prompts1(face_file, style), run_for_prompts2(face_file, style), run_for_prompts3(face_file, style), run_for_prompts4(face_file, style) +# return gallery1, gallery2, gallery3, gallery4 + + def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + def resize_img(input_image, max_side=640, min_side=640, size=None, + pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + print(w) + print(h) + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + # def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: + # p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + # return p.replace("{prompt}", positive), n + ' ' + negative + + def generate_image(face_image,prompt,negative_prompt): + pose_image_path = None + # prompt = "superman" + enable_LCM = False + identitynet_strength_ratio = 0.95 + adapter_strength_ratio = 0.60 + num_steps = 15 + guidance_scale = 8.5 + seed = random.randint(0, MAX_SEED) + # negative_prompt = "" + # negative_prompt += neg + 
enhance_face_region = True + if enable_LCM: + pipe.enable_lora() + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + else: + pipe.disable_lora() + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if face_image is None: + raise gr.Error(f"Cannot find any input face image! Please upload the face image") + + # if prompt is None: + # prompt = "a person" + + # apply the style template + # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + # face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the image! Please upload another person image") + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image") + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=face_kps, + control_mask=control_mask, + controlnet_conditioning_scale=float(identitynet_strength_ratio), + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + # num_images_per_prompt = 4 + ).images + + return images[0] + + ### Description + title = r""" +

Choose your AVATAR

+ """ + + description = r""" +

Powered by IDfy

""" + + article = r"""""" + + tips = r"""""" + + # js = ''' ''' + + css = ''' + .gradio-container {width: 95% !important; background-color: #E6F3FF;} + .image-gallery {height: 100vh !important; overflow: auto;} + .gradio-row .gradio-element { margin: 0 !important; } + ''' +# with gr.Blocks(css=css, js=js) as demo: + +# # description +# gr.Markdown(title) +# with gr.Row(): +# gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) +# gr.Markdown(description) +# with gr.Row(): +# with gr.Column(): +# style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) +# face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") +# submit = gr.Button("Submit", variant="primary") +# with gr.Column(): +# with gr.Row(): +# gallery1 = gr.Image(label="Generated Images") +# gallery2 = gr.Image(label="Generated Images") +# with gr.Row(): +# gallery3 = gr.Image(label="Generated Images") +# gallery4 = gr.Image(label="Generated Images") +# email = gr.Textbox(label="Email", +# info="Enter your email address", +# value="") +# # submit1 = gr.Button("Store") + +# usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) + +# face_file.upload( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) +# submit.click( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) + +# # submit1.click( +# # fn=store_images, +# # inputs=[email,gallery1,gallery2,gallery3,gallery4], +# # outputs=None) + + + +# gr.Markdown(article) + +# demo.launch(share=True) + + with gr.Blocks(css=css) as demo: + + # description + gr.Markdown(title) + with gr.Row(): + gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) + face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") + submit = gr.Button("Submit", variant="primary") + with gr.Column(): + with gr.Row(): + gallery1 = gr.Image(label="Generated Images") + gallery2 = gr.Image(label="Generated Images") + with gr.Row(): + gallery3 = gr.Image(label="Generated Images") + gallery4 = gr.Image(label="Generated Images") + email = gr.Textbox(label="Email", + info="Enter your email address", + value="") + + usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) + # identitynet_strength_ratio = gr.Slider( + # label="IdentityNet strength (for fidelity)", + # minimum=0, + # maximum=1.5, + # step=0.05, + # value=0.95, + # ) + # adapter_strength_ratio = gr.Slider( + # label="Image adapter strength (for detail)", + # minimum=0, + # maximum=1.5, + # step=0.05, + # value=0.60, + # ) + # 
negative_prompt = gr.Textbox( + # label="Negative Prompt", + # placeholder="low quality", + # value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # ) + # num_steps = gr.Slider( + # label="Number of sample steps", + # minimum=15, + # maximum=100, + # step=1, + # value=5 if enable_lcm_arg else 15, + # ) + # guidance_scale = gr.Slider( + # label="Guidance scale", + # minimum=0.1, + # maximum=10.0, + # step=0.1, + # value=0 if enable_lcm_arg else 8.5, + # ) + # if email is None: + # print("STOPPPP") + # raise gr.Error("Email ID is compulsory") + face_file.upload( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ).then( + fn=run_for_prompts2, + inputs=[face_file,style], + outputs=[gallery2] + ).then( + fn=run_for_prompts3, + inputs=[face_file,style], + outputs=[gallery3] + ).then( + fn=run_for_prompts4, + inputs=[face_file,style], + outputs=[gallery4] + ) + submit.click( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ).then( + fn=run_for_prompts2, + inputs=[face_file,style], + outputs=[gallery2] + ).then( + fn=run_for_prompts3, + inputs=[face_file,style], + outputs=[gallery3] + ).then( + fn=run_for_prompts4, + inputs=[face_file,style], + outputs=[gallery4] + ) + + + gr.Markdown(article) + + demo.launch(share=True) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, False) + + +# import sys +# sys.path.append('./') + +# from typing import Tuple + +# import os +# import cv2 +# import math +# import torch +# import random +# import numpy as np +# import argparse + +# import PIL +# from PIL import Image + +# import diffusers +# from diffusers.utils import load_image +# from diffusers.models import ControlNetModel +# from diffusers import LCMScheduler + +# from huggingface_hub import hf_hub_download + +# import insightface +# from insightface.app import FaceAnalysis + +# from style_template import styles +# from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +# from model_util import load_models_xl, get_torch_device, torch_gc + + +# # global variable +# MAX_SEED = np.iinfo(np.int32).max +# device = get_torch_device() +# dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +# STYLE_NAMES = list(styles.keys()) +# DEFAULT_STYLE_NAME = "Watercolor" + +# # Load face encoder +# app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +# app.prepare(ctx_id=0, det_size=(320, 320)) + +# # Path to InstantID models +# face_adapter = f'./checkpoints/ip-adapter.bin' +# controlnet_path = f'./checkpoints/ControlNetModel' + +# # Load pipeline +# controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +# logo = Image.open("./gradio_demo/logo.png") + +# from cv2 import imencode +# import base64 + +# # def encode_pil_to_base64_new(pil_image): +# # 
print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# # image_arr = np.asarray(pil_image)[:,:,::-1] +# # _, byte_data = imencode('.png', image_arr) +# # base64_data = base64.b64encode(byte_data) +# # base64_string_opencv = base64_data.decode("utf-8") +# # return "data:image/png;base64," + base64_string_opencv + +# import gradio as gr + +# # gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new + +# def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + +# if pretrained_model_name_or_path.endswith( +# ".ckpt" +# ) or pretrained_model_name_or_path.endswith(".safetensors"): +# scheduler_kwargs = hf_hub_download( +# repo_id="wangqixun/YamerMIX_v8", +# subfolder="scheduler", +# filename="scheduler_config.json", +# ) + +# (tokenizers, text_encoders, unet, _, vae) = load_models_xl( +# pretrained_model_name_or_path=pretrained_model_name_or_path, +# scheduler_name=None, +# weight_dtype=dtype, +# ) + +# scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) +# pipe = StableDiffusionXLInstantIDPipeline( +# vae=vae, +# text_encoder=text_encoders[0], +# text_encoder_2=text_encoders[1], +# tokenizer=tokenizers[0], +# tokenizer_2=tokenizers[1], +# unet=unet, +# scheduler=scheduler, +# controlnet=controlnet, +# ).to(device) + +# else: +# pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( +# pretrained_model_name_or_path, +# controlnet=controlnet, +# torch_dtype=dtype, +# safety_checker=None, +# feature_extractor=None, +# ).to(device) + +# pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + +# pipe.load_ip_adapter_instantid(face_adapter) +# # load and disable LCM +# pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") +# pipe.disable_lora() + +# def remove_tips(): +# return gr.update(visible=False) + + +# # prompts = [ +# # ["superman","Vibrant Color"], ["japanese anime character with white/neon hair","Watercolor"], +# # # ["Suited professional","(No style)"], +# # ["Scooba diver","Line art"], ["eskimo","Snow"] +# # ] + +# def convert_from_cv2_to_image(img: np.ndarray) -> Image: +# return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + +# def convert_from_image_to_cv2(img: Image) -> np.ndarray: +# return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + +# def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): +# # if email != "": +# p,n = styles.get(style, styles.get(STYLE_NAMES[1])) +# return generate_image(face_file, p[0], n) +# # else: +# # raise gr.Error("Email ID is compulsory") +# def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): +# # if email != "": +# p,n = styles.get(style, styles.get(STYLE_NAMES[1])) +# return generate_image(face_file, p[1], n) +# def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): +# # if email != "": +# p,n = styles.get(style, styles.get(STYLE_NAMES[1])) +# return generate_image(face_file, p[2], n) +# def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): +# # if email != "": +# p,n = styles.get(style, styles.get(STYLE_NAMES[1])) +# return generate_image(face_file, p[3], n) + +# # def validate_and_process(face_file, style, email): + +# # # Your processing logic here +# # gallery1, gallery2, gallery3, gallery4 = run_for_prompts1(face_file, style), run_for_prompts2(face_file, style), run_for_prompts3(face_file, style), run_for_prompts4(face_file, style) +# # return gallery1, gallery2, gallery3, gallery4 + +# def draw_kps(image_pil, kps, color_list=[(255,0,0), 
(0,255,0), (0,0,255), (255,255,0), (255,0,255)]): +# stickwidth = 4 +# limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) +# kps = np.array(kps) + +# w, h = image_pil.size +# out_img = np.zeros([h, w, 3]) + +# for i in range(len(limbSeq)): +# index = limbSeq[i] +# color = color_list[index[0]] + +# x = kps[index][:, 0] +# y = kps[index][:, 1] +# length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 +# angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) +# polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) +# out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) +# out_img = (out_img * 0.6).astype(np.uint8) + +# for idx_kp, kp in enumerate(kps): +# color = color_list[idx_kp] +# x, y = kp +# out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + +# out_img_pil = Image.fromarray(out_img.astype(np.uint8)) +# return out_img_pil + +# def resize_img(input_image, max_side=640, min_side=640, size=None, +# pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + +# w, h = input_image.size +# print(w) +# print(h) +# if size is not None: +# w_resize_new, h_resize_new = size +# else: +# ratio = min_side / min(h, w) +# w, h = round(ratio*w), round(ratio*h) +# ratio = max_side / max(h, w) +# input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) +# w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number +# h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number +# input_image = input_image.resize([w_resize_new, h_resize_new], mode) + +# if pad_to_max_side: +# res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 +# offset_x = (max_side - w_resize_new) // 2 +# offset_y = (max_side - h_resize_new) // 2 +# res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) +# input_image = Image.fromarray(res) +# return input_image + +# # def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: +# # p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) +# # return p.replace("{prompt}", positive), n + ' ' + negative + +# def generate_image(face_image,prompt,negative_prompt): +# pose_image_path = None +# # prompt = "superman" +# enable_LCM = False +# identitynet_strength_ratio = 0.95 +# adapter_strength_ratio = 0.60 +# num_steps = 15 +# guidance_scale = 8.5 +# seed = random.randint(0, MAX_SEED) +# # negative_prompt = "" +# # negative_prompt += neg +# enhance_face_region = True +# if enable_LCM: +# pipe.enable_lora() +# pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) +# else: +# pipe.disable_lora() +# pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + +# if face_image is None: +# raise gr.Error(f"Cannot find any input face image! Please upload the face image") + +# # if prompt is None: +# # prompt = "a person" + +# # apply the style template +# # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + +# # face_image = load_image(face_image_path) +# face_image = resize_img(face_image) +# face_image_cv2 = convert_from_image_to_cv2(face_image) +# height, width, _ = face_image_cv2.shape + +# # Extract face features +# face_info = app.get(face_image_cv2) + +# if len(face_info) == 0: +# raise gr.Error(f"Cannot find any face in the image! 
Please upload another person image") + +# face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face +# face_emb = face_info['embedding'] +# face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + +# if pose_image_path is not None: +# pose_image = load_image(pose_image_path) +# pose_image = resize_img(pose_image) +# pose_image_cv2 = convert_from_image_to_cv2(pose_image) + +# face_info = app.get(pose_image_cv2) + +# if len(face_info) == 0: +# raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image") + +# face_info = face_info[-1] +# face_kps = draw_kps(pose_image, face_info['kps']) + +# width, height = face_kps.size + +# if enhance_face_region: +# control_mask = np.zeros([height, width, 3]) +# x1, y1, x2, y2 = face_info["bbox"] +# x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) +# control_mask[y1:y2, x1:x2] = 255 +# control_mask = Image.fromarray(control_mask.astype(np.uint8)) +# else: +# control_mask = None + +# generator = torch.Generator(device=device).manual_seed(seed) + +# print("Start inference...") +# print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + +# pipe.set_ip_adapter_scale(adapter_strength_ratio) +# images = pipe( +# prompt=prompt, +# negative_prompt=negative_prompt, +# image_embeds=face_emb, +# image=face_kps, +# control_mask=control_mask, +# controlnet_conditioning_scale=float(identitynet_strength_ratio), +# num_inference_steps=num_steps, +# guidance_scale=guidance_scale, +# height=height, +# width=width, +# generator=generator, +# # num_images_per_prompt = 4 +# ).images + +# return images[0] + +# ### Description +# title = r""" +#

+# Choose your AVATAR
+# """ + +# description = r""" +#

+# Powered by IDfy
""" + +# article = r"""""" + +# tips = r"""""" + +# css = ''' +# .gradio-container {width: 95% !important; background-color: #E6F3FF;} +# .image-gallery {height: 100vh !important; overflow: auto;} +# .gradio-row .gradio-element { margin: 0 !important; } +# ''' +# with gr.Blocks(css=css) as demo: + +# # description +# gr.Markdown(title) +# with gr.Row(): +# gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) +# gr.Markdown(description) +# with gr.Row(): +# with gr.Column(): +# style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) +# face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") +# submit = gr.Button("Submit", variant="primary") +# with gr.Column(): +# with gr.Row(): +# gallery1 = gr.Image(label="Generated Images") +# gallery2 = gr.Image(label="Generated Images") +# with gr.Row(): +# gallery3 = gr.Image(label="Generated Images") +# gallery4 = gr.Image(label="Generated Images") +# email = gr.Textbox(label="Email", +# info="Enter your email address", +# value="") + +# usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) +# # identitynet_strength_ratio = gr.Slider( +# # label="IdentityNet strength (for fidelity)", +# # minimum=0, +# # maximum=1.5, +# # step=0.05, +# # value=0.95, +# # ) +# # adapter_strength_ratio = gr.Slider( +# # label="Image adapter strength (for detail)", +# # minimum=0, +# # maximum=1.5, +# # step=0.05, +# # value=0.60, +# # ) +# # negative_prompt = gr.Textbox( +# # label="Negative Prompt", +# # placeholder="low quality", +# # value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# # ) +# # num_steps = gr.Slider( +# # label="Number of sample steps", +# # minimum=15, +# # maximum=100, +# # step=1, +# # value=5 if enable_lcm_arg else 15, +# # ) +# # guidance_scale = gr.Slider( +# # label="Guidance scale", +# # minimum=0.1, +# # maximum=10.0, +# # step=0.1, +# # value=0 if enable_lcm_arg else 8.5, +# # ) +# # if email is None: +# # print("STOPPPP") +# # raise gr.Error("Email ID is compulsory") +# face_file.upload( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) +# submit.click( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) + + +# gr.Markdown(article) + +# demo.launch(share=True) + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") +# args = parser.parse_args() + +# 
main(args.pretrained_model_name_or_path, False) \ No newline at end of file diff --git a/gradio_demo/app-multicontrolnet.py b/gradio_demo/app-multicontrolnet.py new file mode 100644 index 0000000000000000000000000000000000000000..681a24f2319d79247f172490800bc292ad1be49a --- /dev/null +++ b/gradio_demo/app-multicontrolnet.py @@ -0,0 +1,670 @@ +import sys +sys.path.append("./") + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel + +from huggingface_hub import hf_hub_download + +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc +from controlnet_util import openpose, get_depth_map, get_canny_image + +import gradio as gr + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis( + name="antelopev2", + root="./", + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], +) +app.prepare(ctx_id=0, det_size=(640, 640)) + +# Path to InstantID models +face_adapter = f"./checkpoints/ip-adapter.bin" +controlnet_path = f"./checkpoints/ControlNetModel" + +# Load pipeline face ControlNetModel +controlnet_identitynet = ControlNetModel.from_pretrained( + controlnet_path, torch_dtype=dtype +) + +# controlnet-pose +controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0" +controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0" +controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small" + +controlnet_pose = ControlNetModel.from_pretrained( + controlnet_pose_model, torch_dtype=dtype +).to(device) +controlnet_canny = ControlNetModel.from_pretrained( + controlnet_canny_model, torch_dtype=dtype +).to(device) +controlnet_depth = ControlNetModel.from_pretrained( + controlnet_depth_model, torch_dtype=dtype +).to(device) + +controlnet_map = { + "pose": controlnet_pose, + "canny": controlnet_canny, + "depth": controlnet_depth, +} +controlnet_map_fn = { + "pose": openpose, + "canny": get_canny_image, + "depth": get_depth_map, +} + + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=[controlnet_identitynet], + ).to(device) + + else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + 
controlnet=[controlnet_identitynet], + torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config( + pipe.scheduler.config + ) + + pipe.load_ip_adapter_instantid(face_adapter) + # load and disable LCM + pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") + pipe.disable_lora() + + def toggle_lcm_ui(value): + if value: + return ( + gr.update(minimum=0, maximum=100, step=1, value=5), + gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5), + ) + else: + return ( + gr.update(minimum=5, maximum=100, step=1, value=30), + gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5), + ) + + def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + def remove_tips(): + return gr.update(visible=False) + + def get_example(): + case = [ + [ + "./examples/yann-lecun_resize.jpg", + None, + "a man", + "Snow", + "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + ], + [ + "./examples/musk_resize.jpeg", + "./examples/poses/pose2.jpg", + "a man flying in the sky in Mars", + "Mars", + "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + ], + [ + "./examples/sam_resize.png", + "./examples/poses/pose4.jpg", + "a man doing a silly pose wearing a suite", + "Jungle", + "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree", + ], + [ + "./examples/schmidhuber_resize.png", + "./examples/poses/pose3.jpg", + "a man sit on a chair", + "Neon", + "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + ], + [ + "./examples/kaifu_resize.png", + "./examples/poses/pose.jpg", + "a man", + "Vibrant Color", + "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + ], + ] + return case + + def run_for_examples(face_file, pose_file, prompt, style, negative_prompt): + return generate_image( + face_file, + pose_file, + prompt, + negative_prompt, + style, + 20, # num_steps + 0.8, # identitynet_strength_ratio + 0.8, # adapter_strength_ratio + 0.4, # pose_strength + 0.3, # canny_strength + 0.5, # depth_strength + ["pose", "canny"], # controlnet_selection + 5.0, # guidance_scale + 42, # seed + "EulerDiscreteScheduler", # scheduler + False, # enable_LCM + True, # enable_Face_Region + ) + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, 
cv2.COLOR_BGR2RGB)) + + def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + def draw_kps( + image_pil, + kps, + color_list=[ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + ], + ): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), + (int(length / 2), stickwidth), + int(angle), + 0, + 360, + 1, + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + def resize_img( + input_image, + max_side=1280, + min_side=1024, + size=None, + pad_to_max_side=False, + mode=PIL.Image.BILINEAR, + base_pixel_number=64, + ): + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio * w), round(ratio * h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[ + offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new + ] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + def apply_style( + style_name: str, positive: str, negative: str = "" + ) -> Tuple[str, str]: + p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + return p.replace("{prompt}", positive), n + " " + negative + + def generate_image( + face_image_path, + pose_image_path, + prompt, + negative_prompt, + style_name, + num_steps, + identitynet_strength_ratio, + adapter_strength_ratio, + pose_strength, + canny_strength, + depth_strength, + controlnet_selection, + guidance_scale, + seed, + scheduler, + enable_LCM, + enhance_face_region, + progress=gr.Progress(track_tqdm=True), + ): + + if enable_LCM: + pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config) + pipe.enable_lora() + else: + pipe.disable_lora() + scheduler_class_name = scheduler.split("-")[0] + + add_kwargs = {} + if len(scheduler.split("-")) > 1: + add_kwargs["use_karras_sigmas"] = True + if len(scheduler.split("-")) > 2: + add_kwargs["algorithm_type"] = "sde-dpmsolver++" + scheduler = getattr(diffusers, scheduler_class_name) + pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs) + + if face_image_path is None: + raise gr.Error( + f"Cannot find any input face image! 
Please upload the face image" + ) + + if prompt is None: + prompt = "a person" + + # apply the style template + prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + face_image = load_image(face_image_path) + face_image = resize_img(face_image, max_side=1024) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error( + f"Unable to detect a face in the image. Please upload a different photo with a clear face." + ) + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info["embedding"] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"]) + img_controlnet = face_image + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image, max_side=1024) + img_controlnet = pose_image + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error( + f"Cannot find any face in the reference image! Please upload another person image" + ) + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info["kps"]) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + if len(controlnet_selection) > 0: + controlnet_scales = { + "pose": pose_strength, + "canny": canny_strength, + "depth": depth_strength, + } + pipe.controlnet = MultiControlNetModel( + [controlnet_identitynet] + + [controlnet_map[s] for s in controlnet_selection] + ) + control_scales = [float(identitynet_strength_ratio)] + [ + controlnet_scales[s] for s in controlnet_selection + ] + control_images = [face_kps] + [ + controlnet_map_fn[s](img_controlnet).resize((width, height)) + for s in controlnet_selection + ] + else: + pipe.controlnet = controlnet_identitynet + control_scales = float(identitynet_strength_ratio) + control_images = face_kps + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=control_images, + control_mask=control_mask, + controlnet_conditioning_scale=control_scales, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + ).images + + return images[0], gr.update(visible=True) + + # Description + title = r""" +

+ InstantID: Zero-shot Identity-Preserving Generation in Seconds
+ """ + + description = r""" + Official 🤗 Gradio demo for InstantID: Zero-shot Identity-Preserving Generation in Seconds.
+ + How to use:
+ 1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring. + 2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose. + 3. (Optional) You can select multiple ControlNet models to control the generation process. The default is to use the IdentityNet only. The ControlNet models include pose skeleton, canny, and depth. You can adjust the strength of each ControlNet model to control the generation process. + 4. Enter a text prompt, as done in normal text-to-image models. + 5. Click the Submit button to begin customization. + 6. Share your customized photo with your friends and enjoy! 😊""" + + article = r""" + --- + 📝 **Citation** +
+ If our work is helpful for your research or applications, please cite us via: + ```bibtex + @article{wang2024instantid, + title={InstantID: Zero-shot Identity-Preserving Generation in Seconds}, + author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony}, + journal={arXiv preprint arXiv:2401.07519}, + year={2024} + } + ``` + 📧 **Contact** +
+ If you have any questions, please feel free to open an issue or directly reach us out at haofanwang.ai@gmail.com. + """ + + tips = r""" + ### Usage tips of InstantID + 1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength." + 2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength. + 3. If you find that text control is not as expected, decrease Adapter strength. + 4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model. + """ + + css = """ + .gradio-container {width: 85% !important} + """ + with gr.Blocks(css=css) as demo: + # description + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(): + with gr.Row(equal_height=True): + # upload face image + face_file = gr.Image( + label="Upload a photo of your face", type="filepath" + ) + # optional: upload a reference pose image + pose_file = gr.Image( + label="Upload a reference pose image (Optional)", + type="filepath", + ) + + # prompt + prompt = gr.Textbox( + label="Prompt", + info="Give simple prompt is enough to achieve good face fidelity", + placeholder="A photo of a person", + value="", + ) + + submit = gr.Button("Submit", variant="primary") + enable_LCM = gr.Checkbox( + label="Enable Fast Inference with LCM", value=enable_lcm_arg, + info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces", + ) + style = gr.Dropdown( + label="Style template", + choices=STYLE_NAMES, + value=DEFAULT_STYLE_NAME, + ) + + # strength + identitynet_strength_ratio = gr.Slider( + label="IdentityNet strength (for fidelity)", + minimum=0, + maximum=1.5, + step=0.05, + value=0.80, + ) + adapter_strength_ratio = gr.Slider( + label="Image adapter strength (for detail)", + minimum=0, + maximum=1.5, + step=0.05, + value=0.80, + ) + with gr.Accordion("Controlnet"): + controlnet_selection = gr.CheckboxGroup( + ["pose", "canny", "depth"], label="Controlnet", value=["pose"], + info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. 
You can try all three to control the generation process" + ) + pose_strength = gr.Slider( + label="Pose strength", + minimum=0, + maximum=1.5, + step=0.05, + value=0.40, + ) + canny_strength = gr.Slider( + label="Canny strength", + minimum=0, + maximum=1.5, + step=0.05, + value=0.40, + ) + depth_strength = gr.Slider( + label="Depth strength", + minimum=0, + maximum=1.5, + step=0.05, + value=0.40, + ) + with gr.Accordion(open=False, label="Advanced Options"): + negative_prompt = gr.Textbox( + label="Negative Prompt", + placeholder="low quality", + value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + ) + num_steps = gr.Slider( + label="Number of sample steps", + minimum=1, + maximum=100, + step=1, + value=5 if enable_lcm_arg else 30, + ) + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=0.1, + maximum=20.0, + step=0.1, + value=0.0 if enable_lcm_arg else 5.0, + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=42, + ) + schedulers = [ + "DEISMultistepScheduler", + "HeunDiscreteScheduler", + "EulerDiscreteScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverMultistepScheduler-Karras", + "DPMSolverMultistepScheduler-Karras-SDE", + ] + scheduler = gr.Dropdown( + label="Schedulers", + choices=schedulers, + value="EulerDiscreteScheduler", + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True) + + with gr.Column(scale=1): + gallery = gr.Image(label="Generated Images") + usage_tips = gr.Markdown( + label="InstantID Usage Tips", value=tips, visible=False + ) + + submit.click( + fn=remove_tips, + outputs=usage_tips, + ).then( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=generate_image, + inputs=[ + face_file, + pose_file, + prompt, + negative_prompt, + style, + num_steps, + identitynet_strength_ratio, + adapter_strength_ratio, + pose_strength, + canny_strength, + depth_strength, + controlnet_selection, + guidance_scale, + seed, + scheduler, + enable_LCM, + enhance_face_region, + ], + outputs=[gallery, usage_tips], + ) + + enable_LCM.input( + fn=toggle_lcm_ui, + inputs=[enable_LCM], + outputs=[num_steps, guidance_scale], + queue=False, + ) + + gr.Examples( + examples=get_example(), + inputs=[face_file, pose_file, prompt, style, negative_prompt], + fn=run_for_examples, + outputs=[gallery, usage_tips], + cache_examples=True, + ) + + gr.Markdown(article) + + demo.launch() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8" + ) + parser.add_argument( + "--enable_LCM", type=bool, default=os.environ.get("ENABLE_LCM", False) + ) + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, args.enable_LCM) \ No newline at end of file diff --git a/gradio_demo/app.py b/gradio_demo/app.py new file mode 100644 index 0000000000000000000000000000000000000000..112b366797c61baf11512578cc8286d196fbbcf5 --- /dev/null +++ b/gradio_demo/app.py @@ -0,0 +1,656 @@ +import sys +sys.path.append('./') + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse 
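+# pandas is used by store_images() further down to append each submission
+# (the email address and the four saved gallery image paths) to image_data.csv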
+import pandas as pd + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers import LCMScheduler + +from huggingface_hub import hf_hub_download + +import insightface +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(640, 640)) + +# Path to InstantID models +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# Load pipeline +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +logo = Image.open("./gradio_demo/watermark.png") +logo = logo.resize((100, 70)) + +from cv2 import imencode +import base64 + +# def encode_pil_to_base64_new(pil_image): +# print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# image_arr = np.asarray(pil_image)[:,:,::-1] +# _, byte_data = imencode('.png', image_arr) +# base64_data = base64.b64encode(byte_data) +# base64_string_opencv = base64_data.decode("utf-8") +# return "data:image/png;base64," + base64_string_opencv + +import gradio as gr + +# gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ).to(device) + + else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + controlnet=controlnet, + torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.load_ip_adapter_instantid(face_adapter) + # load and disable LCM + pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") + pipe.disable_lora() + + def remove_tips(): + print("GG") + return gr.update(visible=False) + + + # prompts = [ + # ["superman","Vibrant Color"], ["japanese anime character with white/neon hair","Watercolor"], + # # ["Suited professional","(No style)"], + # ["Scooba diver","Line art"], ["eskimo","Snow"] + # ] + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), 
cv2.COLOR_RGB2BGR) + + def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[0], n) + # else: + # raise gr.Error("Email ID is compulsory") + def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[1], n) + def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[2], n) + def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[3], n) + +# def validate_and_process(face_file, style, email): + +# # Your processing logic here +# gallery1, gallery2, gallery3, gallery4 = run_for_prompts1(face_file, style), run_for_prompts2(face_file, style), run_for_prompts3(face_file, style), run_for_prompts4(face_file, style) +# return gallery1, gallery2, gallery3, gallery4 + + def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + def resize_img(input_image, max_side=1280, min_side=1280, size=None, + pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + print(f"Original Size --> {input_image.size}") + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + + print(f"Final modified image size --> {input_image.size}") + return input_image + + # def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: + # p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + # return p.replace("{prompt}", positive), n + ' ' + negative + + def store_images(email, gallery1, gallery2, gallery3, gallery4,consent): + if not 
consent: + raise gr.Error("Consent not provided") + galleries = [] + for i, img in enumerate([gallery1, gallery2, gallery3, gallery4], start=1): + if isinstance(img, np.ndarray): + img = Image.fromarray(img) + print(f"Gallery {i} type after conversion: {type(img)}") + galleries.append(img) + # Create the images directory if it doesn't exist + if not os.path.exists('images'): + os.makedirs('images') + + # Define image file paths + image_paths = [] + for i, img in enumerate(galleries, start=1): + img_path = f'images/{email}_gallery{i}.png' + img.save(img_path) + image_paths.append(img_path) + + # Define the CSV file path + csv_file_path = 'image_data.csv' + + # Create a DataFrame for the email and image paths + df = pd.DataFrame({ + 'email': [email], + 'img1_path': [image_paths[0]], + 'img2_path': [image_paths[1]], + 'img3_path': [image_paths[2]], + 'img4_path': [image_paths[3]], + }) + + # Write to CSV (append if the file exists, create a new one if it doesn't) + if not os.path.isfile(csv_file_path): + df.to_csv(csv_file_path, index=False) + else: + df.to_csv(csv_file_path, mode='a', header=False, index=False) + + gr.Info("Thankyou!! Your avatar is on the way to your inbox") + + def add_watermark(image, watermark=logo, opacity=128, position="bottom_right", padding=10): + # Convert NumPy array to PIL Image if needed + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + + if isinstance(watermark, np.ndarray): + watermark = Image.fromarray(watermark) + + # Convert images to 'RGBA' mode to handle transparency + image = image.convert("RGBA") + watermark = watermark.convert("RGBA") + + # Adjust the watermark opacity + watermark = watermark.copy() + watermark.putalpha(opacity) + + # Calculate the position for the watermark + if position == "bottom_right": + x = image.width - watermark.width - padding + y = image.height - watermark.height - padding + elif position == "bottom_left": + x = padding + y = image.height - watermark.height - padding + elif position == "top_right": + x = image.width - watermark.width - padding + y = padding + elif position == "top_left": + x = padding + y = padding + else: + raise ValueError("Unsupported position. Choose from 'bottom_right', 'bottom_left', 'top_right', 'top_left'.") + + # Paste the watermark onto the image + image.paste(watermark, (x, y), watermark) + + # Convert back to 'RGB' if the original image was not 'RGBA' + if image.mode != "RGBA": + image = image.convert("RGB") + + # return resize_img(image) + return image + + def generate_image(face_image,prompt,negative_prompt): + pose_image_path = None + # prompt = "superman" + enable_LCM = False + identitynet_strength_ratio = 0.90 + adapter_strength_ratio = 0.60 + num_steps = 15 + guidance_scale = 5 + seed = random.randint(0, MAX_SEED) + print(f"Seed --> {seed}") + + # negative_prompt = "" + # negative_prompt += neg + enhance_face_region = True + if enable_LCM: + pipe.enable_lora() + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + else: + pipe.disable_lora() + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if face_image is None: + raise gr.Error(f"Cannot find any input face image! 
Please upload the face image") + + # if prompt is None: + # prompt = "a person" + + # apply the style template + # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + # face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the image! Please upload another person image") + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image") + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=face_kps, + control_mask=control_mask, + controlnet_conditioning_scale=float(identitynet_strength_ratio), + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + # num_images_per_prompt = 4 + ).images + + watermarked_image = add_watermark(images[0]) + + # return images[0] + return watermarked_image + + ### Description + title = r""" +

+ Choose your AVATAR
+ """ + + description = r""" +

+ Powered by IDfy
""" + + article = r"""""" + + tips = r"""""" + # css = ''' + # .gradio-container { + # width: 95% !important; + # background-image: url('./InstantID/gradio_demo/logo.png'); + # background-size: cover; + # background-position: center; + # } + # .image-gallery { + # height: 100vh !important; + # overflow: auto; + # } + # .gradio-row .gradio-element { + # margin: 0 !important; + # } + # ''' + css = ''' + .gradio-container {width: 100% !important; color: white; background: linear-gradient(135deg, #1C43B9, #254977, #343434);} + .gradio-row .gradio-element { margin: 0 !important; } + .centered-column { + display: flex; + justify-content: center; + align-items: center; + width: 100%;} + #store-btn { + background: #f2bb13 !important; + color: white !important; + } + ''' + with gr.Blocks(css=css) as demo: + + # description + gr.Markdown(title) + with gr.Column(): + with gr.Row(): + gr.Image("./gradio_demo/logo.png", scale=0, min_width=50, show_label=False, show_download_button=False) + gr.Markdown(description) + style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) + with gr.Row(equal_height=True): # Center the face file + with gr.Column(elem_id="centered-face", elem_classes=["centered-column"]): # Use CSS class for centering + face_file = gr.Image(label="Upload a photo of your face", type="pil", sources="webcam", height=400, width=500) + # submit = gr.Button("Submit", variant="primary") + with gr.Column(): + with gr.Row(): + gallery1 = gr.Image(label="Generated Images") + gallery2 = gr.Image(label="Generated Images") + with gr.Row(): + gallery3 = gr.Image(label="Generated Images") + gallery4 = gr.Image(label="Generated Images") + email = gr.Textbox(label="Email", info="Enter your email address", value="") + consent = gr.Checkbox(label="I am giving my consent to use my data to share my AI Avtar and IDfy relevant information from time to time") + submit1 = gr.Button("STORE",elem_id="store-btn") +# with gr.Blocks(css=css) as demo: + +# # description +# gr.Markdown(title) +# with gr.Column(): +# with gr.Row(): +# gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) +# gr.Markdown(description) +# style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) +# face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam", height=400, width=500) +# submit = gr.Button("Submit", variant="primary") +# with gr.Column(): +# with gr.Row(): +# gallery1 = gr.Image(label="Generated Images") +# gallery2 = gr.Image(label="Generated Images") +# with gr.Row(): +# gallery3 = gr.Image(label="Generated Images") +# gallery4 = gr.Image(label="Generated Images") +# email = gr.Textbox(label="Email", +# info="Enter your email address", +# value="") +# consent = gr.Checkbox(label="I am giving my consent to use my data to share my AI Avtar and IDfy relevant information from time to time") +# submit1 = gr.Button("STORE", variant="primary") +# # submit1 = gr.Button("Store") + usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) + + face_file.upload( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ).then( + fn=run_for_prompts2, + inputs=[face_file,style], + outputs=[gallery2] + ).then( + fn=run_for_prompts3, + inputs=[face_file,style], + outputs=[gallery3] + ).then( + fn=run_for_prompts4, + inputs=[face_file,style], + outputs=[gallery4] + ) +# submit.click( +# fn=remove_tips, 
+# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) + +# submit1.click( +# fn=store_images, +# inputs=[email,gallery1,gallery2,gallery3,gallery4,consent], +# outputs=None) + + + + gr.Markdown(article) + + demo.launch(share=True) + +# with gr.Blocks(css=css, js=js) as demo: + +# # description +# gr.Markdown(title) +# with gr.Row(): +# gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) +# gr.Markdown(description) +# with gr.Row(): +# with gr.Column(): +# style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) +# face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") +# submit = gr.Button("Submit", variant="primary") +# with gr.Column(): +# with gr.Row(): +# gallery1 = gr.Image(label="Generated Images") +# gallery2 = gr.Image(label="Generated Images") +# with gr.Row(): +# gallery3 = gr.Image(label="Generated Images") +# gallery4 = gr.Image(label="Generated Images") +# email = gr.Textbox(label="Email", +# info="Enter your email address", +# value="") + +# usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) +# # identitynet_strength_ratio = gr.Slider( +# # label="IdentityNet strength (for fidelity)", +# # minimum=0, +# # maximum=1.5, +# # step=0.05, +# # value=0.95, +# # ) +# # adapter_strength_ratio = gr.Slider( +# # label="Image adapter strength (for detail)", +# # minimum=0, +# # maximum=1.5, +# # step=0.05, +# # value=0.60, +# # ) +# # negative_prompt = gr.Textbox( +# # label="Negative Prompt", +# # placeholder="low quality", +# # value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# # ) +# # num_steps = gr.Slider( +# # label="Number of sample steps", +# # minimum=15, +# # maximum=100, +# # step=1, +# # value=5 if enable_lcm_arg else 15, +# # ) +# # guidance_scale = gr.Slider( +# # label="Guidance scale", +# # minimum=0.1, +# # maximum=10.0, +# # step=0.1, +# # value=0 if enable_lcm_arg else 8.5, +# # ) +# # if email is None: +# # print("STOPPPP") +# # raise gr.Error("Email ID is compulsory") +# face_file.upload( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, +# inputs=[face_file,style], +# outputs=[gallery4] +# ) +# submit.click( +# fn=remove_tips, +# outputs=usage_tips, +# queue=True, +# api_name=False, +# show_progress = "full" +# ).then( +# fn=run_for_prompts1, +# inputs=[face_file,style], +# outputs=[gallery1] +# ).then( +# fn=run_for_prompts2, +# inputs=[face_file,style], +# outputs=[gallery2] +# ).then( +# fn=run_for_prompts3, +# inputs=[face_file,style], +# outputs=[gallery3] +# ).then( +# fn=run_for_prompts4, 
+# inputs=[face_file,style], +# outputs=[gallery4] +# ) + + +# gr.Markdown(article) + +# demo.launch(share=True) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, False) + + diff --git a/gradio_demo/app1.py b/gradio_demo/app1.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d31aa956276dbbf16304548fe0f59c005007df --- /dev/null +++ b/gradio_demo/app1.py @@ -0,0 +1,434 @@ +import sys +sys.path.append('./') + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers import LCMScheduler + +from huggingface_hub import hf_hub_download + +import insightface +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(320, 320)) + +# Path to InstantID models +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# Load pipeline +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +logo = Image.open("./gradio_demo/logo.png") + +from cv2 import imencode +import base64 + +# def encode_pil_to_base64_new(pil_image): +# print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# image_arr = np.asarray(pil_image)[:,:,::-1] +# _, byte_data = imencode('.png', image_arr) +# base64_data = base64.b64encode(byte_data) +# base64_string_opencv = base64_data.decode("utf-8") +# return "data:image/png;base64," + base64_string_opencv + +import gradio as gr + +# gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ).to(device) + + else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + controlnet=controlnet, + torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = 
diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.load_ip_adapter_instantid(face_adapter) + # load and disable LCM + pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") + pipe.disable_lora() + + def remove_tips(): + return gr.update(visible=False) + + + # prompts = [ + # ["superman","Vibrant Color"], ["japanese anime character with white/neon hair","Watercolor"], + # # ["Suited professional","(No style)"], + # ["Scooba diver","Line art"], ["eskimo","Snow"] + # ] + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[0], n) + # else: + # raise gr.Error("Email ID is compulsory") + def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[1], n) + def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[2], n) + def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[3], n) + +# def validate_and_process(face_file, style, email): + +# # Your processing logic here +# gallery1, gallery2, gallery3, gallery4 = run_for_prompts1(face_file, style), run_for_prompts2(face_file, style), run_for_prompts3(face_file, style), run_for_prompts4(face_file, style) +# return gallery1, gallery2, gallery3, gallery4 + + def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + def resize_img(input_image, max_side=640, min_side=640, size=None, + pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + print(w) + print(h) + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if 
pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + # def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]: + # p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + # return p.replace("{prompt}", positive), n + ' ' + negative + + def generate_image(face_image,prompt,negative_prompt): + pose_image_path = None + # prompt = "superman" + enable_LCM = False + identitynet_strength_ratio = 0.95 + adapter_strength_ratio = 0.60 + num_steps = 15 + guidance_scale = 8.5 + seed = random.randint(0, MAX_SEED) + # negative_prompt = "" + # negative_prompt += neg + enhance_face_region = True + if enable_LCM: + pipe.enable_lora() + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + else: + pipe.disable_lora() + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if face_image is None: + raise gr.Error(f"Cannot find any input face image! Please upload the face image") + + # if prompt is None: + # prompt = "a person" + + # apply the style template + # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + # face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the image! Please upload another person image") + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image") + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=face_kps, + control_mask=control_mask, + controlnet_conditioning_scale=float(identitynet_strength_ratio), + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + # num_images_per_prompt = 4 + ).images + + return images[0] + + ### Description + title = r""" +

Choose your AVATAR
+ """ + + description = r""" +

Powered by IDfy
""" + + article = r"""""" + + tips = r"""""" + + css = ''' + .gradio-container {width: 95% !important; background-color: #E6F3FF;} + .image-gallery {height: 100vh !important; overflow: auto;} + .gradio-row .gradio-element { margin: 0 !important; } + ''' + with gr.Blocks(css=css) as demo: + + # description + gr.Markdown(title) + with gr.Row(): + gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) + face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") + submit = gr.Button("Submit", variant="primary") + with gr.Column(): + with gr.Row(): + gallery1 = gr.Image(label="Generated Images") + gallery2 = gr.Image(label="Generated Images") + with gr.Row(): + gallery3 = gr.Image(label="Generated Images") + gallery4 = gr.Image(label="Generated Images") + email = gr.Textbox(label="Email", + info="Enter your email address", + value="") + + usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) + # identitynet_strength_ratio = gr.Slider( + # label="IdentityNet strength (for fidelity)", + # minimum=0, + # maximum=1.5, + # step=0.05, + # value=0.95, + # ) + # adapter_strength_ratio = gr.Slider( + # label="Image adapter strength (for detail)", + # minimum=0, + # maximum=1.5, + # step=0.05, + # value=0.60, + # ) + # negative_prompt = gr.Textbox( + # label="Negative Prompt", + # placeholder="low quality", + # value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # ) + # num_steps = gr.Slider( + # label="Number of sample steps", + # minimum=15, + # maximum=100, + # step=1, + # value=5 if enable_lcm_arg else 15, + # ) + # guidance_scale = gr.Slider( + # label="Guidance scale", + # minimum=0.1, + # maximum=10.0, + # step=0.1, + # value=0 if enable_lcm_arg else 8.5, + # ) + # if email is None: + # print("STOPPPP") + # raise gr.Error("Email ID is compulsory") + face_file.upload( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ).then( + fn=run_for_prompts2, + inputs=[face_file,style], + outputs=[gallery2] + ).then( + fn=run_for_prompts3, + inputs=[face_file,style], + outputs=[gallery3] + ).then( + fn=run_for_prompts4, + inputs=[face_file,style], + outputs=[gallery4] + ) + submit.click( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ).then( + fn=run_for_prompts2, + inputs=[face_file,style], + outputs=[gallery2] + ).then( + fn=run_for_prompts3, + inputs=[face_file,style], + outputs=[gallery3] + ).then( + fn=run_for_prompts4, + inputs=[face_file,style], + outputs=[gallery4] + ) + + + gr.Markdown(article) + + demo.launch(share=True) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, False) \ No newline at end of file diff --git a/gradio_demo/background.jpg b/gradio_demo/background.jpg new file 
mode 100644 index 0000000000000000000000000000000000000000..a996db243164e6bc7251bebc9c24133ad15d5d46 Binary files /dev/null and b/gradio_demo/background.jpg differ diff --git a/gradio_demo/controlnet_util.py b/gradio_demo/controlnet_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9f04147617408c6d91e6a7e5b1da2b5a95090c --- /dev/null +++ b/gradio_demo/controlnet_util.py @@ -0,0 +1,39 @@ +import torch +import numpy as np +from PIL import Image +from controlnet_aux import OpenposeDetector +from model_util import get_torch_device +import cv2 + + +from transformers import DPTImageProcessor, DPTForDepthEstimation + +device = get_torch_device() +depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device) +feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") +openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + +def get_depth_map(image): + image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + with torch.no_grad(), torch.autocast("cuda"): + depth_map = depth_estimator(image).predicted_depth + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(1024, 1024), + mode="bicubic", + align_corners=False, + ) + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = (depth_map - depth_min) / (depth_max - depth_min) + image = torch.cat([depth_map] * 3, dim=1) + + image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + return image + +def get_canny_image(image, t1=100, t2=200): + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + edges = cv2.Canny(image, t1, t2) + return Image.fromarray(edges, "L") \ No newline at end of file diff --git a/gradio_demo/demo.py b/gradio_demo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..856a857e262ee8c20856c74d6dbe4e48a0ec94f7 --- /dev/null +++ b/gradio_demo/demo.py @@ -0,0 +1,369 @@ +import sys +sys.path.append('./') + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers import LCMScheduler + +from huggingface_hub import hf_hub_download + +import insightface +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc + +from cv2 import imencode +import base64 + +# def encode_pil_to_base64_new(pil_image): +# print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# image_arr = np.asarray(pil_image)[:,:,::-1] +# _, byte_data = imencode('.png', image_arr) +# base64_data = base64.b64encode(byte_data) +# base64_string_opencv = base64_data.decode("utf-8") +# return "data:image/png;base64," + base64_string_opencv + +import gradio as gr + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, 
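# --- Illustrative aside, not part of the original diff ------------------------
# FaceAnalysis.get() returns one dict per detected face; generate_image() below
# keeps the face with the largest bounding-box area and uses its identity
# 'embedding' and 'kps' keypoints. A minimal sketch of that selection, where the
# name `faces` is a hypothetical stand-in for the result of app.get(image_bgr):
#
#     largest = max(faces, key=lambda f: (f['bbox'][2] - f['bbox'][0])
#                                        * (f['bbox'][3] - f['bbox'][1]))
#     face_embedding = largest['embedding']  # passed to the pipeline as image_embeds
#     face_keypoints = largest['kps']        # rendered by draw_kps() as the ControlNet image
# ------------------------------------------------------------------------------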
det_size=(320, 320)) + +# Path to InstantID models +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# Load pipeline +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +logo = Image.open("./gradio_demo/logo.png") + +pretrained_model_name_or_path="wangqixun/YamerMIX_v8" + + +if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ).to(device) + +else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + controlnet=controlnet, + torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + +pipe.load_ip_adapter_instantid(face_adapter) +# load and disable LCM +pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") +pipe.disable_lora() + +# gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new +def remove_tips(): + print("GG") + return gr.update(visible=False) + +def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + +def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + +def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[0], n) + # else: + # raise gr.Error("Email ID is compulsory") +def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[1], n) + +def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[2], n) + +def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[3], n) + + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for 
idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +def resize_img(input_image, max_side=640, min_side=640, size=None, + pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + print(w) + print(h) + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + +def generate_image(face_image,prompt,negative_prompt): + pose_image_path = None + # prompt = "superman" + enable_LCM = False + identitynet_strength_ratio = 0.95 + adapter_strength_ratio = 0.60 + num_steps = 15 + guidance_scale = 8.5 + seed = random.randint(0, MAX_SEED) + # negative_prompt = "" + # negative_prompt += neg + enhance_face_region = True + if enable_LCM: + pipe.enable_lora() + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + else: + pipe.disable_lora() + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if face_image is None: + raise gr.Error(f"Cannot find any input face image! Please upload the face image") + + # if prompt is None: + # prompt = "a person" + + # apply the style template + # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + # face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the image! Please upload another person image") + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the reference image! 
Please upload another person image") + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=face_kps, + control_mask=control_mask, + controlnet_conditioning_scale=float(identitynet_strength_ratio), + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + # num_images_per_prompt = 4 + ).images + + return images[0] + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + + + + + ### Description + title = r""" +

Choose your AVATAR
+ """ + + description = r""" +

Powered by IDfy
""" + + article = r"""""" + + tips = r"""""" + + js = ''' ''' + + css = ''' + .gradio-container {width: 95% !important; background-color: #E6F3FF;} + .image-gallery {height: 100vh !important; overflow: auto;} + .gradio-row .gradio-element { margin: 0 !important; } + ''' + + + with gr.Blocks(css=css, js=js) as demo: + + # description + gr.Markdown(title) + with gr.Row(): + gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) + face_file = gr.Image(label="Upload a photo of your face", type="pil",sources="webcam") + submit = gr.Button("Submit", variant="primary") + with gr.Column(): + with gr.Row(): + gallery1 = gr.Image(label="Generated Images") + gallery2 = gr.Image(label="Generated Images") + with gr.Row(): + gallery3 = gr.Image(label="Generated Images") + gallery4 = gr.Image(label="Generated Images") + email = gr.Textbox(label="Email", + info="Enter your email address", + value="") + + usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False) + + face_file.upload( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ) + + submit.click( + fn=remove_tips, + outputs=usage_tips, + queue=True, + api_name=False, + show_progress = "full" + ).then( + fn=run_for_prompts1, + inputs=[face_file,style], + outputs=[gallery1] + ) + + + gr.Markdown(article) + + demo.launch(share=True) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, False) \ No newline at end of file diff --git a/gradio_demo/download_models.py b/gradio_demo/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1f0780a83e367667a41ee6029a1a7edaa22e02 --- /dev/null +++ b/gradio_demo/download_models.py @@ -0,0 +1,27 @@ +from huggingface_hub import hf_hub_download +import gdown +import os + +# download models +hf_hub_download( + repo_id="InstantX/InstantID", + filename="ControlNetModel/config.json", + local_dir="./checkpoints", +) +hf_hub_download( + repo_id="InstantX/InstantID", + filename="ControlNetModel/diffusion_pytorch_model.safetensors", + local_dir="./checkpoints", +) +hf_hub_download( + repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints" +) +hf_hub_download( + repo_id="latent-consistency/lcm-lora-sdxl", + filename="pytorch_lora_weights.safetensors", + local_dir="./checkpoints", +) +# download antelopev2 +gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="./models/", quiet=False, fuzzy=True) +# unzip antelopev2.zip +os.system("unzip ./models/antelopev2.zip -d ./models/") \ No newline at end of file diff --git a/gradio_demo/logo.png b/gradio_demo/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..ba0d13e4f52bc5bf51f45efc185c5208a122c567 Binary files /dev/null and b/gradio_demo/logo.png differ diff --git a/gradio_demo/logo1.png b/gradio_demo/logo1.png new file mode 100644 index 0000000000000000000000000000000000000000..d33f17760db871dd28ea33a132a9afda92089a6e Binary files /dev/null and b/gradio_demo/logo1.png differ diff --git a/gradio_demo/model_util.py b/gradio_demo/model_util.py new file mode 100644 index 
0000000000000000000000000000000000000000..0cee9dc2bbaa3356f24a8e5ca103d8eb6a599317 --- /dev/null +++ b/gradio_demo/model_util.py @@ -0,0 +1,472 @@ +from typing import Literal, Union, Optional, Tuple, List + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from diffusers import ( + UNet2DConditionModel, + SchedulerMixin, + StableDiffusionPipeline, + StableDiffusionXLPipeline, + AutoencoderKL, +) +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + convert_ldm_unet_checkpoint, +) +from safetensors.torch import load_file +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + UniPCMultistepScheduler, +) + +from omegaconf import OmegaConf + +# DiffUsers版StableDiffusionのモデルパラメータ +NUM_TRAIN_TIMESTEPS = 1000 +BETA_START = 0.00085 +BETA_END = 0.0120 + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32` +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 +# UNET_PARAMS_USE_LINEAR_PROJECTION = False + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True + +TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" +TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" + +AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"] + +SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] + +DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ( + "cond_stage_model.transformer.embeddings.", + "cond_stage_model.transformer.text_model.embeddings.", + ), + ( + "cond_stage_model.transformer.encoder.", + "cond_stage_model.transformer.text_model.encoder.", + ), + ( + "cond_stage_model.transformer.final_layer_norm.", + "cond_stage_model.transformer.text_model.final_layer_norm.", + ), + ] + + if ckpt_path.endswith(".safetensors"): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [ + UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT + ] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlock2D" + if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS + else "DownBlock2D" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlock2D" + if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS + else "UpBlock2D" + ) + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM + if not v2 + else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS + if not v2 + else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def load_diffusers_model( + pretrained_model_name_or_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + if v2: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V2_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + # default is clip skip 2 + num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V1_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + return tokenizer, text_encoder, unet, vae + + +def load_checkpoint_model( + checkpoint_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + pipe = StableDiffusionPipeline.from_single_file( + checkpoint_path, + upcast_attention=True if v2 else False, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path) + unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2) + unet_config["class_embed_type"] = None + unet_config["addition_embed_type"] = None + converted_unet_checkpoint = 
convert_ldm_unet_checkpoint(state_dict, unet_config) + unet = UNet2DConditionModel(**unet_config) + unet.load_state_dict(converted_unet_checkpoint) + + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + vae = pipe.vae + if clip_skip is not None: + if v2: + text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) + else: + text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) + + del pipe + + return tokenizer, text_encoder, unet, vae + + +def load_models( + pretrained_model_name_or_path: str, + scheduler_name: str, + v2: bool = False, + v_pred: bool = False, + weight_dtype: torch.dtype = torch.float32, +) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + tokenizer, text_encoder, unet, vae = load_checkpoint_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + else: # diffusers + tokenizer, text_encoder, unet, vae = load_diffusers_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + + if scheduler_name: + scheduler = create_noise_scheduler( + scheduler_name, + prediction_type="v_prediction" if v_pred else "epsilon", + ) + else: + scheduler = None + + return tokenizer, text_encoder, unet, scheduler, vae + + +def load_diffusers_model_xl( + pretrained_model_name_or_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet + + tokenizers = [ + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + pad_token_id=0, # same as open clip + ), + ] + + text_encoders = [ + CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTextModelWithProjection.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + ] + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + return tokenizers, text_encoders, unet, vae + + +def load_checkpoint_model_xl( + checkpoint_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + pipe = StableDiffusionXLPipeline.from_single_file( + checkpoint_path, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + vae = pipe.vae + tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + if len(text_encoders) == 2: + text_encoders[1].pad_token_id = 0 + + del pipe + + return tokenizers, text_encoders, unet, vae + + +def load_models_xl( + pretrained_model_name_or_path: str, + scheduler_name: str, + weight_dtype: torch.dtype = torch.float32, + noise_scheduler_kwargs=None, +) -> Tuple[ + List[CLIPTokenizer], + List[SDXL_TEXT_ENCODER_TYPE], + UNet2DConditionModel, + SchedulerMixin, +]: + if 
pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl( + pretrained_model_name_or_path, weight_dtype + ) + else: # diffusers + (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl( + pretrained_model_name_or_path, weight_dtype + ) + if scheduler_name: + scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs) + else: + scheduler = None + + return tokenizers, text_encoders, unet, scheduler, vae + +def create_noise_scheduler( + scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", + noise_scheduler_kwargs=None, + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", +) -> SchedulerMixin: + name = scheduler_name.lower().replace(" ", "_") + if name.lower() == "ddim": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim + scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + elif name.lower() == "ddpm": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm + scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + elif name.lower() == "lms": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete + scheduler = LMSDiscreteScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + elif name.lower() == "euler_a": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral + scheduler = EulerAncestralDiscreteScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + elif name.lower() == "euler": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral + scheduler = EulerDiscreteScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + elif name.lower() == "unipc": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc + scheduler = UniPCMultistepScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + else: + raise ValueError(f"Unknown scheduler name: {name}") + + return scheduler + + +def torch_gc(): + import gc + + gc.collect() + if torch.cuda.is_available(): + with torch.cuda.device("cuda"): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +from enum import Enum + + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 + + +cpu_state = CPUState.GPU +xpu_available = False +directml_enabled = False + + +def is_intel_xpu(): + global cpu_state + global xpu_available + if cpu_state == CPUState.GPU: + if xpu_available: + return True + return False + + +try: + import intel_extension_for_pytorch as ipex + + if torch.xpu.is_available(): + xpu_available = True +except: + pass + +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS + import torch.mps +except: + pass + + +def get_torch_device(): + global directml_enabled + global cpu_state + if directml_enabled: + global directml_device + return directml_device + if cpu_state == CPUState.MPS: + return torch.device("mps") + if cpu_state == CPUState.CPU: + return torch.device("cpu") + else: + if is_intel_xpu(): + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) diff --git a/gradio_demo/preprocess.py b/gradio_demo/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..70a598744e8c501fc0386792f58639826582eda9 --- /dev/null +++ b/gradio_demo/preprocess.py @@ -0,0 +1,232 @@ +import os +import random +import csv +import gc +import glob +from datetime import datetime +import time +from pathlib import 
Path +from style_template import style_list +from PIL import Image, ImageOps + +# Default Configuration variables +INPUT_FOLDER_NAME = 'examples' +OUTPUT_FOLDER_NAME = 'generated_images' +LOG_FILENAME = 'generation_log.csv' +logfile_path = os.path.join(os.getcwd(), LOG_FILENAME) + +PROMPT = "human, sharp focus" +NEGATIVE_PROMPT = "(blurry, blur, text, abstract, glitch, lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" +IDENTITYNET_STRENGTH_RATIO_RANGE = (1.0, 1.5) +ADAPTER_STRENGTH_RATIO_RANGE = (0.7, 1.0) +NUM_INFERENCE_STEPS_RANGE = (40, 60) +GUIDANCE_SCALE_RANGE = (7.0, 12.0) +MAX_SIDE = 1280 +MIN_SIDE = 1024 +NUMBER_OF_LOOPS = 1 + +# Dynamically create the STYLES list from imported style_list +STYLES = [style["name"] for style in style_list] +USE_RANDOM_STYLE = False + +def choose_random_style(): + return random.choice(STYLES) + +def get_random_image_file(input_folder): + valid_extensions = [".jpg", ".jpeg", ".png"] + files = [file for file in Path(input_folder).glob("*") if file.suffix.lower() in valid_extensions] + if not files: + raise FileNotFoundError(f"No images found in directory {input_folder}") + return str(random.choice(files)) + +def resize_and_pad_image(image_path, max_side, min_side, pad_color=(255, 255, 255)): + # Open an image using PIL + image = Image.open(image_path) + + # Calculate the scale and new size + ratio = min(min_side / min(image.size), max_side / max(image.size)) + new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) + + # Resize the image + image = image.resize(new_size, Image.BILINEAR) + + # Calculate padding + delta_w = max_side - new_size[0] + delta_h = max_side - new_size[1] + + # Pad the resized image to make it square + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + image = ImageOps.expand(image, padding, pad_color) + + return image + +def log_to_csv(logfile_path, image_name, new_file_name='Unknown', identitynet_strength_ratio=0.0, adapter_strength_ratio=0.0, num_inference_steps=0, guidance_scale=0.0, seed=0, success=True, error_message='', style_name="", prompt="", negative_prompt="", time_taken=0.0, current_timestamp=""): + os.makedirs(os.path.dirname(logfile_path), exist_ok=True) + file_exists = os.path.isfile(logfile_path) + + with open(logfile_path, 'a', newline='', encoding='utf-8') as csvfile: + fieldnames = ['image_name', 'new_file_name', 'identitynet_strength_ratio', 'adapter_strength_ratio', 'num_inference_steps', 'guidance_scale', 'seed', 'success', 'error_message', 'style_name', 'prompt', 'negative_prompt', 'time_taken', 'current_timestamp'] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + writer.writerow({ + 'image_name': image_name, + 'new_file_name': new_file_name, + 'identitynet_strength_ratio': identitynet_strength_ratio, + 'adapter_strength_ratio': adapter_strength_ratio, + 'num_inference_steps': num_inference_steps, + 'guidance_scale': guidance_scale, + 'seed': seed, + 'success': success, + 'error_message': error_message, + 'style_name': style_name, + 'prompt': prompt, + 'negative_prompt': negative_prompt, + 'time_taken': time_taken, + 'current_timestamp': current_timestamp + }) + +def initial_image(generate_image_func): + overall_start_time = time.time() + total_time_taken = 0.0 + + # Initialize a counter for processed images at the beginning of the function + processed_images_count = 0 + + # List all image files 
in the `INPUT_FOLDER_NAME` + image_files = glob.glob(f'{INPUT_FOLDER_NAME}/*.png') + \ + glob.glob(f'{INPUT_FOLDER_NAME}/*.jpg') + \ + glob.glob(f'{INPUT_FOLDER_NAME}/*.jpeg') + + # Check if we found any images + if not image_files: + raise FileNotFoundError(f"No images found in directory {INPUT_FOLDER_NAME}") + + # Print the count of detected image files + print(f"Processing a total of {len(image_files)} image(s) in '{INPUT_FOLDER_NAME}'") + + # Shuffle the image files randomly + random.shuffle(image_files) + + total_images = len(image_files) # Get the total number of images to process + + for loop in range(NUMBER_OF_LOOPS): + print(f"Starting loop {loop+1} of {NUMBER_OF_LOOPS}") + + for image_number, face_image_path in enumerate(image_files, start=1): + loop_start_time = datetime.now() + face_image = [face_image_path] + basename = os.path.basename(face_image_path) + processed_images_count += 1 + + # Resize and pad the image before processing + processed_image = resize_and_pad_image( + image_path=face_image_path, + max_side=MAX_SIDE, + min_side=MIN_SIDE + ) + + if USE_RANDOM_STYLE: + style_name = choose_random_style() + else: + style_name = "(No style)" + + identitynet_strength_ratio = random.uniform(*IDENTITYNET_STRENGTH_RATIO_RANGE) + adapter_strength_ratio = random.uniform(*ADAPTER_STRENGTH_RATIO_RANGE) + num_inference_steps = random.randint(*NUM_INFERENCE_STEPS_RANGE) + guidance_scale = random.uniform(*GUIDANCE_SCALE_RANGE) + seed = random.randint(0, 2**32 - 1) + + # Print settings for the current image BEFORE processing it + print_generation_settings(basename, style_name, identitynet_strength_ratio, + adapter_strength_ratio, num_inference_steps, guidance_scale, seed, + image_number, total_images) + + # Here, the generate_image_func is supposedly called and image processing happens + _, _, generated_file_paths = generate_image_func( + face_image=face_image, + pose_image=None, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + style_name=style_name, + enhance_face_region=True, + num_steps=num_inference_steps, + identitynet_strength_ratio=identitynet_strength_ratio, + adapter_strength_ratio=adapter_strength_ratio, + guidance_scale=guidance_scale, + seed=seed + ) + + loop_end_time = datetime.now() + loop_time_taken = (loop_end_time - loop_start_time).total_seconds() + + # Immediately print the time taken and current time. 
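# --- Illustrative aside, not part of the original script: a worked example of
# the remaining-time estimate computed a few lines below, using hypothetical
# numbers (10 images per loop, NUMBER_OF_LOOPS = 2, currently loop 1, image 3,
# average of 12.0 seconds per processed image so far):
#
#     remaining_images_this_loop           = 10 - 3           # 7
#     remaining_images_in_additional_loops = (2 - 1) * 10     # 10
#     estimated_time_remaining             = 12.0 * (7 + 10)  # 204 s ~= 3 min 24 s
# ------------------------------------------------------------------------------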
+ print(f"Time taken to process image: {loop_time_taken:.2f} seconds") + + # Update the total time taken with this image's processing time + total_time_taken += loop_time_taken + + # Calculate the average time taken per image + average_time_per_image = total_time_taken / image_number + + current_timestamp = loop_end_time.strftime("%Y-%m-%d %H:%M:%S") # Current time after processing + print(f"Current timestamp: {current_timestamp}") + + # Calculate estimated remaining time considering the images left in this loop and the additional loops + remaining_images_this_loop = total_images - image_number + remaining_images_in_additional_loops = (NUMBER_OF_LOOPS - (loop + 1)) * total_images + total_remaining_images = remaining_images_this_loop + remaining_images_in_additional_loops + estimated_time_remaining = average_time_per_image * total_remaining_images + + # Display the estimated time remaining including remaining loops + print(f"Estimated time remaining (including loops): {estimated_time_remaining // 60:.0f} minutes, {estimated_time_remaining % 60:.0f} seconds") + + # Display the overall average time per image in seconds + print(f"Overall average time per image: {average_time_per_image:.2f} seconds") + + # Display the total number of remaining images to process including looping + print(f"Total remaining images to process (including loops): {total_remaining_images}") + + + success = True # Assuming generation was successful. + error_message = "" # Assuming no error. + + # Log to CSV after the image generation. + for generated_file_path in generated_file_paths: + new_file_name = os.path.basename(generated_file_path) + log_to_csv(logfile_path, basename, new_file_name, identitynet_strength_ratio, + adapter_strength_ratio, num_inference_steps, guidance_scale, seed, success, + error_message, style_name, PROMPT, NEGATIVE_PROMPT, loop_time_taken, current_timestamp) + + + del generated_file_paths # Explicitly delete large variables + gc.collect() # Call garbage collection + + + # At the end of the initial_image() function, add: + total_elapsed_time = time.time() - overall_start_time + print("\n===FINAL SUMMARY===") + print(f"Total loops completed: {NUMBER_OF_LOOPS}") + print(f"Total images processed per loop: {len(image_files)}") + print(f"Overall total images processed: {NUMBER_OF_LOOPS * len(image_files)}") # Multiplied by the number of loops + print(f"Overall total time: {total_elapsed_time / 60:.2f} minutes") + + +def print_generation_settings(basename, style_name, identitynet_strength_ratio, adapter_strength_ratio, num_inference_steps, guidance_scale, seed, image_number, total_images): + print("===IMAGE GENERATION DATA SUMMARY===") + # Print settings for the current image + print(f"- Image {image_number} of {total_images}\n" + f"- Filename: {basename}\n" + f"- Style: {style_name}\n" + f"- IdentityNet strength ratio: {identitynet_strength_ratio:0.2f}\n" + f"- Adapter strength ratio: {adapter_strength_ratio:0.2f}\n" + f"- Number of inference steps: {num_inference_steps}\n" + f"- Guidance scale: {guidance_scale:0.2f}\n" + f"- Seed: {seed}\n" + f"- Input folder name: {INPUT_FOLDER_NAME}\n" + f"- Output folder name: {OUTPUT_FOLDER_NAME}\n" + f"- Prompt: {PROMPT}\n" + f"- Negative prompt: {NEGATIVE_PROMPT}\n" + f"- Number of loops: {NUMBER_OF_LOOPS}\n" + f"- Use random style: {USE_RANDOM_STYLE}\n") + print("===DEFINING COMPLETE, GENERATING IMAGE...===") \ No newline at end of file diff --git a/gradio_demo/requirements.txt b/gradio_demo/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..b592707c824e9e1521bcb3fcb4ded0e8629d7f5f --- /dev/null +++ b/gradio_demo/requirements.txt @@ -0,0 +1,19 @@ +diffusers==0.25.1 +torch==2.0.0 +torchvision==0.15.1 +transformers==4.37.1 +accelerate==0.25.0 +safetensors==0.4.3 +einops==0.7.0 +onnxruntime-gpu==1.18.1 +spaces==0.19.4 +omegaconf==2.3.0 +peft==0.11.1 +huggingface-hub==0.23.4 +opencv-python==4.10.0.84 +insightface==0.7.3 +gradio==4.38.1 +controlnet_aux==0.0.9 +gdown==5.2.0 +peft==0.11.1 +setuptools=71.1.0 \ No newline at end of file diff --git a/gradio_demo/style_template.py b/gradio_demo/style_template.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7874d5414d4c50cbf47654f567bb54a7764933 --- /dev/null +++ b/gradio_demo/style_template.py @@ -0,0 +1,136 @@ +# style_list = [ +# { +# "name": "Professional", +# "prompt": ["Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look"], +# # "prompt": ["Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look,highly detailed, sharp focus","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look"], +# "negative_prompt": +# # "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, blurry, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white, multiple people, green, deformed" +# # "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly" + +# # "lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed" + +# "high saturation, multiple people, two people, patchy, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch,blurred, blurry, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white" +# }, +# { +# "name": "Quirky", +# "prompt": ["vibrant colorful, ink sketch|vector|2d colors, sharp focus, superman/wonderwoman, highly detailed, the clouds,colorful,ultra sharpness,4k","watercolor painting, japanese anime character with white/neon hair. 
vibrant, beautiful, painterly, detailed, textural, artistic","vibrant colorful, ink sketch|vector|2d colors, sharp focus, scooba diver, highly detailed, the ocean,fishes,colorful,ultra sharpness,4k","individual dressed as an eskimo, surrounded by snowy mountains and igloo, snow crystals, cold, windy background, frozen natural landscape in background,highly detailed, sharp focus, intricate design, 4k resolution"], +# "negative_prompt": "saturation, highly saturated,(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, multiple people, buildings in background, green" +# }, +# { +# "name": "Sci-fi", +# "prompt": ["ethereal fantasy concept art individual, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy","Dystopian style cyborg. Bleak, post-apocalyptic, somber, dramatic, highly detailed","Alien-themed, Extraterrestrial, cosmic, otherworldly, mysterious, sci-fi, highly detailed", "Legend of Zelda style . Vibrant, fantasy, detailed, epic, heroic, reminiscent of The Legend of Zelda series"], +# "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white, multiple people, green, deformed", +# } +# ] + +# styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} + +# # lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed + +# style_list = [ +# { +# "name": "(No style)", +# "prompt": "Realistic, 4k resolution, ultra sharpness, {prompt} sitiing at a desk, office environment, professional photoshoot", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Watercolor", +# "prompt": "watercolor painting, japanese anime character with white/neon hair. 
vibrant, beautiful, painterly, detailed, textural, artistic", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy", +# }, +# { +# "name": "Film Noir", +# "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Neon", +# "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Jungle", +# "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still', +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Mars", +# "prompt": "{prompt}, Post-apocalyptic. 
Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Vibrant Color", +# "prompt": "vibrant colorful, ink sketch|vector|2d colors, sharp focus, {prompt}, highly detailed, the clouds,colorful,ultra sharpness,4k", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly,distorted, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Snow", +# "prompt": "individual dressed as an {prompt}, high contrast, surrounded by snowy mountains and igloo, snow crystals, cold, windy background, frozen natural landscape in background,highly detailed, sharp focus, intricate design, 4k resolution", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# { +# "name": "Line art", +# "prompt": "vibrant colorful, sharp focus,individual wearing {prompt} costume, highly detailed, sharp focus, the ocean, fishes swimming in the background,coral reef behind, ocean landscape, 4k, colorful", +# "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", +# }, +# ] + +# styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} + +style_list = [ + { + "name": "Professional", + # "prompt": ["Minimalist style, Simple, clean, uncluttered, modern, elegant, White background, suit and tie, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, white background, professional photo, linkedin profile photo, formal attire, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, formal attire, sitting on a chair, professional look","Minimalist style, Simple, clean, uncluttered, modern, elegant, formal attire, individual sitiing at a desk, office environment, professional look"], + "prompt": ["professional portrait, gender-aligned, natural skin tones, cinematic lighting, highly detailed, well-composed, professional photography, subtle background blur","Minimalist portrait, clean lines, soft colors, simple background, modern, elegant, subtle details, focus on facial features","Professional, Corporate, formal attire, polished, sharp features, clean background, high clarity, refined, business style","LinkedIn professional, business attire, neutral background, sharp focus, approachable, polished, suited for professional networking"], + "negative_prompt": "oversaturated, unnatural skin tones, deformed, disfigured, low resolution, cartoonish, unrealistic" + }, + # { + # 
"name": "Watercolor", + # "prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy", + # }, + # { + # "name": "Film Noir", + # "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + # { + # "name": "Neon", + # "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + # { + # "name": "Jungle", + # "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still', + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + # { + # "name": "Mars", + # "prompt": "{prompt}, Post-apocalyptic. 
Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + # { + # "name": "Vibrant Color", + # "prompt": "vibrant colorful, ink sketch|vector|2d colors, sharp focus, {prompt}, highly detailed, the clouds,colorful,ultra sharpness,4k", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly,distorted, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + # { + # "name": "Snow", + # "prompt": "individual dressed as an {prompt}, high contrast, surrounded by snowy mountains and igloo, snow crystals, cold, windy background, frozen natural landscape in background,highly detailed, sharp focus, intricate design, 4k resolution", + # "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green", + # }, + { + "name": "Quirky", + "prompt": ["vibrant colorful, ink sketch|vector|2d colors, sharp focus, superman/wonderwoman, highly detailed, the clouds,colorful,ultra sharpness,4k","watercolor painting, japanese anime character with white/neon hair. vibrant, beautiful, painterly, detailed, textural, artistic","vibrant colorful, ink sketch|vector|2d colors, sharp focus, scooba diver, highly detailed, the ocean,fishes,colorful,ultra sharpness,4k","individual dressed as an eskimo, high contrast, surrounded by snowy mountains and igloo, snow crystals, cold, windy background, frozen natural landscape in background,highly detailed, sharp focus, intricate design, 4k resolution"], + "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green" + }, + { + "name": "Sci-fi", + "prompt": ["ethereal fantasy concept art individual, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy","Dystopian style cyborg. Bleak, post-apocalyptic, somber, dramatic, highly detailed","Alien-themed, Extraterrestrial, cosmic, otherworldly, mysterious, sci-fi, highly detailed", "Legend of Zelda style . 
Vibrant, fantasy, detailed, epic, heroic, reminiscent of The Legend of Zelda series"], + "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white, multiple people", + } +] + +styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} + +# lowres \ No newline at end of file diff --git a/gradio_demo/test.py b/gradio_demo/test.py new file mode 100644 index 0000000000000000000000000000000000000000..324a80d3f9ba9b7bd76723914b1b4600b08f26e6 --- /dev/null +++ b/gradio_demo/test.py @@ -0,0 +1,400 @@ +import sys +sys.path.append('./') + +from typing import Tuple + +import os +import cv2 +import math +import torch +import random +import numpy as np +import argparse +import pandas as pd + +import PIL +from PIL import Image + +import diffusers +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers import LCMScheduler + +from huggingface_hub import hf_hub_download + +import insightface +from insightface.app import FaceAnalysis + +from style_template import styles +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline +from model_util import load_models_xl, get_torch_device, torch_gc + + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = get_torch_device() +dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32 +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Watercolor" + +# Load face encoder +app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(320, 320)) + +# Path to InstantID models +face_adapter = f'./checkpoints/ip-adapter.bin' +controlnet_path = f'./checkpoints/ControlNetModel' + +# Load pipeline +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype) + +logo = Image.open("./gradio_demo/logo.png") + +from cv2 import imencode +import base64 + +# def encode_pil_to_base64_new(pil_image): +# print("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") +# image_arr = np.asarray(pil_image)[:,:,::-1] +# _, byte_data = imencode('.png', image_arr) +# base64_data = base64.b64encode(byte_data) +# base64_string_opencv = base64_data.decode("utf-8") +# return "data:image/png;base64," + base64_string_opencv + +import gradio as gr + +# gr.processing_utils.encode_pil_to_base64 = encode_pil_to_base64_new + +def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False): + + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + scheduler_kwargs = hf_hub_download( + repo_id="wangqixun/YamerMIX_v8", + subfolder="scheduler", + filename="scheduler_config.json", + ) + + (tokenizers, text_encoders, unet, _, vae) = load_models_xl( + pretrained_model_name_or_path=pretrained_model_name_or_path, + scheduler_name=None, + weight_dtype=dtype, + ) + + scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs) + pipe = StableDiffusionXLInstantIDPipeline( + vae=vae, + text_encoder=text_encoders[0], + text_encoder_2=text_encoders[1], + tokenizer=tokenizers[0], + tokenizer_2=tokenizers[1], + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + ).to(device) + + else: + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + pretrained_model_name_or_path, + controlnet=controlnet, + 
torch_dtype=dtype, + safety_checker=None, + feature_extractor=None, + ).to(device) + + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.load_ip_adapter_instantid(face_adapter) + # load and disable LCM + pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") + pipe.disable_lora() + + def remove_tips(): + return gr.update(visible=False) + + + # prompts = [ + # ["superman","Vibrant Color"], ["japanese anime character with white/neon hair","Watercolor"], + # # ["Suited professional","(No style)"], + # ["Scooba diver","Line art"], ["eskimo","Snow"] + # ] + + def convert_from_cv2_to_image(img: np.ndarray) -> Image: + return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + + def run_for_prompts1(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[0], n) + # else: + # raise gr.Error("Email ID is compulsory") + def run_for_prompts2(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[1], n) + def run_for_prompts3(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[2], n) + def run_for_prompts4(face_file,style,progress=gr.Progress(track_tqdm=True)): + # if email != "": + p,n = styles.get(style, styles.get(STYLE_NAMES[1])) + return generate_image(face_file, p[3], n) + +# def validate_and_process(face_file, style, email): + +# # Your processing logic here +# gallery1, gallery2, gallery3, gallery4 = run_for_prompts1(face_file, style), run_for_prompts2(face_file, style), run_for_prompts3(face_file, style), run_for_prompts4(face_file, style) +# return gallery1, gallery2, gallery3, gallery4 + + def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + def resize_img(input_image, max_side=640, min_side=640, size=None, + pad_to_max_side=True, mode=PIL.Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + print(w) + print(h) + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // 
base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + def store_images(email, gallery1, gallery2, gallery3, gallery4): + galleries = [] + for i, img in enumerate([gallery1, gallery2, gallery3, gallery4], start=1): + if isinstance(img, np.ndarray): + img = Image.fromarray(img) + print(f"Gallery {i} type after conversion: {type(img)}") + galleries.append(img) + # Create the images directory if it doesn't exist + if not os.path.exists('images'): + os.makedirs('images') + + # Define image file paths + image_paths = [] + for i, img in enumerate(galleries, start=1): + img_path = f'images/{email}_gallery{i}.png' + img.save(img_path) + image_paths.append(img_path) + + # Define the CSV file path + csv_file_path = 'image_data.csv' + + # Create a DataFrame for the email and image paths + df = pd.DataFrame({ + 'email': [email], + 'img1_path': [image_paths[0]], + 'img2_path': [image_paths[1]], + 'img3_path': [image_paths[2]], + 'img4_path': [image_paths[3]], + }) + + # Write to CSV (append if the file exists, create a new one if it doesn't) + if not os.path.isfile(csv_file_path): + df.to_csv(csv_file_path, index=False) + else: + df.to_csv(csv_file_path, mode='a', header=False, index=False) + + + def generate_image(face_image,prompt,negative_prompt): + pose_image_path = None + # prompt = "superman" + enable_LCM = False + identitynet_strength_ratio = 0.95 + adapter_strength_ratio = 0.60 + num_steps = 15 + guidance_scale = 8.5 + seed = random.randint(0, MAX_SEED) + # negative_prompt = "" + # negative_prompt += neg + enhance_face_region = True + if enable_LCM: + pipe.enable_lora() + pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) + else: + pipe.disable_lora() + pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) + + if face_image is None: + raise gr.Error(f"Cannot find any input face image! Please upload the face image") + + # if prompt is None: + # prompt = "a person" + + # apply the style template + # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + # face_image = load_image(face_image_path) + face_image = resize_img(face_image) + face_image_cv2 = convert_from_image_to_cv2(face_image) + height, width, _ = face_image_cv2.shape + + # Extract face features + face_info = app.get(face_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the image! Please upload another person image") + + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps']) + + if pose_image_path is not None: + pose_image = load_image(pose_image_path) + pose_image = resize_img(pose_image) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + + face_info = app.get(pose_image_cv2) + + if len(face_info) == 0: + raise gr.Error(f"Cannot find any face in the reference image! 
Please upload another person image") + + face_info = face_info[-1] + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + if enhance_face_region: + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + else: + control_mask = None + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + + pipe.set_ip_adapter_scale(adapter_strength_ratio) + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + image_embeds=face_emb, + image=face_kps, + control_mask=control_mask, + controlnet_conditioning_scale=float(identitynet_strength_ratio), + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + # num_images_per_prompt = 4 + ).images + + print(images[0]) + + return images[0] + + + ### Description + title = r""" +

Choose your AVATAR

+ """ + + description = r""" +

Powered by IDfy

""" + + article = r"""""" + + tips = r"""""" + + css = ''' + .gradio-container {width: 95% !important; background-color: #E6F3FF;} + .image-gallery {height: 100vh !important; overflow: auto;} + .gradio-row .gradio-element { margin: 0 !important; } + ''' + with gr.Blocks(css=css) as demo: + title = "

Choose your AVATAR

" + description = "

Powered by IDfy

" + + # Description + gr.Markdown(title) + with gr.Row(): + gr.Image("./gradio_demo/logo.png",scale=0,min_width=50,show_label=False,show_download_button=False) + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + style = gr.Dropdown(label="Choose your STYLE", choices=STYLE_NAMES) + face_file = gr.Image(label="Upload a photo of your face", type="pil") + submit = gr.Button("Submit", variant="primary") + with gr.Column(): + with gr.Row(): + gallery1 = gr.Image(label="Generated Images") + gallery2 = gr.Image(label="Generated Images") + with gr.Row(): + gallery3 = gr.Image(label="Generated Images") + gallery4 = gr.Image(label="Generated Images") + email = gr.Textbox(label="Email", + info="Enter your email address", + value="") + submit1 = gr.Button("STORE", variant="primary") + usage_tips = gr.Markdown(label="Usage tips of InstantID", value="", visible=False) + + # Image upload and processing chain + face_file.upload(remove_tips, outputs=usage_tips).then(run_for_prompts1, inputs=[face_file, style], outputs=[gallery1]).then(run_for_prompts2, inputs=[face_file, style], outputs=[gallery2]).then(run_for_prompts3, inputs=[face_file, style], outputs=[gallery3]).then(run_for_prompts4, inputs=[face_file, style], outputs=[gallery4]) + submit.click(remove_tips, outputs=usage_tips).then(run_for_prompts1, inputs=[face_file, style], outputs=[gallery1]).then(run_for_prompts2, inputs=[face_file, style], outputs=[gallery2]).then(run_for_prompts3, inputs=[face_file, style], outputs=[gallery3]).then(run_for_prompts4, inputs=[face_file, style], outputs=[gallery4]) + + # Store data on button click + submit1.click( + fn=store_images, + inputs=[email,gallery1,gallery2,gallery3,gallery4], + outputs=None) + + gr.Markdown("") + + demo.launch(share=True) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8") + args = parser.parse_args() + + main(args.pretrained_model_name_or_path, False) \ No newline at end of file diff --git a/gradio_demo/watermark.png b/gradio_demo/watermark.png new file mode 100644 index 0000000000000000000000000000000000000000..c91f5ef36321fa98b1396d12dbca24c27a8bdf2e Binary files /dev/null and b/gradio_demo/watermark.png differ diff --git a/image_data.csv b/image_data.csv new file mode 100644 index 0000000000000000000000000000000000000000..13b6d49598ed8dda0849ced833df9a6e72345a23 --- /dev/null +++ b/image_data.csv @@ -0,0 +1,15 @@ +email,img1_path,img2_path,img3_path,img4_path +ll,images/ll_gallery1.png,images/ll_gallery2.png,images/ll_gallery3.png,images/ll_gallery4.png +kajal@img,images/kajal@img_gallery1.png,images/kajal@img_gallery2.png,images/kajal@img_gallery3.png,images/kajal@img_gallery4.png +heeral@img,images/heeral@img_gallery1.png,images/heeral@img_gallery2.png,images/heeral@img_gallery3.png,images/heeral@img_gallery4.png +sanskruti@img-scifi,images/sanskruti@img-scifi_gallery1.png,images/sanskruti@img-scifi_gallery2.png,images/sanskruti@img-scifi_gallery3.png,images/sanskruti@img-scifi_gallery4.png +sanskruti@img-scifi,images/sanskruti@img-scifi_gallery1.png,images/sanskruti@img-scifi_gallery2.png,images/sanskruti@img-scifi_gallery3.png,images/sanskruti@img-scifi_gallery4.png +sanskruti@img-quirky,images/sanskruti@img-quirky_gallery1.png,images/sanskruti@img-quirky_gallery2.png,images/sanskruti@img-quirky_gallery3.png,images/sanskruti@img-quirky_gallery4.png 
+kajal@quirky,images/kajal@quirky_gallery1.png,images/kajal@quirky_gallery2.png,images/kajal@quirky_gallery3.png,images/kajal@quirky_gallery4.png +kajal@prof,images/kajal@prof_gallery1.png,images/kajal@prof_gallery2.png,images/kajal@prof_gallery3.png,images/kajal@prof_gallery4.png +kajal@quirky,images/kajal@quirky_gallery1.png,images/kajal@quirky_gallery2.png,images/kajal@quirky_gallery3.png,images/kajal@quirky_gallery4.png +kajal@sci-fi,images/kajal@sci-fi_gallery1.png,images/kajal@sci-fi_gallery2.png,images/kajal@sci-fi_gallery3.png,images/kajal@sci-fi_gallery4.png +yashvi,images/yashvi_gallery1.png,images/yashvi_gallery2.png,images/yashvi_gallery3.png,images/yashvi_gallery4.png +yashviwhy@instantid.com,images/yashviwhy@instantid.com_gallery1.png,images/yashviwhy@instantid.com_gallery2.png,images/yashviwhy@instantid.com_gallery3.png,images/yashviwhy@instantid.com_gallery4.png +kartik@prof,images/kartik@prof_gallery1.png,images/kartik@prof_gallery2.png,images/kartik@prof_gallery3.png,images/kartik@prof_gallery4.png +yashvii@proffffff,images/yashvii@proffffff_gallery1.png,images/yashvii@proffffff_gallery2.png,images/yashvii@proffffff_gallery3.png,images/yashvii@proffffff_gallery4.png diff --git a/images/aa.ll_gallery1.png b/images/aa.ll_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..55ea2ff36a4524754e08fdb7c11432233fe363eb --- /dev/null +++ b/images/aa.ll_gallery1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c08937723c681094ef3befc375dfe1ae7f67fa4b8f89325df494861a5091a20 +size 1128016 diff --git a/images/aa.ll_gallery2.png b/images/aa.ll_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..00b5f9d2b21a4c8fa62f61f7ae8ca5192c7ae42b Binary files /dev/null and b/images/aa.ll_gallery2.png differ diff --git a/images/aa.ll_gallery3.png b/images/aa.ll_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..2f5441f19451d3aea9bb97de4065e4a65555912a Binary files /dev/null and b/images/aa.ll_gallery3.png differ diff --git a/images/aa.ll_gallery4.png b/images/aa.ll_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..1b7ffd084c4887f2c8ab5144146370e57bddedb7 Binary files /dev/null and b/images/aa.ll_gallery4.png differ diff --git a/images/heeral@img_gallery1.png b/images/heeral@img_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..bf0d6d8b3b82575426f122639ca46c7ca186c3bd Binary files /dev/null and b/images/heeral@img_gallery1.png differ diff --git a/images/heeral@img_gallery2.png b/images/heeral@img_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..c00ac1aecf43241762e7f7bfd10da971aca44fd5 Binary files /dev/null and b/images/heeral@img_gallery2.png differ diff --git a/images/heeral@img_gallery3.png b/images/heeral@img_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..4eefd17ffee12bb24deeb2d1474445f171a764d4 Binary files /dev/null and b/images/heeral@img_gallery3.png differ diff --git a/images/heeral@img_gallery4.png b/images/heeral@img_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..dd4ae20606115e0d972fe6450f4b0d904c4fd97f Binary files /dev/null and b/images/heeral@img_gallery4.png differ diff --git a/images/kajal@img_gallery1.png b/images/kajal@img_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..5aca4f772d9c5cc0f22df5f5ad92cf5938bfbb96 Binary files /dev/null and b/images/kajal@img_gallery1.png 
differ diff --git a/images/kajal@img_gallery2.png b/images/kajal@img_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..15d2413b67a704711445e2164921b8de075d5e47 Binary files /dev/null and b/images/kajal@img_gallery2.png differ diff --git a/images/kajal@img_gallery3.png b/images/kajal@img_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..126271b0b5d06a5fe563e8a1cd0e451e9222bd79 Binary files /dev/null and b/images/kajal@img_gallery3.png differ diff --git a/images/kajal@img_gallery4.png b/images/kajal@img_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..1e506615468625e165165e3f21b954538af79f2b Binary files /dev/null and b/images/kajal@img_gallery4.png differ diff --git a/images/kajal@prof_gallery1.png b/images/kajal@prof_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..bebcc1f5cb33c3c151d68804fb2dd8f59ad515cf Binary files /dev/null and b/images/kajal@prof_gallery1.png differ diff --git a/images/kajal@prof_gallery2.png b/images/kajal@prof_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..4fae6e71d0b9349341c6a4e7cc38764913dd3233 Binary files /dev/null and b/images/kajal@prof_gallery2.png differ diff --git a/images/kajal@prof_gallery3.png b/images/kajal@prof_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..4601e3ff79df46e19e715182873f2538ca4ed1c5 Binary files /dev/null and b/images/kajal@prof_gallery3.png differ diff --git a/images/kajal@prof_gallery4.png b/images/kajal@prof_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..3476bc456c28b0c358c46f19fa26bdc566629789 Binary files /dev/null and b/images/kajal@prof_gallery4.png differ diff --git a/images/kajal@quirky_gallery1.png b/images/kajal@quirky_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..2fce57144b0f46750496938dd10270ecc98e3fe5 Binary files /dev/null and b/images/kajal@quirky_gallery1.png differ diff --git a/images/kajal@quirky_gallery2.png b/images/kajal@quirky_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..029ec01ab659e5d474cce515102d0a7f699ac915 Binary files /dev/null and b/images/kajal@quirky_gallery2.png differ diff --git a/images/kajal@quirky_gallery3.png b/images/kajal@quirky_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..83744b851f15b9817b6810c00933a14f32911c72 Binary files /dev/null and b/images/kajal@quirky_gallery3.png differ diff --git a/images/kajal@quirky_gallery4.png b/images/kajal@quirky_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..58630ce3a31018870e4678e6d514a77fac023c3c Binary files /dev/null and b/images/kajal@quirky_gallery4.png differ diff --git a/images/kajal@sci-fi_gallery1.png b/images/kajal@sci-fi_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..23af741a0a520644d18c36751ea152ab016de372 Binary files /dev/null and b/images/kajal@sci-fi_gallery1.png differ diff --git a/images/kajal@sci-fi_gallery2.png b/images/kajal@sci-fi_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..88e03a6036d5311683890470f7235f2ea8e8e62a Binary files /dev/null and b/images/kajal@sci-fi_gallery2.png differ diff --git a/images/kajal@sci-fi_gallery3.png b/images/kajal@sci-fi_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..99016c731ed662e6ea3fc74974ba343e81a90770 Binary files /dev/null and 
b/images/kajal@sci-fi_gallery3.png differ diff --git a/images/kajal@sci-fi_gallery4.png b/images/kajal@sci-fi_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..c60c0637411cf2c1e0b9085c8520a69464c746c9 Binary files /dev/null and b/images/kajal@sci-fi_gallery4.png differ diff --git a/images/kartik@prof_gallery1.png b/images/kartik@prof_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..e5fb6203bd4842fc087ff36247dc1495f99d6cdb Binary files /dev/null and b/images/kartik@prof_gallery1.png differ diff --git a/images/kartik@prof_gallery2.png b/images/kartik@prof_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..f8d10c9c682d5e2a1136bde96aeabf4efd4064ee Binary files /dev/null and b/images/kartik@prof_gallery2.png differ diff --git a/images/kartik@prof_gallery3.png b/images/kartik@prof_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..f2ffa9c6d00d4d4d8ac523ee4d9daa311b5ae528 Binary files /dev/null and b/images/kartik@prof_gallery3.png differ diff --git a/images/kartik@prof_gallery4.png b/images/kartik@prof_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..37170ef3b2daae08975ac91ae01a3c6670ac85ed Binary files /dev/null and b/images/kartik@prof_gallery4.png differ diff --git a/images/lkjb_gallery1.png b/images/lkjb_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..d378ee864a100f1f7bce952108ecb6a6f3c77b84 Binary files /dev/null and b/images/lkjb_gallery1.png differ diff --git a/images/lkjb_gallery2.png b/images/lkjb_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..cdce5960f63e88c8e8d4a020f674067bd7476fc7 Binary files /dev/null and b/images/lkjb_gallery2.png differ diff --git a/images/lkjb_gallery3.png b/images/lkjb_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..873c85003c33ce4e30d5b1dd971b37cee77a5a78 Binary files /dev/null and b/images/lkjb_gallery3.png differ diff --git a/images/lkjb_gallery4.png b/images/lkjb_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..59a3608df95c35180514bdf866799e759c65f6dd Binary files /dev/null and b/images/lkjb_gallery4.png differ diff --git a/images/ll_gallery1.png b/images/ll_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..2ac1fbccc27da3d8cb6924629d79746bb299f986 Binary files /dev/null and b/images/ll_gallery1.png differ diff --git a/images/ll_gallery2.png b/images/ll_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..ff414c5bc090604f649b9cb0a0bedfbe2e898fe8 Binary files /dev/null and b/images/ll_gallery2.png differ diff --git a/images/ll_gallery3.png b/images/ll_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..008d53d6cecec368e4fb21c1852c967d85c88bc2 Binary files /dev/null and b/images/ll_gallery3.png differ diff --git a/images/ll_gallery4.png b/images/ll_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..9de8eb5ddbb0d02c0826fe259ba899f00912b937 Binary files /dev/null and b/images/ll_gallery4.png differ diff --git a/images/sanskruti@img-quirky_gallery1.png b/images/sanskruti@img-quirky_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..8b82952530304bd7cbdb9fd8924e76a61aca84bf Binary files /dev/null and b/images/sanskruti@img-quirky_gallery1.png differ diff --git a/images/sanskruti@img-quirky_gallery2.png 
b/images/sanskruti@img-quirky_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..590f9d7058b9809cb3db288c5c151e85867eced2 Binary files /dev/null and b/images/sanskruti@img-quirky_gallery2.png differ diff --git a/images/sanskruti@img-quirky_gallery3.png b/images/sanskruti@img-quirky_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..db03afbd7232390213d0233bb82da41ff6b6b9f3 Binary files /dev/null and b/images/sanskruti@img-quirky_gallery3.png differ diff --git a/images/sanskruti@img-quirky_gallery4.png b/images/sanskruti@img-quirky_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..4e071d6f87942b2b220990418629f72a0832b56e Binary files /dev/null and b/images/sanskruti@img-quirky_gallery4.png differ diff --git a/images/sanskruti@img-scifi_gallery1.png b/images/sanskruti@img-scifi_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..8b82952530304bd7cbdb9fd8924e76a61aca84bf Binary files /dev/null and b/images/sanskruti@img-scifi_gallery1.png differ diff --git a/images/sanskruti@img-scifi_gallery2.png b/images/sanskruti@img-scifi_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..590f9d7058b9809cb3db288c5c151e85867eced2 Binary files /dev/null and b/images/sanskruti@img-scifi_gallery2.png differ diff --git a/images/sanskruti@img-scifi_gallery3.png b/images/sanskruti@img-scifi_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..db03afbd7232390213d0233bb82da41ff6b6b9f3 Binary files /dev/null and b/images/sanskruti@img-scifi_gallery3.png differ diff --git a/images/sanskruti@img-scifi_gallery4.png b/images/sanskruti@img-scifi_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..4e071d6f87942b2b220990418629f72a0832b56e Binary files /dev/null and b/images/sanskruti@img-scifi_gallery4.png differ diff --git a/images/yashvi_gallery1.png b/images/yashvi_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..c122d4b912eb4c61f722fd63687020614216ffe7 --- /dev/null +++ b/images/yashvi_gallery1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b607afe8ae4404c1776517bf953d72835ddf99d2fe96b8e3b627822db7524f42 +size 1078744 diff --git a/images/yashvi_gallery2.png b/images/yashvi_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..c3be49c1418a3732438459af829186bb7b4b2023 Binary files /dev/null and b/images/yashvi_gallery2.png differ diff --git a/images/yashvi_gallery3.png b/images/yashvi_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..3b56d706f1664f43b83629f8a782994d136a2c81 Binary files /dev/null and b/images/yashvi_gallery3.png differ diff --git a/images/yashvi_gallery4.png b/images/yashvi_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..643670b0f408042460699e712381b25558e43372 --- /dev/null +++ b/images/yashvi_gallery4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b28eb906303c0bb938739cd2d8248d2536bb02a0f437619a2cc6aa5de02e85a +size 1022424 diff --git a/images/yashvii@proffffff_gallery1.png b/images/yashvii@proffffff_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..e0ad2dd1219ed815ce69d6af6f6acf375e5b13f9 Binary files /dev/null and b/images/yashvii@proffffff_gallery1.png differ diff --git a/images/yashvii@proffffff_gallery2.png b/images/yashvii@proffffff_gallery2.png new file mode 100644 index 
0000000000000000000000000000000000000000..7fc8a5b1dfec3ee6bd3ee32b690d7d9f7940e51b Binary files /dev/null and b/images/yashvii@proffffff_gallery2.png differ diff --git a/images/yashvii@proffffff_gallery3.png b/images/yashvii@proffffff_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..7978fea64383fa4fc68d8248da35135758d076c0 Binary files /dev/null and b/images/yashvii@proffffff_gallery3.png differ diff --git a/images/yashvii@proffffff_gallery4.png b/images/yashvii@proffffff_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..495fad3b0f7614b189cdb1e218ce450c63b4524c Binary files /dev/null and b/images/yashvii@proffffff_gallery4.png differ diff --git a/images/yashviwhy@instantid.com_gallery1.png b/images/yashviwhy@instantid.com_gallery1.png new file mode 100644 index 0000000000000000000000000000000000000000..341dfd4d5fd0453f193128eb45c9eb38b954d0ec --- /dev/null +++ b/images/yashviwhy@instantid.com_gallery1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f9a04d7ab6dafffec9c77abb3c73cc3b58ecc70f5d6fe88a84c99c53c824e39 +size 2283227 diff --git a/images/yashviwhy@instantid.com_gallery2.png b/images/yashviwhy@instantid.com_gallery2.png new file mode 100644 index 0000000000000000000000000000000000000000..62cd8f45b6fc229a0bcf8e66ab696305bdb4321a --- /dev/null +++ b/images/yashviwhy@instantid.com_gallery2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4825e2ffe6eae19c7701250fcd03c2527fccc6ec8f822e406f07df8f2a902b98 +size 2212836 diff --git a/images/yashviwhy@instantid.com_gallery3.png b/images/yashviwhy@instantid.com_gallery3.png new file mode 100644 index 0000000000000000000000000000000000000000..0611526e4620ea5e113d7e6173a6ea4fa7229283 --- /dev/null +++ b/images/yashviwhy@instantid.com_gallery3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:661bc8301a5980660476e4209f8bc982be58782f2312ea3a8d98e090bf60ac34 +size 3264133 diff --git a/images/yashviwhy@instantid.com_gallery4.png b/images/yashviwhy@instantid.com_gallery4.png new file mode 100644 index 0000000000000000000000000000000000000000..a8e7cf25ebb690d3a2481b03d4ac74582ec4664d --- /dev/null +++ b/images/yashviwhy@instantid.com_gallery4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b4aac2e9e1929afa86ac754b7946496a013a229ddf4a9c86da26db227094b67 +size 1576451 diff --git a/infer.py b/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a54a0b2a82f6fe60cfaed24b760351a20a340693 --- /dev/null +++ b/infer.py @@ -0,0 +1,82 @@ +import cv2 +import torch +import numpy as np +from PIL import Image + +from diffusers.utils import load_image +from diffusers.models import ControlNetModel + +from insightface.app import FaceAnalysis +from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if 
pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + +if __name__ == "__main__": + + # Load face encoder + app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + app.prepare(ctx_id=0, det_size=(640, 640)) + + # Path to InstantID models + face_adapter = f'./checkpoints/ip-adapter.bin' + controlnet_path = f'./checkpoints/ControlNetModel' + + # Load pipeline + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0' + + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + base_model_path, + controlnet=controlnet, + torch_dtype=torch.float16, + ) + pipe.cuda() + pipe.load_ip_adapter_instantid(face_adapter) + + # Infer setting + prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + face_image = load_image("./examples/yann-lecun_resize.jpg") + face_image = resize_img(face_image) + + face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(face_image, face_info['kps']) + + image = pipe( + prompt=prompt, + negative_prompt=n_prompt, + image_embeds=face_emb, + image=face_kps, + controlnet_conditioning_scale=0.8, + ip_adapter_scale=0.8, + num_inference_steps=30, + guidance_scale=5, + ).images[0] + + image.save('result.jpg') \ No newline at end of file diff --git a/infer_full.py b/infer_full.py new file mode 100644 index 0000000000000000000000000000000000000000..fe26b998834405dcf77df5c41f9dcbdf029c37fe --- /dev/null +++ b/infer_full.py @@ -0,0 +1,119 @@ +import cv2 +import torch +import numpy as np +from PIL import Image + +from diffusers.utils import load_image +from diffusers.models import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel + +from insightface.app import FaceAnalysis +from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps + +from controlnet_aux import MidasDetector + +def convert_from_image_to_cv2(img: Image) -> np.ndarray: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // 
base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + +if __name__ == "__main__": + + # Load face encoder + app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + app.prepare(ctx_id=0, det_size=(640, 640)) + + # Path to InstantID models + face_adapter = f'./checkpoints/ip-adapter.bin' + controlnet_path = f'./checkpoints/ControlNetModel' + controlnet_depth_path = f'diffusers/controlnet-depth-sdxl-1.0-small' + + # Load depth detector + midas = MidasDetector.from_pretrained("lllyasviel/Annotators") + + # Load pipeline + controlnet_list = [controlnet_path, controlnet_depth_path] + controlnet_model_list = [] + for controlnet_path in controlnet_list: + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + controlnet_model_list.append(controlnet) + controlnet = MultiControlNetModel(controlnet_model_list) + + base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0' + + pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + base_model_path, + controlnet=controlnet, + torch_dtype=torch.float16, + ) + pipe.cuda() + pipe.load_ip_adapter_instantid(face_adapter) + + # Infer setting + prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + face_image = load_image("./examples/yann-lecun_resize.jpg") + face_image = resize_img(face_image) + + face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + + # use another reference image + pose_image = load_image("./examples/poses/pose.jpg") + pose_image = resize_img(pose_image) + + face_info = app.get(cv2.cvtColor(np.array(pose_image), cv2.COLOR_RGB2BGR)) + pose_image_cv2 = convert_from_image_to_cv2(pose_image) + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_kps = draw_kps(pose_image, face_info['kps']) + + width, height = face_kps.size + + # use depth control + processed_image_midas = midas(pose_image) + processed_image_midas = processed_image_midas.resize(pose_image.size) + + # enhance face region + control_mask = np.zeros([height, width, 3]) + x1, y1, x2, y2 = face_info["bbox"] + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + control_mask[y1:y2, x1:x2] = 255 + control_mask = Image.fromarray(control_mask.astype(np.uint8)) + + image = pipe( + prompt=prompt, + negative_prompt=n_prompt, + image_embeds=face_emb, + control_mask=control_mask, + image=[face_kps, processed_image_midas], + controlnet_conditioning_scale=[0.8,0.8], + 
ip_adapter_scale=0.8, + num_inference_steps=30, + guidance_scale=5, + ).images[0] + + image.save('result.jpg') \ No newline at end of file diff --git a/infer_img2img.py b/infer_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..16d9942ac09d003783752a292dcd7f9a336da59b --- /dev/null +++ b/infer_img2img.py @@ -0,0 +1,84 @@ +import cv2 +import torch +import numpy as np +from PIL import Image + +from diffusers.utils import load_image +from diffusers.models import ControlNetModel + +from insightface.app import FaceAnalysis +from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + ratio = min_side / min(h, w) + w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + +if __name__ == "__main__": + + # Load face encoder + app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + app.prepare(ctx_id=0, det_size=(640, 640)) + + # Path to InstantID models + face_adapter = f'./checkpoints/ip-adapter.bin' + controlnet_path = f'./checkpoints/ControlNetModel' + + # Load pipeline + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0' + + pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained( + base_model_path, + controlnet=controlnet, + torch_dtype=torch.float16, + ) + pipe.cuda() + pipe.load_ip_adapter_instantid(face_adapter) + + # Infer setting + prompt = "analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + face_image = load_image("./examples/yann-lecun_resize.jpg") + face_image = resize_img(face_image) + + face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face + face_emb = face_info['embedding'] + face_kps = draw_kps(face_image, face_info['kps']) + + image = pipe( + prompt=prompt, + negative_prompt=n_prompt, + image=face_image, + image_embeds=face_emb, + control_image=face_kps, + controlnet_conditioning_scale=0.8, + ip_adapter_scale=0.8, + num_inference_steps=30, + guidance_scale=5, + strength=0.85 + ).images[0] + + image.save('result.jpg') \ No newline at end of file diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..bfbb50d934a6aad038630a5e8be9d5887ca8f07b --- /dev/null +++ b/ip_adapter/attention_processor.py @@ -0,0 +1,447 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + import xformers + import xformers.ops + xformers_available = True +except Exception as e: + xformers_available = False + +class RegionControler(object): + def __init__(self) -> None: + self.prompt_image_conditioning = [] +region_control = RegionControler() + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
+ """ + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def forward( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def forward( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + if xformers_available: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + if xformers_available: + ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) + else: + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + # region control + if len(region_control.prompt_image_conditioning) == 1: + region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) + if region_mask is not None: + h, w = region_mask.shape[:2] + ratio = (h * w / query.shape[1]) ** 0.5 + mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + else: + mask = torch.ones_like(ip_hidden_states) + ip_hidden_states = ip_hidden_states * mask + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states 
+ residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def forward( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. 
+ cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def forward( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + 
ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + # softmax over the attention scores (the parentheses matter; without them the softmax would be applied to the keys) + self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + # region control + if len(region_control.prompt_image_conditioning) == 1: + region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None) + if region_mask is not None: + query = query.reshape([-1, query.shape[-2], query.shape[-1]]) + h, w = region_mask.shape[:2] + ratio = (h * w / query.shape[1]) ** 0.5 + mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) + else: + mask = torch.ones_like(ip_hidden_states) + ip_hidden_states = ip_hidden_states * mask + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4521c8c3e6f17caf4547c3dd84118da760e5179f --- /dev/null +++ b/ip_adapter/resampler.py @@ -0,0 +1,121 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # shape stays (bs, n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latents (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = 
torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) \ No newline at end of file diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a105f3701c15e8d3bbf838d79bacc51e91d0696 --- /dev/null +++ b/ip_adapter/utils.py @@ -0,0 +1,5 @@ +import torch.nn.functional as F + + +def is_torch2_available(): + return hasattr(F, "scaled_dot_product_attention") diff --git a/models/antelopev2/1k3d68.onnx b/models/antelopev2/1k3d68.onnx new file mode 100644 index 0000000000000000000000000000000000000000..221aa2f02a6faccddb2723529e1f93c7db2edbdc --- /dev/null +++ b/models/antelopev2/1k3d68.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +size 143607619 diff --git a/models/antelopev2/2d106det.onnx b/models/antelopev2/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/models/antelopev2/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/models/antelopev2/genderage.onnx b/models/antelopev2/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/models/antelopev2/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/models/antelopev2/glintr100.onnx b/models/antelopev2/glintr100.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9d221846df998a9c85239fd74a9fe5685193775f --- /dev/null +++ b/models/antelopev2/glintr100.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf +size 260665334 diff --git a/models/antelopev2/scrfd_10g_bnkps.onnx b/models/antelopev2/scrfd_10g_bnkps.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/models/antelopev2/scrfd_10g_bnkps.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/models/buffalo_m/1k3d68.onnx b/models/buffalo_m/1k3d68.onnx new file mode 100644 index 
0000000000000000000000000000000000000000..221aa2f02a6faccddb2723529e1f93c7db2edbdc --- /dev/null +++ b/models/buffalo_m/1k3d68.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc +size 143607619 diff --git a/models/buffalo_m/2d106det.onnx b/models/buffalo_m/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/models/buffalo_m/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/models/buffalo_m/det_2.5g.onnx b/models/buffalo_m/det_2.5g.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e58896988e89af02f348dd0b8046de94170fa36f --- /dev/null +++ b/models/buffalo_m/det_2.5g.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:041f73f47371333d1d17a6fee6c8ab4e6aecabefe398ff32cca4e2d5eaee0af9 +size 3292009 diff --git a/models/buffalo_m/genderage.onnx b/models/buffalo_m/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/models/buffalo_m/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/models/buffalo_m/w600k_r50.onnx b/models/buffalo_m/w600k_r50.onnx new file mode 100644 index 0000000000000000000000000000000000000000..571d2bb9ffd76399b23260620b9101b20bcc4e99 --- /dev/null +++ b/models/buffalo_m/w600k_r50.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43 +size 174383860 diff --git a/models/buffalo_sc/det_500m.onnx b/models/buffalo_sc/det_500m.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a8a550ab6452f2029c977686cf89f48d53400959 --- /dev/null +++ b/models/buffalo_sc/det_500m.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e4447f50245bbd7966bd6c0fa52938c61474a04ec7def48753668a9d8b4ea3a +size 2524817 diff --git a/models/buffalo_sc/w600k_mbf.onnx b/models/buffalo_sc/w600k_mbf.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d81d20d6d55cf5671b30f7e1a811a46858ebc618 --- /dev/null +++ b/models/buffalo_sc/w600k_mbf.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cc6e4a75f0e2bf0b1aed94578f144d15175f357bdc05e815e5c4a02b319eb4f +size 13616099 diff --git a/pipeline_stable_diffusion_xl_instantid.py b/pipeline_stable_diffusion_xl_instantid.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6a9a2193c3137a95a1d34903457f8af4054f6a --- /dev/null +++ b/pipeline_stable_diffusion_xl_instantid.py @@ -0,0 +1,787 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
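+
+# StableDiffusionXLInstantIDPipeline extends StableDiffusionXLControlNetPipeline with identity
+# conditioning: a Resampler (image_proj_model) turns a face embedding into extra cross-attention
+# tokens consumed by the IP-Adapter attention processors, while the IdentityNet ControlNet is
+# conditioned on facial keypoints rendered by draw_kps (see EXAMPLE_DOC_STRING below).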
+ + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import math + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F + +from diffusers.image_processor import PipelineImageInput + +from diffusers.models import ControlNetModel + +from diffusers.utils import ( + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from diffusers import StableDiffusionXLControlNetPipeline +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.utils.import_utils import is_xformers_available + +from ip_adapter.resampler import Resampler +from ip_adapter.utils import is_torch2_available + +if is_torch2_available(): + from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor +else: + from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate insightface + >>> import diffusers + >>> from diffusers.utils import load_image + >>> from diffusers.models import ControlNetModel + + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from insightface.app import FaceAnalysis + >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + + >>> # download 'antelopev2' under ./models + >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + >>> app.prepare(ctx_id=0, det_size=(640, 640)) + + >>> # download models under ./checkpoints + >>> face_adapter = f'./checkpoints/ip-adapter.bin' + >>> controlnet_path = f'./checkpoints/ControlNetModel' + + >>> # load IdentityNet + >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.cuda() + + >>> # load adapter + >>> pipe.load_ip_adapter_instantid(face_adapter) + + >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" + + >>> # load the face image + >>> face_image = load_image("your-example.jpg") + + >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] + >>> face_emb = face_info['embedding'] + >>> face_kps = draw_kps(face_image, face_info['kps']) + + >>> pipe.set_ip_adapter_scale(0.8) + + >>> # generate image + >>> image = pipe( + ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 + ... 
).images[0] + ``` +""" + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): + + def cuda(self, dtype=torch.float16, use_xformers=False): + self.to('cuda', dtype) + + if hasattr(self, 'image_proj_model'): + self.image_proj_model.to(self.unet.device).to(self.unet.dtype) + + if use_xformers: + if is_xformers_available(): + import xformers + from packaging import version + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): + self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) + self.set_ip_adapter(model_ckpt, num_tokens, scale) + + def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): + + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.unet.config.cross_attention_dim, + ff_mult=4, + ) + + image_proj_model.eval() + + self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) + state_dict = torch.load(model_ckpt, map_location="cpu") + if 'image_proj' in state_dict: + state_dict = state_dict["image_proj"] + self.image_proj_model.load_state_dict(state_dict) + + self.image_proj_model_in_features = image_emb_dim + + def set_ip_adapter(self, model_ckpt, num_tokens, scale): + + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) + else: + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=scale, + num_tokens=num_tokens).to(unet.device, dtype=unet.dtype) + unet.set_attn_processor(attn_procs) + + state_dict = torch.load(model_ckpt, map_location="cpu") + ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + if 'ip_adapter' in state_dict: + state_dict = state_dict['ip_adapter'] + ip_layers.load_state_dict(state_dict) + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance): + + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) + + if do_classifier_free_guidance: + prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) + else: + prompt_image_emb = torch.cat([prompt_image_emb], dim=0) + + prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device, + dtype=self.image_proj_model.latents.dtype) + prompt_image_emb = self.image_proj_model(prompt_image_emb) + + bs_embed, seq_len, _ = prompt_image_emb.shape + prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1) + prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_image_emb.to(device=device, dtype=dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: 
Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + + # IP adapter + ip_adapter_scale=None, + + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. 
Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. set ip_adapter_scale + if ip_adapter_scale is not None: + self.set_ip_adapter_scale(ip_adapter_scale) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + prompt_2=prompt_2, + image=image, + callback_steps=callback_steps, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode image prompt + prompt_image_emb = self._encode_prompt_image_emb(image_embeds, + device, + num_images_per_prompt, + self.unet.dtype, + self.do_classifier_free_guidance) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=prompt_image_emb, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
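+ # The residuals below exist only for the conditional half; prepending zeros keeps them
+ # aligned with the doubled [uncond, cond] batch that the UNet receives.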
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file diff --git a/pipeline_stable_diffusion_xl_instantid_full.py b/pipeline_stable_diffusion_xl_instantid_full.py new file mode 100644 index 
0000000000000000000000000000000000000000..1d10294bfd8f7b1ff371dd270432d8ead5069c0b --- /dev/null +++ b/pipeline_stable_diffusion_xl_instantid_full.py @@ -0,0 +1,1224 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import math + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F + +from diffusers.image_processor import PipelineImageInput + +from diffusers.models import ControlNetModel + +from diffusers.utils import ( + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from diffusers import StableDiffusionXLControlNetPipeline +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.utils.import_utils import is_xformers_available + +from ip_adapter.resampler import Resampler +from ip_adapter.utils import is_torch2_available + +if is_torch2_available(): + from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor +else: + from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor +from ip_adapter.attention_processor import region_control + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate insightface + >>> import diffusers + >>> from diffusers.utils import load_image + >>> from diffusers.models import ControlNetModel + + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from insightface.app import FaceAnalysis + >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps + + >>> # download 'antelopev2' under ./models + >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + >>> app.prepare(ctx_id=0, det_size=(640, 640)) + + >>> # download models under ./checkpoints + >>> face_adapter = f'./checkpoints/ip-adapter.bin' + >>> controlnet_path = f'./checkpoints/ControlNetModel' + + >>> # load IdentityNet + >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) + + >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.cuda() + + >>> # load adapter + >>> pipe.load_ip_adapter_instantid(face_adapter) + + >>> prompt = "analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" + >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" + + >>> # load an image + >>> image = load_image("your-example.jpg") + + >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] + >>> face_emb = face_info['embedding'] + >>> face_kps = draw_kps(face_image, face_info['kps']) + + >>> pipe.set_ip_adapter_scale(0.8) + + >>> # generate image + >>> image = pipe( + ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 + ... ).images[0] + ``` +""" + +from transformers import CLIPTokenizer +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline +class LongPromptWeight(object): + + """ + Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py + """ + + def __init__(self) -> None: + pass + + def parse_prompt_attention(self, text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + import re + + re_attention = re.compile( + r""" + \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)| + \)|]|[^\\()\[\]:]+|: + """, + re.X, + ) + + re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + parts = re.split(re_break, text) + for i, part 
in enumerate(parts): + if i > 0: + res.append(["BREAK", -1]) + res.append([part, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str): + """ + Get prompt token ids and weights, this function works for both prompt and negative prompt + + Args: + pipe (CLIPTokenizer) + A CLIPTokenizer + prompt (str) + A prompt string with weights + + Returns: + text_tokens (list) + A list contains token ids + text_weight (list) + A list contains the correspodent weight of token ids + + Example: + import torch + from transformers import CLIPTokenizer + + clip_tokenizer = CLIPTokenizer.from_pretrained( + "stablediffusionapi/deliberate-v2" + , subfolder = "tokenizer" + , dtype = torch.float16 + ) + + token_id_list, token_weight_list = get_prompts_tokens_with_weights( + clip_tokenizer = clip_tokenizer + ,prompt = "a (red:1.5) cat"*70 + ) + """ + texts_and_weights = self.parse_prompt_attention(prompt) + text_tokens, text_weights = [], [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that tokenize whatever length prompt + # the returned token is a 1d list: [320, 1125, 539, 320] + + # merge the new tokens to the all tokens holder: text_tokens + text_tokens = [*text_tokens, *token] + + # each token chunk will come with one weight, like ['red cat', 2.0] + # need to expand weight for each token. 
+ chunk_weights = [weight] * len(token) + + # append the weight back to the weight holder: text_weights + text_weights = [*text_weights, *chunk_weights] + return text_tokens, text_weights + + def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False): + """ + Split token ids and weights into chunks of 75 and pad the missing tokens + + Args: + token_ids (list) + The token ids from tokenizer + weights (list) + The weights list from function get_prompts_tokens_with_weights + pad_last_block (bool) + Whether to pad the last chunk to 75 tokens with eos + Returns: + new_token_ids (2d list) + new_weights (2d list) + + Example: + token_groups,weight_groups = group_tokens_and_weights( + token_ids = token_id_list + , weights = token_weight_list + ) + """ + bos, eos = 49406, 49407 + + # this will be a 2d list + new_token_ids = [] + new_weights = [] + while len(token_ids) >= 75: + # get the first 75 tokens + head_75_tokens = [token_ids.pop(0) for _ in range(75)] + head_75_weights = [weights.pop(0) for _ in range(75)] + + # extract token ids and weights + temp_77_token_ids = [bos] + head_75_tokens + [eos] + temp_77_weights = [1.0] + head_75_weights + [1.0] + + # add 77 token and weights chunk to the holder list + new_token_ids.append(temp_77_token_ids) + new_weights.append(temp_77_weights) + + # pad whatever is left (the last, possibly partial, chunk) + if len(token_ids) >= 0: + padding_len = 75 - len(token_ids) if pad_last_block else 0 + + temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos] + new_token_ids.append(temp_77_token_ids) + + temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0] + new_weights.append(temp_77_weights) + + return new_token_ids, new_weights + + def get_weighted_text_embeddings_sdxl( + self, + pipe: StableDiffusionXLPipeline, + prompt: str = "", + prompt_2: str = None, + neg_prompt: str = "", + neg_prompt_2: str = None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + extra_emb=None, + extra_emb_alpha=0.6, + ): + """ + Process a long prompt with weights for Stable Diffusion XL; there is no length limitation. + + Args: + pipe (StableDiffusionPipeline) + prompt (str) + prompt_2 (str) + neg_prompt (str) + neg_prompt_2 (str) + Returns: + prompt_embeds (torch.Tensor) + neg_prompt_embeds (torch.Tensor) + """ + # + if prompt_embeds is not None and \ + negative_prompt_embeds is not None and \ + pooled_prompt_embeds is not None and \ + negative_pooled_prompt_embeds is not None: + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + if prompt_2: + prompt = f"{prompt} {prompt_2}" + + if neg_prompt_2: + neg_prompt = f"{neg_prompt} {neg_prompt_2}" + + eos = pipe.tokenizer.eos_token_id + + # tokenizer 1 + prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt) + neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt) + + # tokenizer 2 + # prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt) + # neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt) + # tokenizer 2 handles runs of exclamation marks ('!!', '!!!!') differently from tokenizer 1, so tokenizer 1 is reused for the second token set
+ prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt) + neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt) + + # padding the shorter one for prompt set 1 + prompt_token_len = len(prompt_tokens) + neg_prompt_token_len = len(neg_prompt_tokens) + + if prompt_token_len > neg_prompt_token_len: + # padding the neg_prompt with eos token + neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len) + neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len) + else: + # padding the prompt + prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len) + prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len) + + # padding the shorter one for token set 2 + prompt_token_len_2 = len(prompt_tokens_2) + neg_prompt_token_len_2 = len(neg_prompt_tokens_2) + + if prompt_token_len_2 > neg_prompt_token_len_2: + # padding the neg_prompt with eos token + neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2) + neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2) + else: + # padding the prompt + prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2) + prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2) + + embeds = [] + neg_embeds = [] + + prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy()) + + neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights( + neg_prompt_tokens.copy(), neg_prompt_weights.copy() + ) + + prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights( + prompt_tokens_2.copy(), prompt_weights_2.copy() + ) + + neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights( + neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy() + ) + + # encode each 77-token chunk separately with both text encoders, then concatenate the results.
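+ # Token weights are applied after encoding: each 77-token chunk goes through text_encoder and
+ # text_encoder_2, their penultimate hidden states are concatenated on the feature dimension, and
+ # every token with weight w != 1.0 is moved away from the chunk's EOS embedding:
+ # e_j <- e_eos + (e_j - e_eos) * w (e.g. 'a (red:1.5) cat' gives per-token weights [1.0, 1.5, 1.0]).
+ # The per-chunk embeddings are finally concatenated along the sequence dimension (dim=1).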
+ for i in range(len(prompt_token_groups)): + # get positive prompt embeddings with weights + token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device) + weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + + token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) + + # use first text encoder + prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True) + prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2] + + # use second text encoder + prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True) + prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2] + pooled_prompt_embeds = prompt_embeds_2[0] + + prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states] + token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0) + + for j in range(len(weight_tensor)): + if weight_tensor[j] != 1.0: + token_embedding[j] = ( + token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j] + ) + + token_embedding = token_embedding.unsqueeze(0) + embeds.append(token_embedding) + + # get negative prompt embeddings with weights + neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device) + neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) + neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + + # use first text encoder + neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2] + + # use second text encoder + neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2] + negative_pooled_prompt_embeds = neg_prompt_embeds_2[0] + + neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states] + neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0) + + for z in range(len(neg_weight_tensor)): + if neg_weight_tensor[z] != 1.0: + neg_token_embedding[z] = ( + neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z] + ) + + neg_token_embedding = neg_token_embedding.unsqueeze(0) + neg_embeds.append(neg_token_embedding) + + prompt_embeds = torch.cat(embeds, dim=1) + negative_prompt_embeds = torch.cat(neg_embeds, dim=1) + + if extra_emb is not None: + extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha + prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1) + print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}') + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def get_prompt_embeds(self, *args, **kwargs): + prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + return prompt_embeds + +def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 
2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + +class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): + + def cuda(self, dtype=torch.float16, use_xformers=False): + self.to('cuda', dtype) + + if hasattr(self, 'image_proj_model'): + self.image_proj_model.to(self.unet.device).to(self.unet.dtype) + + if use_xformers: + if is_xformers_available(): + import xformers + from packaging import version + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): + self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) + self.set_ip_adapter(model_ckpt, num_tokens, scale) + + def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): + + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.unet.config.cross_attention_dim, + ff_mult=4, + ) + + image_proj_model.eval() + + self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) + state_dict = torch.load(model_ckpt, map_location="cpu") + if 'image_proj' in state_dict: + state_dict = state_dict["image_proj"] + self.image_proj_model.load_state_dict(state_dict) + + self.image_proj_model_in_features = image_emb_dim + + def set_ip_adapter(self, model_ckpt, num_tokens, scale): + + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) + else: + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=scale, + num_tokens=num_tokens).to(unet.device, dtype=unet.dtype) + unet.set_attn_processor(attn_procs) + + state_dict = torch.load(model_ckpt, map_location="cpu") 
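+        # Wrap the freshly installed attention processors in a ModuleList so that the
+        # IP-Adapter projection weights (to_k_ip / to_v_ip) from the checkpoint can be
+        # loaded into them with a single load_state_dict call; the 'ip_adapter' sub-dict
+        # is used when the checkpoint nests the weights under that key.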
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + if 'ip_adapter' in state_dict: + state_dict = state_dict['ip_adapter'] + ip_layers.load_state_dict(state_dict) + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance): + + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) + + if do_classifier_free_guidance: + prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) + else: + prompt_image_emb = torch.cat([prompt_image_emb], dim=0) + + prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device, + dtype=self.image_proj_model.latents.dtype) + prompt_image_emb = self.image_proj_model(prompt_image_emb) + + bs_embed, seq_len, _ = prompt_image_emb.shape + prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1) + prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_image_emb.to(device=device, dtype=dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + + # IP adapter + ip_adapter_scale=None, + + # Enhance Face Region + control_mask = None, + + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. 
+ """ + + lpw = LongPromptWeight() + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. set ip_adapter_scale + if ip_adapter_scale is not None: + self.set_ip_adapter_scale(ip_adapter_scale) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + prompt_2=prompt_2, + image=image, + callback_steps=callback_steps, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = lpw.get_weighted_text_embeddings_sdxl( + pipe=self, + prompt=prompt, + neg_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 3.2 Encode image prompt + prompt_image_emb = self._encode_prompt_image_emb(image_embeds, + device, + num_images_per_prompt, + self.unet.dtype, + self.do_classifier_free_guidance) + + # 4. 
Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 4.1 Region control + if control_mask is not None: + mask_weight_image = control_mask + mask_weight_image = np.array(mask_weight_image) + mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype) + mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255. + mask_weight_image_tensor = mask_weight_image_tensor[None, None] + h, w = mask_weight_image_tensor.shape[-2:] + control_mask_wight_image_list = [] + for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]: + scale_mask_weight_image_tensor = F.interpolate( + mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear') + control_mask_wight_image_list.append(scale_mask_weight_image_tensor) + region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255. + region_control.prompt_image_conditioning = [dict(region_mask=region_mask)] + else: + control_mask_wight_image_list = None + region_control.prompt_image_conditioning = [dict(region_mask=None)] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
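+                    # Feed the ControlNet un-duplicated latents and only the conditional half of
+                    # the text/pooled embeddings and time ids (selected with .chunk(2)[1]).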
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + if isinstance(self.controlnet, MultiControlNetModel): + down_block_res_samples_list, mid_block_res_sample_list = [], [] + for control_index in range(len(self.controlnet.nets)): + controlnet = self.controlnet.nets[control_index] + if control_index == 0: + # assume fhe first controlnet is IdentityNet + controlnet_prompt_embeds = prompt_image_emb + else: + controlnet_prompt_embeds = prompt_embeds + down_block_res_samples, mid_block_res_sample = controlnet(control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image[control_index], + conditioning_scale=cond_scale[control_index], + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False) + + # controlnet mask + if control_index == 0 and control_mask_wight_image_list is not None: + down_block_res_samples = [ + down_block_res_sample * mask_weight + for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list) + ] + mid_block_res_sample *= control_mask_wight_image_list[-1] + + down_block_res_samples_list.append(down_block_res_samples) + mid_block_res_sample_list.append(mid_block_res_sample) + + mid_block_res_sample = torch.stack(mid_block_res_sample_list).sum(dim=0) + down_block_res_samples = [torch.stack(down_block_res_samples).sum(dim=0) for down_block_res_samples in + zip(*down_block_res_samples_list)] + else: + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=prompt_image_emb, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + # controlnet mask + if control_mask_wight_image_list is not None: + down_block_res_samples = [ + down_block_res_sample * mask_weight + for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list) + ] + mid_block_res_sample *= control_mask_wight_image_list[-1] + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/pipeline_stable_diffusion_xl_instantid_img2img.py b/pipeline_stable_diffusion_xl_instantid_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc2f02f67ca24c3a6bc24994550d638acaf7d00 --- /dev/null +++ 
b/pipeline_stable_diffusion_xl_instantid_img2img.py @@ -0,0 +1,1072 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL.Image +import torch +import torch.nn as nn + +from diffusers import StableDiffusionXLControlNetImg2ImgPipeline +from diffusers.image_processor import PipelineImageInput +from diffusers.models import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.utils import ( + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version + + +try: + import xformers + import xformers.ops + + xformers_available = True +except Exception: + xformers_available = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + 
self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + if xformers_available: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + if xformers_available: + ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None) + else: + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + return hidden_states + + +EXAMPLE_DOC_STRING = """ + 
Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate insightface
+        >>> import diffusers
+        >>> from diffusers.utils import load_image
+        >>> from diffusers.models import ControlNetModel
+
+        >>> import cv2
+        >>> import torch
+        >>> import numpy as np
+        >>> from PIL import Image
+
+        >>> from insightface.app import FaceAnalysis
+        >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+
+        >>> # download 'antelopev2' under ./models
+        >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        >>> app.prepare(ctx_id=0, det_size=(640, 640))
+
+        >>> # download models under ./checkpoints
+        >>> face_adapter = f'./checkpoints/ip-adapter.bin'
+        >>> controlnet_path = f'./checkpoints/ControlNetModel'
+
+        >>> # load IdentityNet
+        >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+        >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+        >>> pipe.cuda()
+
+        >>> # load adapter
+        >>> pipe.load_ip_adapter_instantid(face_adapter)
+
+        >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
+        >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
+
+        >>> # load an image
+        >>> face_image = load_image("your-example.jpg")
+
+        >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
+        >>> face_emb = face_info['embedding']
+        >>> face_kps = draw_kps(face_image, face_info['kps'])
+
+        >>> pipe.set_ip_adapter_scale(0.8)
+
+        >>> # generate image
+        >>> image = pipe(
+        ...     prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
+        ... 
).images[0] + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +class StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline): + def cuda(self, dtype=torch.float16, use_xformers=False): + self.to("cuda", dtype) + + if hasattr(self, "image_proj_model"): + self.image_proj_model.to(self.unet.device).to(self.unet.dtype) + + if use_xformers: + if is_xformers_available(): + import xformers + from packaging import version + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): + self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) + self.set_ip_adapter(model_ckpt, num_tokens, scale) + + def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=num_tokens, + embedding_dim=image_emb_dim, + output_dim=self.unet.config.cross_attention_dim, + ff_mult=4, + ) + + image_proj_model.eval() + + self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) + state_dict = torch.load(model_ckpt, map_location="cpu") + if "image_proj" in state_dict: + state_dict = state_dict["image_proj"] + self.image_proj_model.load_state_dict(state_dict) + + self.image_proj_model_in_features = image_emb_dim + + def set_ip_adapter(self, model_ckpt, num_tokens, scale): + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=scale, + num_tokens=num_tokens, + ).to(unet.device, dtype=unet.dtype) + unet.set_attn_processor(attn_procs) + + state_dict = torch.load(model_ckpt, map_location="cpu") + ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + if "ip_adapter" in state_dict: + state_dict = state_dict["ip_adapter"] + ip_layers.load_state_dict(state_dict) + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance): + if isinstance(prompt_image_emb, torch.Tensor): + prompt_image_emb = prompt_image_emb.clone().detach() + else: + prompt_image_emb = torch.tensor(prompt_image_emb) + + prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype) + prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) + + if do_classifier_free_guidance: + prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) + else: + prompt_image_emb = torch.cat([prompt_image_emb], dim=0) + image_proj_model_device = self.image_proj_model.to(device) + prompt_image_emb = image_proj_model_device(prompt_image_emb) + return prompt_image_emb + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + strength: float = 0.8, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + 
negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the same
+                as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + None, + None, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode image prompt + prompt_image_emb = self._encode_prompt_image_emb( + image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance + ) + bs_embed, seq_len, _ = prompt_image_emb.shape + prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1) + prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. 
Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=prompt_image_emb, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..da4b4a452854b80969b0b40e51370918571f1eeb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +diffusers==0.25.1 +torch==2.0.0 +torchvision==0.15.1 +transformers==4.37.1 +accelerate==0.25.0 +safetensors==0.4.3 +einops==0.7.0 +onnxruntime-gpu==1.18.1 +spaces==0.19.4 +omegaconf==2.3.0 +peft==0.11.1 +huggingface-hub==0.23.4 +opencv-python==4.10.0.84 +insightface==0.7.3 +gradio==4.38.1 +controlnet_aux==0.0.9 +gdown==5.2.0 +peft==0.11.1 +setuptools==71.1.0 \ No newline at end of file diff --git a/scripts/download_safety_checker.py b/scripts/download_safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0bf3ca051a8030e85b8a21a8fe7cc5cd9ac764 --- /dev/null +++ b/scripts/download_safety_checker.py @@ -0,0 +1,22 @@ +# Run this before you deploy it on replicate, because if you don't +# 
whenever you run the model, it will download the weights from the
+# internet, which will take a long time.
+
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+# CLIPImageProcessor replaces the deprecated CLIPFeatureExtractor
+from transformers import CLIPImageProcessor
+
+
+safety = StableDiffusionSafetyChecker.from_pretrained(
+    "CompVis/stable-diffusion-safety-checker",
+    torch_dtype=torch.float16,
+)
+safety.save_pretrained("./safety-cache")
+
+feature_extractor = CLIPImageProcessor.from_pretrained(
+    "openai/clip-vit-base-patch32",
+)
+feature_extractor.save_pretrained("./feature-extractor")
diff --git a/scripts/push_to_gcp.py b/scripts/push_to_gcp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a169a08d6bd9373177810cbca95dec156e9b081c
--- /dev/null
+++ b/scripts/push_to_gcp.py
@@ -0,0 +1,67 @@
+# For each checkpoint directory: cd into it, tar up its contents, move back up a directory,
+# then push the tar to GCP, e.g.
+# cd checkpoints/models--stablediffusionapi--nightvision-xl-0791
+# sudo tar -cvf ../models--stablediffusionapi--nightvision-xl-0791.tar *
+# cd ..
+# gcloud storage cp models--stablediffusionapi--nightvision-xl-0791.tar gs://replicate-weights/InstantID/models--stablediffusionapi--nightvision-xl-0791.tar
+
+# TODO
+
+import os
+import subprocess
+
+# Get the list of directories in the checkpoints directory
+dirs = [
+    # "checkpoints/models--stablediffusionapi--juggernaut-xl-v8",
+    # "checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
+    # "checkpoints/models--stablediffusionapi--afrodite-xl-v2",
+    # "checkpoints/models--stablediffusionapi--albedobase-xl-20",
+    # "checkpoints/models--stablediffusionapi--albedobase-xl-v13",
+    # "checkpoints/models--stablediffusionapi--animagine-xl-30",
+    # "checkpoints/models--stablediffusionapi--anime-art-diffusion-xl",
+    # "checkpoints/models--stablediffusionapi--anime-illust-diffusion-xl",
+    # "checkpoints/models--stablediffusionapi--dreamshaper-xl",
+    # "checkpoints/models--stablediffusionapi--duchaiten-real3d-nsfw-xl",
+    # "checkpoints/models--stablediffusionapi--dynavision-xl-v0610",
+    # "checkpoints/models--stablediffusionapi--guofeng4-xl",
+    # "checkpoints/models--stablediffusionapi--hentai-mix-xl",
+    # "checkpoints/models--stablediffusionapi--juggernaut-xl-v8",
+    # "checkpoints/models--stablediffusionapi--nightvision-xl-0791",
+    # "checkpoints/models--stablediffusionapi--omnigen-xl",
+    # "checkpoints/models--stablediffusionapi--pony-diffusion-v6-xl",
+    # "checkpoints/models--stablediffusionapi--protovision-xl-high-fidel",
+    "checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0",
+]
+
+# Iterate over each directory
+for d in dirs:
+    # Construct the tar file name
+    tar_file_name = f"{d}.tar"
+    print(f"[!] Starting the process for directory: {d}")
+    print(f"[!] Step 1: Constructing tar file name as '{tar_file_name}'")
+
+    # Construct the full path to the tar file (logged only; the shell command below rebuilds the path)
+    full_tar_path = os.path.join(
+        "..", tar_file_name
+    )  # Adjusted to account for script's new location
+    print(f"[!] Step 2: The full path for the tar file is '{full_tar_path}'")
+
+    # Remove 'checkpoints/' from tar_file_name for gcloud destination
+    gcloud_tar_file_name = tar_file_name.replace("checkpoints/", "")
+    # Construct the gcloud destination
+    gcloud_destination = f"gs://replicate-weights/InstantID/{gcloud_tar_file_name}"
+    print(
+        f"[!]
Step 3: The destination path on GCloud is set to '{gcloud_destination}'" + ) + + # Adjust the shell command string to account for the script's new location + cmd = f"cd ../{d} && tar -cvf ../../{tar_file_name} * && gcloud storage cp ../../{tar_file_name} {gcloud_destination}" + print( + f"[!] Step 4: The shell command constructed to perform the operations is: {cmd}" + ) + + # Run the shell command + print(f"[!] Step 5: Executing the shell command for directory: {d}") + subprocess.run(cmd, shell=True) + print(f"[!] Step 6: The shell command execution for directory '{d}' has completed.") + print(f"[!] Process completed for directory: {d}") diff --git a/tests/assets/out.png b/tests/assets/out.png new file mode 100644 index 0000000000000000000000000000000000000000..b437f965656d137a2ef25a446d1a07c21c7aebc6 Binary files /dev/null and b/tests/assets/out.png differ diff --git a/tests/run_tests.sh b/tests/run_tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..f47058d882693f633d94f6e1a5bc43d84940ecaf --- /dev/null +++ b/tests/run_tests.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +echo "Starting the script to run tests..." + +# Build the model +echo "Building the model with cog..." +sudo cog build -t test-model --use-cog-base-image +echo "Model build completed." + +# Stop and remove the existing container if it's running +container_name='cog-test' +echo "Checking if the container '$container_name' is already running..." +if sudo docker inspect --format="{{.State.Running}}" "$container_name" &> /dev/null; then + echo "Container '$container_name' is running. Stopping and removing..." + sudo docker stop "$container_name" + sudo docker rm "$container_name" + echo "Container '$container_name' stopped and removed successfully." +else + echo "Container '$container_name' not found or not running. Proceeding to run a new instance." +fi + +# Run the container +echo "Running the container '$container_name'..." +sudo docker run -d -p 5000:5000 --gpus all --name "$container_name" test-model +echo "Container '$container_name' is now running." + +# Wait for the server to be ready +echo "Waiting for the server to be ready..." +sleep 10 +echo "Server should be ready now." + +# Set the environment variable for local testing +echo "Setting environment variable for local testing..." +export TEST_ENV=local +echo "Environment variable set." + +# Run the specific test case +echo "Running the test case: test_seeded_prediction..." +pytest -vv tests/test_predict.py::test_seeded_prediction +echo "Test case execution completed." + +# Stop the container +echo "Stopping the container '$container_name'..." +sudo docker stop "$container_name" +echo "Container '$container_name' stopped. Script execution completed." 
\ No newline at end of file diff --git a/tests/test_predict.py b/tests/test_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..04dff43699617d53f01df66aea8be89e38233620 --- /dev/null +++ b/tests/test_predict.py @@ -0,0 +1,108 @@ +import base64 +from io import BytesIO +import os +import time + +import numpy as np +from PIL import Image, ImageChops +import pytest +import requests + + +def local_run(model_endpoint: str, model_input: dict): + # Maximum wait time in seconds + max_wait_time = 1000 + # Interval between status checks in seconds + retry_interval = 100 + + total_wait_time = 0 + while total_wait_time < max_wait_time: + response = requests.post(model_endpoint, json={"input": model_input}) + data = response.json() + + if "output" in data: + try: + datauri = data["output"][0] + base64_encoded_data = datauri.split(",")[1] + decoded_data = base64.b64decode(base64_encoded_data) + return Image.open(BytesIO(decoded_data)) + except Exception as e: + print("Error while processing output:") + print("input:", model_input) + print(data) + raise e + elif "detail" in data and data["detail"] == "Already running a prediction": + print(f"Prediction in progress, waited {total_wait_time}s, waiting more...") + time.sleep(retry_interval) + total_wait_time += retry_interval + else: + print("Unexpected response data:", data) + break + else: + raise Exception("Max wait time exceeded, unable to get valid response") + + +def image_equal_fuzzy(img_expected, img_actual, test_name="default", tol=20): + """ + Assert that average pixel values differ by less than tol across image + Tol determined empirically - holding everything else equal but varying seed + generates images that vary by at least 50 + """ + img1 = np.array(img_expected, dtype=np.int32) + img2 = np.array(img_actual, dtype=np.int32) + + mean_delta = np.mean(np.abs(img1 - img2)) + imgs_equal = mean_delta < tol + if not imgs_equal: + # save failures for quick inspection + save_dir = f"/tmp/{test_name}" + if not os.path.exists(save_dir): + os.makedirs(save_dir) + img_expected.save(os.path.join(save_dir, "expected.png")) + img_actual.save(os.path.join(save_dir, "actual.png")) + difference = ImageChops.difference(img_expected, img_actual) + difference.save(os.path.join(save_dir, "delta.png")) + + return imgs_equal + + +@pytest.fixture +def expected_image(): + return Image.open("tests/assets/out.png") + + +def test_seeded_prediction(expected_image): + data = { + "image": "https://replicate.delivery/pbxt/KIIutO7jIleskKaWebhvurgBUlHR6M6KN7KHaMMWSt4OnVrF/musk_resize.jpeg", + "prompt": "analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
+        "scheduler": "EulerDiscreteScheduler",
+        "enable_lcm": False,
+        "pose_image": "https://replicate.delivery/pbxt/KJmFdQRQVDXGDVdVXftLvFrrvgOPXXRXbzIVEyExPYYOFPyF/80048a6e6586759dbcb529e74a9042ca.jpeg",
+        "sdxl_weights": "protovision-xl-high-fidel",
+        "pose_strength": 0.4,
+        "canny_strength": 0.3,
+        "depth_strength": 0.5,
+        "guidance_scale": 5,
+        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured",
+        "ip_adapter_scale": 0.8,
+        "lcm_guidance_scale": 1.5,
+        "num_inference_steps": 30,
+        "enable_pose_controlnet": True,
+        "enhance_nonface_region": True,
+        "enable_canny_controlnet": False,
+        "enable_depth_controlnet": False,
+        "lcm_num_inference_steps": 5,
+        "controlnet_conditioning_scale": 0.8,
+        "seed": 1337,
+    }
+
+    actual_image = local_run("http://localhost:5000/predictions", data)
+    # expected_image is supplied by the fixture; pass arguments in the (expected, actual) order of image_equal_fuzzy
+    test_result = image_equal_fuzzy(
+        expected_image, actual_image, test_name="test_seeded_prediction"
+    )
+    assert test_result, "Generated image differs from tests/assets/out.png beyond tolerance"
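The fuzzy check in tests/test_predict.py passes when the mean absolute per-pixel difference stays under tol=20; per its docstring, changing only the seed pushes that number above 50. Below is a minimal sketch (not part of the diff above) of how that threshold could be re-checked if tests/assets/out.png is ever regenerated; the run_*.png file names are placeholders, not files in this repo.

import numpy as np
from PIL import Image


def mean_pixel_delta(path_a: str, path_b: str) -> float:
    # Same idea as image_equal_fuzzy: mean absolute difference over int32 pixel arrays.
    # convert("RGB") guards against comparing images with mismatched modes.
    a = np.array(Image.open(path_a).convert("RGB"), dtype=np.int32)
    b = np.array(Image.open(path_b).convert("RGB"), dtype=np.int32)
    return float(np.mean(np.abs(a - b)))


if __name__ == "__main__":
    # Placeholder inputs: two runs with identical inputs and seed, plus one run with a different seed.
    same_seed = mean_pixel_delta("run_seed1337_a.png", "run_seed1337_b.png")
    diff_seed = mean_pixel_delta("run_seed1337_a.png", "run_seed42.png")
    print(f"same-seed drift: {same_seed:.2f}  cross-seed delta: {diff_seed:.2f}")
    # A reasonable tol sits comfortably above the same-seed drift and well below the cross-seed delta.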