1 Imagenet验证集介绍

 Imagenet验证集数据大小为6.5G,共有1000类、50000张图片。本文主要是将这50000张图片的标签信息整理、分类汇总成一个csv表格,以便实验时读入所需信息。Imagenet验证集标签整理的文件和代码链接如下:https://download.csdn.net/download/qq_38406029/86030944

2 Imagenet验证集处理

 待处理的文件有两个:一个是imagenet_img_info.txt文件,它包含了50000张图片与1000个类别的对应关系;另一个是imagenet_label_info.txt文件,它包含了Imagenet数据集中1000个类别的详细信息。最终输出的用于实验的csv文件如下所示,它详细地概括了图片与其类别的关系。

3 标签分类信息代码

 用imagenet_img_info.txt文件和imagenet_label_info.txt文件最终生成csv文件的程序如下所示:

import os

import re

import json

import csv

# Build selected_imagenet.csv from the two ImageNet validation label files.
#
# path1: JSON object mapping a class index (string key) to a sequence whose
#        items [0] and [1] become the `class` and `class_name` CSV columns
#        (presumably [synset id, readable name] -- confirm against the file).
# path2: text file with one "<image_name>\t<class_index>" pair per line.
#
# Output columns: class_index (int), class, image_name, class_name; one row
# per line of path2, in file order.
path1 = 'imagenet_label_info.txt'
path2 = 'imagenet_img_info.txt'

with open(path1, 'r') as f1:
    json_all = json.load(f1)

img_dict = {}  # image name -> class index string, as read from path2
img_list = []  # image names in file order, so the CSV keeps that order
with open(path2, 'r') as f2:
    for item in f2:
        if not item.strip():
            # Skip blank lines (e.g. a trailing newline at EOF) instead of
            # crashing with IndexError on ditem[1].
            continue
        ditem = item.strip().split('\t')
        img_dict[ditem[0]] = ditem[1]
        img_list.append(ditem[0])

data_list = []
for img_name in img_list:
    class_index = img_dict[img_name]
    class_info = json_all[class_index]  # JSON keys are strings, so look up before int()
    data_list.append((int(class_index), class_info[0], img_name.strip(), class_info[1]))

header = ['class_index', 'class', 'image_name', 'class_name']

# newline='' is what the csv module expects; it prevents blank rows on Windows.
with open('selected_imagenet.csv', 'w', newline='') as file_obj:
    writer = csv.writer(file_obj)
    writer.writerow(header)
    writer.writerows(data_list)

# Echo the written CSV for a quick sanity check (file is closed afterwards).
with open('selected_imagenet.csv', 'r', newline='') as file_obj:
    for row in csv.reader(file_obj):
        print(row)

4 验证集归类文件夹

import csv

import os

import shutil

# Copy every validation image listed in selected_imagenet.csv into
# img_deal/<class>/, where <class> is the CSV's second column and the image
# file name is the third column.
path_val = 'ILSVRC2012_img_val'
path_csv = 'selected_imagenet.csv'
path_root = 'img_deal'

with open(path_csv, 'r', newline='') as csv_file:
    reader = csv.reader(csv_file)
    next(reader, None)  # skip the header row (tolerates an empty file)
    for item in reader:
        item_path = os.path.join(path_root, item[1])
        # exist_ok avoids the check-then-create race and the extra stat call.
        os.makedirs(item_path, exist_ok=True)
        path_img = os.path.join(path_val, item[2])
        shutil.copy(path_img, item_path)

5 Imagenet预训练模型分类

import os

import torch

import torchvision.transforms as T

import torch.nn as nn

import torchvision

from torch.utils.data import Dataset

import csv

import numpy as np

import pretrainedmodels

import PIL.Image as Image

import ssl

# Point torch's cache root at a local folder; torch.hub reads TORCH_HOME to
# decide where downloaded weights are stored.
os.environ.update(TORCH_HOME=r'F:\code\Imagenet\model')

# NOTE(review): this replaces the stdlib's default HTTPS context factory with
# one that skips certificate verification, for the whole process (presumably
# to work around a local certificate problem when downloading weights). That
# is a security risk -- prefer fixing the local CA certificates and dropping
# this line.
ssl._create_default_https_context = ssl._create_unverified_context

class SelectedImagenet(Dataset):
    """Map-style dataset over the images listed in a selected_imagenet CSV.

    Each CSV row is ``class_index, class, image_name, class_name``; the header
    row is skipped. Items are ``(image, class_index)``, where ``image`` is the
    RGB PIL image after ``transform`` (if given) and ``class_index`` is an int.

    Args:
        imagenet_val_dir: directory containing the image files named in the CSV.
        selected_images_csv: path to the CSV file.
        transform: optional callable applied to each RGB PIL image.
        max_images: keep only the first ``max_images`` rows. The default of
            1000 matches the previous hard-coded subset; ``None`` keeps all rows.
    """

    def __init__(self, imagenet_val_dir, selected_images_csv, transform=None, max_images=1000):
        super(SelectedImagenet, self).__init__()
        self.imagenet_val_dir = imagenet_val_dir
        self.selected_images_csv = selected_images_csv
        self.transform = transform
        self.max_images = max_images
        self._load_csv()

    def _load_csv(self):
        """Read the CSV rows (without the header) into ``self.selected_list``."""
        with open(self.selected_images_csv, 'r', newline='') as csv_file:
            reader = csv.reader(csv_file)
            next(reader, None)  # skip header; an empty file yields an empty dataset
            rows = list(reader)
        self.selected_list = rows if self.max_images is None else rows[:self.max_images]

    def __getitem__(self, item):
        target, _, image_name, _ = self.selected_list[item]
        path = os.path.join(self.imagenet_val_dir, image_name)
        # Use a context manager so the file handle is released immediately.
        # convert('RGB') loads the pixels (and copies when already RGB), so the
        # returned image stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, int(target)

    def __len__(self):
        return len(self.selected_list)

# Evaluate a pretrained ImageNet classifier on the selected validation subset
# and report the number (and fraction) of top-1 correct predictions.
model_name = 'senet'

if model_name == 'resnet':
    # ResNet50_Weights is not imported by name in this script, so reach it
    # through torchvision.models (the bare name raised NameError).
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
elif model_name == 'senet':
    model = pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained='imagenet')
else:
    # Fail fast: previously this only printed a message and then crashed with
    # NameError on `model.eval()` because `model` was never assigned.
    raise ValueError('No implementation for model_name={!r}'.format(model_name))

batch_size = 4
device = 'cpu'
model.eval()
model.to(device)

input_size = [3, 299, 299] if model_name in ['inception'] else [3, 224, 224]

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
norm = T.Normalize(mean, std)
resize = T.Resize(tuple(input_size[1:]))

# Resize to 256x256, center-crop 224x224, then resize to the model input size.
# NOTE(review): for a 299x299 model this upsamples the 224 crop -- confirm intended.
trans = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop((224, 224)),
    resize,
    T.ToTensor(),
    norm,
])

dataset = SelectedImagenet(imagenet_val_dir='data/imagenet/ILSVRC2012_img_val',
                           selected_images_csv='data/imagenet/selected_imagenet.csv',
                           transform=trans)

ori_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                         num_workers=0, pin_memory=False)

correct = 0
# Inference only: no_grad stops autograd from recording a graph for every batch.
with torch.no_grad():
    for ori_img, label in ori_loader:
        ori_img = ori_img.to(device)
        label = label.to(device)
        predict = model(ori_img)
        predicted = predict.argmax(dim=1)
        correct += (predicted == label).sum().item()

print(correct)
if len(dataset) > 0:
    print('top-1 accuracy: {:.2%}'.format(correct / len(dataset)))

相关链接

评论可见,请评论后查看内容,谢谢!!!评论后请刷新页面。