File size: 863 Bytes
cfb7702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf

from scripts.demo.streamlit_helpers import (
    load_model_from_config,
    get_sampler,
    get_batch,
    do_sample,
)


def load_config_and_model(ckpt: Path):
    """Locate the project config for a checkpoint and load the model from it.

    The config is searched for as ``configs/*-project.yaml`` first in
    ``ckpt.parent.parent`` and then in ``ckpt.parent.parent.parent``; the
    first of those ``configs`` directories that contains a match wins. When
    several files match in the same directory, the lexicographically first
    one is used so the choice does not depend on filesystem listing order.

    Args:
        ckpt: Path to the checkpoint file. A ``str`` path is accepted too.

    Returns:
        ``(config, model)``: the config loaded with ``OmegaConf.load`` and the
        model returned by ``load_model_from_config``.

    Raises:
        FileNotFoundError: If neither candidate ``configs`` directory contains
            a ``*-project.yaml`` file.
    """
    ckpt = Path(ckpt)
    candidate_dirs = [
        ckpt.parent.parent / "configs",
        ckpt.parent.parent.parent / "configs",
    ]

    for configs_dir in candidate_dirs:
        # Path.glob yields entries in filesystem order (and nothing for a
        # missing directory); sort so the selected config is reproducible.
        matches = sorted(configs_dir.glob("*-project.yaml"))
        if matches:
            config_path = matches[0]
            break
    else:
        searched = ", ".join(str(d) for d in candidate_dirs)
        raise FileNotFoundError(
            f"No '*-project.yaml' config found for checkpoint {ckpt} "
            f"(searched: {searched})"
        )

    config = OmegaConf.load(config_path)

    # load_model_from_config returns a second value that is not needed here.
    model, _ = load_model_from_config(config, ckpt)

    return config, model


def load_sampler(sampler_cfg):
    """Build a sampler by forwarding every entry of ``sampler_cfg`` to ``get_sampler``.

    Args:
        sampler_cfg: Mapping whose keys are keyword arguments understood by
            ``get_sampler`` (presumably a dict or OmegaConf node — confirm
            against callers).

    Returns:
        Whatever ``get_sampler`` returns for those keyword arguments.
    """
    sampler = get_sampler(**sampler_cfg)
    return sampler


def load_batch():
    """Placeholder for batch loading; not implemented yet.

    Returns:
        ``None`` — the function currently has no body beyond this return.
    """
    # TODO: implement batch construction (e.g. via get_batch) once the
    # expected inputs are defined.
    return None


class DiffusionEngine:
    """Holder for a sampling configuration; sampling itself is not implemented yet."""

    def __init__(self, cfg) -> None:
        """Store ``cfg`` unchanged on the instance as ``self.cfg``.

        Args:
            cfg: Engine configuration (type not constrained here — presumably
                an OmegaConf/dict config; confirm against callers).
        """
        self.cfg = cfg

    def sample(self):
        """Placeholder for running sampling; currently a no-op returning ``None``."""
        # TODO: wire this up to get_sampler / do_sample using self.cfg.
        return None