Training a yolov3 model with darknet

GitHub code: https://github.com/pjreddie/darknet

1. Installing darknet

See the official site: https://pjreddie.com/darknet/

2. Dataset preparation

a. Prepare a dataset in VOC format

b. Use the voc_label.py script in darknet/scripts to convert the dataset to YOLO format (txt label files)
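The core of that conversion is turning VOC's pixel corner coordinates into YOLO's normalized center/size format. A minimal sketch (the function name is mine, not from voc_label.py):

```python
# Minimal sketch of the VOC -> YOLO box conversion that voc_label.py performs:
# (xmin, ymin, xmax, ymax) in pixels -> (cx, cy, w, h) normalized to [0, 1].
def voc_to_yolo(size, box):
    img_w, img_h = size            # image width and height in pixels
    xmin, ymin, xmax, ymax = box   # VOC corner coordinates
    cx = (xmin + xmax) / 2.0 / img_w   # normalized box center x
    cy = (ymin + ymax) / 2.0 / img_h   # normalized box center y
    w = (xmax - xmin) / img_w          # normalized width
    h = (ymax - ymin) / img_h          # normalized height
    return cx, cy, w, h

# Each object becomes one line in the txt label file: "<class_id> cx cy w h"
print(0, *voc_to_yolo((640, 480), (100, 120, 300, 360)))
```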

3. Configuration files

a. The cfg/voc.data file

classes= 20                                 # number of classes in your dataset
train  = /home/pjreddie/data/voc/train.txt  # path to the training-list txt file
valid  = /home/pjreddie/data/voc/2007_test.txt
names = data/voc.names                      # path to the .names file
backup = backup                             # directory where weights are saved

b. The data/voc.names file

Replace the contents with your own class names, one per line.

c. The cfg/yolov3-voc.cfg file

[net]
# Testing: uncomment these two lines when testing
# batch=1
# subdivisions=1
# Training: uncomment these two lines when training
batch=64
subdivisions=16
width=608
height=608
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1

learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1

------------------------------
[convolutional]
size=1
stride=1
pad=1
filters=255 # 3 x (num_classes + 5)
activation=linear


[yolo]
mask = 6,7,8
anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
classes=80 # number of classes
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1 # 1 enables multi-scale training, 0 disables it
------------------------------
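A common pitfall: the filters value of the convolutional layer directly above each [yolo] layer must equal 3 x (classes + 5), because each of the 3 anchors per scale predicts 4 box coordinates, 1 objectness score, and one score per class. A quick check (the helper name is mine):

```python
# Each [yolo] layer uses 3 anchors (one third of num=9), and every anchor
# predicts 4 box coords + 1 objectness + num_classes class scores.
def yolo_filters(num_classes, anchors_per_scale=3):
    return anchors_per_scale * (num_classes + 5)

print(yolo_filters(80))  # COCO, 80 classes: 255
print(yolo_filters(20))  # VOC, 20 classes: 75
```

Remember to update filters in all three places in the cfg, one per [yolo] layer.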

4. Training log fields

Region xx: index of the yolo layer in the cfg file;

Avg IOU: mean IoU between predicted boxes and ground-truth boxes in the current iteration; higher is better, ideally 1;

Class: classification accuracy on labeled objects; higher is better, ideally 1;

Obj: objectness confidence; higher is better, ideally 1;

No Obj: objectness on background; lower is better;

.5R: recall at an IoU threshold of 0.5; recall = detected positives / actual positives

.75R: recall at an IoU threshold of 0.75;

count: number of positive samples
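These fields can be pulled out of a log line with a small regex; the sample line below is illustrative, in the shape darknet prints, not taken from a real run:

```python
import re

# Illustrative Region line (values invented for the example)
line = ("Region 82 Avg IOU: 0.756, Class: 0.891, Obj: 0.567, "
        "No Obj: 0.002, .5R: 1.000, .75R: 0.500, count: 8")

# Match "key: value" pairs and build a dict of floats
pairs = re.findall(r'([\w. ]+?):\s*([\d.]+)', line)
metrics = {k.strip(): float(v) for k, v in pairs}
print(metrics['.5R'], metrics['count'])
```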

5. Saving weights and resuming training

a. Weight-saving interval
Modify line 138 of examples/darknet.c:

        // change 10000 to the desired save interval (in iterations)
        if(i%10000==0 || (i < 1000 && i%100 == 0)){
#ifdef GPU
            if(ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
        }

Then recompile: run make clean, then make.

b. Resuming training
Continue training from the last saved weights:

./darknet detector train cfg/voc.data cfg/yolov3-voc.cfg xxx.weights

6. Visualizing training

Save a log file while training:

./darknet detector train cfg/voc.data cfg/yolov3-voc.cfg darknet53.conv.74 2>&1 | tee visualization/train_yolov3.log

The extract_log.py script turns the raw log into txt files for the plotting scripts:

# coding=utf-8
# Extract the training log: drop unparseable lines and write cleaned
# log files for the plotting scripts.

def extract_log(log_file, new_log_file, key_word):
    with open(log_file, 'r') as f:
        with open(new_log_file, 'w') as train_log:
            for line in f:
                # skip multi-GPU sync lines
                if 'Syncing' in line:
                    continue
                # skip lines containing NaNs from division-by-zero errors
                if 'nan' in line:
                    continue
                if key_word in line:
                    train_log.write(line)

extract_log('train_yolov3.log', 'train_log_loss.txt', 'images')
extract_log('train_yolov3.log', 'train_log_iou.txt', 'IOU')

The train_loss_visualization.py script plots the loss curve:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline

lines = 5124    # set to the number of lines in your train_log_loss.txt
# note: on pandas >= 2.0, replace error_bad_lines=False with on_bad_lines='skip'
result = pd.read_csv('train_log_loss.txt',
                     skiprows=[x for x in range(lines) if (x % 10 != 9) | (x < 1000)],
                     error_bad_lines=False,
                     names=['loss', 'avg', 'rate', 'seconds', 'images'])
result.head()

result['loss'] = result['loss'].str.split(' ').str.get(1)
result['avg'] = result['avg'].str.split(' ').str.get(1)
result['rate'] = result['rate'].str.split(' ').str.get(1)
result['seconds'] = result['seconds'].str.split(' ').str.get(1)
result['images'] = result['images'].str.split(' ').str.get(1)
result.head()
result.tail()

# print(result.head())
# print(result.tail())
# print(result.dtypes)

print(result['loss'])
print(result['avg'])
print(result['rate'])
print(result['seconds'])
print(result['images'])

result['loss'] = pd.to_numeric(result['loss'])
result['avg'] = pd.to_numeric(result['avg'])
result['rate'] = pd.to_numeric(result['rate'])
result['seconds'] = pd.to_numeric(result['seconds'])
result['images'] = pd.to_numeric(result['images'])
result.dtypes


fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(result['avg'].values, label='avg_loss')
# ax.plot(result['loss'].values, label='loss')
ax.legend(loc='best')  # legend placement
ax.set_title('The loss curves')
ax.set_xlabel('batches')
fig.savefig('avg_loss')
# fig.savefig('loss')
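extract_log.py also writes train_log_iou.txt. A companion sketch (my addition, mirroring the loss script; not part of the referenced posts) plots the Avg IOU curve from it:

```python
import os
import re

def avg_ious(lines):
    """Pull the Avg IOU value out of each Region log line."""
    vals = []
    for line in lines:
        m = re.search(r'Avg IOU:\s*([\d.]+)', line)
        if m:
            vals.append(float(m.group(1)))
    return vals

# Plot only if the extracted IOU log is present
if os.path.exists('train_log_iou.txt'):
    import matplotlib.pyplot as plt
    with open('train_log_iou.txt') as f:
        ious = avg_ious(f)
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(ious, label='avg_iou')
    ax.legend(loc='best')
    ax.set_title('The Avg IOU curve')
    ax.set_xlabel('batches')
    fig.savefig('avg_iou')
```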
