EfficientDet 训练自己的数据集_efficientdet训练自己的数据-程序员宅基地

技术标签: python  深度学习  efficientDet  pytorch  json  

EfficientDet训练自己的数据集

项目安装

参考代码:https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
安装及环境配置可参考作者介绍或者其他博客

数据准备

训练时需要将数据集转换为coco格式的数据集,本人使用的数据集为visdrone数据集,转换过程如下:txt->XML->coco.json

txt->XML

import os
from PIL import Image

# 把下面的路径改成你自己的路径即可
root_dir = "./VisDrone2019-DET-train/"
annotations_dir = root_dir+"annotations/"
image_dir = root_dir + "images/"
xml_dir = root_dir+"Annotations_XML/"
# 下面的类别也换成你自己数据类别,也可适用于其他的数据集转换
class_name = ['ignored regions','pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor','others']

for filename in os.listdir(annotations_dir):
    fin = open(annotations_dir+filename, 'r')
    image_name = filename.split('.')[0]
    img = Image.open(image_dir+image_name+".jpg") # 若图像数据是“png”转换成“.png”即可
    xml_name = xml_dir+image_name+'.xml'
    with open(xml_name, 'w') as fout:
        fout.write('<annotation>'+'\n')
        
        fout.write('\t'+'<folder>VOC2007</folder>'+'\n')
        fout.write('\t'+'<filename>'+image_name+'.jpg'+'</filename>'+'\n')
        
        fout.write('\t'+'<source>'+'\n')
        fout.write('\t\t'+'<database>'+'VisDrone2018 Database'+'</database>'+'\n')
        fout.write('\t\t'+'<annotation>'+'VisDrone2018'+'</annotation>'+'\n')
        fout.write('\t\t'+'<image>'+'flickr'+'</image>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Unspecified'+'</flickrid>'+'\n')
        fout.write('\t'+'</source>'+'\n')
        
        fout.write('\t'+'<owner>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Haipeng Zhang'+'</flickrid>'+'\n')
        fout.write('\t\t'+'<name>'+'Haipeng Zhang'+'</name>'+'\n')
        fout.write('\t'+'</owner>'+'\n')
        
        fout.write('\t'+'<size>'+'\n')
        fout.write('\t\t'+'<width>'+str(img.size[0])+'</width>'+'\n')
        fout.write('\t\t'+'<height>'+str(img.size[1])+'</height>'+'\n')
        fout.write('\t\t'+'<depth>'+'3'+'</depth>'+'\n')
        fout.write('\t'+'</size>'+'\n')
        
        fout.write('\t'+'<segmented>'+'0'+'</segmented>'+'\n')

        for line in fin.readlines():
            line = line.split(',')
            fout.write('\t'+'<object>'+'\n')
            fout.write('\t\t'+'<name>'+class_name[int(line[5])]+'</name>'+'\n')
            fout.write('\t\t'+'<pose>'+'Unspecified'+'</pose>'+'\n')
            fout.write('\t\t'+'<truncated>'+line[6]+'</truncated>'+'\n')
            fout.write('\t\t'+'<difficult>'+str(int(line[7]))+'</difficult>'+'\n')
            fout.write('\t\t'+'<bndbox>'+'\n')
            fout.write('\t\t\t'+'<xmin>'+line[0]+'</xmin>'+'\n')
            fout.write('\t\t\t'+'<ymin>'+line[1]+'</ymin>'+'\n')
            # pay attention to this point!(0-based)
            fout.write('\t\t\t'+'<xmax>'+str(int(line[0])+int(line[2])-1)+'</xmax>'+'\n')
            fout.write('\t\t\t'+'<ymax>'+str(int(line[1])+int(line[3])-1)+'</ymax>'+'\n')
            fout.write('\t\t'+'</bndbox>'+'\n')
            fout.write('\t'+'</object>'+'\n')
             
        fin.close()
        fout.write('</annotation>')

XML->coco.json

    # coding=utf-8
import xml.etree.ElementTree as ET
import os
import json


voc_clses = ['aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor']


categories = []
for iind, cat in enumerate(voc_clses):
    cate = {
    }
    cate['supercategory'] = cat
    cate['name'] = cat
    cate['id'] = iind
    categories.append(cate)

def getimages(xmlname, id):
    sig_xml_box = []
    tree = ET.parse(xmlname)
    root = tree.getroot()
    images = {
    }
    for i in root:  # 遍历一级节点
        if i.tag == 'filename':
            file_name = i.text  # 0001.jpg
            # print('image name: ', file_name)
            images['file_name'] = file_name
        if i.tag == 'size':
            for j in i:
                if j.tag == 'width':
                    width = j.text
                    images['width'] = width
                if j.tag == 'height':
                    height = j.text
                    images['height'] = height
        if i.tag == 'object':
            for j in i:
                if j.tag == 'name':
                    cls_name = j.text
                cat_id = voc_clses.index(cls_name) + 1
                if j.tag == 'bndbox':
                    bbox = []
                    xmin = 0
                    ymin = 0
                    xmax = 0
                    ymax = 0
                    for r in j:
                        if r.tag == 'xmin':
                            xmin = eval(r.text)
                        if r.tag == 'ymin':
                            ymin = eval(r.text)
                        if r.tag == 'xmax':
                            xmax = eval(r.text)
                        if r.tag == 'ymax':
                            ymax = eval(r.text)
                    bbox.append(xmin)
                    bbox.append(ymin)
                    bbox.append(xmax - xmin)
                    bbox.append(ymax - ymin)
                    bbox.append(id)   # 保存当前box对应的image_id
                    bbox.append(cat_id)
                    # anno area
                    bbox.append((xmax - xmin) * (ymax - ymin) - 10.0)   # bbox的ares
                    # coco中的ares数值是 < w*h 的, 因为它其实是按segmentation的面积算的,所以我-10.0一下...
                    sig_xml_box.append(bbox)
                    # print('bbox', xmin, ymin, xmax - xmin, ymax - ymin, 'id', id, 'cls_id', cat_id)
    images['id'] = id
    # print ('sig_img_box', sig_xml_box)
    return images, sig_xml_box



def txt2list(txtfile):
    f = open(txtfile)
    l = []
    for line in f:
        l.append(line[:-1])
    return l


# voc2007xmls = 'anns'
voc2007xmls = '/data2/chenjia/data/VOCdevkit/VOC2007/Annotations'
# test_txt = 'voc2007/test.txt'
test_txt = '/data2/chenjia/data/VOCdevkit/VOC2007/ImageSets/Main/test.txt'
xml_names = txt2list(test_txt)
xmls = []
bboxes = []
ann_js = {
    }
for ind, xml_name in enumerate(xml_names):
    xmls.append(os.path.join(voc2007xmls, xml_name + '.xml'))
json_name = 'annotations/instances_voc2007val.json'
images = []
for i_index, xml_file in enumerate(xmls):
    image, sig_xml_bbox = getimages(xml_file, i_index)
    images.append(image)
    bboxes.extend(sig_xml_bbox)
ann_js['images'] = images
ann_js['categories'] = categories
annotations = []
for box_ind, box in enumerate(bboxes):
    anno = {
    }
    anno['image_id'] =  box[-3]
    anno['category_id'] = box[-2]
    anno['bbox'] = box[:-3]
    anno['id'] = box_ind
    anno['area'] = box[-1]
    anno['iscrowd'] = 0
    annotations.append(anno)
ann_js['annotations'] = annotations
json.dump(ann_js, open(json_name, 'w'), indent=4)  # indent=4 更加美观显示               

将生成的json及图片按照一下结构放置,注意修改json文件名称:

  • dadasets
    • visdrone2019
      • train2019
      • val2019
      • annotations
        • instances_train2019.json
        • instances_val2019.json

修改projects下coco.yml内容,按照自己的数据库情况修改

project_name: visdrone2019  # also the folder name of the dataset that under data_path folder
train_set: train2019
val_set: val2019
num_gpus: 1

# mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
mean: [0.373, 0.378, 0.364]
std: [0.191, 0.182, 0.194]

# this is coco anchors, change it if necessary
anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'

# must match your dataset's category_id.
# category_id is one_indexed,
# for example, index of 'car' here is 2, while category_id of is 3
obj_list: ["pedestrian","people","bicycle","car","van","truck","tricycle","awning-tricycle","bus","motor"]

训练

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10
–load_weights /path/to/your/weights/efficientdet-d2.pth

提前下载model文件,放置在文件夹中,建议d0,d1,d2(大了显存会溢出),如出现显存溢出情况,调整batch_size大小。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/gu891221/article/details/105783719

智能推荐

全卷积网络 Fully Convolutional Networks-程序员宅基地

文章浏览阅读1.1w次,点赞9次,收藏45次。CNN能够对图片进行分类,可是怎么样才能识别图片中特定部分的物体,在2015年之前还是一个世界难题。神经网络大神Jonathan Long发表了《Fully Convolutional Networks for Semantic Segmentation》在图像语义分割挖了一个坑,于是无穷无尽的人往坑里面跳。全卷积网络 Fully Convolutional Networks_fully convolutional networks

签证上的mult是什么意思_申根签证中mult是什么意思-程序员宅基地

文章浏览阅读789次。展开全部申根签证中mult是是多次的意思,指可以在有效期内多次往返申根国家。类型申根签证分62616964757a686964616fe78988e69d8331333431373939为入境和过境两类。1.入境签证有一次入境和多次入境两种。签证持有者分别可一次连续停留90天或每半年多次累计不超过3个月。如需长期停留,可向某一成员国申请只在该国使用的国别签证;2.过境签证指过境前往协定国以外国家的..._mult是什么意思?

webpack 配置_webpack设置 require-程序员宅基地

文章浏览阅读602次。corejs处理,在项目根目录下的 babel.config.js 文件配置。webpack.config.js文件。babel.config.js文件。记录学习 webpack 的过程。.eslintrc.js 文件。_webpack设置 require

Vue组件详解-程序员宅基地

文章浏览阅读454次。文章目录什么是组件?模块化与组件化组件定义命名规则创建组件的方式方式一方式二方式三组件的唯一性什么是组件?什么是组件:组件的出现,就是为了拆分vue实例的代码里的,能够让我们以不同的组件,来划分不同的功能模块,将来我们需要什么样的功能,就可以去调用对应的组件即可模块化与组件化名称概念模块化是从代码逻辑角度进行划分的;方便代码的分层开发,保证每个功能模块的职能单一组件化是从UI界面的角度进行划分的;前端的组件化,方便UI组件重用组件定义命名规则推荐全小写,然后

图像分类篇——使用pytorch搭建ResNet网络_resnet实战:使用resnet实现图像分类(pytorch)-程序员宅基地

文章浏览阅读2.7k次,点赞6次,收藏40次。目录1. ResNet网络详解1.1 ResNet网络概述1.2 Batch Normalization1.3 residual结构1.4 ResNet结构和详细参数1.5 迁移学习2. Pytorch搭建本文为学习记录和备忘录,对代码进行了详细注释,以供学习。内容来源:★github:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing★b站:https://space.bilibili.com/18161609/chan_resnet实战:使用resnet实现图像分类(pytorch)

Beautiful Soup之find()和find_all()的基本使用_soup.find_all-程序员宅基地

文章浏览阅读1w次,点赞17次,收藏66次。1.HTML文本这里以官方文档提供的html代码来演示Beautiful Soup中find_all()和find()的基本使用。<html><head><title>The Dormouse's story</title></head><body><p class="title"><b>The Dormouse's story</b></p><p class="stor_soup.find_all

随便推点

for input String "id"错误解决_for input string: "id-程序员宅基地

文章浏览阅读1.3w次。出现这个问题,Idea报错说明是 类型转换错误,可是之前这么做事没错的,类似于我就是这么展示在jsp页面上的,这次居然报类型转换错误,Google了一圈之后,发现不应该这么写,应该用数组下标去获取值。类似下面这种。这里面的下标顺序就是数据库当中对应的显示顺序。比如id是第一位,那么要想取到id,数组下标就是0,依次类推。..._for input string: "id

python heapq 优先队列数组内部不单调递增,但是单独一个个出队列单调递增_python heapq 维护单调-程序员宅基地

文章浏览阅读65次。在使用heapq的优先队列时,发现了标题中的现象,一度怀疑优先队列出错了。总之,使用heapq访问q[0]后面的内容时要注意。具体原因没有深究,有知道的大佬可以指点指点。_python heapq 维护单调

phpStudy环境安装SSL证书教程(apache)-程序员宅基地

文章浏览阅读85次。https://cloud.tencent.com/product/ssl此链接是检测域名 证书的可以检测一下下面是证书配置 小白呢亲测作为PHP程序员,我们一定要学会使用phpStudy环境集成包,PHPstudy用起来方便,快捷,对于刚入门的PHP初学者来说phpStudy是个好东西,我本文我们就和大家分享一下phpStudy环境如何安装SSL证书。第一步:修改apache目录..._d:/phpstudy/apache/conf/ssl/ca.key do not match

Python之小词典应用-程序员宅基地

文章浏览阅读5.9k次,点赞5次,收藏41次。Python之小词典应用这个学期专业开了python课,最后老师布置了一个作业:用python制作一个英语小词典的应用,遂做了一下。题目要求:制作英文学习词典。编写程序制作英文学习词典,词典有三个基本功能:添加、查询、和退出。程序读取源文件路径下的txt格式词典文件,若没有就创建一个。词典文件存储格式为“英文单词 中文单词”,每行仅有一对中英释义。程序会根据用户的选择进入相应的功能...

【轮式平衡机器人】——TMS320F28069片内外设之Timer_IT(补:CCS程序烧录方法)_机器人人烧录程序调试过程-程序员宅基地

文章浏览阅读1.1k次,点赞27次,收藏24次。TMS320F28069 的定时器中断功能。在微控制器或数字信号控制器中,定时器是一个非常重要的外设,它可以用来产生固定时间间隔的中断,或者用来精确计算时间。_机器人人烧录程序调试过程

计算机输入法无法启动,Win7系统开机后输入法总是消失如何解决-程序员宅基地

文章浏览阅读1.2k次。输入法是我们在使用电脑的时候经常会用来输入文字的工具,一般在开机的时候都会自动启动并在任务栏右下角显示,可是有不少win7系统用户却遇到开机后输入法总是消失的情况,要怎么解决呢?现在为大家带来Win7系统开机后输入法总是消失的详细解决步骤。1、在win7系统中点击“开始”--“运行”输入--regedit,打开“注册表编辑器”,找到“HKEY_USERS\.DEFAULT\ControlPanel..._开机输入法消失

推荐文章

热门文章

相关标签