mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
lots of comparator stuff
This commit is contained in:
parent
40cc104639
commit
c40cdeb402
2
Makefile
2
Makefile
@ -34,7 +34,7 @@ CFLAGS+= -DGPU
|
||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||
endif
|
||||
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o
|
||||
ifeq ($(GPU), 1)
|
||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
|
||||
endif
|
||||
|
@ -104,6 +104,7 @@ output=1000
|
||||
activation=leaky
|
||||
|
||||
[softmax]
|
||||
groups=1
|
||||
|
||||
[cost]
|
||||
type=sse
|
||||
|
@ -135,6 +135,7 @@ void get_probs(float *predictions, int total, int classes, int inc, float **prob
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_boxes(float *predictions, int n, int num_boxes, int per_box, box *boxes)
|
||||
{
|
||||
int i,j;
|
||||
|
110
src/compare.c
110
src/compare.c
@ -150,17 +150,20 @@ typedef struct {
|
||||
network net;
|
||||
char *filename;
|
||||
int class;
|
||||
int classes;
|
||||
float elo;
|
||||
float *elos;
|
||||
} sortable_bbox;
|
||||
|
||||
int total_compares = 0;
|
||||
int current_class = 0;
|
||||
|
||||
int elo_comparator(const void*a, const void *b)
|
||||
{
|
||||
sortable_bbox box1 = *(sortable_bbox*)a;
|
||||
sortable_bbox box2 = *(sortable_bbox*)b;
|
||||
if(box1.elo == box2.elo) return 0;
|
||||
if(box1.elo > box2.elo) return -1;
|
||||
if(box1.elos[current_class] == box2.elos[current_class]) return 0;
|
||||
if(box1.elos[current_class] > box2.elos[current_class]) return -1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -188,16 +191,38 @@ int bbox_comparator(const void *a, const void *b)
|
||||
return -1;
|
||||
}
|
||||
|
||||
void bbox_fight(sortable_bbox *a, sortable_bbox *b)
|
||||
void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result)
|
||||
{
|
||||
int k = 32;
|
||||
int result = bbox_comparator(a,b);
|
||||
float EA = 1./(1+pow(10, (b->elo - a->elo)/400.));
|
||||
float EB = 1./(1+pow(10, (a->elo - b->elo)/400.));
|
||||
float SA = 1.*(result > 0);
|
||||
float SB = 1.*(result < 0);
|
||||
a->elo = a->elo + k*(SA - EA);
|
||||
b->elo = b->elo + k*(SB - EB);
|
||||
float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.));
|
||||
float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.));
|
||||
float SA = result ? 1 : 0;
|
||||
float SB = result ? 0 : 1;
|
||||
a->elos[class] += k*(SA - EA);
|
||||
b->elos[class] += k*(SB - EB);
|
||||
}
|
||||
|
||||
/*
 * Run one head-to-head comparison between boxes `a` and `b`.
 *
 * Both images are loaded at the network's input size, copied back to back
 * into a single input buffer, and pushed through `net`.  For every class
 * index c in [0, classes) that is selected, the pair of outputs
 * (predictions[2c], predictions[2c+1]) decides the result handed to
 * bbox_update().
 *
 * net      network used for the comparison
 * a, b     the two contenders; their per-class ratings are updated
 * classes  number of classes the network scores
 * class    class to update, or a negative value to update every class
 *
 * Side effect: increments the global total_compares counter.
 */
void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class)
{
    image first = load_image_color(a->filename, net.w, net.h);
    image second = load_image_color(b->filename, net.w, net.h);

    /* NOTE(review): the buffer holds net.w*net.h*net.c floats, so the network
     * input is assumed to be large enough for both images -- confirm. */
    float *input = calloc(net.w*net.h*net.c, sizeof(float));
    memcpy(input, first.data, first.w*first.h*first.c*sizeof(float));
    memcpy(input+first.w*first.h*first.c, second.data, second.w*second.h*second.c*sizeof(float));

    float *scores = network_predict(net, input);
    ++total_compares;

    int c;
    for(c = 0; c < classes; ++c){
        /* A non-negative `class` restricts the update to that one class. */
        if(class >= 0 && class != c) continue;
        int a_wins = scores[c*2] > scores[c*2+1];
        bbox_update(a, b, c, a_wins);
    }

    free_image(first);
    free_image(second);
    free(input);
}
|
||||
|
||||
void SortMaster3000(char *filename, char *weightfile)
|
||||
@ -233,7 +258,8 @@ void SortMaster3000(char *filename, char *weightfile)
|
||||
|
||||
void BattleRoyaleWithCheese(char *filename, char *weightfile)
|
||||
{
|
||||
int i = 0;
|
||||
int classes = 20;
|
||||
int i,j;
|
||||
network net = parse_network_cfg(filename);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
@ -241,47 +267,67 @@ void BattleRoyaleWithCheese(char *filename, char *weightfile)
|
||||
srand(time(0));
|
||||
set_batch_network(&net, 1);
|
||||
|
||||
//list *plist = get_paths("data/compare.sort.list");
|
||||
list *plist = get_paths("data/compare.cat.list");
|
||||
list *plist = get_paths("data/compare.sort.list");
|
||||
//list *plist = get_paths("data/compare.small.list");
|
||||
//list *plist = get_paths("data/compare.cat.list");
|
||||
//list *plist = get_paths("data/compare.val.old");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
int N = plist->size;
|
||||
int total = N;
|
||||
free_list(plist);
|
||||
sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
|
||||
printf("Battling %d boxes...\n", N);
|
||||
for(i = 0; i < N; ++i){
|
||||
boxes[i].filename = paths[i];
|
||||
boxes[i].net = net;
|
||||
boxes[i].class = 7;
|
||||
boxes[i].elo = 1500;
|
||||
boxes[i].classes = classes;
|
||||
boxes[i].elos = calloc(classes, sizeof(float));;
|
||||
for(j = 0; j < classes; ++j){
|
||||
boxes[i].elos[j] = 1500;
|
||||
}
|
||||
}
|
||||
int round;
|
||||
clock_t time=clock();
|
||||
for(round = 1; round <= 500; ++round){
|
||||
for(round = 1; round <= 4; ++round){
|
||||
clock_t round_time=clock();
|
||||
printf("Round: %d\n", round);
|
||||
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||||
sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
|
||||
shuffle(boxes, N, sizeof(sortable_bbox));
|
||||
for(i = 0; i < N/2; ++i){
|
||||
bbox_fight(boxes+i*2, boxes+i*2+1);
|
||||
}
|
||||
if(round >= 4 && 0){
|
||||
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||||
if(round == 4){
|
||||
N = N/2;
|
||||
}else{
|
||||
N = (N*9/10)/2*2;
|
||||
}
|
||||
bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
|
||||
}
|
||||
printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
|
||||
}
|
||||
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||||
FILE *outfp = fopen("results/battle.log", "w");
|
||||
for(i = 0; i < N; ++i){
|
||||
fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elo);
|
||||
|
||||
int class;
|
||||
|
||||
for (class = 0; class < classes; ++class){
|
||||
|
||||
N = total;
|
||||
current_class = class;
|
||||
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||||
N /= 2;
|
||||
|
||||
for(round = 1; round <= 20; ++round){
|
||||
clock_t round_time=clock();
|
||||
printf("Round: %d\n", round);
|
||||
|
||||
sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
|
||||
for(i = 0; i < N/2; ++i){
|
||||
bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class);
|
||||
}
|
||||
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||||
N = (N*9/10)/2*2;
|
||||
|
||||
printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
|
||||
}
|
||||
char buff[256];
|
||||
sprintf(buff, "results/battle_%d.log", class);
|
||||
FILE *outfp = fopen(buff, "w");
|
||||
for(i = 0; i < N; ++i){
|
||||
fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]);
|
||||
}
|
||||
fclose(outfp);
|
||||
}
|
||||
fclose(outfp);
|
||||
printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
||||
|
||||
l.biases = calloc(n, sizeof(float));
|
||||
l.bias_updates = calloc(n, sizeof(float));
|
||||
//float scale = 1./sqrt(size*size*c);
|
||||
// float scale = 1./sqrt(size*size*c);
|
||||
float scale = sqrt(2./(size*size*c));
|
||||
for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale;
|
||||
for(i = 0; i < n; ++i){
|
||||
|
@ -20,6 +20,7 @@ extern void run_captcha(int argc, char **argv);
|
||||
extern void run_nightmare(int argc, char **argv);
|
||||
extern void run_dice(int argc, char **argv);
|
||||
extern void run_compare(int argc, char **argv);
|
||||
extern void run_classifier(int argc, char **argv);
|
||||
|
||||
void change_rate(char *filename, float scale, float add)
|
||||
{
|
||||
@ -183,6 +184,8 @@ int main(int argc, char **argv)
|
||||
run_swag(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "coco")){
|
||||
run_coco(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "classifier")){
|
||||
run_classifier(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "compare")){
|
||||
run_compare(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "dice")){
|
||||
|
@ -366,7 +366,7 @@ void free_data(data d)
|
||||
}
|
||||
}
|
||||
|
||||
data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes)
|
||||
data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes, float jitter)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
@ -385,8 +385,8 @@ data load_data_region(int n, char **paths, int m, int w, int h, int size, int cl
|
||||
int oh = orig.h;
|
||||
int ow = orig.w;
|
||||
|
||||
int dw = ow/10;
|
||||
int dh = oh/10;
|
||||
int dw = (ow*jitter);
|
||||
int dh = (oh*jitter);
|
||||
|
||||
int pleft = (rand_uniform() * 2*dw - dw);
|
||||
int pright = (rand_uniform() * 2*dw - dw);
|
||||
@ -556,7 +556,7 @@ void *load_thread(void *ptr)
|
||||
} else if (a.type == WRITING_DATA){
|
||||
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
|
||||
} else if (a.type == REGION_DATA){
|
||||
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes);
|
||||
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
|
||||
} else if (a.type == COMPARE_DATA){
|
||||
*a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
|
||||
} else if (a.type == IMAGE_DATA){
|
||||
|
@ -44,6 +44,7 @@ typedef struct load_args{
|
||||
int num_boxes;
|
||||
int classes;
|
||||
int background;
|
||||
float jitter;
|
||||
data *d;
|
||||
image *im;
|
||||
image *resized;
|
||||
|
@ -61,7 +61,7 @@ void validate_dice(char *filename, char *weightfile)
|
||||
free_list(plist);
|
||||
|
||||
data val = load_data(paths, m, 0, labels, 6, net.w, net.h);
|
||||
float *acc = network_accuracies(net, val);
|
||||
float *acc = network_accuracies(net, val, 2);
|
||||
printf("Validation Accuracy: %f, %d images\n", acc[0], m);
|
||||
free_data(val);
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ void validate_imagenet(char *filename, char *weightfile)
|
||||
printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
||||
|
||||
time=clock();
|
||||
float *acc = network_accuracies(net, val);
|
||||
float *acc = network_accuracies(net, val, 5);
|
||||
avg_acc += acc[0];
|
||||
avg_top5 += acc[1];
|
||||
printf("%d: top1: %f, top5: %f, %lf seconds, %d images\n", i, avg_acc/i, avg_top5/i, sec(clock()-time), val.X.rows);
|
||||
|
@ -29,6 +29,9 @@ typedef struct {
|
||||
COST_TYPE cost_type;
|
||||
int batch;
|
||||
int forced;
|
||||
int object_logistic;
|
||||
int class_logistic;
|
||||
int coord_logistic;
|
||||
int inputs;
|
||||
int outputs;
|
||||
int truths;
|
||||
@ -45,6 +48,7 @@ typedef struct {
|
||||
int sqrt;
|
||||
int flip;
|
||||
float angle;
|
||||
float jitter;
|
||||
float saturation;
|
||||
float exposure;
|
||||
int softmax;
|
||||
|
@ -540,12 +540,12 @@ float network_accuracy(network net, data d)
|
||||
return acc;
|
||||
}
|
||||
|
||||
float *network_accuracies(network net, data d)
|
||||
float *network_accuracies(network net, data d, int n)
|
||||
{
|
||||
static float acc[2];
|
||||
matrix guess = network_predict_data(net, d);
|
||||
acc[0] = matrix_topk_accuracy(d.y, guess,1);
|
||||
acc[1] = matrix_topk_accuracy(d.y, guess,5);
|
||||
acc[0] = matrix_topk_accuracy(d.y, guess, 1);
|
||||
acc[1] = matrix_topk_accuracy(d.y, guess, n);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ float train_network_sgd(network net, data d, int n);
|
||||
matrix network_predict_data(network net, data test);
|
||||
float *network_predict(network net, float *input);
|
||||
float network_accuracy(network net, data d);
|
||||
float *network_accuracies(network net, data d);
|
||||
float *network_accuracies(network net, data d, int n);
|
||||
float network_accuracy_multi(network net, data d, int n);
|
||||
void top_predictions(network net, int n, int *index);
|
||||
float *get_network_output(network net);
|
||||
|
@ -3,6 +3,24 @@
|
||||
#include <string.h>
|
||||
#include "option_list.h"
|
||||
|
||||
/*
 * Parse one "key=value" line in place and store the pair in `options`.
 *
 * The first '=' in `s` is overwritten with '\0', so `s` becomes the key and
 * the text after it becomes the value.  Both pointers handed to
 * option_insert() point into `s`; nothing is copied here.
 *
 * s        NUL-terminated, writable line
 * options  list receiving the key/value pair
 *
 * Returns 1 when a pair was inserted, 0 when the line has no '=' or the
 * value after '=' is empty (in the latter case `s` has already been split).
 */
int read_option(char *s, list *options)
{
    char *eq = strchr(s, '=');
    /* No separator at all: not an option line, so do not insert a NULL value. */
    if(!eq) return 0;

    *eq = '\0';
    char *val = eq + 1;
    /* "key=" with nothing after the separator is rejected as well. */
    if(*val == '\0') return 0;

    char *key = s;
    option_insert(options, key, val);
    return 1;
}
|
||||
|
||||
void option_insert(list *l, char *key, char *val)
|
||||
{
|
||||
kvp *p = malloc(sizeof(kvp));
|
||||
|
@ -9,6 +9,7 @@ typedef struct{
|
||||
} kvp;
|
||||
|
||||
|
||||
int read_option(char *s, list *options);
|
||||
void option_insert(list *l, char *key, char *val);
|
||||
char *option_find(list *l, char *key);
|
||||
char *option_find_str(list *l, char *key, char *def);
|
||||
|
23
src/parser.c
23
src/parser.c
@ -186,11 +186,16 @@ region_layer parse_region(list *options, size_params params)
|
||||
layer.softmax = option_find_int(options, "softmax", 0);
|
||||
layer.sqrt = option_find_int(options, "sqrt", 0);
|
||||
|
||||
layer.object_logistic = option_find_int(options, "object_logistic", 0);
|
||||
layer.class_logistic = option_find_int(options, "class_logistic", 0);
|
||||
layer.coord_logistic = option_find_int(options, "coord_logistic", 0);
|
||||
|
||||
layer.coord_scale = option_find_float(options, "coord_scale", 1);
|
||||
layer.forced = option_find_int(options, "forced", 0);
|
||||
layer.object_scale = option_find_float(options, "object_scale", 1);
|
||||
layer.noobject_scale = option_find_float(options, "noobject_scale", 1);
|
||||
layer.class_scale = option_find_float(options, "class_scale", 1);
|
||||
layer.jitter = option_find_float(options, "jitter", .1);
|
||||
return layer;
|
||||
}
|
||||
|
||||
@ -532,24 +537,6 @@ int is_route(section *s)
|
||||
return (strcmp(s->type, "[route]")==0);
|
||||
}
|
||||
|
||||
int read_option(char *s, list *options)
|
||||
{
|
||||
size_t i;
|
||||
size_t len = strlen(s);
|
||||
char *val = 0;
|
||||
for(i = 0; i < len; ++i){
|
||||
if(s[i] == '='){
|
||||
s[i] = '\0';
|
||||
val = s+i+1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(i == len-1) return 0;
|
||||
char *key = s;
|
||||
option_insert(options, key, val);
|
||||
return 1;
|
||||
}
|
||||
|
||||
list *read_cfg(char *filename)
|
||||
{
|
||||
FILE *file = fopen(filename, "r");
|
||||
|
@ -57,6 +57,28 @@ void forward_region_layer(const region_layer l, network_state state)
|
||||
activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
|
||||
}
|
||||
}
|
||||
if (l.object_logistic) {
|
||||
for(b = 0; b < l.batch; ++b){
|
||||
int index = b*l.inputs;
|
||||
int p_index = index + locations*l.classes;
|
||||
activate_array(l.output + p_index, locations*l.n, LOGISTIC);
|
||||
}
|
||||
}
|
||||
|
||||
if (l.coord_logistic) {
|
||||
for(b = 0; b < l.batch; ++b){
|
||||
int index = b*l.inputs;
|
||||
int coord_index = index + locations*(l.classes + l.n);
|
||||
activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC);
|
||||
}
|
||||
}
|
||||
|
||||
if (l.class_logistic) {
|
||||
for(b = 0; b < l.batch; ++b){
|
||||
int class_index = b*l.inputs;
|
||||
activate_array(l.output + class_index, locations*l.classes, LOGISTIC);
|
||||
}
|
||||
}
|
||||
|
||||
if(state.train){
|
||||
float avg_iou = 0;
|
||||
@ -85,7 +107,6 @@ void forward_region_layer(const region_layer l, network_state state)
|
||||
float best_rmse = 20;
|
||||
|
||||
if (!is_obj){
|
||||
//printf(".");
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -113,6 +134,7 @@ void forward_region_layer(const region_layer l, network_state state)
|
||||
}
|
||||
|
||||
float iou = box_iou(out, truth);
|
||||
//iou = 0;
|
||||
float rmse = box_rmse(out, truth);
|
||||
if(best_iou > 0 || iou > 0){
|
||||
if(iou > best_iou){
|
||||
@ -175,6 +197,20 @@ void forward_region_layer(const region_layer l, network_state state)
|
||||
gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords),
|
||||
LOGISTIC, l.delta + index + locations*l.classes);
|
||||
}
|
||||
if (l.object_logistic) {
|
||||
int p_index = index + locations*l.classes;
|
||||
gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index);
|
||||
}
|
||||
|
||||
if (l.class_logistic) {
|
||||
int class_index = index;
|
||||
gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index);
|
||||
}
|
||||
|
||||
if (l.coord_logistic) {
|
||||
int coord_index = index + locations*(l.classes + l.n);
|
||||
gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index);
|
||||
}
|
||||
//printf("\n");
|
||||
}
|
||||
printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
|
||||
|
99
src/swag.c
99
src/swag.c
@ -73,6 +73,7 @@ void train_swag(char *cfgfile, char *weightfile)
|
||||
|
||||
int side = l.side;
|
||||
int classes = l.classes;
|
||||
float jitter = l.jitter;
|
||||
|
||||
list *plist = get_paths(train_images);
|
||||
//int N = plist->size;
|
||||
@ -85,6 +86,7 @@ void train_swag(char *cfgfile, char *weightfile)
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.classes = classes;
|
||||
args.jitter = jitter;
|
||||
args.num_boxes = side;
|
||||
args.d = &buffer;
|
||||
args.type = REGION_DATA;
|
||||
@ -127,7 +129,7 @@ void train_swag(char *cfgfile, char *weightfile)
|
||||
save_weights(net, buff);
|
||||
}
|
||||
|
||||
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes)
|
||||
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness)
|
||||
{
|
||||
int i,j,n;
|
||||
//int per_cell = 5*num+classes;
|
||||
@ -148,6 +150,9 @@ void convert_swag_detections(float *predictions, int classes, int num, int squar
|
||||
float prob = scale*predictions[class_index+j];
|
||||
probs[index][j] = (prob > thresh) ? prob : 0;
|
||||
}
|
||||
if(only_objectness){
|
||||
probs[index][0] = scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -250,7 +255,7 @@ void validate_swag(char *cfgfile, char *weightfile)
|
||||
float *predictions = network_predict(net, X);
|
||||
int w = val[t].w;
|
||||
int h = val[t].h;
|
||||
convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes);
|
||||
convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0);
|
||||
if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh);
|
||||
print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
|
||||
free(id);
|
||||
@ -261,6 +266,95 @@ void validate_swag(char *cfgfile, char *weightfile)
|
||||
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
|
||||
}
|
||||
|
||||
void validate_swag_recall(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
srand(time(0));
|
||||
|
||||
char *base = "results/comp4_det_test_";
|
||||
list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
|
||||
layer l = net.layers[net.n-1];
|
||||
int classes = l.classes;
|
||||
int square = l.sqrt;
|
||||
int side = l.side;
|
||||
|
||||
int j, k;
|
||||
FILE **fps = calloc(classes, sizeof(FILE *));
|
||||
for(j = 0; j < classes; ++j){
|
||||
char buff[1024];
|
||||
snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
|
||||
fps[j] = fopen(buff, "w");
|
||||
}
|
||||
box *boxes = calloc(side*side*l.n, sizeof(box));
|
||||
float **probs = calloc(side*side*l.n, sizeof(float *));
|
||||
for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
|
||||
|
||||
int m = plist->size;
|
||||
int i=0;
|
||||
|
||||
float thresh = .001;
|
||||
int nms = 0;
|
||||
float iou_thresh = .5;
|
||||
float nms_thresh = .5;
|
||||
|
||||
int total = 0;
|
||||
int correct = 0;
|
||||
int proposals = 0;
|
||||
float avg_iou = 0;
|
||||
|
||||
for(i = 0; i < m; ++i){
|
||||
char *path = paths[i];
|
||||
image orig = load_image_color(path, 0, 0);
|
||||
image sized = resize_image(orig, net.w, net.h);
|
||||
char *id = basecfg(path);
|
||||
float *predictions = network_predict(net, sized.data);
|
||||
int w = orig.w;
|
||||
int h = orig.h;
|
||||
convert_swag_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
|
||||
if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
|
||||
|
||||
char *labelpath = find_replace(path, "images", "labels");
|
||||
labelpath = find_replace(labelpath, "JPEGImages", "labels");
|
||||
labelpath = find_replace(labelpath, ".jpg", ".txt");
|
||||
labelpath = find_replace(labelpath, ".JPEG", ".txt");
|
||||
|
||||
int num_labels = 0;
|
||||
box_label *truth = read_boxes(labelpath, &num_labels);
|
||||
for(k = 0; k < side*side*l.n; ++k){
|
||||
if(probs[k][0] > thresh){
|
||||
++proposals;
|
||||
}
|
||||
}
|
||||
for (j = 0; j < num_labels; ++j) {
|
||||
++total;
|
||||
box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
|
||||
float best_iou = 0;
|
||||
for(k = 0; k < side*side*l.n; ++k){
|
||||
float iou = box_iou(boxes[k], t);
|
||||
if(probs[k][0] > thresh && iou > best_iou){
|
||||
best_iou = iou;
|
||||
}
|
||||
}
|
||||
avg_iou += best_iou;
|
||||
if(best_iou > iou_thresh){
|
||||
++correct;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
|
||||
free(id);
|
||||
free_image(orig);
|
||||
free_image(sized);
|
||||
}
|
||||
}
|
||||
|
||||
void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh)
|
||||
{
|
||||
|
||||
@ -316,4 +410,5 @@ void run_swag(int argc, char **argv)
|
||||
if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh);
|
||||
else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "recall")) validate_swag_recall(cfg, weights);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user