码迷,mamicode.com
首页 > 其他好文 > 详细

【pytorch-ssd目标检测】可视化检测结果

时间:2020-03-23 16:45:56      阅读:368      评论:0      收藏:0      [点我收藏+]

标签:ade   rds   ase   测试   path   mod   var   使用   blog   

制作类似pascal voc格式的目标检测数据集:https://www.cnblogs.com/xiximayou/p/12546061.html

训练自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12546556.html

验证自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12550471.html

测试自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12550566.html

 

还是以在谷歌colab上为例:

cd /content/drive/My Drive/pytorch_ssd

导入相应的包:

import os
import sys
module_path = os.path.abspath(os.path.join(..))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import numpy as np
import cv2
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

from ssd import build_ssd

加载谷歌网盘:

from google.colab import drive
drive.mount(/content/drive)

加载模型:

net = build_ssd(test, 300, 3)    # initialize SSD
net.load_weights(weights/ssd300_MASK_5000.pth)

可视化要检测的图像:

# image = cv2.imread(‘./data/example.jpg‘, cv2.IMREAD_COLOR)  # uncomment if dataset not downloaded
%matplotlib inline
from matplotlib import pyplot as plt
from data import MASKDetection, MASK_ROOT, MASKAnnotationTransform
# here we specify year (07 or 12) and dataset (‘test‘, ‘val‘, ‘train‘) 
mask_root="/content/drive/My Drive/pytorch_ssd"
testset = MASKDetection(mask_root, "val", None, MASKAnnotationTransform())
img_id = 2
image = testset.pull_image(img_id)
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# View the sampled input image before transform
plt.figure(figsize=(10,10))
plt.imshow(rgb_image)
plt.show()

技术图片

调整图片的格式:

x = cv2.resize(image, (300, 300)).astype(np.float32)
x -= (104.0, 117.0, 123.0)
x = x.astype(np.float32)
x = x[:, :, ::-1].copy()
plt.imshow(x)
x = torch.from_numpy(x).permute(2, 0, 1)

技术图片

使用模型进行预测:

xx = Variable(x.unsqueeze(0))     # wrap tensor in Variable
if torch.cuda.is_available():
    xx = xx.cuda()
y = net(xx)

输出结果:

from data import MASK_CLASSES as labels
top_k=3

plt.figure(figsize=(10,10))
colors = plt.cm.hsv(np.linspace(0, 1, 3)).tolist()
plt.imshow(rgb_image)  # plot the image for matplotlib
currentAxis = plt.gca()

detections = y.data
# scale each detection back up to the image
scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
for i in range(detections.size(1)):
    j = 0
    while detections[0,i,j,0] >= 0.6:
        score = detections[0,i,j,0]
        label_name = labels[i-1]
        display_txt = %s: %.2f%(label_name, score)
        pt = (detections[0,i,j,1:]*scale).cpu().numpy()
        coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1
        color = colors[i]
        currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
        currentAxis.text(pt[0], pt[1], display_txt, bbox={facecolor:color, alpha:0.5})
        j+=1

技术图片

由于我的数据集中很少没有戴口罩的样本,因此没有戴口罩的AP较低。

至此,使用pytorch-ssd训练测试自己数据集就全部完成啦。 

 

【pytorch-ssd目标检测】可视化检测结果

标签:ade   rds   ase   测试   path   mod   var   使用   blog   

原文地址:https://www.cnblogs.com/xiximayou/p/12552854.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!