# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
- """Implementation of the Image-to-Image Translation model.
- This network represents a port of the following work:
- Image-to-Image Translation with Conditional Adversarial Networks
- Phillip Isola, Jun-Yan Zhu, Tinghui Zhou and Alexei A. Efros
- Arxiv, 2017
- https://phillipi.github.io/pix2pix/
- A reference implementation written in Lua can be found at:
- https://github.com/phillipi/pix2pix/blob/master/models.lua
- """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools

import tensorflow as tf

layers = tf.contrib.layers

def pix2pix_arg_scope():
  """Returns a default argument scope for isola_net.

  Returns:
    An arg scope.
  """
  # These parameters come from the online port, which don't necessarily match
  # those in the paper.
  # TODO(nsilberman): confirm these values with Philip.
  instance_norm_params = {
      'center': True,
      'scale': True,
      'epsilon': 0.00001,
  }

  with tf.contrib.framework.arg_scope(
      [layers.conv2d, layers.conv2d_transpose],
      normalizer_fn=layers.instance_norm,
      normalizer_params=instance_norm_params,
      weights_initializer=tf.random_normal_initializer(0, 0.02)) as sc:
    return sc

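# Minimal usage sketch (not part of the original port): layers built under
# pix2pix_arg_scope pick up instance norm and the 0.02-stddev initializer by
# default. The placeholder shape below is an illustrative assumption.
def _example_arg_scope_usage():
  images = tf.placeholder(tf.float32, [1, 256, 256, 3])
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    # Inherits normalizer_fn and weights_initializer from the scope.
    return layers.conv2d(images, 64, [4, 4], stride=2)
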
def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'):
  """Upsamples the given inputs.

  Args:
    net: A `Tensor` of size [batch_size, height, width, filters].
    num_outputs: The number of output filters.
    kernel_size: A list of 2 scalars or a 1x2 `Tensor` indicating the scale,
      relative to the inputs, of the output dimensions. For example, if kernel
      size is [2, 3], then the output height and width will be twice and three
      times the input size.
    method: The upsampling method.

  Returns:
    A `Tensor` which was upsampled using the specified method.

  Raises:
    ValueError: if `method` is not recognized.
  """
  net_shape = tf.shape(net)
  height = net_shape[1]
  width = net_shape[2]

  if method == 'nn_upsample_conv':
    net = tf.image.resize_nearest_neighbor(
        net, [kernel_size[0] * height, kernel_size[1] * width])
    net = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None)
  elif method == 'conv2d_transpose':
    net = layers.conv2d_transpose(
        net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None)
  else:
    raise ValueError('Unknown method: [%s]' % method)

  return net

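# Illustrative sketch (not part of the original port): upsampling a feature
# map by 2x in both spatial dimensions with the default
# nearest-neighbor-plus-conv method. The shapes are assumptions chosen for
# the example.
def _example_upsample_usage():
  features = tf.placeholder(tf.float32, [1, 32, 32, 64])
  # Produces a [1, 64, 64, 128] tensor: spatial dims scale by the
  # `kernel_size` factors and the channel count becomes `num_outputs`.
  return upsample(features, num_outputs=128, kernel_size=[2, 2])
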
class Block(
    collections.namedtuple('Block', ['num_filters', 'decoder_keep_prob'])):
  """Represents a single block of encoder and decoder processing.

  The Image-to-Image Translation paper differs slightly from the original
  U-Net model. In particular, each block represents a single operation in
  the encoder whose output is concatenated with the corresponding decoder
  representation. A dropout layer follows the concatenation and convolution
  of the concatenated features.
  """
  pass

def _default_generator_blocks():
  """Returns the default generator block definitions.

  Returns:
    A list of generator blocks.
  """
  return [
      Block(64, 0.5),
      Block(128, 0.5),
      Block(256, 0.5),
      Block(512, 0),
      Block(512, 0),
      Block(512, 0),
      Block(512, 0),
  ]

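# Illustrative sketch (not part of the original port): a shallower custom
# block list for smaller inputs. Each Block holds the encoder filter count
# and the decoder dropout keep probability (0 disables dropout).
_EXAMPLE_SMALL_BLOCKS = [
    Block(32, 0.5),
    Block(64, 0.5),
    Block(128, 0),
]
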
def pix2pix_generator(net,
                      num_outputs,
                      blocks=None,
                      upsample_method='nn_upsample_conv',
                      is_training=False):  # pylint: disable=unused-argument
  """Defines the network architecture.

  Args:
    net: A `Tensor` of size [batch, height, width, channels]. Note that the
      generator currently requires square inputs (i.e. height == width).
    num_outputs: The number of (per-pixel) outputs.
    blocks: A list of generator blocks or `None` to use the default generator
      definition.
    upsample_method: The method of upsampling images, one of 'nn_upsample_conv'
      or 'conv2d_transpose'.
    is_training: Whether or not we're in training or testing mode.

  Returns:
    A `Tensor` representing the model output and a dictionary of model end
    points.

  Raises:
    ValueError: if the input height does not match the input width.
  """
  end_points = {}

  blocks = blocks or _default_generator_blocks()

  input_size = net.get_shape().as_list()
  height, width = input_size[1], input_size[2]
  if height != width:
    raise ValueError('The input height must match the input width.')

  input_size[3] = num_outputs

  upsample_fn = functools.partial(upsample, method=upsample_method)
  encoder_activations = []

  ###########
  # Encoder #
  ###########
  with tf.variable_scope('encoder'):
    with tf.contrib.framework.arg_scope(
        [layers.conv2d],
        kernel_size=[4, 4],
        stride=2,
        activation_fn=tf.nn.leaky_relu):

      for block_id, block in enumerate(blocks):
        # No normalizer for the first encoder layer as per 'Image-to-Image',
        # Section 5.1.1.
        if block_id == 0:
          # First layer doesn't use normalizer_fn.
          net = layers.conv2d(net, block.num_filters, normalizer_fn=None)
        elif block_id < len(blocks) - 1:
          net = layers.conv2d(net, block.num_filters)
        else:
          # Last layer doesn't use activation_fn nor normalizer_fn.
          net = layers.conv2d(
              net, block.num_filters, activation_fn=None, normalizer_fn=None)

        encoder_activations.append(net)
        end_points['encoder%d' % block_id] = net

  ###########
  # Decoder #
  ###########
  reversed_blocks = list(blocks)
  reversed_blocks.reverse()

  with tf.variable_scope('decoder'):
    # Dropout is used at both train and test time as per 'Image-to-Image',
    # Section 2.1 (last paragraph).
    with tf.contrib.framework.arg_scope([layers.dropout], is_training=True):
      for block_id, block in enumerate(reversed_blocks):
        if block_id > 0:
          net = tf.concat([net, encoder_activations[-block_id - 1]], axis=3)

        # The ReLU comes BEFORE the upsample op:
        net = tf.nn.relu(net)
        net = upsample_fn(net, block.num_filters, [2, 2])
        if block.decoder_keep_prob > 0:
          net = layers.dropout(net, keep_prob=block.decoder_keep_prob)
        end_points['decoder%d' % block_id] = net

  with tf.variable_scope('output'):
    logits = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None)
    logits = tf.reshape(logits, input_size)

    end_points['logits'] = logits
    end_points['predictions'] = tf.tanh(logits)

  return logits, end_points

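# Illustrative sketch (not part of the original port): running the generator
# on a dummy batch of square RGB images. The input size and num_outputs are
# assumptions; inputs must be square and large enough for the encoder depth
# (256 halves cleanly through all seven default blocks).
def _example_generator_usage():
  images = tf.placeholder(tf.float32, [1, 256, 256, 3])
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    logits, end_points = pix2pix_generator(images, num_outputs=3)
  # `logits` is [1, 256, 256, 3]; `end_points['predictions']` is tanh(logits).
  return logits, end_points
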
def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
  """Creates the Image2Image Translation Discriminator.

  Args:
    net: A `Tensor` of size [batch_size, height, width, channels] representing
      the input.
    num_filters: A list of the filters in the discriminator. The length of the
      list determines the number of layers in the discriminator.
    padding: Amount of reflection padding applied before each convolution.
    is_training: Whether or not the model is training or testing.

  Returns:
    A logits `Tensor` of size [batch_size, N, N, 1], where each of the N x N
    spatial entries scores one image 'patch', and a dictionary of model end
    points.
  """
  del is_training
  end_points = {}

  num_layers = len(num_filters)

  def padded(net, scope):
    if padding:
      with tf.variable_scope(scope):
        spatial_pad = tf.constant(
            [[0, 0], [padding, padding], [padding, padding], [0, 0]],
            dtype=tf.int32)
        return tf.pad(net, spatial_pad, 'REFLECT')
    else:
      return net

  with tf.contrib.framework.arg_scope(
      [layers.conv2d],
      kernel_size=[4, 4],
      stride=2,
      padding='valid',
      activation_fn=tf.nn.leaky_relu):

    # No normalization on the input layer.
    net = layers.conv2d(
        padded(net, 'conv0'), num_filters[0], normalizer_fn=None,
        scope='conv0')
    end_points['conv0'] = net

    for i in range(1, num_layers - 1):
      net = layers.conv2d(
          padded(net, 'conv%d' % i), num_filters[i], scope='conv%d' % i)
      end_points['conv%d' % i] = net

    # Stride 1 on the last layer.
    net = layers.conv2d(
        padded(net, 'conv%d' % (num_layers - 1)),
        num_filters[-1],
        stride=1,
        scope='conv%d' % (num_layers - 1))
    end_points['conv%d' % (num_layers - 1)] = net

    # 1-dim logits, stride 1, no activation, no normalization.
    logits = layers.conv2d(
        padded(net, 'conv%d' % num_layers),
        1,
        stride=1,
        activation_fn=None,
        normalizer_fn=None,
        scope='conv%d' % num_layers)
    end_points['logits'] = logits
    end_points['predictions'] = tf.sigmoid(logits)

  return logits, end_points

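# Illustrative sketch (not part of the original port): a PatchGAN-style
# discriminator over a 256x256 input. The filter list mirrors the common
# [64, 128, 256, 512] configuration but is an assumption, not prescribed by
# this file.
def _example_discriminator_usage():
  images = tf.placeholder(tf.float32, [1, 256, 256, 3])
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    logits, end_points = pix2pix_discriminator(
        images, num_filters=[64, 128, 256, 512])
  # `logits` has shape [1, N, N, 1]; each spatial cell scores one patch.
  return logits, end_points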