cyclegan_test.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for tensorflow.contrib.slim.nets.cyclegan."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import cyclegan
  21. # TODO(joelshor): Add a test to check generator endpoints.
  22. class CycleganTest(tf.test.TestCase):
  23. def test_generator_inference(self):
  24. """Check one inference step."""
  25. img_batch = tf.zeros([2, 32, 32, 3])
  26. model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch)
  27. with self.test_session() as sess:
  28. sess.run(tf.global_variables_initializer())
  29. sess.run(model_output)
  30. def _test_generator_graph_helper(self, shape):
  31. """Check that generator can take small and non-square inputs."""
  32. output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape))
  33. self.assertAllEqual(shape, output_imgs.shape.as_list())
  34. def test_generator_graph_small(self):
  35. self._test_generator_graph_helper([4, 32, 32, 3])
  36. def test_generator_graph_medium(self):
  37. self._test_generator_graph_helper([3, 128, 128, 3])
  38. def test_generator_graph_nonsquare(self):
  39. self._test_generator_graph_helper([2, 80, 400, 3])
  40. def test_generator_unknown_batch_dim(self):
  41. """Check that generator can take unknown batch dimension inputs."""
  42. img = tf.placeholder(tf.float32, shape=[None, 32, None, 3])
  43. output_imgs, _ = cyclegan.cyclegan_generator_resnet(img)
  44. self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list())
  45. def _input_and_output_same_shape_helper(self, kernel_size):
  46. img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
  47. output_img_batch, _ = cyclegan.cyclegan_generator_resnet(
  48. img_batch, kernel_size=kernel_size)
  49. self.assertAllEqual(img_batch.shape.as_list(),
  50. output_img_batch.shape.as_list())
  51. def input_and_output_same_shape_kernel3(self):
  52. self._input_and_output_same_shape_helper(3)
  53. def input_and_output_same_shape_kernel4(self):
  54. self._input_and_output_same_shape_helper(4)
  55. def input_and_output_same_shape_kernel5(self):
  56. self._input_and_output_same_shape_helper(5)
  57. def input_and_output_same_shape_kernel6(self):
  58. self._input_and_output_same_shape_helper(6)
  59. def _error_if_height_not_multiple_of_four_helper(self, height):
  60. self.assertRaisesRegexp(
  61. ValueError,
  62. 'The input height must be a multiple of 4.',
  63. cyclegan.cyclegan_generator_resnet,
  64. tf.placeholder(tf.float32, shape=[None, height, 32, 3]))
  65. def test_error_if_height_not_multiple_of_four_height29(self):
  66. self._error_if_height_not_multiple_of_four_helper(29)
  67. def test_error_if_height_not_multiple_of_four_height30(self):
  68. self._error_if_height_not_multiple_of_four_helper(30)
  69. def test_error_if_height_not_multiple_of_four_height31(self):
  70. self._error_if_height_not_multiple_of_four_helper(31)
  71. def _error_if_width_not_multiple_of_four_helper(self, width):
  72. self.assertRaisesRegexp(
  73. ValueError,
  74. 'The input width must be a multiple of 4.',
  75. cyclegan.cyclegan_generator_resnet,
  76. tf.placeholder(tf.float32, shape=[None, 32, width, 3]))
  77. def test_error_if_width_not_multiple_of_four_width29(self):
  78. self._error_if_width_not_multiple_of_four_helper(29)
  79. def test_error_if_width_not_multiple_of_four_width30(self):
  80. self._error_if_width_not_multiple_of_four_helper(30)
  81. def test_error_if_width_not_multiple_of_four_width31(self):
  82. self._error_if_width_not_multiple_of_four_helper(31)
  83. if __name__ == '__main__':
  84. tf.test.main()