Added Yolo v3

This commit is contained in:
AlexeyAB
2018-03-28 02:59:03 +03:00
parent 47c7af1cea
commit d9ae3dd681
37 changed files with 3666 additions and 23 deletions

View File

@ -27,6 +27,7 @@
#include "dropout_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "yolo_layer.h"
int get_current_batch(network net)
{
@ -499,6 +500,107 @@ float *network_predict(network net, float *input)
return out;
}
int num_detections(network *net, float thresh)
{
int i;
int s = 0;
for (i = 0; i < net->n; ++i) {
layer l = net->layers[i];
if (l.type == YOLO) {
s += yolo_num_detections(l, thresh);
}
if (l.type == DETECTION || l.type == REGION) {
s += l.w*l.h*l.n;
}
}
return s;
}
detection *make_network_boxes(network *net, float thresh, int *num)
{
layer l = net->layers[net->n - 1];
int i;
int nboxes = num_detections(net, thresh);
if (num) *num = nboxes;
detection *dets = calloc(nboxes, sizeof(detection));
for (i = 0; i < nboxes; ++i) {
dets[i].prob = calloc(l.classes, sizeof(float));
if (l.coords > 4) {
dets[i].mask = calloc(l.coords - 4, sizeof(float));
}
}
return dets;
}
void custom_get_region_detections(layer l, int w, int h, int net_w, int net_h, float thresh, int *map, float hier, int relative, detection *dets, int letter)
{
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
int i, j;
for (j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, map);
for (j = 0; j < l.w*l.h*l.n; ++j) {
dets[j].classes = l.classes;
dets[j].bbox = boxes[j];
dets[j].objectness = 1;
for (i = 0; i < l.classes; ++i) dets[j].prob[i] = probs[j][i];
}
free(boxes);
free_ptrs((void **)probs, l.w*l.h*l.n);
}
void fill_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter)
{
int j;
for (j = 0; j < net->n; ++j) {
layer l = net->layers[j];
if (l.type == YOLO) {
int count = get_yolo_detections(l, w, h, net->w, net->h, thresh, map, relative, dets, letter);
dets += count;
}
if (l.type == REGION) {
custom_get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets, letter);
//get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
dets += l.w*l.h*l.n;
}
if (l.type == DETECTION) {
get_detection_detections(l, w, h, thresh, dets);
dets += l.w*l.h*l.n;
}
}
}
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter)
{
detection *dets = make_network_boxes(net, thresh, num);
fill_network_boxes(net, w, h, thresh, hier, map, relative, dets, letter);
return dets;
}
void free_detections(detection *dets, int n)
{
int i;
for (i = 0; i < n; ++i) {
free(dets[i].prob);
if (dets[i].mask) free(dets[i].mask);
}
free(dets);
}
float *network_predict_image(network *net, image im)
{
image imr = letterbox_image(im, net->w, net->h);
set_batch_network(net, 1);
float *p = network_predict(*net, imr.data);
free_image(imr);
return p;
}
int network_width(network *net) { return net->w; }
int network_height(network *net) { return net->h; }
matrix network_predict_data_multi(network net, data test, int n)
{
int i,j,b,m;