# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for pix2pix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import pix2pix
  21. class GeneratorTest(tf.test.TestCase):
  22. def test_nonsquare_inputs_raise_exception(self):
  23. batch_size = 2
  24. height, width = 240, 320
  25. num_outputs = 4
  26. images = tf.ones((batch_size, height, width, 3))
  27. with self.assertRaises(ValueError):
  28. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  29. pix2pix.pix2pix_generator(
  30. images, num_outputs, upsample_method='nn_upsample_conv')
  31. def _reduced_default_blocks(self):
  32. """Returns the default blocks, scaled down to make test run faster."""
  33. return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob)
  34. for b in pix2pix._default_generator_blocks()]
  35. def test_output_size_nn_upsample_conv(self):
  36. batch_size = 2
  37. height, width = 256, 256
  38. num_outputs = 4
  39. images = tf.ones((batch_size, height, width, 3))
  40. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  41. logits, _ = pix2pix.pix2pix_generator(
  42. images, num_outputs, blocks=self._reduced_default_blocks(),
  43. upsample_method='nn_upsample_conv')
  44. with self.test_session() as session:
  45. session.run(tf.global_variables_initializer())
  46. np_outputs = session.run(logits)
  47. self.assertListEqual([batch_size, height, width, num_outputs],
  48. list(np_outputs.shape))
  49. def test_output_size_conv2d_transpose(self):
  50. batch_size = 2
  51. height, width = 256, 256
  52. num_outputs = 4
  53. images = tf.ones((batch_size, height, width, 3))
  54. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  55. logits, _ = pix2pix.pix2pix_generator(
  56. images, num_outputs, blocks=self._reduced_default_blocks(),
  57. upsample_method='conv2d_transpose')
  58. with self.test_session() as session:
  59. session.run(tf.global_variables_initializer())
  60. np_outputs = session.run(logits)
  61. self.assertListEqual([batch_size, height, width, num_outputs],
  62. list(np_outputs.shape))
  63. def test_block_number_dictates_number_of_layers(self):
  64. batch_size = 2
  65. height, width = 256, 256
  66. num_outputs = 4
  67. images = tf.ones((batch_size, height, width, 3))
  68. blocks = [
  69. pix2pix.Block(64, 0.5),
  70. pix2pix.Block(128, 0),
  71. ]
  72. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  73. _, end_points = pix2pix.pix2pix_generator(
  74. images, num_outputs, blocks)
  75. num_encoder_layers = 0
  76. num_decoder_layers = 0
  77. for end_point in end_points:
  78. if end_point.startswith('encoder'):
  79. num_encoder_layers += 1
  80. elif end_point.startswith('decoder'):
  81. num_decoder_layers += 1
  82. self.assertEqual(num_encoder_layers, len(blocks))
  83. self.assertEqual(num_decoder_layers, len(blocks))
  84. class DiscriminatorTest(tf.test.TestCase):
  85. def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2):
  86. return (input_size + pad * 2 - kernel_size) // stride + 1
  87. def test_four_layers(self):
  88. batch_size = 2
  89. input_size = 256
  90. output_size = self._layer_output_size(input_size)
  91. output_size = self._layer_output_size(output_size)
  92. output_size = self._layer_output_size(output_size)
  93. output_size = self._layer_output_size(output_size, stride=1)
  94. output_size = self._layer_output_size(output_size, stride=1)
  95. images = tf.ones((batch_size, input_size, input_size, 3))
  96. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  97. logits, end_points = pix2pix.pix2pix_discriminator(
  98. images, num_filters=[64, 128, 256, 512])
  99. self.assertListEqual([batch_size, output_size, output_size, 1],
  100. logits.shape.as_list())
  101. self.assertListEqual([batch_size, output_size, output_size, 1],
  102. end_points['predictions'].shape.as_list())
  103. def test_four_layers_no_padding(self):
  104. batch_size = 2
  105. input_size = 256
  106. output_size = self._layer_output_size(input_size, pad=0)
  107. output_size = self._layer_output_size(output_size, pad=0)
  108. output_size = self._layer_output_size(output_size, pad=0)
  109. output_size = self._layer_output_size(output_size, stride=1, pad=0)
  110. output_size = self._layer_output_size(output_size, stride=1, pad=0)
  111. images = tf.ones((batch_size, input_size, input_size, 3))
  112. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  113. logits, end_points = pix2pix.pix2pix_discriminator(
  114. images, num_filters=[64, 128, 256, 512], padding=0)
  115. self.assertListEqual([batch_size, output_size, output_size, 1],
  116. logits.shape.as_list())
  117. self.assertListEqual([batch_size, output_size, output_size, 1],
  118. end_points['predictions'].shape.as_list())
  119. def test_four_layers_wrog_paddig(self):
  120. batch_size = 2
  121. input_size = 256
  122. images = tf.ones((batch_size, input_size, input_size, 3))
  123. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  124. with self.assertRaises(TypeError):
  125. pix2pix.pix2pix_discriminator(
  126. images, num_filters=[64, 128, 256, 512], padding=1.5)
  127. def test_four_layers_negative_padding(self):
  128. batch_size = 2
  129. input_size = 256
  130. images = tf.ones((batch_size, input_size, input_size, 3))
  131. with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
  132. with self.assertRaises(ValueError):
  133. pix2pix.pix2pix_discriminator(
  134. images, num_filters=[64, 128, 256, 512], padding=-1)
  135. if __name__ == '__main__':
  136. tf.test.main()