# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """A custom module for some common operations used by NASNet.
  16. Functions exposed in this file:
  17. - calc_reduction_layers
  18. - get_channel_index
  19. - get_channel_dim
  20. - global_avg_pool
  21. - factorized_reduction
  22. - drop_path
  23. Classes exposed in this file:
  24. - NasNetABaseCell
  25. - NasNetANormalCell
  26. - NasNetAReductionCell
  27. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
INVALID = 'null'


def calc_reduction_layers(num_cells, num_reduction_layers):
  """Figure out what layers should have reductions."""
  reduction_layers = []
  for pool_num in range(1, num_reduction_layers + 1):
    layer_num = (float(pool_num) / (num_reduction_layers + 1)) * num_cells
    layer_num = int(layer_num)
    reduction_layers.append(layer_num)
  return reduction_layers
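

# Example (illustrative): with num_cells=18 and two reduction layers, the
# reductions land a third and two thirds of the way up the stack:
#   calc_reduction_layers(18, 2)  # -> [6, 12]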


@tf.contrib.framework.add_arg_scope
def get_channel_index(data_format=INVALID):
  assert data_format != INVALID
  axis = 3 if data_format == 'NHWC' else 1
  return axis


@tf.contrib.framework.add_arg_scope
def get_channel_dim(shape, data_format=INVALID):
  assert data_format != INVALID
  assert len(shape) == 4
  if data_format == 'NHWC':
    return int(shape[3])
  elif data_format == 'NCHW':
    return int(shape[1])
  else:
    raise ValueError('Not a valid data_format', data_format)


@tf.contrib.framework.add_arg_scope
def global_avg_pool(x, data_format=INVALID):
  """Average pool away the height and width spatial dimensions of x."""
  assert data_format != INVALID
  assert data_format in ['NHWC', 'NCHW']
  assert x.shape.ndims == 4
  if data_format == 'NHWC':
    return tf.reduce_mean(x, [1, 2])
  else:
    return tf.reduce_mean(x, [2, 3])
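

# Usage sketch (assumes a 4-D `images` tensor; the INVALID default forces the
# caller to supply data_format, typically via arg_scope, since these helpers
# are decorated with add_arg_scope):
#   with arg_scope([get_channel_index, get_channel_dim, global_avg_pool],
#                  data_format='NHWC'):
#     pooled = global_avg_pool(images)  # [batch, h, w, c] -> [batch, c]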


@tf.contrib.framework.add_arg_scope
def factorized_reduction(net, output_filters, stride, data_format=INVALID):
  """Reduces the shape of net without information loss due to striding."""
  assert output_filters % 2 == 0, (
      'Need even number of filters when using this factorized reduction.')
  assert data_format != INVALID
  if stride == 1:
    net = slim.conv2d(net, output_filters, 1, scope='path_conv')
    net = slim.batch_norm(net, scope='path_bn')
    return net
  if data_format == 'NHWC':
    stride_spec = [1, stride, stride, 1]
  else:
    stride_spec = [1, 1, stride, stride]

  # Skip path 1
  path1 = tf.nn.avg_pool(
      net, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path1 = slim.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')

  # Skip path 2
  # First pad with 0's on the right and bottom, then shift the filter to
  # include those 0's that were added.
  if data_format == 'NHWC':
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    path2 = tf.pad(net, pad_arr)[:, 1:, 1:, :]
    concat_axis = 3
  else:
    pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
    path2 = tf.pad(net, pad_arr)[:, :, 1:, 1:]
    concat_axis = 1

  path2 = tf.nn.avg_pool(
      path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format=data_format)
  path2 = slim.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')

  # Concat and apply BN
  final_path = tf.concat(values=[path1, path2], axis=concat_axis)
  final_path = slim.batch_norm(final_path, scope='final_path_bn')
  return final_path
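

# Shape sketch: with stride=2 and an NHWC input of shape [B, H, W, C], each
# path average-pools with a 1x1 window and stride 2 (path 2 on an input
# shifted by one pixel, so the two paths sample complementary positions),
# projects to output_filters / 2 channels, and the concat yields
# [B, ceil(H / 2), ceil(W / 2), output_filters].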


@tf.contrib.framework.add_arg_scope
def drop_path(net, keep_prob, is_training=True):
  """Drops out a whole example hiddenstate with the specified probability."""
  if is_training:
    batch_size = tf.shape(net)[0]
    noise_shape = [batch_size, 1, 1, 1]
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape, dtype=tf.float32)
    binary_tensor = tf.floor(random_tensor)
    net = tf.div(net, keep_prob) * binary_tensor
  return net
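

# Behavior sketch: floor(keep_prob + U[0, 1)) is 1 with probability keep_prob
# and 0 otherwise, so each example in the batch is either zeroed out entirely
# or rescaled by 1 / keep_prob to keep the expected activation unchanged:
#   net = drop_path(net, keep_prob=0.9)  # drops ~10% of examples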


def _operation_to_filter_shape(operation):
  splitted_operation = operation.split('x')
  filter_shape = int(splitted_operation[0][-1])
  assert filter_shape == int(
      splitted_operation[1][0]), 'Rectangular filters not supported.'
  return filter_shape


def _operation_to_num_layers(operation):
  splitted_operation = operation.split('_')
  if 'x' in splitted_operation[-1]:
    return 1
  return int(splitted_operation[-1])


def _operation_to_info(operation):
  """Takes in operation name and returns meta information.

  An example would be 'separable_3x3_4' -> (4, 3).

  Args:
    operation: String that corresponds to convolution operation.

  Returns:
    Tuple of (num_layers, filter_shape).
  """
  num_layers = _operation_to_num_layers(operation)
  filter_shape = _operation_to_filter_shape(operation)
  return num_layers, filter_shape


def _stacked_separable_conv(net, stride, operation, filter_size):
  """Takes in an operation and parses it to the correct sep operation."""
  num_layers, kernel_size = _operation_to_info(operation)
  for layer_num in range(num_layers - 1):
    net = tf.nn.relu(net)
    net = slim.separable_conv2d(
        net,
        filter_size,
        kernel_size,
        depth_multiplier=1,
        scope='separable_{0}x{0}_{1}'.format(kernel_size, layer_num + 1),
        stride=stride)
    net = slim.batch_norm(
        net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, layer_num + 1))
    # Only the first separable layer in the stack uses the requested stride.
    stride = 1
  net = tf.nn.relu(net)
  net = slim.separable_conv2d(
      net,
      filter_size,
      kernel_size,
      depth_multiplier=1,
      scope='separable_{0}x{0}_{1}'.format(kernel_size, num_layers),
      stride=stride)
  net = slim.batch_norm(
      net, scope='bn_sep_{0}x{0}_{1}'.format(kernel_size, num_layers))
  return net
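

# Example (illustrative): operation='separable_5x5_2' with stride=2 expands to
# two relu -> separable 5x5 conv -> batch-norm blocks, where only the first
# block is strided.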


def _operation_to_pooling_type(operation):
  """Takes in the operation string and returns the pooling type."""
  splitted_operation = operation.split('_')
  return splitted_operation[0]


def _operation_to_pooling_shape(operation):
  """Takes in the operation string and returns the pooling kernel shape."""
  splitted_operation = operation.split('_')
  shape = splitted_operation[-1]
  assert 'x' in shape
  filter_height, filter_width = shape.split('x')
  assert filter_height == filter_width
  return int(filter_height)


def _operation_to_pooling_info(operation):
  """Parses the pooling operation string to return its type and shape."""
  pooling_type = _operation_to_pooling_type(operation)
  pooling_shape = _operation_to_pooling_shape(operation)
  return pooling_type, pooling_shape


def _pooling(net, stride, operation):
  """Parses operation and performs the correct pooling operation on net."""
  padding = 'SAME'
  pooling_type, pooling_shape = _operation_to_pooling_info(operation)
  if pooling_type == 'avg':
    net = slim.avg_pool2d(net, pooling_shape, stride=stride, padding=padding)
  elif pooling_type == 'max':
    net = slim.max_pool2d(net, pooling_shape, stride=stride, padding=padding)
  else:
    raise NotImplementedError('Unimplemented pooling type: ' + pooling_type)
  return net
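

# Example (illustrative): 'avg_pool_3x3' parses to ('avg', 3), so
# _pooling(net, 2, 'avg_pool_3x3') applies a 3x3 average pool with stride 2
# and SAME padding.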


class NasNetABaseCell(object):
  """NASNet Cell class that is used as a 'layer' in image architectures.

  Args:
    num_conv_filters: The number of filters for each convolution operation.
    operations: List of operations that are performed in the NASNet Cell in
      order.
    used_hiddenstates: Binary array that signals if the hiddenstate was used
      within the cell. This is used to determine what outputs of the cell
      should be concatenated together.
    hiddenstate_indices: Determines what hiddenstates should be combined
      together with the specified operations to create the NASNet cell.
  """

  def __init__(self, num_conv_filters, operations, used_hiddenstates,
               hiddenstate_indices, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    self._num_conv_filters = num_conv_filters
    self._operations = operations
    self._used_hiddenstates = used_hiddenstates
    self._hiddenstate_indices = hiddenstate_indices
    self._drop_path_keep_prob = drop_path_keep_prob
    self._total_num_cells = total_num_cells
    self._total_training_steps = total_training_steps

  def _reduce_prev_layer(self, prev_layer, curr_layer):
    """Matches dimension of prev_layer to the curr_layer."""
    # Set the prev layer to the current layer if it is None
    if prev_layer is None:
      return curr_layer
    curr_num_filters = self._filter_size
    prev_num_filters = get_channel_dim(prev_layer.shape)
    curr_filter_shape = int(curr_layer.shape[2])
    prev_filter_shape = int(prev_layer.shape[2])
    if curr_filter_shape != prev_filter_shape:
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = factorized_reduction(
          prev_layer, curr_num_filters, stride=2)
    elif curr_num_filters != prev_num_filters:
      prev_layer = tf.nn.relu(prev_layer)
      prev_layer = slim.conv2d(
          prev_layer, curr_num_filters, 1, scope='prev_1x1')
      prev_layer = slim.batch_norm(prev_layer, scope='prev_bn')
    return prev_layer

  def _cell_base(self, net, prev_layer):
    """Runs the beginning of the conv cell before the predicted ops are run."""
    num_filters = self._filter_size

    # Check to be sure prev layer stuff is setup correctly
    prev_layer = self._reduce_prev_layer(prev_layer, net)

    net = tf.nn.relu(net)
    net = slim.conv2d(net, num_filters, 1, scope='1x1')
    net = slim.batch_norm(net, scope='beginning_bn')
    split_axis = get_channel_index()
    # num_or_size_splits=1 is a no-op split: it simply wraps net in a
    # one-element list so prev_layer can be appended below.
    net = tf.split(axis=split_axis, num_or_size_splits=1, value=net)
    for split in net:
      assert int(split.shape[split_axis]) == int(self._num_conv_filters *
                                                 self._filter_scaling)
    net.append(prev_layer)
    return net

  def __call__(self, net, scope=None, filter_scaling=1, stride=1,
               prev_layer=None, cell_num=-1):
    """Runs the conv cell."""
    self._cell_num = cell_num
    self._filter_scaling = filter_scaling
    self._filter_size = int(self._num_conv_filters * filter_scaling)

    i = 0
    with tf.variable_scope(scope):
      net = self._cell_base(net, prev_layer)
      for iteration in range(5):
        with tf.variable_scope('comb_iter_{}'.format(iteration)):
          left_hiddenstate_idx, right_hiddenstate_idx = (
              self._hiddenstate_indices[i],
              self._hiddenstate_indices[i + 1])
          original_input_left = left_hiddenstate_idx < 2
          original_input_right = right_hiddenstate_idx < 2
          h1 = net[left_hiddenstate_idx]
          h2 = net[right_hiddenstate_idx]

          operation_left = self._operations[i]
          operation_right = self._operations[i + 1]
          i += 2

          # Apply conv operations
          with tf.variable_scope('left'):
            h1 = self._apply_conv_operation(h1, operation_left,
                                            stride, original_input_left)
          with tf.variable_scope('right'):
            h2 = self._apply_conv_operation(h2, operation_right,
                                            stride, original_input_right)

          # Combine hidden states using 'add'.
          with tf.variable_scope('combine'):
            h = h1 + h2

          # Add hiddenstate to the list of hiddenstates we can choose from
          net.append(h)

      with tf.variable_scope('cell_output'):
        net = self._combine_unused_states(net)

      return net

  def _apply_conv_operation(self, net, operation,
                            stride, is_from_original_input):
    """Applies the predicted conv operation to net."""
    # Don't stride if this is not one of the original hiddenstates
    if stride > 1 and not is_from_original_input:
      stride = 1
    input_filters = get_channel_dim(net.shape)
    filter_size = self._filter_size
    if 'separable' in operation:
      net = _stacked_separable_conv(net, stride, operation, filter_size)
    elif operation in ['none']:
      # Check if a stride is needed, then use a strided 1x1 here
      if stride > 1 or (input_filters != filter_size):
        net = tf.nn.relu(net)
        net = slim.conv2d(net, filter_size, 1, stride=stride, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    elif 'pool' in operation:
      net = _pooling(net, stride, operation)
      if input_filters != filter_size:
        net = slim.conv2d(net, filter_size, 1, stride=1, scope='1x1')
        net = slim.batch_norm(net, scope='bn_1')
    else:
      raise ValueError('Unimplemented operation', operation)

    if operation != 'none':
      net = self._apply_drop_path(net)
    return net

  def _combine_unused_states(self, net):
    """Concatenate the unused hidden states of the cell."""
    used_hiddenstates = self._used_hiddenstates

    final_height = int(net[-1].shape[2])
    final_num_filters = get_channel_dim(net[-1].shape)
    assert len(used_hiddenstates) == len(net)
    for idx, used_h in enumerate(used_hiddenstates):
      curr_height = int(net[idx].shape[2])
      curr_num_filters = get_channel_dim(net[idx].shape)

      # Determine if a reduction should be applied to make the number of
      # filters match.
      should_reduce = final_num_filters != curr_num_filters
      should_reduce = (final_height != curr_height) or should_reduce
      should_reduce = should_reduce and not used_h
      if should_reduce:
        stride = 2 if final_height != curr_height else 1
        with tf.variable_scope('reduction_{}'.format(idx)):
          net[idx] = factorized_reduction(
              net[idx], final_num_filters, stride)

    states_to_combine = (
        [h for h, is_used in zip(net, used_hiddenstates) if not is_used])

    # Return the concat of all the states
    concat_axis = get_channel_index()
    net = tf.concat(values=states_to_combine, axis=concat_axis)
    return net

  def _apply_drop_path(self, net):
    """Apply drop_path regularization to net."""
    drop_path_keep_prob = self._drop_path_keep_prob
    if drop_path_keep_prob < 1.0:
      # Scale keep prob by layer number
      assert self._cell_num != -1
      # The added 2 is for the reduction cells
      num_cells = self._total_num_cells
      layer_ratio = (self._cell_num + 1) / float(num_cells)
      with tf.device('/cpu:0'):
        tf.summary.scalar('layer_ratio', layer_ratio)
      drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)

      # Decrease the keep probability over time
      current_step = tf.cast(tf.train.get_or_create_global_step(),
                             tf.float32)
      drop_path_burn_in_steps = self._total_training_steps
      current_ratio = (
          current_step / drop_path_burn_in_steps)
      current_ratio = tf.minimum(1.0, current_ratio)
      with tf.device('/cpu:0'):
        tf.summary.scalar('current_ratio', current_ratio)
      drop_path_keep_prob = (
          1 - current_ratio * (1 - drop_path_keep_prob))

      with tf.device('/cpu:0'):
        tf.summary.scalar('drop_path_keep_prob', drop_path_keep_prob)
      net = drop_path(net, drop_path_keep_prob)
    return net
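
  # Schedule sketch: the keep probability is relaxed twice -- linearly in
  # cell depth (later cells drop whole examples more often) and linearly in
  # training progress (ramping from no dropping at step 0 to full strength
  # after total_training_steps). E.g. with drop_path_keep_prob=0.6, the final
  # cell keeps examples with probability 0.6 at the end of training, while
  # every cell keeps everything at step 0.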


class NasNetANormalCell(NasNetABaseCell):
  """NASNetA Normal Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_3x3_2',
                  'separable_5x5_2',
                  'separable_3x3_2',
                  'avg_pool_3x3',
                  'none',
                  'avg_pool_3x3',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'none']
    used_hiddenstates = [1, 0, 0, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 1, 1, 0, 1, 1, 1, 0, 0]
    super(NasNetANormalCell, self).__init__(num_conv_filters, operations,
                                            used_hiddenstates,
                                            hiddenstate_indices,
                                            drop_path_keep_prob,
                                            total_num_cells,
                                            total_training_steps)
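

# Note: the ten operations and ten hiddenstate_indices above are consumed in
# pairs by NasNetABaseCell.__call__, one (left, right) pair per each of the
# five combination iterations. Indices 0 and 1 refer to the cell's two inputs
# (the processed current layer and the reduced previous layer); larger indices
# refer to hiddenstates produced by earlier iterations within the same cell.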


class NasNetAReductionCell(NasNetABaseCell):
  """NASNetA Reduction Cell."""

  def __init__(self, num_conv_filters, drop_path_keep_prob, total_num_cells,
               total_training_steps):
    operations = ['separable_5x5_2',
                  'separable_7x7_2',
                  'max_pool_3x3',
                  'separable_7x7_2',
                  'avg_pool_3x3',
                  'separable_5x5_2',
                  'none',
                  'avg_pool_3x3',
                  'separable_3x3_2',
                  'max_pool_3x3']
    used_hiddenstates = [1, 1, 1, 0, 0, 0, 0]
    hiddenstate_indices = [0, 1, 0, 1, 0, 1, 3, 2, 2, 0]
    super(NasNetAReductionCell, self).__init__(num_conv_filters, operations,
                                               used_hiddenstates,
                                               hiddenstate_indices,
                                               drop_path_keep_prob,
                                               total_num_cells,
                                               total_training_steps)
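

# A minimal construction sketch (hyperparameter values here are illustrative,
# not prescriptive). The cells and the helpers above only work under an
# arg_scope that pins data_format:
#
#   normal_cell = NasNetANormalCell(
#       num_conv_filters=44, drop_path_keep_prob=0.6,
#       total_num_cells=12, total_training_steps=250000)
#   with arg_scope([get_channel_index, get_channel_dim, global_avg_pool,
#                   factorized_reduction], data_format='NHWC'):
#     net = normal_cell(images, scope='cell_0', filter_scaling=1.0,
#                       stride=1, prev_layer=None, cell_num=0)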