# inception_v1_test.py (9.9 KB)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for nets.inception_v1."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from nets import inception
  22. slim = tf.contrib.slim
  23. class InceptionV1Test(tf.test.TestCase):
  24. def testBuildClassificationNetwork(self):
  25. batch_size = 5
  26. height, width = 224, 224
  27. num_classes = 1000
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, end_points = inception.inception_v1(inputs, num_classes)
  30. self.assertTrue(logits.op.name.startswith(
  31. 'InceptionV1/Logits/SpatialSqueeze'))
  32. self.assertListEqual(logits.get_shape().as_list(),
  33. [batch_size, num_classes])
  34. self.assertTrue('Predictions' in end_points)
  35. self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
  36. [batch_size, num_classes])
  37. def testBuildPreLogitsNetwork(self):
  38. batch_size = 5
  39. height, width = 224, 224
  40. num_classes = None
  41. inputs = tf.random_uniform((batch_size, height, width, 3))
  42. net, end_points = inception.inception_v1(inputs, num_classes)
  43. self.assertTrue(net.op.name.startswith('InceptionV1/Logits/AvgPool'))
  44. self.assertListEqual(net.get_shape().as_list(), [batch_size, 1, 1, 1024])
  45. self.assertFalse('Logits' in end_points)
  46. self.assertFalse('Predictions' in end_points)
  47. def testBuildBaseNetwork(self):
  48. batch_size = 5
  49. height, width = 224, 224
  50. inputs = tf.random_uniform((batch_size, height, width, 3))
  51. mixed_6c, end_points = inception.inception_v1_base(inputs)
  52. self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c'))
  53. self.assertListEqual(mixed_6c.get_shape().as_list(),
  54. [batch_size, 7, 7, 1024])
  55. expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
  56. 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b',
  57. 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c',
  58. 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2',
  59. 'Mixed_5b', 'Mixed_5c']
  60. self.assertItemsEqual(end_points.keys(), expected_endpoints)
  61. def testBuildOnlyUptoFinalEndpoint(self):
  62. batch_size = 5
  63. height, width = 224, 224
  64. endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
  65. 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
  66. 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
  67. 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b',
  68. 'Mixed_5c']
  69. for index, endpoint in enumerate(endpoints):
  70. with tf.Graph().as_default():
  71. inputs = tf.random_uniform((batch_size, height, width, 3))
  72. out_tensor, end_points = inception.inception_v1_base(
  73. inputs, final_endpoint=endpoint)
  74. self.assertTrue(out_tensor.op.name.startswith(
  75. 'InceptionV1/' + endpoint))
  76. self.assertItemsEqual(endpoints[:index+1], end_points)
  77. def testBuildAndCheckAllEndPointsUptoMixed5c(self):
  78. batch_size = 5
  79. height, width = 224, 224
  80. inputs = tf.random_uniform((batch_size, height, width, 3))
  81. _, end_points = inception.inception_v1_base(inputs,
  82. final_endpoint='Mixed_5c')
  83. endpoints_shapes = {'Conv2d_1a_7x7': [5, 112, 112, 64],
  84. 'MaxPool_2a_3x3': [5, 56, 56, 64],
  85. 'Conv2d_2b_1x1': [5, 56, 56, 64],
  86. 'Conv2d_2c_3x3': [5, 56, 56, 192],
  87. 'MaxPool_3a_3x3': [5, 28, 28, 192],
  88. 'Mixed_3b': [5, 28, 28, 256],
  89. 'Mixed_3c': [5, 28, 28, 480],
  90. 'MaxPool_4a_3x3': [5, 14, 14, 480],
  91. 'Mixed_4b': [5, 14, 14, 512],
  92. 'Mixed_4c': [5, 14, 14, 512],
  93. 'Mixed_4d': [5, 14, 14, 512],
  94. 'Mixed_4e': [5, 14, 14, 528],
  95. 'Mixed_4f': [5, 14, 14, 832],
  96. 'MaxPool_5a_2x2': [5, 7, 7, 832],
  97. 'Mixed_5b': [5, 7, 7, 832],
  98. 'Mixed_5c': [5, 7, 7, 1024]}
  99. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  100. for endpoint_name in endpoints_shapes:
  101. expected_shape = endpoints_shapes[endpoint_name]
  102. self.assertTrue(endpoint_name in end_points)
  103. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  104. expected_shape)
  105. def testModelHasExpectedNumberOfParameters(self):
  106. batch_size = 5
  107. height, width = 224, 224
  108. inputs = tf.random_uniform((batch_size, height, width, 3))
  109. with slim.arg_scope(inception.inception_v1_arg_scope()):
  110. inception.inception_v1_base(inputs)
  111. total_params, _ = slim.model_analyzer.analyze_vars(
  112. slim.get_model_variables())
  113. self.assertAlmostEqual(5607184, total_params)
  114. def testHalfSizeImages(self):
  115. batch_size = 5
  116. height, width = 112, 112
  117. inputs = tf.random_uniform((batch_size, height, width, 3))
  118. mixed_5c, _ = inception.inception_v1_base(inputs)
  119. self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
  120. self.assertListEqual(mixed_5c.get_shape().as_list(),
  121. [batch_size, 4, 4, 1024])
  122. def testUnknownImageShape(self):
  123. tf.reset_default_graph()
  124. batch_size = 2
  125. height, width = 224, 224
  126. num_classes = 1000
  127. input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
  128. with self.test_session() as sess:
  129. inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
  130. logits, end_points = inception.inception_v1(inputs, num_classes)
  131. self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
  132. self.assertListEqual(logits.get_shape().as_list(),
  133. [batch_size, num_classes])
  134. pre_pool = end_points['Mixed_5c']
  135. feed_dict = {inputs: input_np}
  136. tf.global_variables_initializer().run()
  137. pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
  138. self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
  139. def testGlobalPoolUnknownImageShape(self):
  140. tf.reset_default_graph()
  141. batch_size = 2
  142. height, width = 300, 400
  143. num_classes = 1000
  144. input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
  145. with self.test_session() as sess:
  146. inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
  147. logits, end_points = inception.inception_v1(inputs, num_classes,
  148. global_pool=True)
  149. self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
  150. self.assertListEqual(logits.get_shape().as_list(),
  151. [batch_size, num_classes])
  152. pre_pool = end_points['Mixed_5c']
  153. feed_dict = {inputs: input_np}
  154. tf.global_variables_initializer().run()
  155. pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
  156. self.assertListEqual(list(pre_pool_out.shape), [batch_size, 10, 13, 1024])
  157. def testUnknowBatchSize(self):
  158. batch_size = 1
  159. height, width = 224, 224
  160. num_classes = 1000
  161. inputs = tf.placeholder(tf.float32, (None, height, width, 3))
  162. logits, _ = inception.inception_v1(inputs, num_classes)
  163. self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
  164. self.assertListEqual(logits.get_shape().as_list(),
  165. [None, num_classes])
  166. images = tf.random_uniform((batch_size, height, width, 3))
  167. with self.test_session() as sess:
  168. sess.run(tf.global_variables_initializer())
  169. output = sess.run(logits, {inputs: images.eval()})
  170. self.assertEquals(output.shape, (batch_size, num_classes))
  171. def testEvaluation(self):
  172. batch_size = 2
  173. height, width = 224, 224
  174. num_classes = 1000
  175. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  176. logits, _ = inception.inception_v1(eval_inputs, num_classes,
  177. is_training=False)
  178. predictions = tf.argmax(logits, 1)
  179. with self.test_session() as sess:
  180. sess.run(tf.global_variables_initializer())
  181. output = sess.run(predictions)
  182. self.assertEquals(output.shape, (batch_size,))
  183. def testTrainEvalWithReuse(self):
  184. train_batch_size = 5
  185. eval_batch_size = 2
  186. height, width = 224, 224
  187. num_classes = 1000
  188. train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
  189. inception.inception_v1(train_inputs, num_classes)
  190. eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
  191. logits, _ = inception.inception_v1(eval_inputs, num_classes, reuse=True)
  192. predictions = tf.argmax(logits, 1)
  193. with self.test_session() as sess:
  194. sess.run(tf.global_variables_initializer())
  195. output = sess.run(predictions)
  196. self.assertEquals(output.shape, (eval_batch_size,))
  197. def testLogitsNotSqueezed(self):
  198. num_classes = 25
  199. images = tf.random_uniform([1, 224, 224, 3])
  200. logits, _ = inception.inception_v1(images,
  201. num_classes=num_classes,
  202. spatial_squeeze=False)
  203. with self.test_session() as sess:
  204. tf.global_variables_initializer().run()
  205. logits_out = sess.run(logits)
  206. self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
  207. if __name__ == '__main__':
  208. tf.test.main()