nasnet_utils_test.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.nets.nasnet.nasnet_utils."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets.nasnet import nasnet_utils
  21. class NasnetUtilsTest(tf.test.TestCase):
  22. def testCalcReductionLayers(self):
  23. num_cells = 18
  24. num_reduction_layers = 2
  25. reduction_layers = nasnet_utils.calc_reduction_layers(
  26. num_cells, num_reduction_layers)
  27. self.assertEqual(len(reduction_layers), 2)
  28. self.assertEqual(reduction_layers[0], 6)
  29. self.assertEqual(reduction_layers[1], 12)
  30. def testGetChannelIndex(self):
  31. data_formats = ['NHWC', 'NCHW']
  32. for data_format in data_formats:
  33. index = nasnet_utils.get_channel_index(data_format)
  34. correct_index = 3 if data_format == 'NHWC' else 1
  35. self.assertEqual(index, correct_index)
  36. def testGetChannelDim(self):
  37. data_formats = ['NHWC', 'NCHW']
  38. shape = [10, 20, 30, 40]
  39. for data_format in data_formats:
  40. dim = nasnet_utils.get_channel_dim(shape, data_format)
  41. correct_dim = shape[3] if data_format == 'NHWC' else shape[1]
  42. self.assertEqual(dim, correct_dim)
  43. def testGlobalAvgPool(self):
  44. data_formats = ['NHWC', 'NCHW']
  45. inputs = tf.placeholder(tf.float32, (5, 10, 20, 10))
  46. for data_format in data_formats:
  47. output = nasnet_utils.global_avg_pool(
  48. inputs, data_format)
  49. self.assertEqual(output.shape, [5, 10])
  50. if __name__ == '__main__':
  51. tf.test.main()