123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Defines the CycleGAN generator and discriminator networks."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import numpy as np
- from six.moves import xrange
- import tensorflow as tf
# Short module-level alias for `tf.contrib.layers`, used by the network
# definitions in this file.
layers = tf.contrib.layers
def cyclegan_arg_scope(instance_norm_center=True,
                       instance_norm_scale=True,
                       instance_norm_epsilon=0.001,
                       weights_init_stddev=0.02,
                       weight_decay=0.0):
  """Builds the default arg_scope shared by generators and discriminators.

  Every `layers.conv2d` created under the returned scope uses instance
  normalization as its normalizer, initializes its kernel from a zero-mean
  normal distribution, and, when `weight_decay` is positive, attaches an L2
  regularizer to the kernel. Only `layers.conv2d` is covered by this scope.

  Args:
    instance_norm_center: Whether instance normalization applies centering.
    instance_norm_scale: Whether instance normalization applies scaling.
    instance_norm_epsilon: Small float added to the variance in the instance
      normalization to avoid dividing by zero.
    weights_init_stddev: Standard deviation of the normal distribution the
      convolution kernels are initialized from.
    weight_decay: Magnitude of the L2 weight decay applied to all convolution
      kernel variables. A falsy or non-positive value disables it.

  Returns:
    An arg-scope.
  """
  norm_params = dict(
      center=instance_norm_center,
      scale=instance_norm_scale,
      epsilon=instance_norm_epsilon)
  # `weight_decay` may be None or 0.0; both mean "no regularization".
  regularizer = (
      layers.l2_regularizer(weight_decay)
      if weight_decay and weight_decay > 0.0 else None)
  initializer = tf.random_normal_initializer(0, weights_init_stddev)
  with tf.contrib.framework.arg_scope(
      [layers.conv2d],
      normalizer_fn=layers.instance_norm,
      normalizer_params=norm_params,
      weights_initializer=initializer,
      weights_regularizer=regularizer) as scope:
    # Only the captured scope object is needed; it stays valid after exit.
    pass
  return scope
def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
  """Upsamples the given inputs.

  Args:
    net: A Tensor of size [batch_size, height, width, filters].
    num_outputs: The number of output filters.
    stride: A list of 2 scalars or a 1x2 Tensor indicating the scale,
      relative to the inputs, of the output dimensions. For example, if stride
      is [2, 3], then the output height and width will be twice and three
      times the input size.
    method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv',
      or 'conv2d_transpose'.

  Returns:
    A Tensor which was upsampled using the specified method.

  Raises:
    ValueError: if `method` is not recognized.
  """
  with tf.variable_scope('upconv'):
    net_shape = tf.shape(net)
    height = net_shape[1]
    width = net_shape[2]

    # Reflection pad by 1 in spatial dimensions (axes 1, 2 = h, w) to make a 3x3
    # 'valid' convolution produce an output with the same dimension as the
    # input.
    spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])

    # The branches must be mutually exclusive: with a plain `if` for the
    # bilinear case, 'nn_upsample_conv' fell through to the final `else` and
    # raised ValueError after already building its layers.
    if method == 'nn_upsample_conv':
      net = tf.image.resize_nearest_neighbor(
          net, [stride[0] * height, stride[1] * width])
      net = tf.pad(net, spatial_pad_1, 'REFLECT')
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    elif method == 'bilinear_upsample_conv':
      net = tf.image.resize_bilinear(
          net, [stride[0] * height, stride[1] * width])
      net = tf.pad(net, spatial_pad_1, 'REFLECT')
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    elif method == 'conv2d_transpose':
      net = layers.conv2d_transpose(
          net, num_outputs, kernel_size=[3, 3], stride=stride, padding='same')
    else:
      # Format the message; passing `method` as a second argument left the
      # '%s' placeholder unfilled.
      raise ValueError('Unknown method: [%s]' % method)

    return net
def _dynamic_or_static_shape(tensor):
  """Returns the shape of `tensor`, preferring a graph-construction-time value.

  Args:
    tensor: The tensor whose shape is wanted.

  Returns:
    The value computed by `tf.contrib.util.constant_value` for `tf.shape(tensor)`
    when that value is available (not None); otherwise the `tf.shape(tensor)`
    tensor itself.
  """
  runtime_shape = tf.shape(tensor)
  known_shape = tf.contrib.util.constant_value(runtime_shape)
  if known_shape is None:
    return runtime_shape
  return known_shape
def cyclegan_generator_resnet(images,
                              arg_scope_fn=cyclegan_arg_scope,
                              num_resnet_blocks=6,
                              num_filters=64,
                              upsample_fn=cyclegan_upsample,
                              kernel_size=3,
                              num_outputs=3,
                              tanh_linear_slope=0.0,
                              is_training=False):
  """Defines the cyclegan resnet network architecture.

  As closely as possible following
  https://github.com/junyanz/CycleGAN/blob/master/models/architectures.lua#L232

  FYI: This network requires input height and width to be divisible by 4 in
  order to generate an output with the same spatial shape as the input. A
  ValueError is raised if this is violated for dimensions known at graph
  construction time, but there's no protection if they are unknown at graph
  construction time (you'll see an error later instead).

  Args:
    images: Input image tensor of shape [batch_size, h, w, channels].
    arg_scope_fn: Function to create the global arg_scope for the network.
    num_resnet_blocks: Number of ResNet blocks in the middle of the generator.
    num_filters: Number of filters of the first hidden layer.
    upsample_fn: Upsampling function for the decoder part of the generator.
    kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner
      layers.
    num_outputs: Number of output channels. Defaults to 3 for RGB. It does not
      need to match the number of input channels.
    tanh_linear_slope: Slope of the linear function to add to the tanh over the
      logits.
    is_training: Whether the network is created in training mode or inference
      only mode. Not actually needed, just for compliance with other generator
      network functions.

  Returns:
    A `Tensor` of shape [batch_size, h, w, num_outputs] representing the model
    output, and a dictionary of model end points.

  Raises:
    ValueError: If the input height or width is known at graph construction time
      and not a multiple of 4.
  """
  # Neither dropout nor batch norm -> don't need is_training.
  del is_training

  end_points = {}

  input_size = images.shape.as_list()
  height, width = input_size[1], input_size[2]
  if height and height % 4 != 0:
    raise ValueError('The input height must be a multiple of 4.')
  if width and width % 4 != 0:
    raise ValueError('The input width must be a multiple of 4.')

  if not isinstance(kernel_size, (list, tuple)):
    kernel_size = [kernel_size, kernel_size]

  kernel_height = kernel_size[0]
  kernel_width = kernel_size[1]
  # Reflection-pad a total of (kernel - 1) rows/columns so a 'VALID'
  # convolution yields the same output size as 'SAME' padding would; for even
  # kernels the extra row/column goes to the bottom/right.
  pad_top = (kernel_height - 1) // 2
  pad_bottom = kernel_height // 2
  pad_left = (kernel_width - 1) // 2
  pad_right = kernel_width // 2
  paddings = np.array(
      [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
      dtype=np.int32)
  # Reflection pad by 3 for the 7x7 'VALID' input and output convolutions.
  spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]])

  with tf.contrib.framework.arg_scope(arg_scope_fn()):

    ###########
    # Encoder #
    ###########
    with tf.variable_scope('input'):
      # 7x7 input stage
      net = tf.pad(images, spatial_pad_3, 'REFLECT')
      net = layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID')
      end_points['encoder_0'] = net

    with tf.variable_scope('encoder'):
      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=2,
          activation_fn=tf.nn.relu,
          padding='VALID'):

        net = tf.pad(net, paddings, 'REFLECT')
        net = layers.conv2d(net, num_filters * 2)
        end_points['encoder_1'] = net
        net = tf.pad(net, paddings, 'REFLECT')
        net = layers.conv2d(net, num_filters * 4)
        end_points['encoder_2'] = net

    ###################
    # Residual Blocks #
    ###################
    with tf.variable_scope('residual_blocks'):
      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=1,
          activation_fn=tf.nn.relu,
          padding='VALID'):
        for block_id in xrange(num_resnet_blocks):
          with tf.variable_scope('block_{}'.format(block_id)):
            res_net = tf.pad(net, paddings, 'REFLECT')
            res_net = layers.conv2d(res_net, num_filters * 4)
            res_net = tf.pad(res_net, paddings, 'REFLECT')
            # No activation on the second conv: it is added to the skip path.
            res_net = layers.conv2d(res_net, num_filters * 4,
                                    activation_fn=None)
            net += res_net

            end_points['resnet_block_%d' % block_id] = net

    ###########
    # Decoder #
    ###########
    with tf.variable_scope('decoder'):

      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=1,
          activation_fn=tf.nn.relu):

        with tf.variable_scope('decoder1'):
          net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2])
        end_points['decoder1'] = net

        with tf.variable_scope('decoder2'):
          net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2])
        end_points['decoder2'] = net

    with tf.variable_scope('output'):
      net = tf.pad(net, spatial_pad_3, 'REFLECT')
      logits = layers.conv2d(
          net,
          num_outputs, [7, 7],
          activation_fn=None,
          normalizer_fn=None,
          padding='valid')
      # Give the output the input's batch/height/width (a constant shape when
      # it is known at graph construction time) with `num_outputs` channels.
      # Reusing the input's own channel count here would make the reshape fail
      # whenever `num_outputs` differs from it.
      output_shape = _dynamic_or_static_shape(images)
      if isinstance(output_shape, np.ndarray):
        output_shape = np.append(output_shape[:3],
                                 num_outputs).astype(output_shape.dtype)
      else:
        output_shape = tf.concat([output_shape[:3], [num_outputs]], axis=0)
      logits = tf.reshape(logits, output_shape)

      end_points['logits'] = logits
      end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope

  return end_points['predictions'], end_points