码迷,mamicode.com
首页 > 其他好文 > 详细

数据增强(每10度进行旋转,进行一次增强,然后对每张图片进行扩充10张patch,最后得到原始图片数*37*10数量的图片)

时间:2018-05-23 22:07:11      阅读:368      评论:0      收藏:0      [点我收藏+]

标签:assert   return   +=   join   import   rmi   dal   image   proc   

# -*- coding: utf-8 -*-
"""
Fourmi

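Data augmentation script: rotates each training image/label pair in 10-degree steps, writes the rotated copies back into the dataset folders, then samples random 224x224 patches and saves them as individual files.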
"""
import cv2
import os
import numpy as np
import random
import math



def extract_random(full_imgs,full_masks,patch_h,patch_w,N_patches,rng=None):
    """Randomly crop aligned image/mask patches, the same number per image.

    Args:
        full_imgs: sequence (list or array) of 2-D images. The patch window is
            sampled using the height/width of ``full_imgs[0]``, so all images
            are expected to share that shape.
        full_masks: masks paired index-by-index with ``full_imgs``; the same
            window is cut from the image and its mask.
        patch_h: patch height in pixels (odd sizes are supported).
        patch_w: patch width in pixels (odd sizes are supported).
        N_patches: total number of patches; must be a multiple of
            ``len(full_imgs)``.
        rng: optional object with a ``randint(a, b)`` method (e.g.
            ``random.Random(seed)``) for reproducible sampling; defaults to
            the global ``random`` module.

    Returns:
        (patches, patches_masks): two float64 arrays of shape
        ``(N_patches, patch_h, patch_w)``.

    Raises:
        ValueError: if there are no images, the mask count differs from the
            image count, ``N_patches`` is not a multiple of the image count,
            or the patch is larger than the image.
    """
    n_imgs = len(full_imgs)
    if n_imgs == 0:
        raise ValueError("full_imgs is empty")
    if len(full_masks) != n_imgs:
        raise ValueError("full_imgs and full_masks must have the same length")
    if N_patches % n_imgs != 0:
        raise ValueError("N_patches: please enter a multiple of %d" % n_imgs)
    if rng is None:
        rng = random

    img_h, img_w = full_imgs[0].shape[0], full_imgs[0].shape[1]
    if patch_h > img_h or patch_w > img_w:
        raise ValueError("patch (%d x %d) does not fit in image (%d x %d)"
                         % (patch_h, patch_w, img_h, img_w))

    patches = np.empty((N_patches, patch_h, patch_w))
    patches_masks = np.empty((N_patches, patch_h, patch_w))
    patch_per_img = N_patches // n_imgs
    print("patches per full image: " + str(patch_per_img))

    iter_tot = 0
    for img, mask in zip(full_imgs, full_masks):
        for _ in range(patch_per_img):
            # Sample the top-left corner directly so the crop is exactly
            # patch_h x patch_w even for odd sizes (a centre +/- half window
            # loses one pixel when the size is odd). x is drawn before y to
            # keep the random stream in the same order as before.
            x0 = rng.randint(0, img_w - patch_w)
            y0 = rng.randint(0, img_h - patch_h)
            patches[iter_tot] = img[y0:y0 + patch_h, x0:x0 + patch_w]
            patches_masks[iter_tot] = mask[y0:y0 + patch_h, x0:x0 + patch_w]
            iter_tot += 1
    return patches,patches_masks
    
    
def imagePadding(img):
    """Stretch ``img`` to a square whose side is twice its diagonal.

    Despite the name, no border is added: the whole image is resized with
    ``cv2.INTER_AREA`` (aspect ratio not preserved) to
    ``2 * int(sqrt(h**2 + w**2))`` pixels on each side.

    NOTE(review): the rotation code crops an original-size window from the
    centre of this enlarged image, which yields a magnified centre region
    rather than the original frame surrounded by padding. Confirm this is
    intended.
    """
    rows, cols = img.shape[0], img.shape[1]
    diagonal = int(math.sqrt(rows * rows + cols * cols))
    side = diagonal * 2
    return cv2.resize(img, (side, side), interpolation=cv2.INTER_AREA)

def get_data(data_imgs_org,
             data_groundTruth,
             patch_height,
             patch_width,
             N_subimgs):
    """Load image/label pairs and sample random patches from them.

    Args:
        data_imgs_org: directory of input images (passed to
            ``ReadandProcessImage``).
        data_groundTruth: directory of matching ``.png`` labels.
        patch_height: patch height in pixels.
        patch_width: patch width in pixels.
        N_subimgs: total number of patches; must be a multiple of the number
            of loaded images (enforced by ``extract_random``).

    Returns:
        (patches_imgs_train, patches_masks_train) as produced by
        ``extract_random``.
    """
    imgs_org,imgs_groundTruth=ReadandProcessImage(data_imgs_org,data_groundTruth)
    # Log only the shapes; printing the full label array floods stdout.
    print(imgs_org.shape, imgs_groundTruth.shape)
    patches_imgs_train,patches_masks_train=extract_random(imgs_org,
                        imgs_groundTruth,patch_height,patch_width,N_subimgs)
    return patches_imgs_train,patches_masks_train


def ReadandProcessImage(orgImgPath,groundTruthPath):
    """Load every image under ``orgImgPath`` together with its label.

    Files are visited recursively with ``os.walk`` (sorted within each
    directory for a deterministic order). For each file, the label
    ``<groundTruthPath>/<stem>.png`` is loaded. Both are read as grayscale.
    The image is histogram-equalised, and both are passed through
    ``imagePadding``.

    Returns:
        (images, labels): numpy arrays built from the processed lists.
        NOTE(review): these are regular stacked arrays only if every
        processed image ends up with the same shape.

    Raises:
        FileNotFoundError: if an image or its label cannot be read.
        ValueError: if an image and its label differ in shape.
    """
    images=[]
    labels=[]
    for root, dirs, files in os.walk(orgImgPath, topdown=False):
        for file in sorted(files):
            stem = os.path.splitext(file)[0]
            ImgPath=os.path.join(root,file)
            LabelPath=os.path.join(groundTruthPath,stem+".png")
            myimg=cv2.imread(ImgPath,0)
            mylabel=cv2.imread(LabelPath,0)
            print("ImgPath:",ImgPath)
            print("LabelPath:",LabelPath)
            # cv2.imread signals failure by returning None, not by raising.
            if myimg is None:
                raise FileNotFoundError("cannot read image: " + ImgPath)
            if mylabel is None:
                raise FileNotFoundError("cannot read label: " + LabelPath)
            if myimg.shape != mylabel.shape:
                raise ValueError("image/label shape mismatch: %s vs %s (%s)"
                                 % (myimg.shape, mylabel.shape, ImgPath))
            img=cv2.equalizeHist(myimg)
            images.append(imagePadding(img))
            labels.append(imagePadding(mylabel))
    # Return only after every directory has been walked, not just the first.
    return np.array(images),np.array(labels)


def roatate_img_label_to_file(imgPath,labelPath,angle_step=10,name_offset=115):
    """Write rotated copies of every image/label pair back into the dataset.

    Every file found (recursively) under ``imgPath`` is read as grayscale,
    together with the label ``<labelPath>/<stem>.png``. Both are enlarged
    with ``imagePadding``. Then, for each angle in
    ``range(0, 360, angle_step)`` (36 angles by default, 0 included), both
    are rotated about their centre. The central window of the original
    size is written to ``<imgPath>/<n>.jpg`` and ``<labelPath>/<n>.png``.

    NOTE(review): output goes into the input directories, so running this
    twice augments the already-augmented set again.

    Args:
        imgPath: directory of input images; also receives rotated images.
        labelPath: directory of ``.png`` labels; also receives rotated labels.
        angle_step: rotation increment in degrees (default 10).
        name_offset: added to a running counter that starts at 1 to form
            output file names (default 115, so names start at 116).
            Presumably chosen so generated names do not collide with the
            original files; confirm it is >= the highest original index.

    Raises:
        FileNotFoundError: if an image or its label cannot be read.
        ValueError: if an image and its label differ in shape.
    """
    counter = 1  # running index used to name generated files

    def rotateImg(img,label,orgHeight,orgWidth,imgPath,labelPath):
        nonlocal counter
        h, w = img.shape
        # getRotationMatrix2D takes the centre as (x, y), and warpAffine's
        # dsize is (width, height).
        center = (w / 2, h / 2)
        # Fixed crop window of exactly orgHeight x orgWidth (correct for odd
        # sizes) around the centre of the enlarged image.
        top = h // 2 - orgHeight // 2
        left = w // 2 - orgWidth // 2
        for angle in range(0, 360, angle_step):
            M = cv2.getRotationMatrix2D(center, angle, 1)
            img0 = cv2.warpAffine(img, M, (w, h))[top:top + orgHeight,
                                                  left:left + orgWidth]
            # NOTE(review): warpAffine defaults to bilinear interpolation,
            # which can introduce intermediate values into the label. Consider
            # flags=cv2.INTER_NEAREST if labels are class maps.
            label0 = cv2.warpAffine(label, M, (w, h))[top:top + orgHeight,
                                                      left:left + orgWidth]
            name = str(counter + name_offset)
            cv2.imwrite(os.path.join(imgPath, name + ".jpg"), img0)
            cv2.imwrite(os.path.join(labelPath, name + ".png"), label0)
            counter += 1

        print("ROTATW DONE!!!!")

    for root,dirs,files in os.walk(imgPath,topdown=False):
        # `files` is a snapshot, so files written below are not revisited.
        for file in sorted(files):
            imgpath=os.path.join(root,file)
            stem=os.path.splitext(file)[0]
            labelpath=os.path.join(labelPath,stem+".png")
            img=cv2.imread(imgpath,0)
            label=cv2.imread(labelpath,0)
            print("imgpath:",imgpath)
            print("labelpath:",labelpath)
            # cv2.imread signals failure by returning None, not by raising.
            if img is None:
                raise FileNotFoundError("cannot read image: " + imgpath)
            if label is None:
                raise FileNotFoundError("cannot read label: " + labelpath)
            print("imgshape:",img.shape)
            print("labelshape:",label.shape)
            if img.shape != label.shape:
                raise ValueError("image/label shape mismatch: %s vs %s (%s)"
                                 % (img.shape, label.shape, imgpath))
            org_h, org_w = img.shape[0], img.shape[1]
            img=imagePadding(img)
            label=imagePadding(label)
            print("imgPadding:",img.shape)
            print("labelPadding:",label.shape)
            rotateImg(img,label,org_h,org_w,imgPath,labelPath)
         

data_train_imgs_org="/home/chendali1/Gsj/JX/Image/train/"
data_test_imgs_org="/home/chendali1/Gsj/JX/Image/test/"
data_train_grountTruth="/home/chendali1/Gsj/JX/GT/train/"
data_test_grountTruth="/home/chendali1/Gsj/JX/GT/test/"

patches_path_train="/home/chendali1/Gsj/JX/Patches/Org/train/"
patches_path_test="/home/chendali1/Gsj/JX/Patches/Org/test/"
patches_path_label_train="/home/chendali1/Gsj/JX/Patches/Label/train/"
patches_path_label_test="/home/chendali1/Gsj/JX/Patches/Label/test/"

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(patches_path_train, exist_ok=True)
os.makedirs(patches_path_test, exist_ok=True)
os.makedirs(patches_path_label_train, exist_ok=True)
os.makedirs(patches_path_label_test, exist_ok=True)

# Writes 36 rotated copies (0..350 degrees in 10-degree steps) of each pair into the
# training folders themselves. NOTE(review): re-running this script
# augments the already-augmented data again.
roatate_img_label_to_file(data_train_imgs_org,data_train_grountTruth)

# Patch count = 115 (presumably the number of original training images)
# * 37 (original + 36 rotations) * 10 patches per image.
train_patches,train_groundTruth=get_data(data_train_imgs_org,data_train_grountTruth,224,224,37*115*10)
for i, (patch, gt) in enumerate(zip(train_patches, train_groundTruth)):
    # Patches are stored as float64 but hold 0..255 values from 8-bit
    # inputs; cast explicitly because cv2.imwrite expects 8-bit data for
    # JPEG/PNG.
    cv2.imwrite(os.path.join(patches_path_train, str(i) + ".jpg"), patch.astype(np.uint8))
    cv2.imwrite(os.path.join(patches_path_label_train, str(i) + ".png"), gt.astype(np.uint8))

 

数据增强(每10度进行旋转,进行一次增强,然后对每张图片进行扩充10张patch,最后得到原始图片数*37*10数量的图片)

标签:assert   return   +=   join   import   rmi   dal   image   proc   

原文地址:https://www.cnblogs.com/fourmi/p/9079368.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!