vgg_test.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.nets.vgg."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import vgg
  21. slim = tf.contrib.slim
  22. class VGGATest(tf.test.TestCase):
  23. def testBuild(self):
  24. batch_size = 5
  25. height, width = 224, 224
  26. num_classes = 1000
  27. with self.test_session():
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, _ = vgg.vgg_a(inputs, num_classes)
  30. self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed')
  31. self.assertListEqual(logits.get_shape().as_list(),
  32. [batch_size, num_classes])
  33. def testFullyConvolutional(self):
  34. batch_size = 1
  35. height, width = 256, 256
  36. num_classes = 1000
  37. with self.test_session():
  38. inputs = tf.random_uniform((batch_size, height, width, 3))
  39. logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False)
  40. self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
  41. self.assertListEqual(logits.get_shape().as_list(),
  42. [batch_size, 2, 2, num_classes])
  43. def testGlobalPool(self):
  44. batch_size = 1
  45. height, width = 256, 256
  46. num_classes = 1000
  47. with self.test_session():
  48. inputs = tf.random_uniform((batch_size, height, width, 3))
  49. logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False,
  50. global_pool=True)
  51. self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
  52. self.assertListEqual(logits.get_shape().as_list(),
  53. [batch_size, 1, 1, num_classes])
  54. def testEndPoints(self):
  55. batch_size = 5
  56. height, width = 224, 224
  57. num_classes = 1000
  58. with self.test_session():
  59. inputs = tf.random_uniform((batch_size, height, width, 3))
  60. _, end_points = vgg.vgg_a(inputs, num_classes)
  61. expected_names = ['vgg_a/conv1/conv1_1',
  62. 'vgg_a/pool1',
  63. 'vgg_a/conv2/conv2_1',
  64. 'vgg_a/pool2',
  65. 'vgg_a/conv3/conv3_1',
  66. 'vgg_a/conv3/conv3_2',
  67. 'vgg_a/pool3',
  68. 'vgg_a/conv4/conv4_1',
  69. 'vgg_a/conv4/conv4_2',
  70. 'vgg_a/pool4',
  71. 'vgg_a/conv5/conv5_1',
  72. 'vgg_a/conv5/conv5_2',
  73. 'vgg_a/pool5',
  74. 'vgg_a/fc6',
  75. 'vgg_a/fc7',
  76. 'vgg_a/fc8'
  77. ]
  78. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  79. def testNoClasses(self):
  80. batch_size = 5
  81. height, width = 224, 224
  82. num_classes = None
  83. with self.test_session():
  84. inputs = tf.random_uniform((batch_size, height, width, 3))
  85. net, end_points = vgg.vgg_a(inputs, num_classes)
  86. expected_names = ['vgg_a/conv1/conv1_1',
  87. 'vgg_a/pool1',
  88. 'vgg_a/conv2/conv2_1',
  89. 'vgg_a/pool2',
  90. 'vgg_a/conv3/conv3_1',
  91. 'vgg_a/conv3/conv3_2',
  92. 'vgg_a/pool3',
  93. 'vgg_a/conv4/conv4_1',
  94. 'vgg_a/conv4/conv4_2',
  95. 'vgg_a/pool4',
  96. 'vgg_a/conv5/conv5_1',
  97. 'vgg_a/conv5/conv5_2',
  98. 'vgg_a/pool5',
  99. 'vgg_a/fc6',
  100. 'vgg_a/fc7',
  101. ]
  102. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  103. self.assertTrue(net.op.name.startswith('vgg_a/fc7'))
  104. def testModelVariables(self):
  105. batch_size = 5
  106. height, width = 224, 224
  107. num_classes = 1000
  108. with self.test_session():
  109. inputs = tf.random_uniform((batch_size, height, width, 3))
  110. vgg.vgg_a(inputs, num_classes)
  111. expected_names = ['vgg_a/conv1/conv1_1/weights',
  112. 'vgg_a/conv1/conv1_1/biases',
  113. 'vgg_a/conv2/conv2_1/weights',
  114. 'vgg_a/conv2/conv2_1/biases',
  115. 'vgg_a/conv3/conv3_1/weights',
  116. 'vgg_a/conv3/conv3_1/biases',
  117. 'vgg_a/conv3/conv3_2/weights',
  118. 'vgg_a/conv3/conv3_2/biases',
  119. 'vgg_a/conv4/conv4_1/weights',
  120. 'vgg_a/conv4/conv4_1/biases',
  121. 'vgg_a/conv4/conv4_2/weights',
  122. 'vgg_a/conv4/conv4_2/biases',
  123. 'vgg_a/conv5/conv5_1/weights',
  124. 'vgg_a/conv5/conv5_1/biases',
  125. 'vgg_a/conv5/conv5_2/weights',
  126. 'vgg_a/conv5/conv5_2/biases',
  127. 'vgg_a/fc6/weights',
  128. 'vgg_a/fc6/biases',
  129. 'vgg_a/fc7/weights',
  130. 'vgg_a/fc7/biases',
  131. 'vgg_a/fc8/weights',
  132. 'vgg_a/fc8/biases',
  133. ]
  134. model_variables = [v.op.name for v in slim.get_model_variables()]
  135. self.assertSetEqual(set(model_variables), set(expected_names))
  136. def testEvaluation(self):
  137. batch_size = 2
  138. height, width = 224, 224
  139. num_classes = 1000
  140. with self.test_session():
  141. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  142. logits, _ = vgg.vgg_a(eval_inputs, is_training=False)
  143. self.assertListEqual(logits.get_shape().as_list(),
  144. [batch_size, num_classes])
  145. predictions = tf.argmax(logits, 1)
  146. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  147. def testTrainEvalWithReuse(self):
  148. train_batch_size = 2
  149. eval_batch_size = 1
  150. train_height, train_width = 224, 224
  151. eval_height, eval_width = 256, 256
  152. num_classes = 1000
  153. with self.test_session():
  154. train_inputs = tf.random_uniform(
  155. (train_batch_size, train_height, train_width, 3))
  156. logits, _ = vgg.vgg_a(train_inputs)
  157. self.assertListEqual(logits.get_shape().as_list(),
  158. [train_batch_size, num_classes])
  159. tf.get_variable_scope().reuse_variables()
  160. eval_inputs = tf.random_uniform(
  161. (eval_batch_size, eval_height, eval_width, 3))
  162. logits, _ = vgg.vgg_a(eval_inputs, is_training=False,
  163. spatial_squeeze=False)
  164. self.assertListEqual(logits.get_shape().as_list(),
  165. [eval_batch_size, 2, 2, num_classes])
  166. logits = tf.reduce_mean(logits, [1, 2])
  167. predictions = tf.argmax(logits, 1)
  168. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  169. def testForward(self):
  170. batch_size = 1
  171. height, width = 224, 224
  172. with self.test_session() as sess:
  173. inputs = tf.random_uniform((batch_size, height, width, 3))
  174. logits, _ = vgg.vgg_a(inputs)
  175. sess.run(tf.global_variables_initializer())
  176. output = sess.run(logits)
  177. self.assertTrue(output.any())
  178. class VGG16Test(tf.test.TestCase):
  179. def testBuild(self):
  180. batch_size = 5
  181. height, width = 224, 224
  182. num_classes = 1000
  183. with self.test_session():
  184. inputs = tf.random_uniform((batch_size, height, width, 3))
  185. logits, _ = vgg.vgg_16(inputs, num_classes)
  186. self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed')
  187. self.assertListEqual(logits.get_shape().as_list(),
  188. [batch_size, num_classes])
  189. def testFullyConvolutional(self):
  190. batch_size = 1
  191. height, width = 256, 256
  192. num_classes = 1000
  193. with self.test_session():
  194. inputs = tf.random_uniform((batch_size, height, width, 3))
  195. logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False)
  196. self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
  197. self.assertListEqual(logits.get_shape().as_list(),
  198. [batch_size, 2, 2, num_classes])
  199. def testGlobalPool(self):
  200. batch_size = 1
  201. height, width = 256, 256
  202. num_classes = 1000
  203. with self.test_session():
  204. inputs = tf.random_uniform((batch_size, height, width, 3))
  205. logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False,
  206. global_pool=True)
  207. self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
  208. self.assertListEqual(logits.get_shape().as_list(),
  209. [batch_size, 1, 1, num_classes])
  210. def testEndPoints(self):
  211. batch_size = 5
  212. height, width = 224, 224
  213. num_classes = 1000
  214. with self.test_session():
  215. inputs = tf.random_uniform((batch_size, height, width, 3))
  216. _, end_points = vgg.vgg_16(inputs, num_classes)
  217. expected_names = ['vgg_16/conv1/conv1_1',
  218. 'vgg_16/conv1/conv1_2',
  219. 'vgg_16/pool1',
  220. 'vgg_16/conv2/conv2_1',
  221. 'vgg_16/conv2/conv2_2',
  222. 'vgg_16/pool2',
  223. 'vgg_16/conv3/conv3_1',
  224. 'vgg_16/conv3/conv3_2',
  225. 'vgg_16/conv3/conv3_3',
  226. 'vgg_16/pool3',
  227. 'vgg_16/conv4/conv4_1',
  228. 'vgg_16/conv4/conv4_2',
  229. 'vgg_16/conv4/conv4_3',
  230. 'vgg_16/pool4',
  231. 'vgg_16/conv5/conv5_1',
  232. 'vgg_16/conv5/conv5_2',
  233. 'vgg_16/conv5/conv5_3',
  234. 'vgg_16/pool5',
  235. 'vgg_16/fc6',
  236. 'vgg_16/fc7',
  237. 'vgg_16/fc8'
  238. ]
  239. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  240. def testNoClasses(self):
  241. batch_size = 5
  242. height, width = 224, 224
  243. num_classes = None
  244. with self.test_session():
  245. inputs = tf.random_uniform((batch_size, height, width, 3))
  246. net, end_points = vgg.vgg_16(inputs, num_classes)
  247. expected_names = ['vgg_16/conv1/conv1_1',
  248. 'vgg_16/conv1/conv1_2',
  249. 'vgg_16/pool1',
  250. 'vgg_16/conv2/conv2_1',
  251. 'vgg_16/conv2/conv2_2',
  252. 'vgg_16/pool2',
  253. 'vgg_16/conv3/conv3_1',
  254. 'vgg_16/conv3/conv3_2',
  255. 'vgg_16/conv3/conv3_3',
  256. 'vgg_16/pool3',
  257. 'vgg_16/conv4/conv4_1',
  258. 'vgg_16/conv4/conv4_2',
  259. 'vgg_16/conv4/conv4_3',
  260. 'vgg_16/pool4',
  261. 'vgg_16/conv5/conv5_1',
  262. 'vgg_16/conv5/conv5_2',
  263. 'vgg_16/conv5/conv5_3',
  264. 'vgg_16/pool5',
  265. 'vgg_16/fc6',
  266. 'vgg_16/fc7',
  267. ]
  268. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  269. self.assertTrue(net.op.name.startswith('vgg_16/fc7'))
  270. def testModelVariables(self):
  271. batch_size = 5
  272. height, width = 224, 224
  273. num_classes = 1000
  274. with self.test_session():
  275. inputs = tf.random_uniform((batch_size, height, width, 3))
  276. vgg.vgg_16(inputs, num_classes)
  277. expected_names = ['vgg_16/conv1/conv1_1/weights',
  278. 'vgg_16/conv1/conv1_1/biases',
  279. 'vgg_16/conv1/conv1_2/weights',
  280. 'vgg_16/conv1/conv1_2/biases',
  281. 'vgg_16/conv2/conv2_1/weights',
  282. 'vgg_16/conv2/conv2_1/biases',
  283. 'vgg_16/conv2/conv2_2/weights',
  284. 'vgg_16/conv2/conv2_2/biases',
  285. 'vgg_16/conv3/conv3_1/weights',
  286. 'vgg_16/conv3/conv3_1/biases',
  287. 'vgg_16/conv3/conv3_2/weights',
  288. 'vgg_16/conv3/conv3_2/biases',
  289. 'vgg_16/conv3/conv3_3/weights',
  290. 'vgg_16/conv3/conv3_3/biases',
  291. 'vgg_16/conv4/conv4_1/weights',
  292. 'vgg_16/conv4/conv4_1/biases',
  293. 'vgg_16/conv4/conv4_2/weights',
  294. 'vgg_16/conv4/conv4_2/biases',
  295. 'vgg_16/conv4/conv4_3/weights',
  296. 'vgg_16/conv4/conv4_3/biases',
  297. 'vgg_16/conv5/conv5_1/weights',
  298. 'vgg_16/conv5/conv5_1/biases',
  299. 'vgg_16/conv5/conv5_2/weights',
  300. 'vgg_16/conv5/conv5_2/biases',
  301. 'vgg_16/conv5/conv5_3/weights',
  302. 'vgg_16/conv5/conv5_3/biases',
  303. 'vgg_16/fc6/weights',
  304. 'vgg_16/fc6/biases',
  305. 'vgg_16/fc7/weights',
  306. 'vgg_16/fc7/biases',
  307. 'vgg_16/fc8/weights',
  308. 'vgg_16/fc8/biases',
  309. ]
  310. model_variables = [v.op.name for v in slim.get_model_variables()]
  311. self.assertSetEqual(set(model_variables), set(expected_names))
  312. def testEvaluation(self):
  313. batch_size = 2
  314. height, width = 224, 224
  315. num_classes = 1000
  316. with self.test_session():
  317. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  318. logits, _ = vgg.vgg_16(eval_inputs, is_training=False)
  319. self.assertListEqual(logits.get_shape().as_list(),
  320. [batch_size, num_classes])
  321. predictions = tf.argmax(logits, 1)
  322. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  323. def testTrainEvalWithReuse(self):
  324. train_batch_size = 2
  325. eval_batch_size = 1
  326. train_height, train_width = 224, 224
  327. eval_height, eval_width = 256, 256
  328. num_classes = 1000
  329. with self.test_session():
  330. train_inputs = tf.random_uniform(
  331. (train_batch_size, train_height, train_width, 3))
  332. logits, _ = vgg.vgg_16(train_inputs)
  333. self.assertListEqual(logits.get_shape().as_list(),
  334. [train_batch_size, num_classes])
  335. tf.get_variable_scope().reuse_variables()
  336. eval_inputs = tf.random_uniform(
  337. (eval_batch_size, eval_height, eval_width, 3))
  338. logits, _ = vgg.vgg_16(eval_inputs, is_training=False,
  339. spatial_squeeze=False)
  340. self.assertListEqual(logits.get_shape().as_list(),
  341. [eval_batch_size, 2, 2, num_classes])
  342. logits = tf.reduce_mean(logits, [1, 2])
  343. predictions = tf.argmax(logits, 1)
  344. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  345. def testForward(self):
  346. batch_size = 1
  347. height, width = 224, 224
  348. with self.test_session() as sess:
  349. inputs = tf.random_uniform((batch_size, height, width, 3))
  350. logits, _ = vgg.vgg_16(inputs)
  351. sess.run(tf.global_variables_initializer())
  352. output = sess.run(logits)
  353. self.assertTrue(output.any())
  354. class VGG19Test(tf.test.TestCase):
  355. def testBuild(self):
  356. batch_size = 5
  357. height, width = 224, 224
  358. num_classes = 1000
  359. with self.test_session():
  360. inputs = tf.random_uniform((batch_size, height, width, 3))
  361. logits, _ = vgg.vgg_19(inputs, num_classes)
  362. self.assertEquals(logits.op.name, 'vgg_19/fc8/squeezed')
  363. self.assertListEqual(logits.get_shape().as_list(),
  364. [batch_size, num_classes])
  365. def testFullyConvolutional(self):
  366. batch_size = 1
  367. height, width = 256, 256
  368. num_classes = 1000
  369. with self.test_session():
  370. inputs = tf.random_uniform((batch_size, height, width, 3))
  371. logits, _ = vgg.vgg_19(inputs, num_classes, spatial_squeeze=False)
  372. self.assertEquals(logits.op.name, 'vgg_19/fc8/BiasAdd')
  373. self.assertListEqual(logits.get_shape().as_list(),
  374. [batch_size, 2, 2, num_classes])
  375. def testGlobalPool(self):
  376. batch_size = 1
  377. height, width = 256, 256
  378. num_classes = 1000
  379. with self.test_session():
  380. inputs = tf.random_uniform((batch_size, height, width, 3))
  381. logits, _ = vgg.vgg_19(inputs, num_classes, spatial_squeeze=False,
  382. global_pool=True)
  383. self.assertEquals(logits.op.name, 'vgg_19/fc8/BiasAdd')
  384. self.assertListEqual(logits.get_shape().as_list(),
  385. [batch_size, 1, 1, num_classes])
  386. def testEndPoints(self):
  387. batch_size = 5
  388. height, width = 224, 224
  389. num_classes = 1000
  390. with self.test_session():
  391. inputs = tf.random_uniform((batch_size, height, width, 3))
  392. _, end_points = vgg.vgg_19(inputs, num_classes)
  393. expected_names = [
  394. 'vgg_19/conv1/conv1_1',
  395. 'vgg_19/conv1/conv1_2',
  396. 'vgg_19/pool1',
  397. 'vgg_19/conv2/conv2_1',
  398. 'vgg_19/conv2/conv2_2',
  399. 'vgg_19/pool2',
  400. 'vgg_19/conv3/conv3_1',
  401. 'vgg_19/conv3/conv3_2',
  402. 'vgg_19/conv3/conv3_3',
  403. 'vgg_19/conv3/conv3_4',
  404. 'vgg_19/pool3',
  405. 'vgg_19/conv4/conv4_1',
  406. 'vgg_19/conv4/conv4_2',
  407. 'vgg_19/conv4/conv4_3',
  408. 'vgg_19/conv4/conv4_4',
  409. 'vgg_19/pool4',
  410. 'vgg_19/conv5/conv5_1',
  411. 'vgg_19/conv5/conv5_2',
  412. 'vgg_19/conv5/conv5_3',
  413. 'vgg_19/conv5/conv5_4',
  414. 'vgg_19/pool5',
  415. 'vgg_19/fc6',
  416. 'vgg_19/fc7',
  417. 'vgg_19/fc8'
  418. ]
  419. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  420. def testNoClasses(self):
  421. batch_size = 5
  422. height, width = 224, 224
  423. num_classes = None
  424. with self.test_session():
  425. inputs = tf.random_uniform((batch_size, height, width, 3))
  426. net, end_points = vgg.vgg_19(inputs, num_classes)
  427. expected_names = [
  428. 'vgg_19/conv1/conv1_1',
  429. 'vgg_19/conv1/conv1_2',
  430. 'vgg_19/pool1',
  431. 'vgg_19/conv2/conv2_1',
  432. 'vgg_19/conv2/conv2_2',
  433. 'vgg_19/pool2',
  434. 'vgg_19/conv3/conv3_1',
  435. 'vgg_19/conv3/conv3_2',
  436. 'vgg_19/conv3/conv3_3',
  437. 'vgg_19/conv3/conv3_4',
  438. 'vgg_19/pool3',
  439. 'vgg_19/conv4/conv4_1',
  440. 'vgg_19/conv4/conv4_2',
  441. 'vgg_19/conv4/conv4_3',
  442. 'vgg_19/conv4/conv4_4',
  443. 'vgg_19/pool4',
  444. 'vgg_19/conv5/conv5_1',
  445. 'vgg_19/conv5/conv5_2',
  446. 'vgg_19/conv5/conv5_3',
  447. 'vgg_19/conv5/conv5_4',
  448. 'vgg_19/pool5',
  449. 'vgg_19/fc6',
  450. 'vgg_19/fc7',
  451. ]
  452. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  453. self.assertTrue(net.op.name.startswith('vgg_19/fc7'))
  454. def testModelVariables(self):
  455. batch_size = 5
  456. height, width = 224, 224
  457. num_classes = 1000
  458. with self.test_session():
  459. inputs = tf.random_uniform((batch_size, height, width, 3))
  460. vgg.vgg_19(inputs, num_classes)
  461. expected_names = [
  462. 'vgg_19/conv1/conv1_1/weights',
  463. 'vgg_19/conv1/conv1_1/biases',
  464. 'vgg_19/conv1/conv1_2/weights',
  465. 'vgg_19/conv1/conv1_2/biases',
  466. 'vgg_19/conv2/conv2_1/weights',
  467. 'vgg_19/conv2/conv2_1/biases',
  468. 'vgg_19/conv2/conv2_2/weights',
  469. 'vgg_19/conv2/conv2_2/biases',
  470. 'vgg_19/conv3/conv3_1/weights',
  471. 'vgg_19/conv3/conv3_1/biases',
  472. 'vgg_19/conv3/conv3_2/weights',
  473. 'vgg_19/conv3/conv3_2/biases',
  474. 'vgg_19/conv3/conv3_3/weights',
  475. 'vgg_19/conv3/conv3_3/biases',
  476. 'vgg_19/conv3/conv3_4/weights',
  477. 'vgg_19/conv3/conv3_4/biases',
  478. 'vgg_19/conv4/conv4_1/weights',
  479. 'vgg_19/conv4/conv4_1/biases',
  480. 'vgg_19/conv4/conv4_2/weights',
  481. 'vgg_19/conv4/conv4_2/biases',
  482. 'vgg_19/conv4/conv4_3/weights',
  483. 'vgg_19/conv4/conv4_3/biases',
  484. 'vgg_19/conv4/conv4_4/weights',
  485. 'vgg_19/conv4/conv4_4/biases',
  486. 'vgg_19/conv5/conv5_1/weights',
  487. 'vgg_19/conv5/conv5_1/biases',
  488. 'vgg_19/conv5/conv5_2/weights',
  489. 'vgg_19/conv5/conv5_2/biases',
  490. 'vgg_19/conv5/conv5_3/weights',
  491. 'vgg_19/conv5/conv5_3/biases',
  492. 'vgg_19/conv5/conv5_4/weights',
  493. 'vgg_19/conv5/conv5_4/biases',
  494. 'vgg_19/fc6/weights',
  495. 'vgg_19/fc6/biases',
  496. 'vgg_19/fc7/weights',
  497. 'vgg_19/fc7/biases',
  498. 'vgg_19/fc8/weights',
  499. 'vgg_19/fc8/biases',
  500. ]
  501. model_variables = [v.op.name for v in slim.get_model_variables()]
  502. self.assertSetEqual(set(model_variables), set(expected_names))
  503. def testEvaluation(self):
  504. batch_size = 2
  505. height, width = 224, 224
  506. num_classes = 1000
  507. with self.test_session():
  508. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  509. logits, _ = vgg.vgg_19(eval_inputs, is_training=False)
  510. self.assertListEqual(logits.get_shape().as_list(),
  511. [batch_size, num_classes])
  512. predictions = tf.argmax(logits, 1)
  513. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  514. def testTrainEvalWithReuse(self):
  515. train_batch_size = 2
  516. eval_batch_size = 1
  517. train_height, train_width = 224, 224
  518. eval_height, eval_width = 256, 256
  519. num_classes = 1000
  520. with self.test_session():
  521. train_inputs = tf.random_uniform(
  522. (train_batch_size, train_height, train_width, 3))
  523. logits, _ = vgg.vgg_19(train_inputs)
  524. self.assertListEqual(logits.get_shape().as_list(),
  525. [train_batch_size, num_classes])
  526. tf.get_variable_scope().reuse_variables()
  527. eval_inputs = tf.random_uniform(
  528. (eval_batch_size, eval_height, eval_width, 3))
  529. logits, _ = vgg.vgg_19(eval_inputs, is_training=False,
  530. spatial_squeeze=False)
  531. self.assertListEqual(logits.get_shape().as_list(),
  532. [eval_batch_size, 2, 2, num_classes])
  533. logits = tf.reduce_mean(logits, [1, 2])
  534. predictions = tf.argmax(logits, 1)
  535. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  536. def testForward(self):
  537. batch_size = 1
  538. height, width = 224, 224
  539. with self.test_session() as sess:
  540. inputs = tf.random_uniform((batch_size, height, width, 3))
  541. logits, _ = vgg.vgg_19(inputs)
  542. sess.run(tf.global_variables_initializer())
  543. output = sess.run(logits)
  544. self.assertTrue(output.any())
  545. if __name__ == '__main__':
  546. tf.test.main()