代码如下:
import os

import cv2
import numpy as np  # kept from the original file; not used below


def _read_image(path):
    """Load an image from *path* as a 3-channel BGR array, failing loudly.

    ``cv2.imread`` signals a missing or undecodable file by returning ``None``
    rather than raising; without this check the failure would only surface
    later as an obscure error when the result is sliced.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    return img


def _grid_origins(height, width, crop_size, step):
    """Return the ``(y, x)`` top-left corners of every full tile in a 2-D grid.

    Tiles are ``crop_size`` x ``crop_size`` pixels and advance by ``step``
    pixels both down and across. Only tiles lying entirely inside the image
    are produced, so no partial crops need to be filtered out later. Order is
    row-major (top-to-bottom, then left-to-right). An image smaller than one
    tile yields an empty list.

    Raises:
        ValueError: if ``crop_size`` or ``step`` is not positive.
    """
    if crop_size <= 0 or step <= 0:
        raise ValueError("crop_size and step must be positive")
    ys = range(0, height - crop_size + 1, step)
    xs = range(0, width - crop_size + 1, step)
    return [(y, x) for y in ys for x in xs]


# The eight symmetries of a square: identity, three rotations, two axis
# flips, transpose and anti-transpose. The original chained flips/rotations
# generated exactly this set, but with every element repeated many times.
_GEOMETRIC = (
    lambda im: im,
    lambda im: cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE),
    lambda im: cv2.rotate(im, cv2.ROTATE_180),
    lambda im: cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE),
    lambda im: cv2.flip(im, 1),  # horizontal flip
    lambda im: cv2.flip(im, 0),  # vertical flip
    lambda im: cv2.flip(cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE), 1),  # transpose
    lambda im: cv2.flip(cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE), 0),  # anti-transpose
)


def _scale(im):
    """Multiply every intensity by 1.5 (OpenCV saturates the result)."""
    return cv2.addWeighted(im, 1.5, im, 0, 0)


def _shift(im):
    """Subtract 50 from every intensity (OpenCV saturates the result)."""
    return cv2.addWeighted(im, 1, im, 0, -50)


# Intensity variants, matching the original's "brightness" and "contrast"
# steps (and their combination). The original HSV "saturation"/"hue" steps
# are intentionally dropped: they saved raw HSV arrays as if they were BGR
# images and then ran HSV->BGR on BGR data, producing meaningless samples.
_PHOTOMETRIC = (
    lambda im: im,
    _scale,
    _shift,
    lambda im: _shift(_scale(im)),
)


def _augmented_triples(crop_a, crop_b, crop_label):
    """Yield augmented ``(a, b, label)`` triples for one aligned tile.

    Each geometric transform is applied identically to all three images so
    they stay pixel-aligned. Photometric transforms are applied only to the
    two input images. The label presumably encodes classes/masks (it is saved
    under ``Label/``); rescaling or shifting its pixel values would corrupt
    that encoding, so it is passed through unchanged.
    """
    for geo in _GEOMETRIC:
        a, b, label = geo(crop_a), geo(crop_b), geo(crop_label)
        for photo in _PHOTOMETRIC:
            yield photo(a), photo(b), label


def _save_triple(split_dir, index, a, b, label, label_ext):
    """Write one triple as ``<split_dir>/{A,B,Label}/<index:04>.<ext>``.

    A and B are written as JPEG. The label uses *label_ext* so that a
    lossless format can be chosen for it.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns ``False``
            instead of raising).
    """
    stem = str(index).zfill(4)
    targets = (
        (os.path.join(split_dir, "A", stem + ".jpg"), a),
        (os.path.join(split_dir, "B", stem + ".jpg"), b),
        (os.path.join(split_dir, "Label", stem + label_ext), label),
    )
    for path, img in targets:
        if not cv2.imwrite(path, img):
            raise OSError(f"failed to write image: {path}")


def main(path_a="a.jpg", path_b="b.jpg", path_label="label.jpg", *,
         crop_size=256, step=128, out_root=".", label_ext=".png"):
    """Tile an aligned A/B/label image triple into train/val/test datasets.

    The three images are cut into ``crop_size`` x ``crop_size`` tiles on a
    2-D grid with stride ``step``. Tile sizes come from the actual image
    shape, not a hard-coded 20000. Tiles are assigned to splits in row-major
    order: the first 80% go to train, the next 10% to val, the rest to test.
    Assignment happens *before* augmentation, so augmented copies of one tile
    never land in different splits. Only training tiles are augmented
    (8 geometric symmetries x 4 intensity variants). Val/test hold
    un-augmented tiles so evaluation is not inflated by near-duplicates.

    NOTE(review): when ``step < crop_size``, neighbouring tiles overlap, so
    tiles on either side of a split boundary still share some pixels.

    Output layout: ``<out_root>/{train,val,test}/{A,B,Label}/NNNN.<ext>``,
    numbered from 0 within each split. Each triple is written as soon as it
    is produced, so memory use is bounded by the three source images rather
    than by the (potentially huge) number of augmented tiles.

    Args:
        path_a: First input image.
        path_b: Second input image, same height/width as ``path_a``.
        path_label: Label image, same height/width as ``path_a``.
        crop_size: Side length of each square tile in pixels.
        step: Grid stride in pixels.
        out_root: Directory under which the split folders are created.
        label_ext: File extension for label tiles. The default ``.png`` is
            lossless; JPEG would introduce compression artefacts into labels.

    Returns:
        dict mapping each split name to the number of triples written.

    Raises:
        FileNotFoundError: if an input image cannot be read.
        ValueError: if the images differ in height/width, or if
            ``crop_size``/``step`` is not positive.
        OSError: if an output tile cannot be written.
    """
    img_a = _read_image(path_a)
    img_b = _read_image(path_b)
    img_label = _read_image(path_label)
    if not img_a.shape[:2] == img_b.shape[:2] == img_label.shape[:2]:
        raise ValueError(
            "A, B and label images must share height/width, got "
            f"{img_a.shape[:2]}, {img_b.shape[:2]}, {img_label.shape[:2]}"
        )

    height, width = img_a.shape[:2]
    origins = _grid_origins(height, width, crop_size, step)
    total = len(origins)
    train_end = total * 8 // 10
    val_end = total * 9 // 10

    splits = ("train", "val", "test")
    for split in splits:
        for sub in ("A", "B", "Label"):
            os.makedirs(os.path.join(out_root, split, sub), exist_ok=True)

    counts = dict.fromkeys(splits, 0)
    for i, (y, x) in enumerate(origins):
        if i < train_end:
            split = "train"
        elif i < val_end:
            split = "val"
        else:
            split = "test"

        window = (slice(y, y + crop_size), slice(x, x + crop_size))
        crop_a, crop_b, crop_label = img_a[window], img_b[window], img_label[window]

        if split == "train":
            triples = _augmented_triples(crop_a, crop_b, crop_label)
        else:
            triples = ((crop_a, crop_b, crop_label),)

        split_dir = os.path.join(out_root, split)
        for a, b, label in triples:
            _save_triple(split_dir, counts[split], a, b, label, label_ext)
            counts[split] += 1

    return counts
if __name__ == "__main__":
    # Run the dataset-building pipeline only when executed as a script,
    # never as a side effect of importing this module.
    main()
精彩文章
发表评论