nasnet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains the definition for the NASNet classification networks.
  16. Paper: https://arxiv.org/abs/1707.07012
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import tensorflow as tf
  22. from nets.nasnet import nasnet_utils
  23. arg_scope = tf.contrib.framework.arg_scope
  24. slim = tf.contrib.slim
  25. # Notes for training NASNet Cifar Model
  26. # -------------------------------------
  27. # batch_size: 32
  28. # learning rate: 0.025
  29. # cosine (single period) learning rate decay
  30. # auxiliary head loss weighting: 0.4
  31. # clip global norm of all gradients by 5
  32. def _cifar_config(is_training=True):
  33. drop_path_keep_prob = 1.0 if not is_training else 0.6
  34. return tf.contrib.training.HParams(
  35. stem_multiplier=3.0,
  36. drop_path_keep_prob=drop_path_keep_prob,
  37. num_cells=18,
  38. use_aux_head=1,
  39. num_conv_filters=32,
  40. dense_dropout_keep_prob=1.0,
  41. filter_scaling_rate=2.0,
  42. num_reduction_layers=2,
  43. data_format='NHWC',
  44. skip_reduction_layer_input=0,
  45. # 600 epochs with a batch size of 32
  46. # This is used for the drop path probabilities since it needs to increase
  47. # the drop out probability over the course of training.
  48. total_training_steps=937500,
  49. )
  50. # Notes for training large NASNet model on ImageNet
  51. # -------------------------------------
  52. # batch size (per replica): 16
  53. # learning rate: 0.015 * 100
  54. # learning rate decay factor: 0.97
  55. # num epochs per decay: 2.4
  56. # sync sgd with 100 replicas
  57. # auxiliary head loss weighting: 0.4
  58. # label smoothing: 0.1
  59. # clip global norm of all gradients by 10
  60. def _large_imagenet_config(is_training=True):
  61. drop_path_keep_prob = 1.0 if not is_training else 0.7
  62. return tf.contrib.training.HParams(
  63. stem_multiplier=3.0,
  64. dense_dropout_keep_prob=0.5,
  65. num_cells=18,
  66. filter_scaling_rate=2.0,
  67. num_conv_filters=168,
  68. drop_path_keep_prob=drop_path_keep_prob,
  69. use_aux_head=1,
  70. num_reduction_layers=2,
  71. data_format='NHWC',
  72. skip_reduction_layer_input=1,
  73. total_training_steps=250000,
  74. )
  75. # Notes for training the mobile NASNet ImageNet model
  76. # -------------------------------------
  77. # batch size (per replica): 32
  78. # learning rate: 0.04 * 50
  79. # learning rate scaling factor: 0.97
  80. # num epochs per decay: 2.4
  81. # sync sgd with 50 replicas
  82. # auxiliary head weighting: 0.4
  83. # label smoothing: 0.1
  84. # clip global norm of all gradients by 10
  85. def _mobile_imagenet_config():
  86. return tf.contrib.training.HParams(
  87. stem_multiplier=1.0,
  88. dense_dropout_keep_prob=0.5,
  89. num_cells=12,
  90. filter_scaling_rate=2.0,
  91. drop_path_keep_prob=1.0,
  92. num_conv_filters=44,
  93. use_aux_head=1,
  94. num_reduction_layers=2,
  95. data_format='NHWC',
  96. skip_reduction_layer_input=0,
  97. total_training_steps=250000,
  98. )
  99. def nasnet_cifar_arg_scope(weight_decay=5e-4,
  100. batch_norm_decay=0.9,
  101. batch_norm_epsilon=1e-5):
  102. """Defines the default arg scope for the NASNet-A Cifar model.
  103. Args:
  104. weight_decay: The weight decay to use for regularizing the model.
  105. batch_norm_decay: Decay for batch norm moving average.
  106. batch_norm_epsilon: Small float added to variance to avoid dividing by zero
  107. in batch norm.
  108. Returns:
  109. An `arg_scope` to use for the NASNet Cifar Model.
  110. """
  111. batch_norm_params = {
  112. # Decay for the moving averages.
  113. 'decay': batch_norm_decay,
  114. # epsilon to prevent 0s in variance.
  115. 'epsilon': batch_norm_epsilon,
  116. 'scale': True,
  117. 'fused': True,
  118. }
  119. weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  120. weights_initializer = tf.contrib.layers.variance_scaling_initializer(
  121. mode='FAN_OUT')
  122. with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
  123. weights_regularizer=weights_regularizer,
  124. weights_initializer=weights_initializer):
  125. with arg_scope([slim.fully_connected],
  126. activation_fn=None, scope='FC'):
  127. with arg_scope([slim.conv2d, slim.separable_conv2d],
  128. activation_fn=None, biases_initializer=None):
  129. with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
  130. return sc
  131. def nasnet_mobile_arg_scope(weight_decay=4e-5,
  132. batch_norm_decay=0.9997,
  133. batch_norm_epsilon=1e-3):
  134. """Defines the default arg scope for the NASNet-A Mobile ImageNet model.
  135. Args:
  136. weight_decay: The weight decay to use for regularizing the model.
  137. batch_norm_decay: Decay for batch norm moving average.
  138. batch_norm_epsilon: Small float added to variance to avoid dividing by zero
  139. in batch norm.
  140. Returns:
  141. An `arg_scope` to use for the NASNet Mobile Model.
  142. """
  143. batch_norm_params = {
  144. # Decay for the moving averages.
  145. 'decay': batch_norm_decay,
  146. # epsilon to prevent 0s in variance.
  147. 'epsilon': batch_norm_epsilon,
  148. 'scale': True,
  149. 'fused': True,
  150. }
  151. weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  152. weights_initializer = tf.contrib.layers.variance_scaling_initializer(
  153. mode='FAN_OUT')
  154. with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
  155. weights_regularizer=weights_regularizer,
  156. weights_initializer=weights_initializer):
  157. with arg_scope([slim.fully_connected],
  158. activation_fn=None, scope='FC'):
  159. with arg_scope([slim.conv2d, slim.separable_conv2d],
  160. activation_fn=None, biases_initializer=None):
  161. with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
  162. return sc
  163. def nasnet_large_arg_scope(weight_decay=5e-5,
  164. batch_norm_decay=0.9997,
  165. batch_norm_epsilon=1e-3):
  166. """Defines the default arg scope for the NASNet-A Large ImageNet model.
  167. Args:
  168. weight_decay: The weight decay to use for regularizing the model.
  169. batch_norm_decay: Decay for batch norm moving average.
  170. batch_norm_epsilon: Small float added to variance to avoid dividing by zero
  171. in batch norm.
  172. Returns:
  173. An `arg_scope` to use for the NASNet Large Model.
  174. """
  175. batch_norm_params = {
  176. # Decay for the moving averages.
  177. 'decay': batch_norm_decay,
  178. # epsilon to prevent 0s in variance.
  179. 'epsilon': batch_norm_epsilon,
  180. 'scale': True,
  181. 'fused': True,
  182. }
  183. weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  184. weights_initializer = tf.contrib.layers.variance_scaling_initializer(
  185. mode='FAN_OUT')
  186. with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
  187. weights_regularizer=weights_regularizer,
  188. weights_initializer=weights_initializer):
  189. with arg_scope([slim.fully_connected],
  190. activation_fn=None, scope='FC'):
  191. with arg_scope([slim.conv2d, slim.separable_conv2d],
  192. activation_fn=None, biases_initializer=None):
  193. with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
  194. return sc
  195. def _build_aux_head(net, end_points, num_classes, hparams, scope):
  196. """Auxiliary head used for all models across all datasets."""
  197. with tf.variable_scope(scope):
  198. aux_logits = tf.identity(net)
  199. with tf.variable_scope('aux_logits'):
  200. aux_logits = slim.avg_pool2d(
  201. aux_logits, [5, 5], stride=3, padding='VALID')
  202. aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
  203. aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
  204. aux_logits = tf.nn.relu(aux_logits)
  205. # Shape of feature map before the final layer.
  206. shape = aux_logits.shape
  207. if hparams.data_format == 'NHWC':
  208. shape = shape[1:3]
  209. else:
  210. shape = shape[2:4]
  211. aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
  212. aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
  213. aux_logits = tf.nn.relu(aux_logits)
  214. aux_logits = tf.contrib.layers.flatten(aux_logits)
  215. aux_logits = slim.fully_connected(aux_logits, num_classes)
  216. end_points['AuxLogits'] = aux_logits
  217. def _imagenet_stem(inputs, hparams, stem_cell):
  218. """Stem used for models trained on ImageNet."""
  219. num_stem_cells = 2
  220. # 149 x 149 x 32
  221. num_stem_filters = int(32 * hparams.stem_multiplier)
  222. net = slim.conv2d(
  223. inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
  224. padding='VALID')
  225. net = slim.batch_norm(net, scope='conv0_bn')
  226. # Run the reduction cells
  227. cell_outputs = [None, net]
  228. filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  229. for cell_num in range(num_stem_cells):
  230. net = stem_cell(
  231. net,
  232. scope='cell_stem_{}'.format(cell_num),
  233. filter_scaling=filter_scaling,
  234. stride=2,
  235. prev_layer=cell_outputs[-2],
  236. cell_num=cell_num)
  237. cell_outputs.append(net)
  238. filter_scaling *= hparams.filter_scaling_rate
  239. return net, cell_outputs
  240. def _cifar_stem(inputs, hparams):
  241. """Stem used for models trained on Cifar."""
  242. num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
  243. net = slim.conv2d(
  244. inputs,
  245. num_stem_filters,
  246. 3,
  247. scope='l1_stem_3x3')
  248. net = slim.batch_norm(net, scope='l1_stem_bn')
  249. return net, [None, net]
  250. def build_nasnet_cifar(
  251. images, num_classes, is_training=True):
  252. """Build NASNet model for the Cifar Dataset."""
  253. hparams = _cifar_config(is_training=is_training)
  254. if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
  255. tf.logging.info('A GPU is available on the machine, consider using NCHW '
  256. 'data format for increased speed on GPU.')
  257. if hparams.data_format == 'NCHW':
  258. images = tf.transpose(images, [0, 3, 1, 2])
  259. # Calculate the total number of cells in the network
  260. # Add 2 for the reduction cells
  261. total_num_cells = hparams.num_cells + 2
  262. normal_cell = nasnet_utils.NasNetANormalCell(
  263. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  264. total_num_cells, hparams.total_training_steps)
  265. reduction_cell = nasnet_utils.NasNetAReductionCell(
  266. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  267. total_num_cells, hparams.total_training_steps)
  268. with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
  269. is_training=is_training):
  270. with arg_scope([slim.avg_pool2d,
  271. slim.max_pool2d,
  272. slim.conv2d,
  273. slim.batch_norm,
  274. slim.separable_conv2d,
  275. nasnet_utils.factorized_reduction,
  276. nasnet_utils.global_avg_pool,
  277. nasnet_utils.get_channel_index,
  278. nasnet_utils.get_channel_dim],
  279. data_format=hparams.data_format):
  280. return _build_nasnet_base(images,
  281. normal_cell=normal_cell,
  282. reduction_cell=reduction_cell,
  283. num_classes=num_classes,
  284. hparams=hparams,
  285. is_training=is_training,
  286. stem_type='cifar')
  287. build_nasnet_cifar.default_image_size = 32
  288. def build_nasnet_mobile(images, num_classes,
  289. is_training=True,
  290. final_endpoint=None):
  291. """Build NASNet Mobile model for the ImageNet Dataset."""
  292. hparams = _mobile_imagenet_config()
  293. if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
  294. tf.logging.info('A GPU is available on the machine, consider using NCHW '
  295. 'data format for increased speed on GPU.')
  296. if hparams.data_format == 'NCHW':
  297. images = tf.transpose(images, [0, 3, 1, 2])
  298. # Calculate the total number of cells in the network
  299. # Add 2 for the reduction cells
  300. total_num_cells = hparams.num_cells + 2
  301. # If ImageNet, then add an additional two for the stem cells
  302. total_num_cells += 2
  303. normal_cell = nasnet_utils.NasNetANormalCell(
  304. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  305. total_num_cells, hparams.total_training_steps)
  306. reduction_cell = nasnet_utils.NasNetAReductionCell(
  307. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  308. total_num_cells, hparams.total_training_steps)
  309. with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
  310. is_training=is_training):
  311. with arg_scope([slim.avg_pool2d,
  312. slim.max_pool2d,
  313. slim.conv2d,
  314. slim.batch_norm,
  315. slim.separable_conv2d,
  316. nasnet_utils.factorized_reduction,
  317. nasnet_utils.global_avg_pool,
  318. nasnet_utils.get_channel_index,
  319. nasnet_utils.get_channel_dim],
  320. data_format=hparams.data_format):
  321. return _build_nasnet_base(images,
  322. normal_cell=normal_cell,
  323. reduction_cell=reduction_cell,
  324. num_classes=num_classes,
  325. hparams=hparams,
  326. is_training=is_training,
  327. stem_type='imagenet',
  328. final_endpoint=final_endpoint)
  329. build_nasnet_mobile.default_image_size = 224
  330. def build_nasnet_large(images, num_classes,
  331. is_training=True,
  332. final_endpoint=None):
  333. """Build NASNet Large model for the ImageNet Dataset."""
  334. hparams = _large_imagenet_config(is_training=is_training)
  335. if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
  336. tf.logging.info('A GPU is available on the machine, consider using NCHW '
  337. 'data format for increased speed on GPU.')
  338. if hparams.data_format == 'NCHW':
  339. images = tf.transpose(images, [0, 3, 1, 2])
  340. # Calculate the total number of cells in the network
  341. # Add 2 for the reduction cells
  342. total_num_cells = hparams.num_cells + 2
  343. # If ImageNet, then add an additional two for the stem cells
  344. total_num_cells += 2
  345. normal_cell = nasnet_utils.NasNetANormalCell(
  346. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  347. total_num_cells, hparams.total_training_steps)
  348. reduction_cell = nasnet_utils.NasNetAReductionCell(
  349. hparams.num_conv_filters, hparams.drop_path_keep_prob,
  350. total_num_cells, hparams.total_training_steps)
  351. with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
  352. is_training=is_training):
  353. with arg_scope([slim.avg_pool2d,
  354. slim.max_pool2d,
  355. slim.conv2d,
  356. slim.batch_norm,
  357. slim.separable_conv2d,
  358. nasnet_utils.factorized_reduction,
  359. nasnet_utils.global_avg_pool,
  360. nasnet_utils.get_channel_index,
  361. nasnet_utils.get_channel_dim],
  362. data_format=hparams.data_format):
  363. return _build_nasnet_base(images,
  364. normal_cell=normal_cell,
  365. reduction_cell=reduction_cell,
  366. num_classes=num_classes,
  367. hparams=hparams,
  368. is_training=is_training,
  369. stem_type='imagenet',
  370. final_endpoint=final_endpoint)
  371. build_nasnet_large.default_image_size = 331
  372. def _build_nasnet_base(images,
  373. normal_cell,
  374. reduction_cell,
  375. num_classes,
  376. hparams,
  377. is_training,
  378. stem_type,
  379. final_endpoint=None):
  380. """Constructs a NASNet image model."""
  381. end_points = {}
  382. def add_and_check_endpoint(endpoint_name, net):
  383. end_points[endpoint_name] = net
  384. return final_endpoint and (endpoint_name == final_endpoint)
  385. # Find where to place the reduction cells or stride normal cells
  386. reduction_indices = nasnet_utils.calc_reduction_layers(
  387. hparams.num_cells, hparams.num_reduction_layers)
  388. stem_cell = reduction_cell
  389. if stem_type == 'imagenet':
  390. stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  391. elif stem_type == 'cifar':
  392. stem = lambda: _cifar_stem(images, hparams)
  393. else:
  394. raise ValueError('Unknown stem_type: ', stem_type)
  395. net, cell_outputs = stem()
  396. if add_and_check_endpoint('Stem', net): return net, end_points
  397. # Setup for building in the auxiliary head.
  398. aux_head_cell_idxes = []
  399. if len(reduction_indices) >= 2:
  400. aux_head_cell_idxes.append(reduction_indices[1] - 1)
  401. # Run the cells
  402. filter_scaling = 1.0
  403. # true_cell_num accounts for the stem cells
  404. true_cell_num = 2 if stem_type == 'imagenet' else 0
  405. for cell_num in range(hparams.num_cells):
  406. stride = 1
  407. if hparams.skip_reduction_layer_input:
  408. prev_layer = cell_outputs[-2]
  409. if cell_num in reduction_indices:
  410. filter_scaling *= hparams.filter_scaling_rate
  411. net = reduction_cell(
  412. net,
  413. scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
  414. filter_scaling=filter_scaling,
  415. stride=2,
  416. prev_layer=cell_outputs[-2],
  417. cell_num=true_cell_num)
  418. if add_and_check_endpoint(
  419. 'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
  420. return net, end_points
  421. true_cell_num += 1
  422. cell_outputs.append(net)
  423. if not hparams.skip_reduction_layer_input:
  424. prev_layer = cell_outputs[-2]
  425. net = normal_cell(
  426. net,
  427. scope='cell_{}'.format(cell_num),
  428. filter_scaling=filter_scaling,
  429. stride=stride,
  430. prev_layer=prev_layer,
  431. cell_num=true_cell_num)
  432. if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
  433. return net, end_points
  434. true_cell_num += 1
  435. if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
  436. num_classes and is_training):
  437. aux_net = tf.nn.relu(net)
  438. _build_aux_head(aux_net, end_points, num_classes, hparams,
  439. scope='aux_{}'.format(cell_num))
  440. cell_outputs.append(net)
  441. # Final softmax layer
  442. with tf.variable_scope('final_layer'):
  443. net = tf.nn.relu(net)
  444. net = nasnet_utils.global_avg_pool(net)
  445. if add_and_check_endpoint('global_pool', net) or num_classes is None:
  446. return net, end_points
  447. net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
  448. logits = slim.fully_connected(net, num_classes)
  449. if add_and_check_endpoint('Logits', logits):
  450. return net, end_points
  451. predictions = tf.nn.softmax(logits, name='predictions')
  452. if add_and_check_endpoint('Predictions', predictions):
  453. return net, end_points
  454. return logits, end_points