cyclegan.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Defines the CycleGAN generator and discriminator networks."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. from six.moves import xrange
  21. import tensorflow as tf
  22. layers = tf.contrib.layers
  23. def cyclegan_arg_scope(instance_norm_center=True,
  24. instance_norm_scale=True,
  25. instance_norm_epsilon=0.001,
  26. weights_init_stddev=0.02,
  27. weight_decay=0.0):
  28. """Returns a default argument scope for all generators and discriminators.
  29. Args:
  30. instance_norm_center: Whether instance normalization applies centering.
  31. instance_norm_scale: Whether instance normalization applies scaling.
  32. instance_norm_epsilon: Small float added to the variance in the instance
  33. normalization to avoid dividing by zero.
  34. weights_init_stddev: Standard deviation of the random values to initialize
  35. the convolution kernels with.
  36. weight_decay: Magnitude of weight decay applied to all convolution kernel
  37. variables of the generator.
  38. Returns:
  39. An arg-scope.
  40. """
  41. instance_norm_params = {
  42. 'center': instance_norm_center,
  43. 'scale': instance_norm_scale,
  44. 'epsilon': instance_norm_epsilon,
  45. }
  46. weights_regularizer = None
  47. if weight_decay and weight_decay > 0.0:
  48. weights_regularizer = layers.l2_regularizer(weight_decay)
  49. with tf.contrib.framework.arg_scope(
  50. [layers.conv2d],
  51. normalizer_fn=layers.instance_norm,
  52. normalizer_params=instance_norm_params,
  53. weights_initializer=tf.random_normal_initializer(0, weights_init_stddev),
  54. weights_regularizer=weights_regularizer) as sc:
  55. return sc
  56. def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
  57. """Upsamples the given inputs.
  58. Args:
  59. net: A Tensor of size [batch_size, height, width, filters].
  60. num_outputs: The number of output filters.
  61. stride: A list of 2 scalars or a 1x2 Tensor indicating the scale,
  62. relative to the inputs, of the output dimensions. For example, if kernel
  63. size is [2, 3], then the output height and width will be twice and three
  64. times the input size.
  65. method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv',
  66. or 'conv2d_transpose'.
  67. Returns:
  68. A Tensor which was upsampled using the specified method.
  69. Raises:
  70. ValueError: if `method` is not recognized.
  71. """
  72. with tf.variable_scope('upconv'):
  73. net_shape = tf.shape(net)
  74. height = net_shape[1]
  75. width = net_shape[2]
  76. # Reflection pad by 1 in spatial dimensions (axes 1, 2 = h, w) to make a 3x3
  77. # 'valid' convolution produce an output with the same dimension as the
  78. # input.
  79. spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])
  80. if method == 'nn_upsample_conv':
  81. net = tf.image.resize_nearest_neighbor(
  82. net, [stride[0] * height, stride[1] * width])
  83. net = tf.pad(net, spatial_pad_1, 'REFLECT')
  84. net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
  85. if method == 'bilinear_upsample_conv':
  86. net = tf.image.resize_bilinear(
  87. net, [stride[0] * height, stride[1] * width])
  88. net = tf.pad(net, spatial_pad_1, 'REFLECT')
  89. net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
  90. elif method == 'conv2d_transpose':
  91. net = layers.conv2d_transpose(
  92. net, num_outputs, kernel_size=[3, 3], stride=stride, padding='same')
  93. else:
  94. raise ValueError('Unknown method: [%s]', method)
  95. return net
  96. def _dynamic_or_static_shape(tensor):
  97. shape = tf.shape(tensor)
  98. static_shape = tf.contrib.util.constant_value(shape)
  99. return static_shape if static_shape is not None else shape
  100. def cyclegan_generator_resnet(images,
  101. arg_scope_fn=cyclegan_arg_scope,
  102. num_resnet_blocks=6,
  103. num_filters=64,
  104. upsample_fn=cyclegan_upsample,
  105. kernel_size=3,
  106. num_outputs=3,
  107. tanh_linear_slope=0.0,
  108. is_training=False):
  109. """Defines the cyclegan resnet network architecture.
  110. As closely as possible following
  111. https://github.com/junyanz/CycleGAN/blob/master/models/architectures.lua#L232
  112. FYI: This network requires input height and width to be divisible by 4 in
  113. order to generate an output with shape equal to input shape. Assertions will
  114. catch this if input dimensions are known at graph construction time, but
  115. there's no protection if unknown at graph construction time (you'll see an
  116. error).
  117. Args:
  118. images: Input image tensor of shape [batch_size, h, w, 3].
  119. arg_scope_fn: Function to create the global arg_scope for the network.
  120. num_resnet_blocks: Number of ResNet blocks in the middle of the generator.
  121. num_filters: Number of filters of the first hidden layer.
  122. upsample_fn: Upsampling function for the decoder part of the generator.
  123. kernel_size: Size w or list/tuple [h, w] of the filter kernels for all inner
  124. layers.
  125. num_outputs: Number of output layers. Defaults to 3 for RGB.
  126. tanh_linear_slope: Slope of the linear function to add to the tanh over the
  127. logits.
  128. is_training: Whether the network is created in training mode or inference
  129. only mode. Not actually needed, just for compliance with other generator
  130. network functions.
  131. Returns:
  132. A `Tensor` representing the model output and a dictionary of model end
  133. points.
  134. Raises:
  135. ValueError: If the input height or width is known at graph construction time
  136. and not a multiple of 4.
  137. """
  138. # Neither dropout nor batch norm -> dont need is_training
  139. del is_training
  140. end_points = {}
  141. input_size = images.shape.as_list()
  142. height, width = input_size[1], input_size[2]
  143. if height and height % 4 != 0:
  144. raise ValueError('The input height must be a multiple of 4.')
  145. if width and width % 4 != 0:
  146. raise ValueError('The input width must be a multiple of 4.')
  147. if not isinstance(kernel_size, (list, tuple)):
  148. kernel_size = [kernel_size, kernel_size]
  149. kernel_height = kernel_size[0]
  150. kernel_width = kernel_size[1]
  151. pad_top = (kernel_height - 1) // 2
  152. pad_bottom = kernel_height // 2
  153. pad_left = (kernel_width - 1) // 2
  154. pad_right = kernel_width // 2
  155. paddings = np.array(
  156. [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
  157. dtype=np.int32)
  158. spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]])
  159. with tf.contrib.framework.arg_scope(arg_scope_fn()):
  160. ###########
  161. # Encoder #
  162. ###########
  163. with tf.variable_scope('input'):
  164. # 7x7 input stage
  165. net = tf.pad(images, spatial_pad_3, 'REFLECT')
  166. net = layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID')
  167. end_points['encoder_0'] = net
  168. with tf.variable_scope('encoder'):
  169. with tf.contrib.framework.arg_scope(
  170. [layers.conv2d],
  171. kernel_size=kernel_size,
  172. stride=2,
  173. activation_fn=tf.nn.relu,
  174. padding='VALID'):
  175. net = tf.pad(net, paddings, 'REFLECT')
  176. net = layers.conv2d(net, num_filters * 2)
  177. end_points['encoder_1'] = net
  178. net = tf.pad(net, paddings, 'REFLECT')
  179. net = layers.conv2d(net, num_filters * 4)
  180. end_points['encoder_2'] = net
  181. ###################
  182. # Residual Blocks #
  183. ###################
  184. with tf.variable_scope('residual_blocks'):
  185. with tf.contrib.framework.arg_scope(
  186. [layers.conv2d],
  187. kernel_size=kernel_size,
  188. stride=1,
  189. activation_fn=tf.nn.relu,
  190. padding='VALID'):
  191. for block_id in xrange(num_resnet_blocks):
  192. with tf.variable_scope('block_{}'.format(block_id)):
  193. res_net = tf.pad(net, paddings, 'REFLECT')
  194. res_net = layers.conv2d(res_net, num_filters * 4)
  195. res_net = tf.pad(res_net, paddings, 'REFLECT')
  196. res_net = layers.conv2d(res_net, num_filters * 4,
  197. activation_fn=None)
  198. net += res_net
  199. end_points['resnet_block_%d' % block_id] = net
  200. ###########
  201. # Decoder #
  202. ###########
  203. with tf.variable_scope('decoder'):
  204. with tf.contrib.framework.arg_scope(
  205. [layers.conv2d],
  206. kernel_size=kernel_size,
  207. stride=1,
  208. activation_fn=tf.nn.relu):
  209. with tf.variable_scope('decoder1'):
  210. net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2])
  211. end_points['decoder1'] = net
  212. with tf.variable_scope('decoder2'):
  213. net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2])
  214. end_points['decoder2'] = net
  215. with tf.variable_scope('output'):
  216. net = tf.pad(net, spatial_pad_3, 'REFLECT')
  217. logits = layers.conv2d(
  218. net,
  219. num_outputs, [7, 7],
  220. activation_fn=None,
  221. normalizer_fn=None,
  222. padding='valid')
  223. logits = tf.reshape(logits, _dynamic_or_static_shape(images))
  224. end_points['logits'] = logits
  225. end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope
  226. return end_points['predictions'], end_points