# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the NASNet classification networks.

Paper: https://arxiv.org/abs/1707.07012
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets.nasnet import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim


# Notes for training NASNet Cifar Model
# -------------------------------------
# batch_size: 32
# learning rate: 0.025
# cosine (single period) learning rate decay
# auxiliary head loss weighting: 0.4
# clip global norm of all gradients by 5
def _cifar_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.6
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      drop_path_keep_prob=drop_path_keep_prob,
      num_cells=18,
      use_aux_head=1,
      num_conv_filters=32,
      dense_dropout_keep_prob=1.0,
      filter_scaling_rate=2.0,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      # 600 epochs with a batch size of 32
      # This is used for the drop path probabilities since it needs to increase
      # the drop out probability over the course of training.
      total_training_steps=937500,
  )


# Notes for training large NASNet model on ImageNet
# -------------------------------------
# batch size (per replica): 16
# learning rate: 0.015 * 100
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 100 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _large_imagenet_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.7
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      dense_dropout_keep_prob=0.5,
      num_cells=18,
      filter_scaling_rate=2.0,
      num_conv_filters=168,
      drop_path_keep_prob=drop_path_keep_prob,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=1,
      total_training_steps=250000,
  )


# Notes for training the mobile NASNet ImageNet model
# -------------------------------------
# batch size (per replica): 32
# learning rate: 0.04 * 50
# learning rate scaling factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 50 replicas
# auxiliary head weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _mobile_imagenet_config():
  return tf.contrib.training.HParams(
      stem_multiplier=1.0,
      dense_dropout_keep_prob=0.5,
      num_cells=12,
      filter_scaling_rate=2.0,
      drop_path_keep_prob=1.0,
      num_conv_filters=44,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      total_training_steps=250000,
  )
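

# The helper below is an illustrative sketch, not part of the original model
# code: it expresses the learning-rate schedules described in the training
# notes above (single-period cosine decay for the Cifar model, exponential
# decay by a factor of 0.97 for the ImageNet models). It assumes
# `tf.train.cosine_decay` is available in this TF 1.x build; the ImageNet
# `steps_per_decay` value depends on dataset size, batch size and the
# "2.4 epochs per decay" note, so the number below is only a placeholder.
def _example_learning_rate_schedule(global_step, for_cifar=True):
  if for_cifar:
    # Learning rate 0.025, decayed over a single cosine period spanning the
    # 937500 total training steps used by _cifar_config().
    return tf.train.cosine_decay(
        learning_rate=0.025, global_step=global_step, decay_steps=937500)
  # Mobile ImageNet notes: learning rate 0.04 * 50 replicas, decayed by 0.97
  # every ~2.4 epochs of training.
  steps_per_decay = 10000  # Placeholder; derive from epochs * steps_per_epoch.
  return tf.train.exponential_decay(
      learning_rate=0.04 * 50, global_step=global_step,
      decay_steps=steps_per_decay, decay_rate=0.97, staircase=True)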


def nasnet_cifar_arg_scope(weight_decay=5e-4,
                           batch_norm_decay=0.9,
                           batch_norm_epsilon=1e-5):
  """Defines the default arg scope for the NASNet-A Cifar model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by
      zero in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Cifar Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_mobile_arg_scope(weight_decay=4e-5,
                            batch_norm_decay=0.9997,
                            batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Mobile ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by
      zero in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Mobile Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_large_arg_scope(weight_decay=5e-5,
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Large ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by
      zero in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets."""
  with tf.variable_scope(scope):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Shape of feature map before the final layer.
      shape = aux_logits.shape
      if hparams.data_format == 'NHWC':
        shape = shape[1:3]
      else:
        shape = shape[2:4]
      aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = tf.contrib.layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_points['AuxLogits'] = aux_logits


def _imagenet_stem(inputs, hparams, stem_cell):
  """Stem used for models trained on ImageNet."""
  num_stem_cells = 2

  # 149 x 149 x 32
  num_stem_filters = int(32 * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
      padding='VALID')
  net = slim.batch_norm(net, scope='conv0_bn')

  # Run the reduction cells
  cell_outputs = [None, net]
  filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  for cell_num in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=cell_num)
    cell_outputs.append(net)
    filter_scaling *= hparams.filter_scaling_rate
  return net, cell_outputs


def _cifar_stem(inputs, hparams):
  """Stem used for models trained on Cifar."""
  num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs,
      num_stem_filters,
      3,
      scope='l1_stem_3x3')
  net = slim.batch_norm(net, scope='l1_stem_bn')
  return net, [None, net]
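

# Illustrative usage sketch, not part of the original file: the Cifar builder
# defined below is normally constructed under its matching arg scope. The
# 32x32x3 placeholder shape and num_classes=10 are Cifar-10 assumptions.
def _example_build_cifar_graph():
  images = tf.placeholder(tf.float32, [None, 32, 32, 3])
  with arg_scope(nasnet_cifar_arg_scope()):
    logits, end_points = build_nasnet_cifar(
        images, num_classes=10, is_training=False)
  return logits, end_points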


def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset."""
  hparams = _cifar_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='cifar')
build_nasnet_cifar.default_image_size = 32


def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None):
  """Build NASNet Mobile model for the ImageNet Dataset."""
  hparams = _mobile_imagenet_config()

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_mobile.default_image_size = 224
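

# Illustrative usage sketch, not part of the original file: building the
# mobile model for inference under its default arg scope. The placeholder
# shape matches build_nasnet_mobile.default_image_size; num_classes=1001
# (ImageNet plus a background class) is an assumption.
def _example_build_mobile_for_inference():
  images = tf.placeholder(tf.float32, [None, 224, 224, 3])
  with arg_scope(nasnet_mobile_arg_scope()):
    _, end_points = build_nasnet_mobile(
        images, num_classes=1001, is_training=False)
  return end_points['Predictions']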


def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None):
  """Build NASNet Large model for the ImageNet Dataset."""
  hparams = _large_imagenet_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_large.default_image_size = 331


def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None):
  """Constructs a NASNet image model."""
  end_points = {}

  def add_and_check_endpoint(endpoint_name, net):
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)

  stem_cell = reduction_cell
  if stem_type == 'imagenet':
    stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  elif stem_type == 'cifar':
    stem = lambda: _cifar_stem(images, hparams)
  else:
    raise ValueError('Unknown stem_type: ', stem_type)
  net, cell_outputs = stem()
  if add_and_check_endpoint('Stem', net):
    return net, end_points

  # Setup for building in the auxiliary head.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2 if stem_type == 'imagenet' else 0
  for cell_num in range(hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num)
      if add_and_check_endpoint(
          'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
        return net, end_points
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)
    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1
    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = tf.nn.relu(net)
      _build_aux_head(aux_net, end_points, num_classes, hparams,
                      scope='aux_{}'.format(cell_num))
    cell_outputs.append(net)

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or num_classes is None:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    if add_and_check_endpoint('Logits', logits):
      return net, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return net, end_points
  return logits, end_points
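

# Illustrative sketch, not part of the original file: `final_endpoint` lets a
# caller truncate the network and read an intermediate activation from
# `end_points`. Endpoint names follow the conventions in _build_nasnet_base
# ('Stem', 'Cell_{num}', 'Reduction_Cell_{num}', 'global_pool', ...);
# 'Cell_7' below is an arbitrary illustrative choice for the 12-cell mobile
# model.
def _example_extract_intermediate_features(images):
  with arg_scope(nasnet_mobile_arg_scope()):
    net, end_points = build_nasnet_mobile(
        images, num_classes=None, is_training=False,
        final_endpoint='Cell_7')
  return net, end_points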