Spaces:
Runtime error
Runtime error
init
Browse files
app.py
CHANGED
@@ -45,20 +45,14 @@ def parse_args() -> argparse.Namespace:
|
|
45 |
return parser.parse_args()
|
46 |
|
47 |
|
48 |
-
|
49 |
-
|
50 |
def run(
|
51 |
image,
|
52 |
-
|
53 |
-
hayao: ImportGraph,
|
54 |
-
paprika: ImportGraph,
|
55 |
-
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
#im3 = paprika.test('paprika', image.name, True)
|
60 |
|
61 |
-
return PIL.Image.open(
|
62 |
|
63 |
|
64 |
def main():
|
@@ -66,13 +60,13 @@ def main():
|
|
66 |
|
67 |
args = parse_args()
|
68 |
|
69 |
-
curPath = os.path.abspath(os.path.dirname(__file__))
|
70 |
#init
|
71 |
-
shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
|
72 |
#hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
|
73 |
#paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
|
74 |
|
75 |
-
func = functools.partial(run
|
76 |
func = functools.update_wrapper(func, run)
|
77 |
|
78 |
|
@@ -84,13 +78,8 @@ def main():
|
|
84 |
[
|
85 |
gr.outputs.Image(
|
86 |
type='pil',
|
87 |
-
label='
|
88 |
-
|
89 |
-
type='pil',
|
90 |
-
label='Hayao Result'),
|
91 |
-
gr.outputs.Image(
|
92 |
-
type='pil',
|
93 |
-
label='Paprika Result'),
|
94 |
],
|
95 |
#examples=examples,
|
96 |
theme=args.theme,
|
|
|
45 |
return parser.parse_args()
|
46 |
|
47 |
|
|
|
|
|
48 |
def run(
|
49 |
image,
|
50 |
+
) -> tuple[PIL.Image.Image]:
|
|
|
|
|
|
|
51 |
|
52 |
+
out = test.test(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'),
|
53 |
+
style_name='Shinkai', test_file=image.name, if_adjust_brightness=True)
|
|
|
54 |
|
55 |
+
return PIL.Image.open(out)
|
56 |
|
57 |
|
58 |
def main():
|
|
|
60 |
|
61 |
args = parse_args()
|
62 |
|
63 |
+
#curPath = os.path.abspath(os.path.dirname(__file__))
|
64 |
#init
|
65 |
+
#shinkai = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Shinkai_weight'))
|
66 |
#hayao = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Hayao_weight'))
|
67 |
#paprika = ImportGraph(checkpoint_dir=os.path.join(curPath,'animeganv2/checkpoint/generator_Paprika_weight'))
|
68 |
|
69 |
+
func = functools.partial(run)
|
70 |
func = functools.update_wrapper(func, run)
|
71 |
|
72 |
|
|
|
78 |
[
|
79 |
gr.outputs.Image(
|
80 |
type='pil',
|
81 |
+
label='Result'),
|
82 |
+
|
|
|
|
|
|
|
|
|
|
|
83 |
],
|
84 |
#examples=examples,
|
85 |
theme=args.theme,
|
test1.py
CHANGED
@@ -53,23 +53,27 @@ def stats_graph(graph):
|
|
53 |
# params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
|
54 |
print('FLOPs: {}'.format(flops.total_float_ops))
|
55 |
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# tf.reset_default_graph()
|
58 |
result_dir = 'results/'+style_name
|
59 |
check_folder(result_dir)
|
60 |
-
test_files = [test_dir]
|
61 |
|
62 |
-
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
-
|
|
|
69 |
|
70 |
-
gpu_options = tf.GPUOptions(allow_growth=True)
|
71 |
-
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)) as sess:
|
72 |
-
# tf.global_variables_initializer().run()
|
73 |
# load model
|
74 |
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
|
75 |
if ckpt and ckpt.model_checkpoint_path:
|
@@ -81,22 +85,20 @@ def test(checkpoint_dir, style_name, test_dir, if_adjust_brightness, img_size=[2
|
|
81 |
return
|
82 |
# stats_graph(tf.get_default_graph())
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
return out_paths
|
100 |
|
101 |
if __name__ == '__main__':
|
102 |
arg = parse_args()
|
|
|
53 |
# params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
|
54 |
print('FLOPs: {}'.format(flops.total_float_ops))
|
55 |
|
56 |
+
g_sess = None
|
57 |
+
test_generated = None
|
58 |
+
|
59 |
+
def test(checkpoint_dir, style_name, test_file, if_adjust_brightness, img_size=[256,256]):
|
60 |
+
global g_sess
|
61 |
+
global test_generated
|
62 |
+
|
63 |
# tf.reset_default_graph()
|
64 |
result_dir = 'results/'+style_name
|
65 |
check_folder(result_dir)
|
|
|
66 |
|
67 |
+
if g_sess is None:
|
68 |
+
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
|
69 |
|
70 |
+
with tf.variable_scope("generator", reuse=False):
|
71 |
+
test_generated = generator.G_net(test_real).fake
|
72 |
+
saver = tf.train.Saver()
|
73 |
|
74 |
+
gpu_options = tf.GPUOptions(allow_growth=True)
|
75 |
+
g_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
|
76 |
|
|
|
|
|
|
|
77 |
# load model
|
78 |
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
|
79 |
if ckpt and ckpt.model_checkpoint_path:
|
|
|
85 |
return
|
86 |
# stats_graph(tf.get_default_graph())
|
87 |
|
88 |
+
begin = time.time()
|
89 |
+
# print('Processing image: ' + sample_file)
|
90 |
+
sample_image = np.asarray(load_test_data(test_file, img_size))
|
91 |
+
image_path = os.path.join(result_dir,'{0}'.format(os.path.basename(test_file)))
|
92 |
+
fake_img = g_sess.run(test_generated, feed_dict = {test_real : sample_image})
|
93 |
+
if if_adjust_brightness:
|
94 |
+
save_images(fake_img, image_path, test_file)
|
95 |
+
else:
|
96 |
+
save_images(fake_img, image_path, None)
|
97 |
+
|
98 |
+
end = time.time()
|
99 |
+
print(f'test-time: {end-begin} s')
|
100 |
+
|
101 |
+
return image_path
|
|
|
|
|
102 |
|
103 |
if __name__ == '__main__':
|
104 |
arg = parse_args()
|