overfeat_test.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for slim.nets.overfeat."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import overfeat
  21. slim = tf.contrib.slim
  22. class OverFeatTest(tf.test.TestCase):
  23. def testBuild(self):
  24. batch_size = 5
  25. height, width = 231, 231
  26. num_classes = 1000
  27. with self.test_session():
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, _ = overfeat.overfeat(inputs, num_classes)
  30. self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
  31. self.assertListEqual(logits.get_shape().as_list(),
  32. [batch_size, num_classes])
  33. def testFullyConvolutional(self):
  34. batch_size = 1
  35. height, width = 281, 281
  36. num_classes = 1000
  37. with self.test_session():
  38. inputs = tf.random_uniform((batch_size, height, width, 3))
  39. logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
  40. self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
  41. self.assertListEqual(logits.get_shape().as_list(),
  42. [batch_size, 2, 2, num_classes])
  43. def testGlobalPool(self):
  44. batch_size = 1
  45. height, width = 281, 281
  46. num_classes = 1000
  47. with self.test_session():
  48. inputs = tf.random_uniform((batch_size, height, width, 3))
  49. logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False,
  50. global_pool=True)
  51. self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
  52. self.assertListEqual(logits.get_shape().as_list(),
  53. [batch_size, 1, 1, num_classes])
  54. def testEndPoints(self):
  55. batch_size = 5
  56. height, width = 231, 231
  57. num_classes = 1000
  58. with self.test_session():
  59. inputs = tf.random_uniform((batch_size, height, width, 3))
  60. _, end_points = overfeat.overfeat(inputs, num_classes)
  61. expected_names = ['overfeat/conv1',
  62. 'overfeat/pool1',
  63. 'overfeat/conv2',
  64. 'overfeat/pool2',
  65. 'overfeat/conv3',
  66. 'overfeat/conv4',
  67. 'overfeat/conv5',
  68. 'overfeat/pool5',
  69. 'overfeat/fc6',
  70. 'overfeat/fc7',
  71. 'overfeat/fc8'
  72. ]
  73. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  74. def testNoClasses(self):
  75. batch_size = 5
  76. height, width = 231, 231
  77. num_classes = None
  78. with self.test_session():
  79. inputs = tf.random_uniform((batch_size, height, width, 3))
  80. net, end_points = overfeat.overfeat(inputs, num_classes)
  81. expected_names = ['overfeat/conv1',
  82. 'overfeat/pool1',
  83. 'overfeat/conv2',
  84. 'overfeat/pool2',
  85. 'overfeat/conv3',
  86. 'overfeat/conv4',
  87. 'overfeat/conv5',
  88. 'overfeat/pool5',
  89. 'overfeat/fc6',
  90. 'overfeat/fc7'
  91. ]
  92. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  93. self.assertTrue(net.op.name.startswith('overfeat/fc7'))
  94. def testModelVariables(self):
  95. batch_size = 5
  96. height, width = 231, 231
  97. num_classes = 1000
  98. with self.test_session():
  99. inputs = tf.random_uniform((batch_size, height, width, 3))
  100. overfeat.overfeat(inputs, num_classes)
  101. expected_names = ['overfeat/conv1/weights',
  102. 'overfeat/conv1/biases',
  103. 'overfeat/conv2/weights',
  104. 'overfeat/conv2/biases',
  105. 'overfeat/conv3/weights',
  106. 'overfeat/conv3/biases',
  107. 'overfeat/conv4/weights',
  108. 'overfeat/conv4/biases',
  109. 'overfeat/conv5/weights',
  110. 'overfeat/conv5/biases',
  111. 'overfeat/fc6/weights',
  112. 'overfeat/fc6/biases',
  113. 'overfeat/fc7/weights',
  114. 'overfeat/fc7/biases',
  115. 'overfeat/fc8/weights',
  116. 'overfeat/fc8/biases',
  117. ]
  118. model_variables = [v.op.name for v in slim.get_model_variables()]
  119. self.assertSetEqual(set(model_variables), set(expected_names))
  120. def testEvaluation(self):
  121. batch_size = 2
  122. height, width = 231, 231
  123. num_classes = 1000
  124. with self.test_session():
  125. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  126. logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
  127. self.assertListEqual(logits.get_shape().as_list(),
  128. [batch_size, num_classes])
  129. predictions = tf.argmax(logits, 1)
  130. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  131. def testTrainEvalWithReuse(self):
  132. train_batch_size = 2
  133. eval_batch_size = 1
  134. train_height, train_width = 231, 231
  135. eval_height, eval_width = 281, 281
  136. num_classes = 1000
  137. with self.test_session():
  138. train_inputs = tf.random_uniform(
  139. (train_batch_size, train_height, train_width, 3))
  140. logits, _ = overfeat.overfeat(train_inputs)
  141. self.assertListEqual(logits.get_shape().as_list(),
  142. [train_batch_size, num_classes])
  143. tf.get_variable_scope().reuse_variables()
  144. eval_inputs = tf.random_uniform(
  145. (eval_batch_size, eval_height, eval_width, 3))
  146. logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
  147. spatial_squeeze=False)
  148. self.assertListEqual(logits.get_shape().as_list(),
  149. [eval_batch_size, 2, 2, num_classes])
  150. logits = tf.reduce_mean(logits, [1, 2])
  151. predictions = tf.argmax(logits, 1)
  152. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  153. def testForward(self):
  154. batch_size = 1
  155. height, width = 231, 231
  156. with self.test_session() as sess:
  157. inputs = tf.random_uniform((batch_size, height, width, 3))
  158. logits, _ = overfeat.overfeat(inputs)
  159. sess.run(tf.global_variables_initializer())
  160. output = sess.run(logits)
  161. self.assertTrue(output.any())
  162. if __name__ == '__main__':
  163. tf.test.main()