mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
batch inference refactoring
This commit is contained in:
17
darknet.py
17
darknet.py
@ -212,10 +212,10 @@ predict_image_letterbox = lib.network_predict_image_letterbox
|
|||||||
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
|
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
|
||||||
predict_image_letterbox.restype = POINTER(c_float)
|
predict_image_letterbox.restype = POINTER(c_float)
|
||||||
|
|
||||||
network_predict_custom = lib.network_predict_custom
|
network_predict_batch = lib.network_predict_batch
|
||||||
network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
|
network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
|
||||||
c_float, c_float, POINTER(c_int), c_int, c_int]
|
c_float, c_float, POINTER(c_int), c_int, c_int]
|
||||||
network_predict_custom.restype = POINTER(DETNUMPAIR)
|
network_predict_batch.restype = POINTER(DETNUMPAIR)
|
||||||
|
|
||||||
def array_to_image(arr):
|
def array_to_image(arr):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -460,9 +460,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
|||||||
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
|
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
|
||||||
image_list = [cv2.imread(k) for k in img_samples]
|
image_list = [cv2.imread(k) for k in img_samples]
|
||||||
|
|
||||||
if len(image_list) > batch_size:
|
|
||||||
raise ValueError(
|
|
||||||
"Please check if batch size is equal to the number of images passed to the function")
|
|
||||||
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
|
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
|
||||||
meta = load_meta(metaPath.encode('utf-8'))
|
meta = load_meta(metaPath.encode('utf-8'))
|
||||||
pred_height, pred_width, c = image_list[0].shape
|
pred_height, pred_width, c = image_list[0].shape
|
||||||
@ -480,7 +477,7 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
|||||||
data = arr.ctypes.data_as(POINTER(c_float))
|
data = arr.ctypes.data_as(POINTER(c_float))
|
||||||
im = IMAGE(net_width, net_height, c, data)
|
im = IMAGE(net_width, net_height, c, data)
|
||||||
|
|
||||||
batch_dets = network_predict_custom(net, im, batch_size, pred_width,
|
batch_dets = network_predict_batch(net, im, batch_size, pred_width,
|
||||||
pred_height, thresh, hier_thresh, None, 0, 0)
|
pred_height, thresh, hier_thresh, None, 0, 0)
|
||||||
batch_boxes = []
|
batch_boxes = []
|
||||||
batch_scores = []
|
batch_scores = []
|
||||||
@ -521,6 +518,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
|
|||||||
return batch_boxes, batch_scores, batch_classes
|
return batch_boxes, batch_scores, batch_classes
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(performDetect())
|
#print(performDetect())
|
||||||
# Uncomment the following line to see batch inference working
|
#Uncomment the following line to see batch inference working
|
||||||
#print(performBatchDetect())
|
print(performBatchDetect())
|
@ -694,14 +694,14 @@ int num_detections(network *net, float thresh)
|
|||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
int num_detections_custom(network *net, float thresh, int b)
|
int num_detections_batch(network *net, float thresh, int batch)
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
int s = 0;
|
int s = 0;
|
||||||
for (i = 0; i < net->n; ++i) {
|
for (i = 0; i < net->n; ++i) {
|
||||||
layer l = net->layers[i];
|
layer l = net->layers[i];
|
||||||
if (l.type == YOLO) {
|
if (l.type == YOLO) {
|
||||||
s += yolo_num_detections_custom(l, thresh, b);
|
s += yolo_num_detections_batch(l, thresh, batch);
|
||||||
}
|
}
|
||||||
if (l.type == DETECTION || l.type == REGION) {
|
if (l.type == DETECTION || l.type == REGION) {
|
||||||
s += l.w*l.h*l.n;
|
s += l.w*l.h*l.n;
|
||||||
@ -726,12 +726,13 @@ detection *make_network_boxes(network *net, float thresh, int *num)
|
|||||||
return dets;
|
return dets;
|
||||||
}
|
}
|
||||||
|
|
||||||
detection *make_network_boxes_custom(network *net, float thresh, int *num, int batch)
|
detection *make_network_boxes_batch(network *net, float thresh, int *num, int batch)
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
layer l = net->layers[net->n - 1];
|
layer l = net->layers[net->n - 1];
|
||||||
int nboxes = num_detections_custom(net, thresh, batch);
|
int nboxes = num_detections_batch(net, thresh, batch);
|
||||||
if (num) *num = nboxes;
|
assert(num != NULL);
|
||||||
|
*num = nboxes;
|
||||||
detection* dets = (detection*)calloc(nboxes, sizeof(detection));
|
detection* dets = (detection*)calloc(nboxes, sizeof(detection));
|
||||||
for (i = 0; i < nboxes; ++i) {
|
for (i = 0; i < nboxes; ++i) {
|
||||||
dets[i].prob = (float*)calloc(l.classes, sizeof(float));
|
dets[i].prob = (float*)calloc(l.classes, sizeof(float));
|
||||||
@ -792,14 +793,14 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void fill_network_boxes_custom(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch)
|
void fill_network_boxes_batch(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch)
|
||||||
{
|
{
|
||||||
int prev_classes = -1;
|
int prev_classes = -1;
|
||||||
int j;
|
int j;
|
||||||
for (j = 0; j < net->n; ++j) {
|
for (j = 0; j < net->n; ++j) {
|
||||||
layer l = net->layers[j];
|
layer l = net->layers[j];
|
||||||
if (l.type == YOLO) {
|
if (l.type == YOLO) {
|
||||||
int count = get_yolo_detections_custom(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch);
|
int count = get_yolo_detections_batch(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch);
|
||||||
dets += count;
|
dets += count;
|
||||||
if (prev_classes < 0) prev_classes = l.classes;
|
if (prev_classes < 0) prev_classes = l.classes;
|
||||||
else if (prev_classes != l.classes) {
|
else if (prev_classes != l.classes) {
|
||||||
@ -840,7 +841,7 @@ void free_batch_detections(detNumPair *detNumPairs, int n)
|
|||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
for(i=0; i<n; ++i)
|
for(i=0; i<n; ++i)
|
||||||
free_detections(detNumPairs[i].dets,detNumPairs[i].num);
|
free_detections(detNumPairs[i].dets, detNumPairs[i].num);
|
||||||
free(detNumPairs);
|
free(detNumPairs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -915,17 +916,16 @@ float *network_predict_image(network *net, image im)
|
|||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
detNumPair* network_predict_custom(network *net, image im, int batch, int w, int h, float thresh, float hier, int *map, int relative, int letter)
|
detNumPair* network_predict_batch(network *net, image im, int batch_size, int w, int h, float thresh, float hier, int *map, int relative, int letter)
|
||||||
{
|
{
|
||||||
set_batch_network(net, batch);
|
|
||||||
network_predict(*net, im.data);
|
network_predict(*net, im.data);
|
||||||
detNumPair *pdets = ( struct detNumPair * )malloc(batch*sizeof(detNumPair));
|
detNumPair *pdets = (struct detNumPair *)calloc(batch_size, sizeof(detNumPair));
|
||||||
int num;
|
int num;
|
||||||
for(int b=0;b<batch;b++){
|
for(int batch=0; batch<batch_size; batch++){
|
||||||
detection *dets = make_network_boxes_custom(net, thresh, &num, b);
|
detection *dets = make_network_boxes_batch(net, thresh, &num, batch);
|
||||||
fill_network_boxes_custom(net, w, h, thresh, hier, map, relative, dets, letter,b);
|
fill_network_boxes_batch(net, w, h, thresh, hier, map, relative, dets, letter, batch);
|
||||||
pdets[b].num = num;
|
pdets[batch].num = num;
|
||||||
pdets[b].dets = dets;
|
pdets[batch].dets = dets;
|
||||||
}
|
}
|
||||||
return pdets;
|
return pdets;
|
||||||
}
|
}
|
||||||
|
@ -461,7 +461,7 @@ int yolo_num_detections(layer l, float thresh)
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
int yolo_num_detections_custom(layer l, float thresh, int batch)
|
int yolo_num_detections_batch(layer l, float thresh, int batch)
|
||||||
{
|
{
|
||||||
int i, n;
|
int i, n;
|
||||||
int count = 0;
|
int count = 0;
|
||||||
@ -537,9 +537,8 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter,int batch)
|
int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch)
|
||||||
{
|
{
|
||||||
//printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n);
|
|
||||||
int i,j,n;
|
int i,j,n;
|
||||||
float *predictions = l.output;
|
float *predictions = l.output;
|
||||||
//if (l.batch == 2) avg_flipped_yolo(l);
|
//if (l.batch == 2) avg_flipped_yolo(l);
|
||||||
|
@ -13,9 +13,9 @@ void forward_yolo_layer(const layer l, network_state state);
|
|||||||
void backward_yolo_layer(const layer l, network_state state);
|
void backward_yolo_layer(const layer l, network_state state);
|
||||||
void resize_yolo_layer(layer *l, int w, int h);
|
void resize_yolo_layer(layer *l, int w, int h);
|
||||||
int yolo_num_detections(layer l, float thresh);
|
int yolo_num_detections(layer l, float thresh);
|
||||||
int yolo_num_detections_custom(layer l, float thresh, int batch);
|
int yolo_num_detections_batch(layer l, float thresh, int batch);
|
||||||
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter);
|
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter);
|
||||||
int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch);
|
int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch);
|
||||||
void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter);
|
void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter);
|
||||||
|
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
|
Reference in New Issue
Block a user