dcgan_test.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for dcgan."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from six.moves import xrange
  20. import tensorflow as tf
  21. from nets import dcgan
  22. class DCGANTest(tf.test.TestCase):
  23. def test_generator_run(self):
  24. tf.set_random_seed(1234)
  25. noise = tf.random_normal([100, 64])
  26. image, _ = dcgan.generator(noise)
  27. with self.test_session() as sess:
  28. sess.run(tf.global_variables_initializer())
  29. image.eval()
  30. def test_generator_graph(self):
  31. tf.set_random_seed(1234)
  32. # Check graph construction for a number of image size/depths and batch
  33. # sizes.
  34. for i, batch_size in zip(xrange(3, 7), xrange(3, 8)):
  35. tf.reset_default_graph()
  36. final_size = 2 ** i
  37. noise = tf.random_normal([batch_size, 64])
  38. image, end_points = dcgan.generator(
  39. noise,
  40. depth=32,
  41. final_size=final_size)
  42. self.assertAllEqual([batch_size, final_size, final_size, 3],
  43. image.shape.as_list())
  44. expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits']
  45. self.assertSetEqual(set(expected_names), set(end_points.keys()))
  46. # Check layer depths.
  47. for j in range(1, i):
  48. layer = end_points['deconv%i' % j]
  49. self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1])
  50. def test_generator_invalid_input(self):
  51. wrong_dim_input = tf.zeros([5, 32, 32])
  52. with self.assertRaises(ValueError):
  53. dcgan.generator(wrong_dim_input)
  54. correct_input = tf.zeros([3, 2])
  55. with self.assertRaisesRegexp(ValueError, 'must be a power of 2'):
  56. dcgan.generator(correct_input, final_size=30)
  57. with self.assertRaisesRegexp(ValueError, 'must be greater than 8'):
  58. dcgan.generator(correct_input, final_size=4)
  59. def test_discriminator_run(self):
  60. image = tf.random_uniform([5, 32, 32, 3], -1, 1)
  61. output, _ = dcgan.discriminator(image)
  62. with self.test_session() as sess:
  63. sess.run(tf.global_variables_initializer())
  64. output.eval()
  65. def test_discriminator_graph(self):
  66. # Check graph construction for a number of image size/depths and batch
  67. # sizes.
  68. for i, batch_size in zip(xrange(1, 6), xrange(3, 8)):
  69. tf.reset_default_graph()
  70. img_w = 2 ** i
  71. image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1)
  72. output, end_points = dcgan.discriminator(
  73. image,
  74. depth=32)
  75. self.assertAllEqual([batch_size, 1], output.get_shape().as_list())
  76. expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits']
  77. self.assertSetEqual(set(expected_names), set(end_points.keys()))
  78. # Check layer depths.
  79. for j in range(1, i+1):
  80. layer = end_points['conv%i' % j]
  81. self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1])
  82. def test_discriminator_invalid_input(self):
  83. wrong_dim_img = tf.zeros([5, 32, 32])
  84. with self.assertRaises(ValueError):
  85. dcgan.discriminator(wrong_dim_img)
  86. spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3])
  87. with self.assertRaises(ValueError):
  88. dcgan.discriminator(spatially_undefined_shape)
  89. not_square = tf.zeros([5, 32, 16, 3])
  90. with self.assertRaisesRegexp(ValueError, 'not have equal width and height'):
  91. dcgan.discriminator(not_square)
  92. not_power_2 = tf.zeros([5, 30, 30, 3])
  93. with self.assertRaisesRegexp(ValueError, 'not a power of 2'):
  94. dcgan.discriminator(not_power_2)
  95. if __name__ == '__main__':
  96. tf.test.main()