overfeat.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Contains the model definition for the OverFeat network.

The definition for the network was obtained from:
  OverFeat: Integrated Recognition, Localization and Detection using
  Convolutional Networks
  Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
  Yann LeCun, 2014
  http://arxiv.org/abs/1312.6229

Usage:
  with slim.arg_scope(overfeat.overfeat_arg_scope()):
    outputs, end_points = overfeat.overfeat(inputs)

@@overfeat
"""
  27. from __future__ import absolute_import
  28. from __future__ import division
  29. from __future__ import print_function
  30. import tensorflow as tf
  31. slim = tf.contrib.slim
  32. trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
  33. def overfeat_arg_scope(weight_decay=0.0005):
  34. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  35. activation_fn=tf.nn.relu,
  36. weights_regularizer=slim.l2_regularizer(weight_decay),
  37. biases_initializer=tf.zeros_initializer()):
  38. with slim.arg_scope([slim.conv2d], padding='SAME'):
  39. with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
  40. return arg_sc
  41. def overfeat(inputs,
  42. num_classes=1000,
  43. is_training=True,
  44. dropout_keep_prob=0.5,
  45. spatial_squeeze=True,
  46. scope='overfeat',
  47. global_pool=False):
  48. """Contains the model definition for the OverFeat network.
  49. The definition for the network was obtained from:
  50. OverFeat: Integrated Recognition, Localization and Detection using
  51. Convolutional Networks
  52. Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
  53. Yann LeCun, 2014
  54. http://arxiv.org/abs/1312.6229
  55. Note: All the fully_connected layers have been transformed to conv2d layers.
  56. To use in classification mode, resize input to 231x231. To use in fully
  57. convolutional mode, set spatial_squeeze to false.
  58. Args:
  59. inputs: a tensor of size [batch_size, height, width, channels].
  60. num_classes: number of predicted classes. If 0 or None, the logits layer is
  61. omitted and the input features to the logits layer are returned instead.
  62. is_training: whether or not the model is being trained.
  63. dropout_keep_prob: the probability that activations are kept in the dropout
  64. layers during training.
  65. spatial_squeeze: whether or not should squeeze the spatial dimensions of the
  66. outputs. Useful to remove unnecessary dimensions for classification.
  67. scope: Optional scope for the variables.
  68. global_pool: Optional boolean flag. If True, the input to the classification
  69. layer is avgpooled to size 1x1, for any input size. (This is not part
  70. of the original OverFeat.)
  71. Returns:
  72. net: the output of the logits layer (if num_classes is a non-zero integer),
  73. or the non-dropped-out input to the logits layer (if num_classes is 0 or
  74. None).
  75. end_points: a dict of tensors with intermediate activations.
  76. """
  77. with tf.variable_scope(scope, 'overfeat', [inputs]) as sc:
  78. end_points_collection = sc.original_name_scope + '_end_points'
  79. # Collect outputs for conv2d, fully_connected and max_pool2d
  80. with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
  81. outputs_collections=end_points_collection):
  82. net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
  83. scope='conv1')
  84. net = slim.max_pool2d(net, [2, 2], scope='pool1')
  85. net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
  86. net = slim.max_pool2d(net, [2, 2], scope='pool2')
  87. net = slim.conv2d(net, 512, [3, 3], scope='conv3')
  88. net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
  89. net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
  90. net = slim.max_pool2d(net, [2, 2], scope='pool5')
  91. # Use conv2d instead of fully_connected layers.
  92. with slim.arg_scope([slim.conv2d],
  93. weights_initializer=trunc_normal(0.005),
  94. biases_initializer=tf.constant_initializer(0.1)):
  95. net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
  96. net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
  97. scope='dropout6')
  98. net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
  99. # Convert end_points_collection into a end_point dict.
  100. end_points = slim.utils.convert_collection_to_dict(
  101. end_points_collection)
  102. if global_pool:
  103. net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
  104. end_points['global_pool'] = net
  105. if num_classes:
  106. net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
  107. scope='dropout7')
  108. net = slim.conv2d(net, num_classes, [1, 1],
  109. activation_fn=None,
  110. normalizer_fn=None,
  111. biases_initializer=tf.zeros_initializer(),
  112. scope='fc8')
  113. if spatial_squeeze:
  114. net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
  115. end_points[sc.name + '/fc8'] = net
  116. return net, end_points
  117. overfeat.default_image_size = 231