batch inference refactoring

This commit is contained in:
enes
2019-10-19 16:18:44 +03:00
parent c435fca738
commit b392621e2e
4 changed files with 27 additions and 31 deletions

View File

@ -212,10 +212,10 @@ predict_image_letterbox = lib.network_predict_image_letterbox
predict_image_letterbox.argtypes = [c_void_p, IMAGE] predict_image_letterbox.argtypes = [c_void_p, IMAGE]
predict_image_letterbox.restype = POINTER(c_float) predict_image_letterbox.restype = POINTER(c_float)
network_predict_custom = lib.network_predict_custom network_predict_batch = lib.network_predict_batch
network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int, network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
c_float, c_float, POINTER(c_int), c_int, c_int] c_float, c_float, POINTER(c_int), c_int, c_int]
network_predict_custom.restype = POINTER(DETNUMPAIR) network_predict_batch.restype = POINTER(DETNUMPAIR)
def array_to_image(arr): def array_to_image(arr):
import numpy as np import numpy as np
@ -460,9 +460,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg'] img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
image_list = [cv2.imread(k) for k in img_samples] image_list = [cv2.imread(k) for k in img_samples]
if len(image_list) > batch_size:
raise ValueError(
"Please check if batch size is equal to the number of images passed to the function")
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size) net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
meta = load_meta(metaPath.encode('utf-8')) meta = load_meta(metaPath.encode('utf-8'))
pred_height, pred_width, c = image_list[0].shape pred_height, pred_width, c = image_list[0].shape
@ -480,7 +477,7 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
data = arr.ctypes.data_as(POINTER(c_float)) data = arr.ctypes.data_as(POINTER(c_float))
im = IMAGE(net_width, net_height, c, data) im = IMAGE(net_width, net_height, c, data)
batch_dets = network_predict_custom(net, im, batch_size, pred_width, batch_dets = network_predict_batch(net, im, batch_size, pred_width,
pred_height, thresh, hier_thresh, None, 0, 0) pred_height, thresh, hier_thresh, None, 0, 0)
batch_boxes = [] batch_boxes = []
batch_scores = [] batch_scores = []
@ -521,6 +518,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath
return batch_boxes, batch_scores, batch_classes return batch_boxes, batch_scores, batch_classes
if __name__ == "__main__": if __name__ == "__main__":
print(performDetect()) #print(performDetect())
# Uncomment the following line to see batch inference working #Uncomment the following line to see batch inference working
#print(performBatchDetect()) print(performBatchDetect())

View File

@ -694,14 +694,14 @@ int num_detections(network *net, float thresh)
return s; return s;
} }
int num_detections_custom(network *net, float thresh, int b) int num_detections_batch(network *net, float thresh, int batch)
{ {
int i; int i;
int s = 0; int s = 0;
for (i = 0; i < net->n; ++i) { for (i = 0; i < net->n; ++i) {
layer l = net->layers[i]; layer l = net->layers[i];
if (l.type == YOLO) { if (l.type == YOLO) {
s += yolo_num_detections_custom(l, thresh, b); s += yolo_num_detections_batch(l, thresh, batch);
} }
if (l.type == DETECTION || l.type == REGION) { if (l.type == DETECTION || l.type == REGION) {
s += l.w*l.h*l.n; s += l.w*l.h*l.n;
@ -726,12 +726,13 @@ detection *make_network_boxes(network *net, float thresh, int *num)
return dets; return dets;
} }
detection *make_network_boxes_custom(network *net, float thresh, int *num, int batch) detection *make_network_boxes_batch(network *net, float thresh, int *num, int batch)
{ {
int i; int i;
layer l = net->layers[net->n - 1]; layer l = net->layers[net->n - 1];
int nboxes = num_detections_custom(net, thresh, batch); int nboxes = num_detections_batch(net, thresh, batch);
if (num) *num = nboxes; assert(num != NULL);
*num = nboxes;
detection* dets = (detection*)calloc(nboxes, sizeof(detection)); detection* dets = (detection*)calloc(nboxes, sizeof(detection));
for (i = 0; i < nboxes; ++i) { for (i = 0; i < nboxes; ++i) {
dets[i].prob = (float*)calloc(l.classes, sizeof(float)); dets[i].prob = (float*)calloc(l.classes, sizeof(float));
@ -792,14 +793,14 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in
} }
} }
void fill_network_boxes_custom(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch) void fill_network_boxes_batch(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch)
{ {
int prev_classes = -1; int prev_classes = -1;
int j; int j;
for (j = 0; j < net->n; ++j) { for (j = 0; j < net->n; ++j) {
layer l = net->layers[j]; layer l = net->layers[j];
if (l.type == YOLO) { if (l.type == YOLO) {
int count = get_yolo_detections_custom(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch); int count = get_yolo_detections_batch(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch);
dets += count; dets += count;
if (prev_classes < 0) prev_classes = l.classes; if (prev_classes < 0) prev_classes = l.classes;
else if (prev_classes != l.classes) { else if (prev_classes != l.classes) {
@ -840,7 +841,7 @@ void free_batch_detections(detNumPair *detNumPairs, int n)
{ {
int i; int i;
for(i=0; i<n; ++i) for(i=0; i<n; ++i)
free_detections(detNumPairs[i].dets,detNumPairs[i].num); free_detections(detNumPairs[i].dets, detNumPairs[i].num);
free(detNumPairs); free(detNumPairs);
} }
@ -915,17 +916,16 @@ float *network_predict_image(network *net, image im)
return p; return p;
} }
detNumPair* network_predict_custom(network *net, image im, int batch, int w, int h, float thresh, float hier, int *map, int relative, int letter) detNumPair* network_predict_batch(network *net, image im, int batch_size, int w, int h, float thresh, float hier, int *map, int relative, int letter)
{ {
set_batch_network(net, batch);
network_predict(*net, im.data); network_predict(*net, im.data);
detNumPair *pdets = ( struct detNumPair * )malloc(batch*sizeof(detNumPair)); detNumPair *pdets = (struct detNumPair *)calloc(batch_size, sizeof(detNumPair));
int num; int num;
for(int b=0;b<batch;b++){ for(int batch=0; batch<batch_size; batch++){
detection *dets = make_network_boxes_custom(net, thresh, &num, b); detection *dets = make_network_boxes_batch(net, thresh, &num, batch);
fill_network_boxes_custom(net, w, h, thresh, hier, map, relative, dets, letter,b); fill_network_boxes_batch(net, w, h, thresh, hier, map, relative, dets, letter, batch);
pdets[b].num = num; pdets[batch].num = num;
pdets[b].dets = dets; pdets[batch].dets = dets;
} }
return pdets; return pdets;
} }

View File

@ -461,7 +461,7 @@ int yolo_num_detections(layer l, float thresh)
return count; return count;
} }
int yolo_num_detections_custom(layer l, float thresh, int batch) int yolo_num_detections_batch(layer l, float thresh, int batch)
{ {
int i, n; int i, n;
int count = 0; int count = 0;
@ -537,9 +537,8 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
return count; return count;
} }
int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter,int batch) int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch)
{ {
//printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n);
int i,j,n; int i,j,n;
float *predictions = l.output; float *predictions = l.output;
//if (l.batch == 2) avg_flipped_yolo(l); //if (l.batch == 2) avg_flipped_yolo(l);

View File

@ -13,9 +13,9 @@ void forward_yolo_layer(const layer l, network_state state);
void backward_yolo_layer(const layer l, network_state state); void backward_yolo_layer(const layer l, network_state state);
void resize_yolo_layer(layer *l, int w, int h); void resize_yolo_layer(layer *l, int w, int h);
int yolo_num_detections(layer l, float thresh); int yolo_num_detections(layer l, float thresh);
int yolo_num_detections_custom(layer l, float thresh, int batch); int yolo_num_detections_batch(layer l, float thresh, int batch);
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter); int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter);
int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch); int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch);
void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter); void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter);
#ifdef GPU #ifdef GPU