mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
80
darknet.py
80
darknet.py
@ -63,6 +63,9 @@ class DETECTION(Structure):
|
||||
("uc", POINTER(c_float)),
|
||||
("points", c_int)]
|
||||
|
||||
class DETNUMPAIR(Structure):
|
||||
_fields_ = [("num", c_int),
|
||||
("dets", POINTER(DETECTION))]
|
||||
|
||||
class IMAGE(Structure):
|
||||
_fields_ = [("w", c_int),
|
||||
@ -161,6 +164,9 @@ make_network_boxes.restype = POINTER(DETECTION)
|
||||
free_detections = lib.free_detections
|
||||
free_detections.argtypes = [POINTER(DETECTION), c_int]
|
||||
|
||||
free_batch_detections = lib.free_batch_detections
|
||||
free_batch_detections.argtypes = [POINTER(DETNUMPAIR), c_int]
|
||||
|
||||
free_ptrs = lib.free_ptrs
|
||||
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
|
||||
|
||||
@ -210,6 +216,11 @@ predict_image_letterbox = lib.network_predict_image_letterbox
|
||||
predict_image_letterbox.argtypes = [c_void_p, IMAGE]
|
||||
predict_image_letterbox.restype = POINTER(c_float)
|
||||
|
||||
network_predict_batch = lib.network_predict_batch
|
||||
network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int,
|
||||
c_float, c_float, POINTER(c_int), c_int, c_int]
|
||||
network_predict_batch.restype = POINTER(DETNUMPAIR)
|
||||
|
||||
def array_to_image(arr):
|
||||
import numpy as np
|
||||
# need to return old values to avoid python freeing memory
|
||||
@ -445,5 +456,72 @@ def performDetect(imagePath="data/dog.jpg", thresh= 0.25, configPath = "./cfg/yo
|
||||
print("Unable to show image: "+str(e))
|
||||
return detections
|
||||
|
||||
def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath = "yolov3.weights", metaPath= "./cfg/coco.data", hier_thresh=.5, nms=.45, batch_size=3):
|
||||
import cv2
|
||||
import numpy as np
|
||||
# NB! Image sizes should be the same
|
||||
# You can change the images, yet, be sure that they have the same width and height
|
||||
img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg']
|
||||
image_list = [cv2.imread(k) for k in img_samples]
|
||||
|
||||
net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size)
|
||||
meta = load_meta(metaPath.encode('utf-8'))
|
||||
pred_height, pred_width, c = image_list[0].shape
|
||||
net_width, net_height = (network_width(net), network_height(net))
|
||||
img_list = []
|
||||
for custom_image_bgr in image_list:
|
||||
custom_image = cv2.cvtColor(custom_image_bgr, cv2.COLOR_BGR2RGB)
|
||||
custom_image = cv2.resize(
|
||||
custom_image, (net_width, net_height), interpolation=cv2.INTER_NEAREST)
|
||||
custom_image = custom_image.transpose(2, 0, 1)
|
||||
img_list.append(custom_image)
|
||||
|
||||
arr = np.concatenate(img_list, axis=0)
|
||||
arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
|
||||
data = arr.ctypes.data_as(POINTER(c_float))
|
||||
im = IMAGE(net_width, net_height, c, data)
|
||||
|
||||
batch_dets = network_predict_batch(net, im, batch_size, pred_width,
|
||||
pred_height, thresh, hier_thresh, None, 0, 0)
|
||||
batch_boxes = []
|
||||
batch_scores = []
|
||||
batch_classes = []
|
||||
for b in range(batch_size):
|
||||
num = batch_dets[b].num
|
||||
dets = batch_dets[b].dets
|
||||
if nms:
|
||||
do_nms_obj(dets, num, meta.classes, nms)
|
||||
boxes = []
|
||||
scores = []
|
||||
classes = []
|
||||
for i in range(num):
|
||||
det = dets[i]
|
||||
score = -1
|
||||
label = None
|
||||
for c in range(det.classes):
|
||||
p = det.prob[c]
|
||||
if p > score:
|
||||
score = p
|
||||
label = c
|
||||
if score > thresh:
|
||||
box = det.bbox
|
||||
left, top, right, bottom = map(int,(box.x - box.w / 2, box.y - box.h / 2,
|
||||
box.x + box.w / 2, box.y + box.h / 2))
|
||||
boxes.append((top, left, bottom, right))
|
||||
scores.append(score)
|
||||
classes.append(label)
|
||||
boxColor = (int(255 * (1 - (score ** 2))), int(255 * (score ** 2)), 0)
|
||||
cv2.rectangle(image_list[b], (left, top),
|
||||
(right, bottom), boxColor, 2)
|
||||
cv2.imwrite(os.path.basename(img_samples[b]),image_list[b])
|
||||
|
||||
batch_boxes.append(boxes)
|
||||
batch_scores.append(scores)
|
||||
batch_classes.append(classes)
|
||||
free_batch_detections(batch_dets, batch_size)
|
||||
return batch_boxes, batch_scores, batch_classes
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(performDetect())
|
||||
#print(performDetect())
|
||||
#Uncomment the following line to see batch inference working
|
||||
print(performBatchDetect())
|
@ -830,6 +830,12 @@ typedef struct detection{
|
||||
int points; // bit-0 - center, bit-1 - top-left-corner, bit-2 - bottom-right-corner
|
||||
} detection;
|
||||
|
||||
// network.c -batch inference
|
||||
typedef struct detNumPair {
|
||||
int num;
|
||||
detection *dets;
|
||||
} detNumPair, *pdetNumPair;
|
||||
|
||||
// matrix.h
|
||||
typedef struct matrix {
|
||||
int rows, cols;
|
||||
|
@ -745,6 +745,22 @@ int num_detections(network *net, float thresh)
|
||||
return s;
|
||||
}
|
||||
|
||||
int num_detections_batch(network *net, float thresh, int batch)
|
||||
{
|
||||
int i;
|
||||
int s = 0;
|
||||
for (i = 0; i < net->n; ++i) {
|
||||
layer l = net->layers[i];
|
||||
if (l.type == YOLO) {
|
||||
s += yolo_num_detections_batch(l, thresh, batch);
|
||||
}
|
||||
if (l.type == DETECTION || l.type == REGION) {
|
||||
s += l.w*l.h*l.n;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
detection *make_network_boxes(network *net, float thresh, int *num)
|
||||
{
|
||||
layer l = net->layers[net->n - 1];
|
||||
@ -763,6 +779,22 @@ detection *make_network_boxes(network *net, float thresh, int *num)
|
||||
return dets;
|
||||
}
|
||||
|
||||
detection *make_network_boxes_batch(network *net, float thresh, int *num, int batch)
|
||||
{
|
||||
int i;
|
||||
layer l = net->layers[net->n - 1];
|
||||
int nboxes = num_detections_batch(net, thresh, batch);
|
||||
assert(num != NULL);
|
||||
*num = nboxes;
|
||||
detection* dets = (detection*)calloc(nboxes, sizeof(detection));
|
||||
for (i = 0; i < nboxes; ++i) {
|
||||
dets[i].prob = (float*)calloc(l.classes, sizeof(float));
|
||||
if (l.coords > 4) {
|
||||
dets[i].mask = (float*)calloc(l.coords - 4, sizeof(float));
|
||||
}
|
||||
}
|
||||
return dets;
|
||||
}
|
||||
|
||||
void custom_get_region_detections(layer l, int w, int h, int net_w, int net_h, float thresh, int *map, float hier, int relative, detection *dets, int letter)
|
||||
{
|
||||
@ -818,6 +850,33 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in
|
||||
}
|
||||
}
|
||||
|
||||
void fill_network_boxes_batch(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch)
|
||||
{
|
||||
int prev_classes = -1;
|
||||
int j;
|
||||
for (j = 0; j < net->n; ++j) {
|
||||
layer l = net->layers[j];
|
||||
if (l.type == YOLO) {
|
||||
int count = get_yolo_detections_batch(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch);
|
||||
dets += count;
|
||||
if (prev_classes < 0) prev_classes = l.classes;
|
||||
else if (prev_classes != l.classes) {
|
||||
printf(" Error: Different [yolo] layers have different number of classes = %d and %d - check your cfg-file! \n",
|
||||
prev_classes, l.classes);
|
||||
}
|
||||
}
|
||||
if (l.type == REGION) {
|
||||
custom_get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets, letter);
|
||||
//get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets);
|
||||
dets += l.w*l.h*l.n;
|
||||
}
|
||||
if (l.type == DETECTION) {
|
||||
get_detection_detections(l, w, h, thresh, dets);
|
||||
dets += l.w*l.h*l.n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter)
|
||||
{
|
||||
detection *dets = make_network_boxes(net, thresh, num);
|
||||
@ -836,6 +895,14 @@ void free_detections(detection *dets, int n)
|
||||
free(dets);
|
||||
}
|
||||
|
||||
void free_batch_detections(detNumPair *detNumPairs, int n)
|
||||
{
|
||||
int i;
|
||||
for(i=0; i<n; ++i)
|
||||
free_detections(detNumPairs[i].dets, detNumPairs[i].num);
|
||||
free(detNumPairs);
|
||||
}
|
||||
|
||||
// JSON format:
|
||||
//{
|
||||
// "frame_id":8990,
|
||||
@ -911,6 +978,20 @@ float *network_predict_image(network *net, image im)
|
||||
return p;
|
||||
}
|
||||
|
||||
detNumPair* network_predict_batch(network *net, image im, int batch_size, int w, int h, float thresh, float hier, int *map, int relative, int letter)
|
||||
{
|
||||
network_predict(*net, im.data);
|
||||
detNumPair *pdets = (struct detNumPair *)calloc(batch_size, sizeof(detNumPair));
|
||||
int num;
|
||||
for(int batch=0; batch<batch_size; batch++){
|
||||
detection *dets = make_network_boxes_batch(net, thresh, &num, batch);
|
||||
fill_network_boxes_batch(net, w, h, thresh, hier, map, relative, dets, letter, batch);
|
||||
pdets[batch].num = num;
|
||||
pdets[batch].dets = dets;
|
||||
}
|
||||
return pdets;
|
||||
}
|
||||
|
||||
float *network_predict_image_letterbox(network *net, image im)
|
||||
{
|
||||
//image imr = letterbox_image(im, net->w, net->h);
|
||||
|
@ -729,6 +729,21 @@ int yolo_num_detections(layer l, float thresh)
|
||||
return count;
|
||||
}
|
||||
|
||||
int yolo_num_detections_batch(layer l, float thresh, int batch)
|
||||
{
|
||||
int i, n;
|
||||
int count = 0;
|
||||
for (i = 0; i < l.w*l.h; ++i){
|
||||
for(n = 0; n < l.n; ++n){
|
||||
int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4);
|
||||
if(l.output[obj_index] > thresh){
|
||||
++count;
|
||||
}
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
void avg_flipped_yolo(layer l)
|
||||
{
|
||||
int i,j,n,z;
|
||||
@ -790,6 +805,38 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh,
|
||||
return count;
|
||||
}
|
||||
|
||||
int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch)
|
||||
{
|
||||
int i,j,n;
|
||||
float *predictions = l.output;
|
||||
//if (l.batch == 2) avg_flipped_yolo(l);
|
||||
int count = 0;
|
||||
for (i = 0; i < l.w*l.h; ++i){
|
||||
int row = i / l.w;
|
||||
int col = i % l.w;
|
||||
for(n = 0; n < l.n; ++n){
|
||||
int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4);
|
||||
float objectness = predictions[obj_index];
|
||||
//if(objectness <= thresh) continue; // incorrect behavior for Nan values
|
||||
if (objectness > thresh) {
|
||||
//printf("\n objectness = %f, thresh = %f, i = %d, n = %d \n", objectness, thresh, i, n);
|
||||
int box_index = entry_index(l, batch, n*l.w*l.h + i, 0);
|
||||
dets[count].bbox = get_yolo_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h);
|
||||
dets[count].objectness = objectness;
|
||||
dets[count].classes = l.classes;
|
||||
for (j = 0; j < l.classes; ++j) {
|
||||
int class_index = entry_index(l, batch, n*l.w*l.h + i, 4 + 1 + j);
|
||||
float prob = objectness*predictions[class_index];
|
||||
dets[count].prob[j] = (prob > thresh) ? prob : 0;
|
||||
}
|
||||
++count;
|
||||
}
|
||||
}
|
||||
}
|
||||
correct_yolo_boxes(dets, count, w, h, netw, neth, relative, letter);
|
||||
return count;
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
|
||||
void forward_yolo_layer_gpu(const layer l, network_state state)
|
||||
|
@ -13,7 +13,9 @@ void forward_yolo_layer(const layer l, network_state state);
|
||||
void backward_yolo_layer(const layer l, network_state state);
|
||||
void resize_yolo_layer(layer *l, int w, int h);
|
||||
int yolo_num_detections(layer l, float thresh);
|
||||
int yolo_num_detections_batch(layer l, float thresh, int batch);
|
||||
int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter);
|
||||
int get_yolo_detections_batch(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch);
|
||||
void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter);
|
||||
|
||||
#ifdef GPU
|
||||
|
Reference in New Issue
Block a user