码迷,mamicode.com
首页 > 其他好文 > 详细

CA code1.1 image处理

时间:2021-03-29 12:41:51      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:str   sample   处理   closed   sha   rgba   tuple   spatial   ati   

# config

技术图片
 1 # data parameters
 2 dataset_name: paris
 3 data_with_subfolder: False # 是否有子文件夹
 4 train_data_path: F:\\pycharm\\Dataset\\paris\\paris_eval_gt
 5 val_data_path:
 6 resume:
 7 batch_size: 5
 8 image_shape: [256, 256, 3] # resize之后输入network的size
 9 mask_shape: [128, 128]
10 mask_batch_same: True
11 max_delta_shape: [32, 32]
12 margin: [0, 0]
13 discounted_mask: True
14 spatial_discounting_gamma: 0.9
15 random_crop: True
16 mask_type: hole     # hole | mosaic
17 mosaic_unit_size: 12
View Code

# 加载处理数据

技术图片
 1         train_dataset = Dataset(data_path=config['train_data_path'],
 2                                 with_subfolder=config['data_with_subfolder'],
 3                                 image_shape=config['image_shape'],
 4                                 random_crop=config['random_crop'])
 5         # val_dataset = Dataset(data_path=config['val_data_path'],
 6         #                       with_subfolder=config['data_with_subfolder'],
 7         #                       image_size=config['image_size'],
 8         #                       random_crop=config['random_crop'])
 9         train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
10                                                    batch_size=config['batch_size'],
11                                                    shuffle=True,
12                                                    num_workers=config['num_workers'])
13         # val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
14         #                                           batch_size=config['batch_size'],
15         #                                           shuffle=False,
16         #                                           num_workers=config['num_workers'])
View Code

# Dataset

技术图片
 class Dataset(data.Dataset):
     """Image dataset that returns fixed-size, normalized image tensors.

     Images are listed either directly from ``data_path`` or, when
     ``with_subfolder`` is true, from every first-level sub-directory of
     ``data_path`` (walked recursively). Each item is loaded, resized and/or
     cropped to ``image_shape[:-1]``, converted to a tensor and passed through
     ``normalize``.

     Args:
         data_path: Root directory containing the images.
         image_shape: Target shape whose last element is dropped to obtain the
             spatial size, e.g. ``[256, 256, 3]`` -> ``[256, 256]``.
         with_subfolder: If true, collect images from the sub-directories of
             ``data_path`` instead of from ``data_path`` itself.
         random_crop: If true, take a random crop of the target size (resizing
             first only when the image is too small); otherwise resize the
             whole image to the target size.
         return_name: If true, ``__getitem__`` returns ``(sample, tensor)``
             instead of just the tensor.
     """

     def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
         super(Dataset, self).__init__()
         if with_subfolder:
             # Entries are paths that already start with ``data_path``.
             self.samples = self._find_samples_in_subfolders(data_path)
         else:
             # Entries are bare file names relative to ``data_path``.
             self.samples = [x for x in listdir(data_path) if is_image_file(x)]
         self.data_path = data_path
         # Remembered so __getitem__ knows whether samples need the root prefix.
         self.with_subfolder = with_subfolder
         self.image_shape = image_shape[:-1]
         self.random_crop = random_crop
         self.return_name = return_name

     def __getitem__(self, index):
         """Load, resize/crop and normalize the image at ``index``.

         Returns:
             The normalized image tensor, or ``(sample, tensor)`` when the
             dataset was created with ``return_name=True``.
         """
         sample = self.samples[index]
         if self.with_subfolder:
             # Sub-folder samples already include data_path; joining it again
             # would duplicate the prefix whenever data_path is relative.
             path = sample
         else:
             path = os.path.join(self.data_path, sample)

         img = default_loader(path)

         if self.random_crop:
             # Assumes a PIL-style image whose ``size`` is (width, height) - TODO confirm.
             imgw, imgh = img.size
             if imgh < self.image_shape[0] or imgw < self.image_shape[1]:
                 # Upscale so the shorter edge reaches the smaller target side.
                 # NOTE(review): for a non-square target the other side may
                 # still end up smaller than the crop size - confirm.
                 img = transforms.Resize(min(self.image_shape))(img)
             img = transforms.RandomCrop(self.image_shape)(img)
         else:
             # Resizing to the exact target size makes a following crop of the
             # same size a no-op, so only the resize is applied.
             img = transforms.Resize(self.image_shape)(img)

         # Convert to a tensor; presumably [C, image_shape[0], image_shape[1]].
         img = transforms.ToTensor()(img)
         img = normalize(img)

         if self.return_name:
             return sample, img
         return img

     def _find_samples_in_subfolders(self, dir):
         """Collect image file paths from the sub-directories of ``dir``.

         Args:
             dir (string): Root directory path.

         Returns:
             list: Paths (each starting with ``dir``) of every file accepted by
             ``is_image_file``, visiting first-level sub-directories in sorted
             order and walking each of them recursively.
         """
         if sys.version_info >= (3, 5):
             # Faster and available in Python 3.5 and above
             classes = [d.name for d in os.scandir(dir) if d.is_dir()]
         else:
             classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
         classes.sort()

         samples = []
         for target in classes:
             d = os.path.join(dir, target)
             if not os.path.isdir(d):
                 continue
             for root, _, fnames in sorted(os.walk(d)):
                 for fname in sorted(fnames):
                     if is_image_file(fname):
                         samples.append(os.path.join(root, fname))
         return samples

     def __len__(self):
         """Return the number of image samples found."""
         return len(self.samples)
View Code

CA code1.1 image处理

标签:str   sample   处理   closed   sha   rgba   tuple   spatial   ati   

原文地址:https://www.cnblogs.com/Overture/p/14587127.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!