Google Open Images Dataset V4 图片数据集详解2-分类快速下载

上节我们介绍了 Open Images V4 数据集的结构信息，这节我们来尝试真正下载相应的图片。整个数据集很大，有 561GB，这么大的数据量对于学习者来说，传输和存储都是个问题。其实最常用的方式是只下载某个（或某些）对象的图片，这就足够了。根据上节讲的数据关系，以对象检测为例，我们可以写一个脚本，单独获取某些对象的图片。这节我们讲述如何快速下载一个乌龟（Tortoise）的图像集。我们先在 V4 的官网上浏览 Tortoise，大致是这样：
Google Open Images Dataset V4 图片数据集详解2-分类快速下载

一、安装 TensorFlow Object Detection API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Create a working directory at the filesystem root.
# -p makes the step safe to re-run when /output already exists.
mkdir -p /output
cd /output/

# Download a pinned older snapshot of tensorflow/models (the Object Detection
# API lives inside it); the then-latest version had problems (as of 2018-04-20).
wget  https://github.com/tensorflow/models/archive/dcfe009a024854207c9067d785c105f5ebf5a01b.zip
unzip dcfe009a024854207c9067d785c105f5ebf5a01b.zip
mv models-dcfe009a024854207c9067d785c105f5ebf5a01b models
rm dcfe009a024854207c9067d785c105f5ebf5a01b.zip

# Install dependencies (Cython first: pycocotools needs it to build).
pip install Cython
pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib
pip install opencv-python
pip install pycocotools

# Compile the Object Detection API protos, put research/ and slim/ on the
# Python path, and run the builder test to verify the installation.
cd /output/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH="$PYTHONPATH:$(pwd):$(pwd)/slim"
python object_detection/builders/model_builder_test.py


下载代码github


二、根据关键字生成tfrecord

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import pandas as pd
import numpy as np
import os
import tensorflow as tf
import io
import logging
import random
import sys
import PIL.Image
import hashlib
from urllib import request
  
sys.path.append("/output/models/research/")
from object_detection.utils import dataset_util
  
  
class open_image_dataset:
    """Download Open Images V4 metadata and build per-class TFRecord files.

    For a split ("train", "val" or "test") the three metadata CSVs are cached
    under ``<split>/``: image list, bounding boxes and class descriptions.
    ``create_tfrecord`` then downloads every image that has at least one box
    of the requested class. It writes ``<keywords>/<split>/<keywords>.tfrecord``
    with one ``tf.train.Example`` per image, holding all of that image's boxes.
    """

    # Locations of the Open Images V4 (2018_04 release) metadata CSVs.
    # NOTE(review): these attributes were used but never defined in the
    # original listing; the URLs follow the V4 download page layout -- confirm
    # they are still valid. Assign an instance attribute to override one.
    classname_csv = "https://storage.googleapis.com/openimages/2018_04/class-descriptions-boxable.csv"
    train_image_csv = "https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv"
    train_box_csv = "https://storage.googleapis.com/openimages/2018_04/train/train-annotations-bbox.csv"
    val_image_csv = "https://storage.googleapis.com/openimages/2018_04/validation/validation-images-with-rotation.csv"
    val_box_csv = "https://storage.googleapis.com/openimages/2018_04/validation/validation-annotations-bbox.csv"
    test_image_csv = "https://storage.googleapis.com/openimages/2018_04/test/test-images-with-rotation.csv"
    test_box_csv = "https://storage.googleapis.com/openimages/2018_04/test/test-annotations-bbox.csv"

    @staticmethod
    def _fetch(url, path):
        """Download *url* to *path* unless *path* already exists.

        The data is written to ``path + ".part"`` first and renamed only after
        the download completes. An interrupted run therefore never leaves a
        truncated file that a later run would mistake for a finished one.

        Raises:
            OSError: the download failed (urllib's URLError/HTTPError are
                OSError subclasses). No partial file is left behind.
        """
        if os.path.exists(path):
            return
        tmp_path = path + ".part"
        try:
            request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _download_csvs(self, folder, image_url, box_url):
        """Cache one split's image, box and class-name CSVs under *folder*.

        Files that already exist are reused, not downloaded again.
        """
        os.makedirs(folder, exist_ok=True)
        self._fetch(image_url, os.path.join(folder, "image.csv"))
        self._fetch(box_url, os.path.join(folder, "box.csv"))
        self._fetch(self.classname_csv, os.path.join(folder, "classname.csv"))

    def download_test(self):
        """Download the test-split metadata CSVs into ``test/``."""
        print("start download test info")
        self._download_csvs("test", self.test_image_csv, self.test_box_csv)
        print("download test complete")

    def download_val(self):
        """Download the validation-split metadata CSVs into ``val/``."""
        self._download_csvs("val", self.val_image_csv, self.val_box_csv)
        print("download val complete")

    def download_train(self):
        """Download the training-split metadata CSVs into ``train/``."""
        self._download_csvs("train", self.train_image_csv, self.train_box_csv)
        print("download train complete")

    @staticmethod
    def _select_boxes(df_classname, df_box, df_image, keywords):
        """Group the boxes of class *keywords* by image.

        Args:
            df_classname: class descriptions with columns ``labelID`` (label
                id) and ``LabelName`` (display name).
            df_box: box annotations; uses ``ImageID``, ``LabelName`` (label
                id) and the XMin/XMax/YMin/YMax columns.
            df_image: image list; uses ``ImageID`` and ``OriginalURL``.
            keywords: display name to select, e.g. ``"Tortoise"``.

        Returns:
            A list of ``(ImageID, rows)`` pairs in first-appearance order.
            ``rows`` holds every matching box of that image plus its
            ``OriginalURL``. Boxes whose image is absent from *df_image* are
            dropped.
        """
        label_ids = df_classname.loc[df_classname['LabelName'] == keywords, 'labelID']
        boxes = df_box[df_box['LabelName'].isin(label_ids)]
        # The inner join keeps only boxes whose image is listed, without
        # building a row for every image in the split.
        data = boxes.merge(df_image[['ImageID', 'OriginalURL']], on='ImageID', how='inner')
        return list(data.groupby('ImageID', sort=False))

    def _build_example(self, folder_path, image_id, rows, keywords):
        """Download one image and encode it, with all its boxes, as an Example.

        Args:
            folder_path: directory for the temporary local image copy.
            image_id: Open Images ImageID; saved as ``<image_id>.jpg``.
            rows: that image's box rows (XMin/XMax/YMin/YMax, OriginalURL).
            keywords: class display name stored as ``image/object/class/text``.

        Returns:
            A ``tf.train.Example``, or None when the download fails or the
            file is not a JPEG. The local copy is removed once it has been read.
        """
        file_name = image_id + ".jpg"
        file_path = os.path.join(folder_path, file_name)
        try:
            self._fetch(rows['OriginalURL'].iloc[0], file_path)
        except OSError as err:
            # Source URLs may be dead; skip the image rather than abort the split.
            print("download error " + file_path + ": " + str(err))
            return None

        with tf.gfile.GFile(file_path, 'rb') as fid:
            encoded_jpg = fid.read()
        # Only the encoded bytes go into the record, so the file is no longer needed.
        os.remove(file_path)

        try:
            with PIL.Image.open(io.BytesIO(encoded_jpg)) as image:
                image_format = image.format
                width, height = image.size
        except OSError:
            # Pillow raises an OSError subclass for data it cannot identify.
            image_format = None
        if image_format != 'JPEG':
            print("file format error " + file_path)
            return None

        key = hashlib.sha256(encoded_jpg).hexdigest()
        count = len(rows)
        feature = {
            'image/height': dataset_util.int64_feature(int(height)),
            'image/width': dataset_util.int64_feature(int(width)),
            'image/filename': dataset_util.bytes_feature(file_name.encode('utf8')),
            'image/source_id': dataset_util.bytes_feature(file_name.encode('utf8')),
            'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            # One entry per box, all boxes of the image in the same Example.
            'image/object/bbox/xmin': dataset_util.float_list_feature(rows['XMin'].astype(float).tolist()),
            'image/object/bbox/xmax': dataset_util.float_list_feature(rows['XMax'].astype(float).tolist()),
            'image/object/bbox/ymin': dataset_util.float_list_feature(rows['YMin'].astype(float).tolist()),
            'image/object/bbox/ymax': dataset_util.float_list_feature(rows['YMax'].astype(float).tolist()),
            'image/object/class/text': dataset_util.bytes_list_feature([keywords.encode('utf8')] * count),
            # Single-class dataset: every box gets class id 1.
            'image/object/class/label': dataset_util.int64_list_feature([1] * count),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    def create_tfrecord(self, folder, keywords):
        """Write ``<keywords>/<folder>/<keywords>.tfrecord`` for one split.

        The split's CSVs must already exist under *folder* (see the
        ``download_*`` methods). Each image containing *keywords* becomes one
        ``tf.train.Example`` with all of its boxes. Images that cannot be
        downloaded, or that are not JPEG, are skipped.
        """
        df_image = pd.read_csv(os.path.join(folder, "image.csv"))
        df_box = pd.read_csv(os.path.join(folder, "box.csv"))
        # Passing names= makes pandas read the first line as data (header=None).
        df_classname = pd.read_csv(os.path.join(folder, "classname.csv"),
                                   names=['labelID', 'LabelName'])
        groups = self._select_boxes(df_classname, df_box, df_image, keywords)

        folder_path = os.path.join(keywords, folder)
        os.makedirs(folder_path, exist_ok=True)
        tfrecord_file = os.path.join(folder_path, keywords + ".tfrecord")

        # The context manager closes the writer even if the loop is interrupted.
        with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
            for image_id, rows in groups:
                example = self._build_example(folder_path, image_id, rows, keywords)
                if example is None:
                    continue
                writer.write(example.SerializeToString())
                print("file " + os.path.join(folder_path, image_id + ".jpg"))
        print("create " + tfrecord_file + " success!")

    def create_train_tfrecord(self, keywords):
        """Download the training metadata and build its TFRecord for *keywords*."""
        self.download_train()
        self.create_tfrecord("train", keywords)

    def create_val_tfrecord(self, keywords):
        """Download the validation metadata and build its TFRecord for *keywords*."""
        self.download_val()
        self.create_tfrecord("val", keywords)

    def create_test_tfrecord(self, keywords):
        """Download the test metadata and build its TFRecord for *keywords*."""
        self.download_test()
        self.create_tfrecord("test", keywords)

    def create_all_tfrecord(self, keywords):
        """Build the train and validation TFRecords (not test) for *keywords*."""
        self.create_train_tfrecord(keywords)
        self.create_val_tfrecord(keywords)
          
if __name__ == "__main__":
    # Only run the (network-heavy) download when executed as a script, not on import.
    dataset = open_image_dataset()
    # Build the TFRecord of the "Tortoise" class from the test split.
    dataset.download_test()
    dataset.create_tfrecord("test", "Tortoise")
    # Validation split for "Tortoise":
    # dataset.download_val()
    # dataset.create_tfrecord("val", "Tortoise")
    # Training split for "Tortoise":
    # dataset.download_train()
    # dataset.create_tfrecord("train", "Tortoise")

    # Train and validation splits (not test) in one call:
    # dataset.create_all_tfrecord("Tortoise")


三、对生成的tfrecord进行验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
import numpy as np
import os
import skimage.io as io
import cv2
# TFRecord produced by create_tfrecord("test", "Tortoise").
tfrecords_filename = "Tortoise/test/Tortoise.tfrecord"

# TF1 queue-based reader; string_input_producer cycles over the file
# indefinitely (num_epochs defaults to None).
filename_queue = tf.train.string_input_producer([tfrecords_filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Box and class fields are lists (one entry per object), so they are parsed
# as VarLenFeature; FixedLenFeature([]) would reject any example with more
# than one box.
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/filename': tf.FixedLenFeature([], tf.string),
        'image/source_id': tf.FixedLenFeature([], tf.string),
        'image/key/sha256': tf.FixedLenFeature([], tf.string),
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/text': tf.VarLenFeature(tf.string),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
    })

width = tf.cast(features['image/width'], tf.int32)
height = tf.cast(features['image/height'], tf.int32)
filename = features['image/filename']
image_format = features['image/format']
# .values gives the dense 1-D tensor of a SparseTensor's entries.
xmin = features['image/object/bbox/xmin'].values
xmax = features['image/object/bbox/xmax'].values
ymin = features['image/object/bbox/ymin'].values
ymax = features['image/object/bbox/ymax'].values
text = features['image/object/class/text'].values
label = features['image/object/class/label'].values

# channels=3 forces RGB output so grayscale JPEGs still fit the reshape below.
image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
image = tf.reshape(image, tf.stack([height, width, 3]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(20):
            (width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1,
             text1, label1, image1) = sess.run(
                [width, height, filename, image_format, xmin, xmax, ymin, ymax,
                 text, label, image])
            print(width1, height1, filename1, format1, xmin1, xmax1, ymin1, ymax1, text1, label1)
            canvas = np.array(image1)
            # Box coordinates are normalized; scale them to pixels and draw every box.
            for box_xmin, box_xmax, box_ymin, box_ymax in zip(xmin1, xmax1, ymin1, ymax1):
                top_left = (int(box_xmin * width1), int(box_ymin * height1))
                bottom_right = (int(box_xmax * width1), int(box_ymax * height1))
                # Draws in place on canvas; the return value is not relied upon.
                cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 3)
            io.imshow(canvas, cmap='gray', interpolation='bicubic')
            io.show()
    finally:
        # Stop the queue-runner threads even if reading or plotting fails.
        coord.request_stop()
        coord.join(threads)

下载代码github

最终的结果如下:

Google Open Images Dataset V4 图片数据集详解2-分类快速下载