wahaha commited on
Commit
c9dfbd7
1 Parent(s): 41224f4
Files changed (1) hide show
  1. test1.py +5 -1
test1.py CHANGED
@@ -11,7 +11,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
11
  class ImportGraph:
12
  def __init__(self, checkpoint_dir):
13
  self.graph = tf.Graph()
14
- self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True)))
15
  with self.graph.as_default():
16
 
17
  test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
@@ -20,6 +20,10 @@ class ImportGraph:
20
 
21
  self.saver = tf.train.Saver()
22
 
 
 
 
 
23
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
24
  if ckpt and ckpt.model_checkpoint_path:
25
  ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
 
11
  class ImportGraph:
12
  def __init__(self, checkpoint_dir):
13
  self.graph = tf.Graph()
14
+
15
  with self.graph.as_default():
16
 
17
  test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
 
20
 
21
  self.saver = tf.train.Saver()
22
 
23
+ self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(allow_soft_placement=True,
24
+ gpu_options=tf.GPUOptions(
25
+ allow_growth=True)))
26
+
27
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
28
  if ckpt and ckpt.model_checkpoint_path:
29
  ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line