# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for the NASNet classification networks.

Paper: https://arxiv.org/abs/1707.07012
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets.nasnet import nasnet_utils

arg_scope = tf.contrib.framework.arg_scope
slim = tf.contrib.slim

# Notes for training NASNet Cifar Model
# -------------------------------------
# batch_size: 32
# learning rate: 0.025
# cosine (single period) learning rate decay
# auxiliary head loss weighting: 0.4
# clip global norm of all gradients by 5
def _cifar_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.6
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      drop_path_keep_prob=drop_path_keep_prob,
      num_cells=18,
      use_aux_head=1,
      num_conv_filters=32,
      dense_dropout_keep_prob=1.0,
      filter_scaling_rate=2.0,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      # 600 epochs with a batch size of 32 on the standard 50,000-image
      # CIFAR-10 training split: 600 * 50000 / 32 = 937,500 steps.
      # This is used for the drop path schedule, which increases the drop
      # probability over the course of training.
      total_training_steps=937500,
  )

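# The schedule in the notes above (cosine learning-rate decay, clipping the
# global gradient norm at 5) is not implemented in this file. A minimal sketch
# of that training setup in TF 1.x, assuming a scalar `loss` already built from
# the model's logits and a hypothetical momentum value of 0.9:
#
#   hparams = _cifar_config(is_training=True)
#   global_step = tf.train.get_or_create_global_step()
#   lr = tf.train.cosine_decay(0.025, global_step,
#                              decay_steps=hparams.total_training_steps)
#   optimizer = tf.train.MomentumOptimizer(lr, 0.9)
#   grads, variables = zip(*optimizer.compute_gradients(loss))
#   grads, _ = tf.clip_by_global_norm(grads, 5.0)
#   train_op = optimizer.apply_gradients(zip(grads, variables),
#                                        global_step=global_step)
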
# Notes for training large NASNet model on ImageNet
# -------------------------------------
# batch size (per replica): 16
# learning rate: 0.015 * 100
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 100 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _large_imagenet_config(is_training=True):
  drop_path_keep_prob = 1.0 if not is_training else 0.7
  return tf.contrib.training.HParams(
      stem_multiplier=3.0,
      dense_dropout_keep_prob=0.5,
      num_cells=18,
      filter_scaling_rate=2.0,
      num_conv_filters=168,
      drop_path_keep_prob=drop_path_keep_prob,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=1,
      total_training_steps=250000,
  )


# Notes for training the mobile NASNet ImageNet model
# -------------------------------------
# batch size (per replica): 32
# learning rate: 0.04 * 50
# learning rate decay factor: 0.97
# num epochs per decay: 2.4
# sync sgd with 50 replicas
# auxiliary head loss weighting: 0.4
# label smoothing: 0.1
# clip global norm of all gradients by 10
def _mobile_imagenet_config():
  return tf.contrib.training.HParams(
      stem_multiplier=1.0,
      dense_dropout_keep_prob=0.5,
      num_cells=12,
      filter_scaling_rate=2.0,
      drop_path_keep_prob=1.0,
      num_conv_filters=44,
      use_aux_head=1,
      num_reduction_layers=2,
      data_format='NHWC',
      skip_reduction_layer_input=0,
      total_training_steps=250000,
  )

def nasnet_cifar_arg_scope(weight_decay=5e-4,
                           batch_norm_decay=0.9,
                           batch_norm_epsilon=1e-5):
  """Defines the default arg scope for the NASNet-A Cifar model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Cifar Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_mobile_arg_scope(weight_decay=4e-5,
                            batch_norm_decay=0.9997,
                            batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Mobile ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Mobile Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc


def nasnet_large_arg_scope(weight_decay=5e-5,
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=1e-3):
  """Defines the default arg scope for the NASNet-A Large ImageNet model.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.

  Returns:
    An `arg_scope` to use for the NASNet Large Model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      'scale': True,
      'fused': True,
  }
  weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  weights_initializer = tf.contrib.layers.variance_scaling_initializer(
      mode='FAN_OUT')
  with arg_scope([slim.fully_connected, slim.conv2d, slim.separable_conv2d],
                 weights_regularizer=weights_regularizer,
                 weights_initializer=weights_initializer):
    with arg_scope([slim.fully_connected],
                   activation_fn=None, scope='FC'):
      with arg_scope([slim.conv2d, slim.separable_conv2d],
                     activation_fn=None, biases_initializer=None):
        with arg_scope([slim.batch_norm], **batch_norm_params) as sc:
          return sc

def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets."""
  with tf.variable_scope(scope):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Shape of feature map before the final layer.
      shape = aux_logits.shape
      if hparams.data_format == 'NHWC':
        shape = shape[1:3]
      else:
        shape = shape[2:4]
      aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = tf.contrib.layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_points['AuxLogits'] = aux_logits

def _imagenet_stem(inputs, hparams, stem_cell):
  """Stem used for models trained on ImageNet."""
  num_stem_cells = 2

  # 149 x 149 x 32
  num_stem_filters = int(32 * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs, num_stem_filters, [3, 3], stride=2, scope='conv0',
      padding='VALID')
  net = slim.batch_norm(net, scope='conv0_bn')

  # Run the reduction cells
  cell_outputs = [None, net]
  filter_scaling = 1.0 / (hparams.filter_scaling_rate**num_stem_cells)
  for cell_num in range(num_stem_cells):
    net = stem_cell(
        net,
        scope='cell_stem_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=2,
        prev_layer=cell_outputs[-2],
        cell_num=cell_num)
    cell_outputs.append(net)
    filter_scaling *= hparams.filter_scaling_rate
  return net, cell_outputs


def _cifar_stem(inputs, hparams):
  """Stem used for models trained on Cifar."""
  num_stem_filters = int(hparams.num_conv_filters * hparams.stem_multiplier)
  net = slim.conv2d(
      inputs,
      num_stem_filters,
      3,
      scope='l1_stem_3x3')
  net = slim.batch_norm(net, scope='l1_stem_bn')
  return net, [None, net]

def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset."""
  hparams = _cifar_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='cifar')
build_nasnet_cifar.default_image_size = 32

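# Typical inference usage (a minimal sketch, not part of the original file;
# assumes a hypothetical `images` tensor of shape [batch, 32, 32, 3] and the
# 10 CIFAR-10 classes):
#
#   with slim.arg_scope(nasnet_cifar_arg_scope()):
#     logits, end_points = build_nasnet_cifar(images, num_classes=10,
#                                             is_training=False)
#   probabilities = end_points['Predictions']
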
def build_nasnet_mobile(images, num_classes,
                        is_training=True,
                        final_endpoint=None):
  """Build NASNet Mobile model for the ImageNet Dataset."""
  hparams = _mobile_imagenet_config()

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_mobile.default_image_size = 224

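# Typical inference usage (a minimal sketch; assumes a hypothetical `images`
# tensor of shape [batch, 224, 224, 3] and the 1001-class TF-Slim ImageNet
# label set, which includes a background class):
#
#   with slim.arg_scope(nasnet_mobile_arg_scope()):
#     logits, end_points = build_nasnet_mobile(images, num_classes=1001,
#                                              is_training=False)
#   top_prediction = tf.argmax(end_points['Predictions'], axis=1)
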
def build_nasnet_large(images, num_classes,
                       is_training=True,
                       final_endpoint=None):
  """Build NASNet Large model for the ImageNet Dataset."""
  hparams = _large_imagenet_config(is_training=is_training)

  if tf.test.is_gpu_available() and hparams.data_format == 'NHWC':
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')

  if hparams.data_format == 'NCHW':
    images = tf.transpose(images, [0, 3, 1, 2])

  # Calculate the total number of cells in the network
  # Add 2 for the reduction cells
  total_num_cells = hparams.num_cells + 2
  # If ImageNet, then add an additional two for the stem cells
  total_num_cells += 2

  normal_cell = nasnet_utils.NasNetANormalCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  reduction_cell = nasnet_utils.NasNetAReductionCell(
      hparams.num_conv_filters, hparams.drop_path_keep_prob,
      total_num_cells, hparams.total_training_steps)
  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                 is_training=is_training):
    with arg_scope([slim.avg_pool2d,
                    slim.max_pool2d,
                    slim.conv2d,
                    slim.batch_norm,
                    slim.separable_conv2d,
                    nasnet_utils.factorized_reduction,
                    nasnet_utils.global_avg_pool,
                    nasnet_utils.get_channel_index,
                    nasnet_utils.get_channel_dim],
                   data_format=hparams.data_format):
      return _build_nasnet_base(images,
                                normal_cell=normal_cell,
                                reduction_cell=reduction_cell,
                                num_classes=num_classes,
                                hparams=hparams,
                                is_training=is_training,
                                stem_type='imagenet',
                                final_endpoint=final_endpoint)
build_nasnet_large.default_image_size = 331

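# The `final_endpoint` argument can be used to cut the network early and reuse
# it as a feature extractor. A sketch; 'global_pool' is one of the endpoint
# names registered in _build_nasnet_base below:
#
#   with slim.arg_scope(nasnet_large_arg_scope()):
#     features, end_points = build_nasnet_large(images, num_classes=1001,
#                                               is_training=False,
#                                               final_endpoint='global_pool')
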
def _build_nasnet_base(images,
                       normal_cell,
                       reduction_cell,
                       num_classes,
                       hparams,
                       is_training,
                       stem_type,
                       final_endpoint=None):
  """Constructs a NASNet image model."""

  end_points = {}
  def add_and_check_endpoint(endpoint_name, net):
    end_points[endpoint_name] = net
    return final_endpoint and (endpoint_name == final_endpoint)

  # Find where to place the reduction cells or stride normal cells
  reduction_indices = nasnet_utils.calc_reduction_layers(
      hparams.num_cells, hparams.num_reduction_layers)
  stem_cell = reduction_cell

  if stem_type == 'imagenet':
    stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  elif stem_type == 'cifar':
    stem = lambda: _cifar_stem(images, hparams)
  else:
    raise ValueError('Unknown stem_type: {}'.format(stem_type))
  net, cell_outputs = stem()
  if add_and_check_endpoint('Stem', net): return net, end_points

  # Setup for building in the auxiliary head.
  aux_head_cell_idxes = []
  if len(reduction_indices) >= 2:
    aux_head_cell_idxes.append(reduction_indices[1] - 1)

  # Run the cells
  filter_scaling = 1.0
  # true_cell_num accounts for the stem cells
  true_cell_num = 2 if stem_type == 'imagenet' else 0
  for cell_num in range(hparams.num_cells):
    stride = 1
    if hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    if cell_num in reduction_indices:
      filter_scaling *= hparams.filter_scaling_rate
      net = reduction_cell(
          net,
          scope='reduction_cell_{}'.format(reduction_indices.index(cell_num)),
          filter_scaling=filter_scaling,
          stride=2,
          prev_layer=cell_outputs[-2],
          cell_num=true_cell_num)
      if add_and_check_endpoint(
          'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
        return net, end_points
      true_cell_num += 1
      cell_outputs.append(net)
    if not hparams.skip_reduction_layer_input:
      prev_layer = cell_outputs[-2]
    net = normal_cell(
        net,
        scope='cell_{}'.format(cell_num),
        filter_scaling=filter_scaling,
        stride=stride,
        prev_layer=prev_layer,
        cell_num=true_cell_num)
    if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
      return net, end_points
    true_cell_num += 1

    if (hparams.use_aux_head and cell_num in aux_head_cell_idxes and
        num_classes and is_training):
      aux_net = tf.nn.relu(net)
      _build_aux_head(aux_net, end_points, num_classes, hparams,
                      scope='aux_{}'.format(cell_num))
    cell_outputs.append(net)

  # Final softmax layer
  with tf.variable_scope('final_layer'):
    net = tf.nn.relu(net)
    net = nasnet_utils.global_avg_pool(net)
    if add_and_check_endpoint('global_pool', net) or num_classes is None:
      return net, end_points
    net = slim.dropout(net, hparams.dense_dropout_keep_prob, scope='dropout')
    logits = slim.fully_connected(net, num_classes)

    if add_and_check_endpoint('Logits', logits):
      return net, end_points

    predictions = tf.nn.softmax(logits, name='predictions')
    if add_and_check_endpoint('Predictions', predictions):
      return net, end_points
  return logits, end_points

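# During training, the notes at the top of this file weight the auxiliary
# classifier at 0.4 and (for the ImageNet configs) use label smoothing of 0.1.
# A minimal sketch of combining the two heads, assuming one-hot `labels` and
# the `logits`/`end_points` returned by one of the build functions above:
#
#   loss = tf.losses.softmax_cross_entropy(labels, logits,
#                                          label_smoothing=0.1)
#   if 'AuxLogits' in end_points:
#     loss += 0.4 * tf.losses.softmax_cross_entropy(
#         labels, end_points['AuxLogits'], label_smoothing=0.1)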