pix2pix.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of the Image-to-Image Translation model.

This network represents a port of the following work:

  Image-to-Image Translation with Conditional Adversarial Networks
  Phillip Isola, Jun-Yan Zhu, Tinghui Zhou and Alexei A. Efros
  Arxiv, 2017
  https://phillipi.github.io/pix2pix/

A reference implementation written in Lua can be found at:
https://github.com/phillipi/pix2pix/blob/master/models.lua
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools

import tensorflow as tf

layers = tf.contrib.layers


def pix2pix_arg_scope():
  """Returns a default argument scope for isola_net.

  Returns:
    An arg scope.
  """
  # These parameters come from the online port, which don't necessarily match
  # those in the paper.
  # TODO(nsilberman): confirm these values with Philip.
  instance_norm_params = {
      'center': True,
      'scale': True,
      'epsilon': 0.00001,
  }

  with tf.contrib.framework.arg_scope(
      [layers.conv2d, layers.conv2d_transpose],
      normalizer_fn=layers.instance_norm,
      normalizer_params=instance_norm_params,
      weights_initializer=tf.random_normal_initializer(0, 0.02)) as sc:
    return sc
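

# A minimal usage sketch (an illustration, not part of the reference port):
# wrapping model construction in pix2pix_arg_scope() makes every conv2d and
# conv2d_transpose pick up instance normalization and the 0.02-stddev normal
# initializer. The 256x256x3 placeholder shape is an assumption for the
# example.
def _example_arg_scope_usage():
  images = tf.placeholder(tf.float32, [1, 256, 256, 3])
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    logits, end_points = pix2pix_generator(images, num_outputs=3)
  return logits, end_points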


def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'):
  """Upsamples the given inputs.

  Args:
    net: A `Tensor` of size [batch_size, height, width, filters].
    num_outputs: The number of output filters.
    kernel_size: A list of 2 scalars or a 1x2 `Tensor` indicating the scale,
      relative to the inputs, of the output dimensions. For example, if kernel
      size is [2, 3], then the output height and width will be twice and three
      times the input size.
    method: The upsampling method: one of 'nn_upsample_conv' or
      'conv2d_transpose'.

  Returns:
    A `Tensor` which was upsampled using the specified method.

  Raises:
    ValueError: if `method` is not recognized.
  """
  net_shape = tf.shape(net)
  height = net_shape[1]
  width = net_shape[2]

  if method == 'nn_upsample_conv':
    net = tf.image.resize_nearest_neighbor(
        net, [kernel_size[0] * height, kernel_size[1] * width])
    net = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None)
  elif method == 'conv2d_transpose':
    net = layers.conv2d_transpose(
        net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None)
  else:
    raise ValueError('Unknown method: [%s]' % method)

  return net
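

# A small sketch contrasting the two upsampling methods (illustrative only;
# the input shape is an assumption). Both double the spatial dimensions when
# given a [2, 2] scale: 'nn_upsample_conv' resizes with nearest-neighbor and
# then convolves, while 'conv2d_transpose' learns the upsampling directly.
def _example_upsample():
  net = tf.placeholder(tf.float32, [1, 16, 16, 64])
  resized = upsample(net, 32, [2, 2], method='nn_upsample_conv')  # 32x32x32
  deconved = upsample(net, 32, [2, 2], method='conv2d_transpose')  # 32x32x32
  return resized, deconved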


class Block(
    collections.namedtuple('Block', ['num_filters', 'decoder_keep_prob'])):
  """Represents a single block of encoder and decoder processing.

  The Image-to-Image translation paper works a bit differently from the
  original U-Net model. In particular, each block represents a single
  operation in the encoder which is concatenated with the corresponding
  decoder representation. A dropout layer follows the concatenation and
  convolution of the concatenated features.
  """


def _default_generator_blocks():
  """Returns the default generator block definitions.

  Returns:
    A list of generator blocks.
  """
  return [
      Block(64, 0.5),
      Block(128, 0.5),
      Block(256, 0.5),
      Block(512, 0),
      Block(512, 0),
      Block(512, 0),
      Block(512, 0),
  ]
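

# The default blocks assume a 256x256 input: seven stride-2 encoder layers
# shrink 256 -> 2 at the bottleneck. A sketch (an illustrative assumption,
# not part of the reference port) of a shallower stack for 64x64 inputs,
# keeping dropout on the three outermost decoder blocks as in the default
# definition:
def _example_small_input_blocks():
  return [
      Block(64, 0.5),
      Block(128, 0.5),
      Block(256, 0.5),
      Block(512, 0),
      Block(512, 0),
  ]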


def pix2pix_generator(net,
                      num_outputs,
                      blocks=None,
                      upsample_method='nn_upsample_conv',
                      is_training=False):  # pylint: disable=unused-argument
  """Defines the network architecture.

  Args:
    net: A `Tensor` of size [batch, height, width, channels]. Note that the
      generator currently requires square inputs (i.e. height == width).
    num_outputs: The number of (per-pixel) outputs.
    blocks: A list of generator blocks or `None` to use the default generator
      definition.
    upsample_method: The method of upsampling images, one of 'nn_upsample_conv'
      or 'conv2d_transpose'.
    is_training: Whether or not we're in training mode.

  Returns:
    A `Tensor` representing the model output and a dictionary of model end
    points.

  Raises:
    ValueError: if the input height does not match the input width.
  """
  end_points = {}

  blocks = blocks or _default_generator_blocks()

  input_size = net.get_shape().as_list()
  height, width = input_size[1], input_size[2]
  if height != width:
    raise ValueError('The input height must match the input width.')

  input_size[3] = num_outputs

  upsample_fn = functools.partial(upsample, method=upsample_method)
  encoder_activations = []

  ###########
  # Encoder #
  ###########
  with tf.variable_scope('encoder'):
    with tf.contrib.framework.arg_scope(
        [layers.conv2d],
        kernel_size=[4, 4],
        stride=2,
        activation_fn=tf.nn.leaky_relu):

      for block_id, block in enumerate(blocks):
        # No normalizer for the first encoder layers as per 'Image-to-Image',
        # Section 5.1.1.
        if block_id == 0:
          # First layer doesn't use normalizer_fn.
          net = layers.conv2d(net, block.num_filters, normalizer_fn=None)
        elif block_id < len(blocks) - 1:
          net = layers.conv2d(net, block.num_filters)
        else:
          # Last layer doesn't use activation_fn nor normalizer_fn.
          net = layers.conv2d(
              net, block.num_filters, activation_fn=None, normalizer_fn=None)

        encoder_activations.append(net)
        end_points['encoder%d' % block_id] = net

  ###########
  # Decoder #
  ###########
  reversed_blocks = list(blocks)
  reversed_blocks.reverse()

  with tf.variable_scope('decoder'):
    # Dropout is used at both train and test time as per 'Image-to-Image',
    # Section 2.1 (last paragraph).
    with tf.contrib.framework.arg_scope([layers.dropout], is_training=True):

      for block_id, block in enumerate(reversed_blocks):
        if block_id > 0:
          net = tf.concat([net, encoder_activations[-block_id - 1]], axis=3)

        # The ReLU comes BEFORE the upsample op:
        net = tf.nn.relu(net)
        net = upsample_fn(net, block.num_filters, [2, 2])
        if block.decoder_keep_prob > 0:
          net = layers.dropout(net, keep_prob=block.decoder_keep_prob)
        end_points['decoder%d' % block_id] = net

  with tf.variable_scope('output'):
    logits = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None)
    logits = tf.reshape(logits, input_size)

    end_points['logits'] = logits
    end_points['predictions'] = tf.tanh(logits)

  return logits, end_points
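

# End-to-end generator sketch (illustrative; the shapes are assumptions). With
# the default seven blocks, a 256x256 input is encoded down to 2x2 at the
# bottleneck and decoded back, so `logits` and `predictions` come out as
# [1, 256, 256, 3]; predictions are tanh-squashed into [-1, 1] to match
# pix2pix's image scaling.
def _example_generator():
  images = tf.placeholder(tf.float32, [1, 256, 256, 3])
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    logits, end_points = pix2pix_generator(
        images, num_outputs=3, upsample_method='nn_upsample_conv')
  return end_points['predictions']  # [1, 256, 256, 3], values in [-1, 1]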


def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
  """Creates the Image2Image Translation Discriminator.

  Args:
    net: A `Tensor` of size [batch_size, height, width, channels] representing
      the input.
    num_filters: A list of the filters in the discriminator. The length of the
      list determines the number of layers in the discriminator.
    padding: Amount of reflection padding applied before each convolution.
    is_training: Whether or not the model is in training mode.

  Returns:
    A logits `Tensor` of size [batch_size, N, N, 1] where N is the number of
    'patches' we're attempting to discriminate and a dictionary of model end
    points.
  """
  del is_training
  end_points = {}
  num_layers = len(num_filters)

  def padded(net, scope):
    if padding:
      with tf.variable_scope(scope):
        spatial_pad = tf.constant(
            [[0, 0], [padding, padding], [padding, padding], [0, 0]],
            dtype=tf.int32)
        return tf.pad(net, spatial_pad, 'REFLECT')
    else:
      return net

  with tf.contrib.framework.arg_scope(
      [layers.conv2d],
      kernel_size=[4, 4],
      stride=2,
      padding='valid',
      activation_fn=tf.nn.leaky_relu):

    # No normalization on the input layer.
    net = layers.conv2d(
        padded(net, 'conv0'), num_filters[0], normalizer_fn=None,
        scope='conv0')

    end_points['conv0'] = net

    for i in range(1, num_layers - 1):
      net = layers.conv2d(
          padded(net, 'conv%d' % i), num_filters[i], scope='conv%d' % i)
      end_points['conv%d' % i] = net

    # Stride 1 on the last layer.
    net = layers.conv2d(
        padded(net, 'conv%d' % (num_layers - 1)),
        num_filters[-1],
        stride=1,
        scope='conv%d' % (num_layers - 1))
    end_points['conv%d' % (num_layers - 1)] = net

    # 1-dim logits, stride 1, no activation, no normalization.
    logits = layers.conv2d(
        padded(net, 'conv%d' % num_layers),
        1,
        stride=1,
        activation_fn=None,
        normalizer_fn=None,
        scope='conv%d' % num_layers)
    end_points['logits'] = logits
    end_points['predictions'] = tf.sigmoid(logits)

  return logits, end_points
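

# A conditional-discriminator sketch in the style of the paper (illustrative;
# the shapes and the [64, 128, 256, 512] filter list are assumptions). pix2pix
# discriminates source/target pairs, so the two images are concatenated along
# the channel axis. With 4x4 VALID convs and reflection padding of 2, a
# 256x256 input should yield a [batch, 35, 35, 1] map of per-patch logits.
def _example_discriminator():
  source = tf.placeholder(tf.float32, [1, 256, 256, 3])
  target = tf.placeholder(tf.float32, [1, 256, 256, 3])
  pairs = tf.concat([source, target], axis=3)
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    logits, _ = pix2pix_discriminator(
        pairs, num_filters=[64, 128, 256, 512])
  return logits  # [1, 35, 35, 1]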