# compute_features.py
'''
File that computes features for a set of images
ex. python compute_features.py --data_dir=/mnt/images/ --model=vgg_19 --checkpoint_file=./vgg_19.ckpt
'''
import argparse
import fnmatch
import os
import pickle
import sys

# import cPickle as pickle
import _pickle as cPickle

import numpy as np
import scipy.misc as misc
import tensorflow as tf
from tqdm import tqdm

sys.path.insert(0, 'nets/')
slim = tf.contrib.slim
  17. '''
  18. Recursively obtains all images in the directory specified
  19. '''
  20. def getPaths(data_dir):
  21. image_paths = []
  22. # add more extensions if need be
  23. ps = ['jpg', 'jpeg', 'JPG', 'JPEG', 'bmp', 'BMP', 'png', 'PNG']
  24. for p in ps:
  25. pattern = '*.'+p
  26. for d, s, fList in os.walk(data_dir):
  27. for filename in fList:
  28. if fnmatch.fnmatch(filename, pattern):
  29. fname_ = os.path.join(d,filename)
  30. image_paths.append(fname_)
  31. return image_paths
  32. if __name__ == '__main__':
  33. parser = argparse.ArgumentParser()
  34. parser.add_argument('--data_dir', required=True, type=str, help='Directory images are in. Searches recursively.')
  35. parser.add_argument('--model', required=True, type=str, help='Model to use')
  36. parser.add_argument('--checkpoint_file', required=True, type=str, help='Model file')
  37. a = parser.parse_args()
  38. data_dir = a.data_dir
  39. model = a.model
  40. checkpoint_file = a.checkpoint_file
  41. print( data_dir, model, checkpoint_file )
  42. # I only have these because I thought some take in size of (299,299), but maybe not
  43. if 'inception' in model: height, width, channels = 224, 224, 3
  44. if 'resnet' in model: height, width, channels = 224, 224, 3
  45. if 'vgg' in model: height, width, channels = 224, 224, 3
  46. if model == 'inception_resnet_v2': height, width, channels = 299, 299, 3
  47. x = tf.placeholder(tf.float32, shape=(1, height, width, channels))
  48. # load up model specific stuff
  49. if model == 'inception_v1':
  50. from inception_v1 import *
  51. arg_scope = inception_v1_arg_scope()
  52. with slim.arg_scope(arg_scope):
  53. logits, end_points = inception_v1(x, is_training=False, num_classes=1001)
  54. features = end_points['AvgPool_0a_7x7']
  55. elif model == 'inception_v2':
  56. from inception_v2 import *
  57. arg_scope = inception_v2_arg_scope()
  58. with slim.arg_scope(arg_scope):
  59. logits, end_points = inception_v2(x, is_training=False, num_classes=1001)
  60. features = end_points['AvgPool_1a']
  61. elif model == 'inception_v3':
  62. from inception_v3 import *
  63. arg_scope = inception_v3_arg_scope()
  64. with slim.arg_scope(arg_scope):
  65. logits, end_points = inception_v3(x, is_training=False, num_classes=1001)
  66. features = end_points['AvgPool_1a']
  67. elif model == 'inception_resnet_v2':
  68. from inception_resnet_v2 import *
  69. arg_scope = inception_resnet_v2_arg_scope()
  70. with slim.arg_scope(arg_scope):
  71. logits, end_points = inception_resnet_v2(x, is_training=False, num_classes=1001)
  72. features = end_points['PreLogitsFlatten']
  73. elif model == 'resnet_v1_50':
  74. from resnet_v1 import *
  75. arg_scope = resnet_arg_scope()
  76. with slim.arg_scope(arg_scope):
  77. logits, end_points = resnet_v1_50(x, is_training=False, num_classes=1000)
  78. features = end_points['global_pool']
  79. elif model == 'resnet_v1_101':
  80. from resnet_v1 import *
  81. arg_scope = resnet_arg_scope()
  82. with slim.arg_scope(arg_scope):
  83. logits, end_points = resnet_v1_101(x, is_training=False, num_classes=1000)
  84. features = end_points['global_pool']
  85. elif model == 'vgg_16':
  86. from vgg import *
  87. arg_scope = vgg_arg_scope()
  88. with slim.arg_scope(arg_scope):
  89. logits, end_points = vgg_16(x, is_training=False)
  90. features = end_points['vgg_16/fc8']
  91. elif model == 'vgg_19':
  92. from vgg import *
  93. arg_scope = vgg_arg_scope()
  94. with slim.arg_scope(arg_scope):
  95. logits, end_points = vgg_19(x, is_training=False)
  96. features = end_points['vgg_19/fc8']
  97. print('init features...')
  98. sess = tf.Session()
  99. saver = tf.train.Saver()
  100. saver.restore(sess, checkpoint_file)
  101. feat_dict = {}
  102. paths = getPaths(data_dir)
  103. print('Computing features...')
  104. for path in tqdm(paths):
  105. image = misc.imread(path)
  106. image = misc.imresize(image, (height, width))
  107. image = np.expand_dims(image, 0)
  108. feat = np.squeeze(sess.run(features, feed_dict={x:image}))
  109. feat_dict[path] = feat
  110. try: os.makedirs('features/')
  111. except: pass
  112. exp_pkl = open('features/'+model+'_features.pkl', 'wb')
  113. data = pickle.dumps(feat_dict)
  114. exp_pkl.write(data)
  115. exp_pkl.close()