alexnet_test.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.nets.alexnet."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import alexnet
  21. slim = tf.contrib.slim
  22. class AlexnetV2Test(tf.test.TestCase):
  23. def testBuild(self):
  24. batch_size = 5
  25. height, width = 224, 224
  26. num_classes = 1000
  27. with self.test_session():
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, _ = alexnet.alexnet_v2(inputs, num_classes)
  30. self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
  31. self.assertListEqual(logits.get_shape().as_list(),
  32. [batch_size, num_classes])
  33. def testFullyConvolutional(self):
  34. batch_size = 1
  35. height, width = 300, 400
  36. num_classes = 1000
  37. with self.test_session():
  38. inputs = tf.random_uniform((batch_size, height, width, 3))
  39. logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
  40. self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
  41. self.assertListEqual(logits.get_shape().as_list(),
  42. [batch_size, 4, 7, num_classes])
  43. def testGlobalPool(self):
  44. batch_size = 1
  45. height, width = 300, 400
  46. num_classes = 1000
  47. with self.test_session():
  48. inputs = tf.random_uniform((batch_size, height, width, 3))
  49. logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False,
  50. global_pool=True)
  51. self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
  52. self.assertListEqual(logits.get_shape().as_list(),
  53. [batch_size, 1, 1, num_classes])
  54. def testEndPoints(self):
  55. batch_size = 5
  56. height, width = 224, 224
  57. num_classes = 1000
  58. with self.test_session():
  59. inputs = tf.random_uniform((batch_size, height, width, 3))
  60. _, end_points = alexnet.alexnet_v2(inputs, num_classes)
  61. expected_names = ['alexnet_v2/conv1',
  62. 'alexnet_v2/pool1',
  63. 'alexnet_v2/conv2',
  64. 'alexnet_v2/pool2',
  65. 'alexnet_v2/conv3',
  66. 'alexnet_v2/conv4',
  67. 'alexnet_v2/conv5',
  68. 'alexnet_v2/pool5',
  69. 'alexnet_v2/fc6',
  70. 'alexnet_v2/fc7',
  71. 'alexnet_v2/fc8'
  72. ]
  73. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  74. def testNoClasses(self):
  75. batch_size = 5
  76. height, width = 224, 224
  77. num_classes = None
  78. with self.test_session():
  79. inputs = tf.random_uniform((batch_size, height, width, 3))
  80. net, end_points = alexnet.alexnet_v2(inputs, num_classes)
  81. expected_names = ['alexnet_v2/conv1',
  82. 'alexnet_v2/pool1',
  83. 'alexnet_v2/conv2',
  84. 'alexnet_v2/pool2',
  85. 'alexnet_v2/conv3',
  86. 'alexnet_v2/conv4',
  87. 'alexnet_v2/conv5',
  88. 'alexnet_v2/pool5',
  89. 'alexnet_v2/fc6',
  90. 'alexnet_v2/fc7'
  91. ]
  92. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  93. self.assertTrue(net.op.name.startswith('alexnet_v2/fc7'))
  94. self.assertListEqual(net.get_shape().as_list(),
  95. [batch_size, 1, 1, 4096])
  96. def testModelVariables(self):
  97. batch_size = 5
  98. height, width = 224, 224
  99. num_classes = 1000
  100. with self.test_session():
  101. inputs = tf.random_uniform((batch_size, height, width, 3))
  102. alexnet.alexnet_v2(inputs, num_classes)
  103. expected_names = ['alexnet_v2/conv1/weights',
  104. 'alexnet_v2/conv1/biases',
  105. 'alexnet_v2/conv2/weights',
  106. 'alexnet_v2/conv2/biases',
  107. 'alexnet_v2/conv3/weights',
  108. 'alexnet_v2/conv3/biases',
  109. 'alexnet_v2/conv4/weights',
  110. 'alexnet_v2/conv4/biases',
  111. 'alexnet_v2/conv5/weights',
  112. 'alexnet_v2/conv5/biases',
  113. 'alexnet_v2/fc6/weights',
  114. 'alexnet_v2/fc6/biases',
  115. 'alexnet_v2/fc7/weights',
  116. 'alexnet_v2/fc7/biases',
  117. 'alexnet_v2/fc8/weights',
  118. 'alexnet_v2/fc8/biases',
  119. ]
  120. model_variables = [v.op.name for v in slim.get_model_variables()]
  121. self.assertSetEqual(set(model_variables), set(expected_names))
  122. def testEvaluation(self):
  123. batch_size = 2
  124. height, width = 224, 224
  125. num_classes = 1000
  126. with self.test_session():
  127. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  128. logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
  129. self.assertListEqual(logits.get_shape().as_list(),
  130. [batch_size, num_classes])
  131. predictions = tf.argmax(logits, 1)
  132. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  133. def testTrainEvalWithReuse(self):
  134. train_batch_size = 2
  135. eval_batch_size = 1
  136. train_height, train_width = 224, 224
  137. eval_height, eval_width = 300, 400
  138. num_classes = 1000
  139. with self.test_session():
  140. train_inputs = tf.random_uniform(
  141. (train_batch_size, train_height, train_width, 3))
  142. logits, _ = alexnet.alexnet_v2(train_inputs)
  143. self.assertListEqual(logits.get_shape().as_list(),
  144. [train_batch_size, num_classes])
  145. tf.get_variable_scope().reuse_variables()
  146. eval_inputs = tf.random_uniform(
  147. (eval_batch_size, eval_height, eval_width, 3))
  148. logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
  149. spatial_squeeze=False)
  150. self.assertListEqual(logits.get_shape().as_list(),
  151. [eval_batch_size, 4, 7, num_classes])
  152. logits = tf.reduce_mean(logits, [1, 2])
  153. predictions = tf.argmax(logits, 1)
  154. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  155. def testForward(self):
  156. batch_size = 1
  157. height, width = 224, 224
  158. with self.test_session() as sess:
  159. inputs = tf.random_uniform((batch_size, height, width, 3))
  160. logits, _ = alexnet.alexnet_v2(inputs)
  161. sess.run(tf.global_variables_initializer())
  162. output = sess.run(logits)
  163. self.assertTrue(output.any())
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  tf.test.main()