nasnet_test.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for slim.nasnet."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets.nasnet import nasnet
  21. slim = tf.contrib.slim
  22. class NASNetTest(tf.test.TestCase):
  23. def testBuildLogitsCifarModel(self):
  24. batch_size = 5
  25. height, width = 32, 32
  26. num_classes = 10
  27. inputs = tf.random_uniform((batch_size, height, width, 3))
  28. tf.train.create_global_step()
  29. with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
  30. logits, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
  31. auxlogits = end_points['AuxLogits']
  32. predictions = end_points['Predictions']
  33. self.assertListEqual(auxlogits.get_shape().as_list(),
  34. [batch_size, num_classes])
  35. self.assertListEqual(logits.get_shape().as_list(),
  36. [batch_size, num_classes])
  37. self.assertListEqual(predictions.get_shape().as_list(),
  38. [batch_size, num_classes])
  39. def testBuildLogitsMobileModel(self):
  40. batch_size = 5
  41. height, width = 224, 224
  42. num_classes = 1000
  43. inputs = tf.random_uniform((batch_size, height, width, 3))
  44. tf.train.create_global_step()
  45. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  46. logits, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
  47. auxlogits = end_points['AuxLogits']
  48. predictions = end_points['Predictions']
  49. self.assertListEqual(auxlogits.get_shape().as_list(),
  50. [batch_size, num_classes])
  51. self.assertListEqual(logits.get_shape().as_list(),
  52. [batch_size, num_classes])
  53. self.assertListEqual(predictions.get_shape().as_list(),
  54. [batch_size, num_classes])
  55. def testBuildLogitsLargeModel(self):
  56. batch_size = 5
  57. height, width = 331, 331
  58. num_classes = 1000
  59. inputs = tf.random_uniform((batch_size, height, width, 3))
  60. tf.train.create_global_step()
  61. with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
  62. logits, end_points = nasnet.build_nasnet_large(inputs, num_classes)
  63. auxlogits = end_points['AuxLogits']
  64. predictions = end_points['Predictions']
  65. self.assertListEqual(auxlogits.get_shape().as_list(),
  66. [batch_size, num_classes])
  67. self.assertListEqual(logits.get_shape().as_list(),
  68. [batch_size, num_classes])
  69. self.assertListEqual(predictions.get_shape().as_list(),
  70. [batch_size, num_classes])
  71. def testBuildPreLogitsCifarModel(self):
  72. batch_size = 5
  73. height, width = 32, 32
  74. num_classes = None
  75. inputs = tf.random_uniform((batch_size, height, width, 3))
  76. tf.train.create_global_step()
  77. with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
  78. net, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
  79. self.assertFalse('AuxLogits' in end_points)
  80. self.assertFalse('Predictions' in end_points)
  81. self.assertTrue(net.op.name.startswith('final_layer/Mean'))
  82. self.assertListEqual(net.get_shape().as_list(), [batch_size, 768])
  83. def testBuildPreLogitsMobileModel(self):
  84. batch_size = 5
  85. height, width = 224, 224
  86. num_classes = None
  87. inputs = tf.random_uniform((batch_size, height, width, 3))
  88. tf.train.create_global_step()
  89. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  90. net, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
  91. self.assertFalse('AuxLogits' in end_points)
  92. self.assertFalse('Predictions' in end_points)
  93. self.assertTrue(net.op.name.startswith('final_layer/Mean'))
  94. self.assertListEqual(net.get_shape().as_list(), [batch_size, 1056])
  95. def testBuildPreLogitsLargeModel(self):
  96. batch_size = 5
  97. height, width = 331, 331
  98. num_classes = None
  99. inputs = tf.random_uniform((batch_size, height, width, 3))
  100. tf.train.create_global_step()
  101. with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
  102. net, end_points = nasnet.build_nasnet_large(inputs, num_classes)
  103. self.assertFalse('AuxLogits' in end_points)
  104. self.assertFalse('Predictions' in end_points)
  105. self.assertTrue(net.op.name.startswith('final_layer/Mean'))
  106. self.assertListEqual(net.get_shape().as_list(), [batch_size, 4032])
  107. def testAllEndPointsShapesCifarModel(self):
  108. batch_size = 5
  109. height, width = 32, 32
  110. num_classes = 10
  111. inputs = tf.random_uniform((batch_size, height, width, 3))
  112. tf.train.create_global_step()
  113. with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
  114. _, end_points = nasnet.build_nasnet_cifar(inputs, num_classes)
  115. endpoints_shapes = {'Stem': [batch_size, 32, 32, 96],
  116. 'Cell_0': [batch_size, 32, 32, 192],
  117. 'Cell_1': [batch_size, 32, 32, 192],
  118. 'Cell_2': [batch_size, 32, 32, 192],
  119. 'Cell_3': [batch_size, 32, 32, 192],
  120. 'Cell_4': [batch_size, 32, 32, 192],
  121. 'Cell_5': [batch_size, 32, 32, 192],
  122. 'Cell_6': [batch_size, 16, 16, 384],
  123. 'Cell_7': [batch_size, 16, 16, 384],
  124. 'Cell_8': [batch_size, 16, 16, 384],
  125. 'Cell_9': [batch_size, 16, 16, 384],
  126. 'Cell_10': [batch_size, 16, 16, 384],
  127. 'Cell_11': [batch_size, 16, 16, 384],
  128. 'Cell_12': [batch_size, 8, 8, 768],
  129. 'Cell_13': [batch_size, 8, 8, 768],
  130. 'Cell_14': [batch_size, 8, 8, 768],
  131. 'Cell_15': [batch_size, 8, 8, 768],
  132. 'Cell_16': [batch_size, 8, 8, 768],
  133. 'Cell_17': [batch_size, 8, 8, 768],
  134. 'Reduction_Cell_0': [batch_size, 16, 16, 256],
  135. 'Reduction_Cell_1': [batch_size, 8, 8, 512],
  136. 'global_pool': [batch_size, 768],
  137. # Logits and predictions
  138. 'AuxLogits': [batch_size, num_classes],
  139. 'Logits': [batch_size, num_classes],
  140. 'Predictions': [batch_size, num_classes]}
  141. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  142. for endpoint_name in endpoints_shapes:
  143. tf.logging.info('Endpoint name: {}'.format(endpoint_name))
  144. expected_shape = endpoints_shapes[endpoint_name]
  145. self.assertTrue(endpoint_name in end_points)
  146. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  147. expected_shape)
  148. def testAllEndPointsShapesMobileModel(self):
  149. batch_size = 5
  150. height, width = 224, 224
  151. num_classes = 1000
  152. inputs = tf.random_uniform((batch_size, height, width, 3))
  153. tf.train.create_global_step()
  154. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  155. _, end_points = nasnet.build_nasnet_mobile(inputs, num_classes)
  156. endpoints_shapes = {'Stem': [batch_size, 28, 28, 88],
  157. 'Cell_0': [batch_size, 28, 28, 264],
  158. 'Cell_1': [batch_size, 28, 28, 264],
  159. 'Cell_2': [batch_size, 28, 28, 264],
  160. 'Cell_3': [batch_size, 28, 28, 264],
  161. 'Cell_4': [batch_size, 14, 14, 528],
  162. 'Cell_5': [batch_size, 14, 14, 528],
  163. 'Cell_6': [batch_size, 14, 14, 528],
  164. 'Cell_7': [batch_size, 14, 14, 528],
  165. 'Cell_8': [batch_size, 7, 7, 1056],
  166. 'Cell_9': [batch_size, 7, 7, 1056],
  167. 'Cell_10': [batch_size, 7, 7, 1056],
  168. 'Cell_11': [batch_size, 7, 7, 1056],
  169. 'Reduction_Cell_0': [batch_size, 14, 14, 352],
  170. 'Reduction_Cell_1': [batch_size, 7, 7, 704],
  171. 'global_pool': [batch_size, 1056],
  172. # Logits and predictions
  173. 'AuxLogits': [batch_size, num_classes],
  174. 'Logits': [batch_size, num_classes],
  175. 'Predictions': [batch_size, num_classes]}
  176. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  177. for endpoint_name in endpoints_shapes:
  178. tf.logging.info('Endpoint name: {}'.format(endpoint_name))
  179. expected_shape = endpoints_shapes[endpoint_name]
  180. self.assertTrue(endpoint_name in end_points)
  181. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  182. expected_shape)
  183. def testAllEndPointsShapesLargeModel(self):
  184. batch_size = 5
  185. height, width = 331, 331
  186. num_classes = 1000
  187. inputs = tf.random_uniform((batch_size, height, width, 3))
  188. tf.train.create_global_step()
  189. with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
  190. _, end_points = nasnet.build_nasnet_large(inputs, num_classes)
  191. endpoints_shapes = {'Stem': [batch_size, 42, 42, 336],
  192. 'Cell_0': [batch_size, 42, 42, 1008],
  193. 'Cell_1': [batch_size, 42, 42, 1008],
  194. 'Cell_2': [batch_size, 42, 42, 1008],
  195. 'Cell_3': [batch_size, 42, 42, 1008],
  196. 'Cell_4': [batch_size, 42, 42, 1008],
  197. 'Cell_5': [batch_size, 42, 42, 1008],
  198. 'Cell_6': [batch_size, 21, 21, 2016],
  199. 'Cell_7': [batch_size, 21, 21, 2016],
  200. 'Cell_8': [batch_size, 21, 21, 2016],
  201. 'Cell_9': [batch_size, 21, 21, 2016],
  202. 'Cell_10': [batch_size, 21, 21, 2016],
  203. 'Cell_11': [batch_size, 21, 21, 2016],
  204. 'Cell_12': [batch_size, 11, 11, 4032],
  205. 'Cell_13': [batch_size, 11, 11, 4032],
  206. 'Cell_14': [batch_size, 11, 11, 4032],
  207. 'Cell_15': [batch_size, 11, 11, 4032],
  208. 'Cell_16': [batch_size, 11, 11, 4032],
  209. 'Cell_17': [batch_size, 11, 11, 4032],
  210. 'Reduction_Cell_0': [batch_size, 21, 21, 1344],
  211. 'Reduction_Cell_1': [batch_size, 11, 11, 2688],
  212. 'global_pool': [batch_size, 4032],
  213. # Logits and predictions
  214. 'AuxLogits': [batch_size, num_classes],
  215. 'Logits': [batch_size, num_classes],
  216. 'Predictions': [batch_size, num_classes]}
  217. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  218. for endpoint_name in endpoints_shapes:
  219. tf.logging.info('Endpoint name: {}'.format(endpoint_name))
  220. expected_shape = endpoints_shapes[endpoint_name]
  221. self.assertTrue(endpoint_name in end_points)
  222. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  223. expected_shape)
  224. def testVariablesSetDeviceMobileModel(self):
  225. batch_size = 5
  226. height, width = 224, 224
  227. num_classes = 1000
  228. inputs = tf.random_uniform((batch_size, height, width, 3))
  229. tf.train.create_global_step()
  230. # Force all Variables to reside on the device.
  231. with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
  232. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  233. nasnet.build_nasnet_mobile(inputs, num_classes)
  234. with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
  235. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  236. nasnet.build_nasnet_mobile(inputs, num_classes)
  237. for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
  238. self.assertDeviceEqual(v.device, '/cpu:0')
  239. for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
  240. self.assertDeviceEqual(v.device, '/gpu:0')
  241. def testUnknownBatchSizeMobileModel(self):
  242. batch_size = 1
  243. height, width = 224, 224
  244. num_classes = 1000
  245. with self.test_session() as sess:
  246. inputs = tf.placeholder(tf.float32, (None, height, width, 3))
  247. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  248. logits, _ = nasnet.build_nasnet_mobile(inputs, num_classes)
  249. self.assertListEqual(logits.get_shape().as_list(),
  250. [None, num_classes])
  251. images = tf.random_uniform((batch_size, height, width, 3))
  252. sess.run(tf.global_variables_initializer())
  253. output = sess.run(logits, {inputs: images.eval()})
  254. self.assertEquals(output.shape, (batch_size, num_classes))
  255. def testEvaluationMobileModel(self):
  256. batch_size = 2
  257. height, width = 224, 224
  258. num_classes = 1000
  259. with self.test_session() as sess:
  260. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  261. with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  262. logits, _ = nasnet.build_nasnet_mobile(eval_inputs,
  263. num_classes,
  264. is_training=False)
  265. predictions = tf.argmax(logits, 1)
  266. sess.run(tf.global_variables_initializer())
  267. output = sess.run(predictions)
  268. self.assertEquals(output.shape, (batch_size,))
  269. if __name__ == '__main__':
  270. tf.test.main()