inception_v3_test.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for nets.inception_v1."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from nets import inception
  22. slim = tf.contrib.slim
  23. class InceptionV3Test(tf.test.TestCase):
  24. def testBuildClassificationNetwork(self):
  25. batch_size = 5
  26. height, width = 299, 299
  27. num_classes = 1000
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, end_points = inception.inception_v3(inputs, num_classes)
  30. self.assertTrue(logits.op.name.startswith(
  31. 'InceptionV3/Logits/SpatialSqueeze'))
  32. self.assertListEqual(logits.get_shape().as_list(),
  33. [batch_size, num_classes])
  34. self.assertTrue('Predictions' in end_points)
  35. self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
  36. [batch_size, num_classes])
  37. def testBuildPreLogitsNetwork(self):
  38. batch_size = 5
  39. height, width = 299, 299
  40. num_classes = None
  41. inputs = tf.random_uniform((batch_size, height, width, 3))
  42. net, end_points = inception.inception_v3(inputs, num_classes)
  43. self.assertTrue(net.op.name.startswith('InceptionV3/Logits/AvgPool'))
  44. self.assertListEqual(net.get_shape().as_list(), [batch_size, 1, 1, 2048])
  45. self.assertFalse('Logits' in end_points)
  46. self.assertFalse('Predictions' in end_points)
  47. def testBuildBaseNetwork(self):
  48. batch_size = 5
  49. height, width = 299, 299
  50. inputs = tf.random_uniform((batch_size, height, width, 3))
  51. final_endpoint, end_points = inception.inception_v3_base(inputs)
  52. self.assertTrue(final_endpoint.op.name.startswith(
  53. 'InceptionV3/Mixed_7c'))
  54. self.assertListEqual(final_endpoint.get_shape().as_list(),
  55. [batch_size, 8, 8, 2048])
  56. expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
  57. 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
  58. 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
  59. 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
  60. 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
  61. self.assertItemsEqual(end_points.keys(), expected_endpoints)
  62. def testBuildOnlyUptoFinalEndpoint(self):
  63. batch_size = 5
  64. height, width = 299, 299
  65. endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
  66. 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
  67. 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
  68. 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
  69. 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
  70. for index, endpoint in enumerate(endpoints):
  71. with tf.Graph().as_default():
  72. inputs = tf.random_uniform((batch_size, height, width, 3))
  73. out_tensor, end_points = inception.inception_v3_base(
  74. inputs, final_endpoint=endpoint)
  75. self.assertTrue(out_tensor.op.name.startswith(
  76. 'InceptionV3/' + endpoint))
  77. self.assertItemsEqual(endpoints[:index+1], end_points)
  78. def testBuildAndCheckAllEndPointsUptoMixed7c(self):
  79. batch_size = 5
  80. height, width = 299, 299
  81. inputs = tf.random_uniform((batch_size, height, width, 3))
  82. _, end_points = inception.inception_v3_base(
  83. inputs, final_endpoint='Mixed_7c')
  84. endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
  85. 'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
  86. 'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
  87. 'MaxPool_3a_3x3': [batch_size, 73, 73, 64],
  88. 'Conv2d_3b_1x1': [batch_size, 73, 73, 80],
  89. 'Conv2d_4a_3x3': [batch_size, 71, 71, 192],
  90. 'MaxPool_5a_3x3': [batch_size, 35, 35, 192],
  91. 'Mixed_5b': [batch_size, 35, 35, 256],
  92. 'Mixed_5c': [batch_size, 35, 35, 288],
  93. 'Mixed_5d': [batch_size, 35, 35, 288],
  94. 'Mixed_6a': [batch_size, 17, 17, 768],
  95. 'Mixed_6b': [batch_size, 17, 17, 768],
  96. 'Mixed_6c': [batch_size, 17, 17, 768],
  97. 'Mixed_6d': [batch_size, 17, 17, 768],
  98. 'Mixed_6e': [batch_size, 17, 17, 768],
  99. 'Mixed_7a': [batch_size, 8, 8, 1280],
  100. 'Mixed_7b': [batch_size, 8, 8, 2048],
  101. 'Mixed_7c': [batch_size, 8, 8, 2048]}
  102. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  103. for endpoint_name in endpoints_shapes:
  104. expected_shape = endpoints_shapes[endpoint_name]
  105. self.assertTrue(endpoint_name in end_points)
  106. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  107. expected_shape)
  108. def testModelHasExpectedNumberOfParameters(self):
  109. batch_size = 5
  110. height, width = 299, 299
  111. inputs = tf.random_uniform((batch_size, height, width, 3))
  112. with slim.arg_scope(inception.inception_v3_arg_scope()):
  113. inception.inception_v3_base(inputs)
  114. total_params, _ = slim.model_analyzer.analyze_vars(
  115. slim.get_model_variables())
  116. self.assertAlmostEqual(21802784, total_params)
  117. def testBuildEndPoints(self):
  118. batch_size = 5
  119. height, width = 299, 299
  120. num_classes = 1000
  121. inputs = tf.random_uniform((batch_size, height, width, 3))
  122. _, end_points = inception.inception_v3(inputs, num_classes)
  123. self.assertTrue('Logits' in end_points)
  124. logits = end_points['Logits']
  125. self.assertListEqual(logits.get_shape().as_list(),
  126. [batch_size, num_classes])
  127. self.assertTrue('AuxLogits' in end_points)
  128. aux_logits = end_points['AuxLogits']
  129. self.assertListEqual(aux_logits.get_shape().as_list(),
  130. [batch_size, num_classes])
  131. self.assertTrue('Mixed_7c' in end_points)
  132. pre_pool = end_points['Mixed_7c']
  133. self.assertListEqual(pre_pool.get_shape().as_list(),
  134. [batch_size, 8, 8, 2048])
  135. self.assertTrue('PreLogits' in end_points)
  136. pre_logits = end_points['PreLogits']
  137. self.assertListEqual(pre_logits.get_shape().as_list(),
  138. [batch_size, 1, 1, 2048])
  139. def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
  140. batch_size = 5
  141. height, width = 299, 299
  142. num_classes = 1000
  143. inputs = tf.random_uniform((batch_size, height, width, 3))
  144. _, end_points = inception.inception_v3(inputs, num_classes)
  145. endpoint_keys = [key for key in end_points.keys()
  146. if key.startswith('Mixed') or key.startswith('Conv')]
  147. _, end_points_with_multiplier = inception.inception_v3(
  148. inputs, num_classes, scope='depth_multiplied_net',
  149. depth_multiplier=0.5)
  150. for key in endpoint_keys:
  151. original_depth = end_points[key].get_shape().as_list()[3]
  152. new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
  153. self.assertEqual(0.5 * original_depth, new_depth)
  154. def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
  155. batch_size = 5
  156. height, width = 299, 299
  157. num_classes = 1000
  158. inputs = tf.random_uniform((batch_size, height, width, 3))
  159. _, end_points = inception.inception_v3(inputs, num_classes)
  160. endpoint_keys = [key for key in end_points.keys()
  161. if key.startswith('Mixed') or key.startswith('Conv')]
  162. _, end_points_with_multiplier = inception.inception_v3(
  163. inputs, num_classes, scope='depth_multiplied_net',
  164. depth_multiplier=2.0)
  165. for key in endpoint_keys:
  166. original_depth = end_points[key].get_shape().as_list()[3]
  167. new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
  168. self.assertEqual(2.0 * original_depth, new_depth)
  169. def testRaiseValueErrorWithInvalidDepthMultiplier(self):
  170. batch_size = 5
  171. height, width = 299, 299
  172. num_classes = 1000
  173. inputs = tf.random_uniform((batch_size, height, width, 3))
  174. with self.assertRaises(ValueError):
  175. _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1)
  176. with self.assertRaises(ValueError):
  177. _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0)
  178. def testHalfSizeImages(self):
  179. batch_size = 5
  180. height, width = 150, 150
  181. num_classes = 1000
  182. inputs = tf.random_uniform((batch_size, height, width, 3))
  183. logits, end_points = inception.inception_v3(inputs, num_classes)
  184. self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
  185. self.assertListEqual(logits.get_shape().as_list(),
  186. [batch_size, num_classes])
  187. pre_pool = end_points['Mixed_7c']
  188. self.assertListEqual(pre_pool.get_shape().as_list(),
  189. [batch_size, 3, 3, 2048])
  190. def testUnknownImageShape(self):
  191. tf.reset_default_graph()
  192. batch_size = 2
  193. height, width = 299, 299
  194. num_classes = 1000
  195. input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
  196. with self.test_session() as sess:
  197. inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
  198. logits, end_points = inception.inception_v3(inputs, num_classes)
  199. self.assertListEqual(logits.get_shape().as_list(),
  200. [batch_size, num_classes])
  201. pre_pool = end_points['Mixed_7c']
  202. feed_dict = {inputs: input_np}
  203. tf.global_variables_initializer().run()
  204. pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
  205. self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
  206. def testGlobalPoolUnknownImageShape(self):
  207. tf.reset_default_graph()
  208. batch_size = 2
  209. height, width = 400, 600
  210. num_classes = 1000
  211. input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
  212. with self.test_session() as sess:
  213. inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
  214. logits, end_points = inception.inception_v3(inputs, num_classes,
  215. global_pool=True)
  216. self.assertListEqual(logits.get_shape().as_list(),
  217. [batch_size, num_classes])
  218. pre_pool = end_points['Mixed_7c']
  219. feed_dict = {inputs: input_np}
  220. tf.global_variables_initializer().run()
  221. pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
  222. self.assertListEqual(list(pre_pool_out.shape), [batch_size, 11, 17, 2048])
  223. def testUnknowBatchSize(self):
  224. batch_size = 1
  225. height, width = 299, 299
  226. num_classes = 1000
  227. inputs = tf.placeholder(tf.float32, (None, height, width, 3))
  228. logits, _ = inception.inception_v3(inputs, num_classes)
  229. self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
  230. self.assertListEqual(logits.get_shape().as_list(),
  231. [None, num_classes])
  232. images = tf.random_uniform((batch_size, height, width, 3))
  233. with self.test_session() as sess:
  234. sess.run(tf.global_variables_initializer())
  235. output = sess.run(logits, {inputs: images.eval()})
  236. self.assertEquals(output.shape, (batch_size, num_classes))
  237. def testEvaluation(self):
  238. batch_size = 2
  239. height, width = 299, 299
  240. num_classes = 1000
  241. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  242. logits, _ = inception.inception_v3(eval_inputs, num_classes,
  243. is_training=False)
  244. predictions = tf.argmax(logits, 1)
  245. with self.test_session() as sess:
  246. sess.run(tf.global_variables_initializer())
  247. output = sess.run(predictions)
  248. self.assertEquals(output.shape, (batch_size,))
  249. def testTrainEvalWithReuse(self):
  250. train_batch_size = 5
  251. eval_batch_size = 2
  252. height, width = 150, 150
  253. num_classes = 1000
  254. train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
  255. inception.inception_v3(train_inputs, num_classes)
  256. eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
  257. logits, _ = inception.inception_v3(eval_inputs, num_classes,
  258. is_training=False, reuse=True)
  259. predictions = tf.argmax(logits, 1)
  260. with self.test_session() as sess:
  261. sess.run(tf.global_variables_initializer())
  262. output = sess.run(predictions)
  263. self.assertEquals(output.shape, (eval_batch_size,))
  264. def testLogitsNotSqueezed(self):
  265. num_classes = 25
  266. images = tf.random_uniform([1, 299, 299, 3])
  267. logits, _ = inception.inception_v3(images,
  268. num_classes=num_classes,
  269. spatial_squeeze=False)
  270. with self.test_session() as sess:
  271. tf.global_variables_initializer().run()
  272. logits_out = sess.run(logits)
  273. self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
  274. if __name__ == '__main__':
  275. tf.test.main()