利用Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解

利用Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解Torchvision更新到0.3.0后支持了更多的功能,其中新增模块detection中实现了整个faster-rcnn的功能。本博客主要讲述如何通过torchvision和pytorch使用faster-rcnn,并提供一个demo和对应代码及解析注释。目录如果你不想深入了解原理和训练,只想用Faster-rcnn做目标检测任务的demo,请看这里torchvision中Faste…

大家好,又见面了,我是你们的朋友全栈君。

Torchvision更新到0.3.0后支持了更多的功能,其中新增模块detection中实现了整个faster-rcnn的功能。本博客主要讲述如何通过torchvision和pytorch使用faster-rcnn,并提供一个demo和对应代码及解析注释。

目录

如果你不想深入了解原理和训练,只想用Faster-rcnn做目标检测,请看这里

torchvision中Faster-rcnn接口

一个demo

使用方法

如果你想深入了解原理,并训练自己的模型

环境搭建

准备训练数据

模型训练

单张图片检测

效果



如果你不想深入了解原理和训练,只想用Faster-rcnn做目标检测,请看这里

torchvision中Faster-rcnn接口

torchvision内部集成了Faster-rcnn的模型,其接口和调用方式野非常简洁,目前官方提供resnet50+rpn在coco上训练的模型,调用该模型只需要几行代码:

>>> import torch
>>> import torchvision

// 创建模型,pretrained=True将下载官方提供的coco2017模型
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)



注意网络的输入x是一个Tensor构成的list,而输出prediction则是一个由dict构成list。prediction的长度和网络输入的list中Tensor个数相同。prediction中的每个dict包含输出的结果:

其中boxes是检测框坐标,labels是类别,scores则是置信度。

>>> predictions[0]

{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward>)}

一个demo

如果你不想自己写读取图片/预处理/后处理,我这里有个写好的demo.py,可以跑在任何安装了pytorch1.1+和torchvision0.3+的环境下,不需要其他依赖,可以用来完成目标检测的任务。

为了能够显示类别标签,我们将coco的所有类别写入coco_names.py

names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat', '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase', '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball', '38': 'kite', '39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard', '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl', '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza', '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', '64': 'potted plant', '65': 'bed', '67': 'dining table', '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', '75': 'remote', '76': 'keyboard', '77': 'cell phone', '78': 'microwave', '79': 'oven', '80': 'toaster', '81': 'sink', '82': 'refrigerator', '84': 'book', '85': 'clock', '86': 'vase', '87': 'scissors', '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}

然后构建一个可以读取图片并检测的demo.py

import torch
import torchvision
import argparse
import cv2
import numpy as np
import sys
sys.path.append('./')
import coco_names
import random

def get_args():
    parser = argparse.ArgumentParser(description='Pytorch Faster-rcnn Detection')

    parser.add_argument('image_path', type=str, help='image path')
    parser.add_argument('--model', default='fasterrcnn_resnet50_fpn', help='model')
    parser.add_argument('--dataset', default='coco', help='model')
    parser.add_argument('--score', type=float, default=0.8, help='objectness score threshold')
    args = parser.parse_args()

    return args

def random_color():
    b = random.randint(0,255)
    g = random.randint(0,255)
    r = random.randint(0,255)

    return (b,g,r)

def main():
    args = get_args()
    input = []
    num_classes = 91
    names = coco_names.names
        
    # Model creating
    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=True)  
    model = model.cuda()

    model.eval()

    src_img = cv2.imread(args.image_path)
    img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().cuda()
    input.append(img_tensor)
    out = model(input)
    boxes = out[0]['boxes']
    labels = out[0]['labels']
    scores = out[0]['scores']

    for idx in range(boxes.shape[0]):
        if scores[idx] >= args.score:
            x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
            name = names.get(str(labels[idx].item()))
            cv2.rectangle(src_img,(x1,y1),(x2,y2),random_color(),thickness=2)
            cv2.putText(src_img, text=name, org=(x1, y1+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, 
                fontScale=0.5, thickness=1, lineType=cv2.LINE_AA, color=(0, 0, 255))

    cv2.imshow('result',src_img)
    cv2.waitKey()
    cv2.destroyAllWindows()

    

if __name__ == "__main__":
    main()

运行命令

$ python demo.py [image path]

就能完成检测,并且不需要任何其他依赖,只需要Pytorch1.1+和torchvision0.3+。看下效果:

利用Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解

使用方法

我发现好像很多人对上面这个demo怎么用不太清楚,照着下面的流程做就好了:

  1. 下载代码:https://github.com/supernotman/Faster-RCNN-with-torchvision
  2. 下载模型:Baidu Cloud
  3. 运行命令:
    $ python detect.py --model_path [模型路径] --image_path [图片路径]

其实非常简单。

如果你想深入了解原理,并训练自己的模型

这里提供一份我重构过的代码,把torchvision中的faster-rcnn部分提取出来,可以训练自己的模型(目前只支持coco),并有对应博客讲解。

Pytorch torchvision构建Faster-rcnn(一)—-coco数据读取

Pytorch torchvision构建Faster-rcnn(二)—-基础网络

Pytorch torchvision构建Faster-rcnn(三)—-RPN

Pytorch torchvision构建Faster-rcnn(四)—-ROIHead

环境搭建

下载代码:

$ git clone https://github.com/supernotman/Faster-RCNN-with-torchvision.git

安装依赖:

$ pip install -r requirements.txt

注意:

代码要求Pytorch版本大于1.1.0,torchvision版本大于0.3.0。

如果某个依赖项通过pip安装过慢,推荐替换清华源:

$ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

如果pytorch安装过慢,可参考conda安装Pytorch下载过慢解决办法(7月23日更新ubuntu下pytorch1.1安装方法)

准备训练数据

下载coco2017数据集,下载地址:

http://images.cocodataset.org/zips/train2017.zip
http://images.cocodataset.org/annotations/annotations_trainval2017.zip

http://images.cocodataset.org/zips/val2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip

http://images.cocodataset.org/zips/test2017.zip
http://images.cocodataset.org/annotations/image_info_test2017.zip 

如果下载速度过慢,可参考博客COCO2017数据集国内下载地址

数据下载后按照如下结构放置:

  coco/
    2017/
      annotations/
      test2017/
      train2017/
      val2017/

模型训练

$ python -m torch.distributed.launch --nproc_per_node=$gpus --use_env train.py --world-size $gpus --b 4

训练采用了Pytorch的distributedparallel方式,支持多gpu。

注意其中$gpus为指定使用的gpu数量,b为每个gpu上的batch_size,因此实际batch_size大小为$gpus × b。

实测当b=4,1080ti下大概每张卡会占用11G显存,请根据情况自行设定。

训练过程中每个epoch会给出一次评估结果,形式如下:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.352
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.573
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.375
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.207
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.387
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.296
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.474
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.498
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.312
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.538
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631

其中AP为准确率,AR为召回率,第一行为训练结果的mAP,第四、五、六行分别为小/中/大物体对应的mAP

单张图片检测

$ python detect.py --model_path result/model_13.pth --image_path imgs/1.jpg

model_path为模型路径,image_path为测试图片路径。

代码文件夹中assets给出了从coco2017测试集中挑选的11张图片测试结果。

效果

利用Pytorch torchvision完成Faster-rcnn目标检测demo及源码详解

Good Luck!

如果对你有用的话,请给代码一个star,谢谢!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152602.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号