diff --git a/Makefile b/Makefile index 581b6d77..40bfcecb 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ CFLAGS+= -DGPU LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand endif -OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o yoloplus.o +OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o ifeq ($(GPU), 1) OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o endif diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg index eb1310a7..64aab1ed 100644 --- a/cfg/darknet.cfg +++ b/cfg/darknet.cfg @@ -1,15 +1,17 @@ [net] -batch=128 +batch=256 subdivisions=1 height=256 width=256 channels=3 momentum=0.9 decay=0.0005 + learning_rate=0.01 -policy=poly -power=.5 -max_batches=600000 +policy=step +scale=.1 +step=100000 +max_batches=400000 [crop] crop_height=224 diff --git a/src/coco.c b/src/coco.c index 87f3dcaa..234f3426 100644 --- a/src/coco.c +++ b/src/coco.c @@ -111,20 +111,6 @@ void train_coco(char *cfgfile, char *weightfile) avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); - if((i-1)*imgs <= N && i*imgs > N){ - fprintf(stderr, "First stage done\n"); - net.learning_rate *= 10; - char buff[256]; - sprintf(buff, "%s/%s_first_stage.weights", backup_directory, base); - save_weights(net, buff); - } - - if((i-1)*imgs <= 80*N && i*imgs > N*80){ - fprintf(stderr, "Second stage done.\n"); - char buff[256]; - sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base); - save_weights(net, buff); - } if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); diff --git a/src/compare.c b/src/compare.c index 9b6d6bf2..0408f800 100644 --- a/src/compare.c +++ b/src/compare.c @@ -175,8 +175,8 @@ int bbox_comparator(const void *a, const void *b) image im1 = load_image_color(box1.filename, net.w, net.h); image im2 = load_image_color(box2.filename, net.w, net.h); float *X = calloc(net.w*net.h*net.c, sizeof(float)); - memcpy(X, im1.data, im1.w*im1.h*im1.c); - memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c); + memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float)); + memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float)); float *predictions = network_predict(net, X); free_image(im1); diff --git a/src/darknet.c b/src/darknet.c index 833f89ec..9632f914 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -13,7 +13,7 @@ extern void run_imagenet(int argc, char **argv); extern void run_yolo(int argc, char **argv); -extern void run_yoloplus(int argc, char **argv); +extern void run_swag(int argc, char **argv); extern void run_coco(int argc, char **argv); extern void run_writing(int argc, char **argv); extern void run_captcha(int argc, char **argv); @@ -179,8 +179,8 @@ int main(int argc, char **argv) average(argc, argv); } else if (0 == strcmp(argv[1], "yolo")){ run_yolo(argc, argv); - } else if (0 == strcmp(argv[1], "yoloplus")){ - run_yoloplus(argc, argv); + } else if (0 == strcmp(argv[1], "swag")){ + run_swag(argc, argv); } else if (0 == strcmp(argv[1], "coco")){ run_coco(argc, argv); } else if (0 == strcmp(argv[1], "compare")){ diff --git a/src/data.c b/src/data.c index 003338e6..17772c1c 100644 --- a/src/data.c +++ b/src/data.c @@ -176,8 +176,10 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int int index = (col+row*num_boxes)*(5+classes); if (truth[index]) continue; truth[index++] = 1; - if (classes) truth[index+id] = 1; + + if (id < classes) truth[index+id] = 1; index += classes; + truth[index++] = x; truth[index++] = y; truth[index++] = w; diff --git a/src/layer.h b/src/layer.h index 77d7f089..1eb73512 100644 --- a/src/layer.h +++ b/src/layer.h @@ -30,6 +30,7 @@ typedef struct { int batch; int inputs; int outputs; + int truths; int h,w,c; int out_h, out_w, out_c; int n; @@ -40,10 +41,12 @@ typedef struct { int pad; int crop_width; int crop_height; + int sqrt; int flip; float angle; float saturation; float exposure; + int softmax; int classes; int coords; int background; diff --git a/src/network.c b/src/network.c index af4861a3..80ee2915 100644 --- a/src/network.c +++ b/src/network.c @@ -48,7 +48,7 @@ float get_current_rate(network net) case POLY: return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power); case SIG: - return net.learning_rate * (1/(1+exp(net.gamma*(batch_num - net.step)))); + return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step)))); default: fprintf(stderr, "Policy is weird!\n"); return net.learning_rate; diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 1f0a6546..cfc6e83a 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -134,6 +134,7 @@ float train_network_datum_gpu(network net, float *x, float *y) network_state state; int x_size = get_network_input_size(net)*net.batch; int y_size = get_network_output_size(net)*net.batch; + if(net.layers[net.n-1].type == REGION) y_size = net.layers[net.n-1].truths*net.batch; if(!*net.input_gpu){ *net.input_gpu = cuda_make_array(x, x_size); *net.truth_gpu = cuda_make_array(y, y_size); diff --git a/src/parser.c b/src/parser.c index 94dc0fad..53e84615 100644 --- a/src/parser.c +++ b/src/parser.c @@ -182,6 +182,10 @@ region_layer parse_region(list *options, size_params params) int num = option_find_int(options, "num", 1); int side = option_find_int(options, "side", 7); region_layer layer = make_region_layer(params.batch, params.inputs, num, side, classes, coords, rescore); + int softmax = option_find_int(options, "softmax", 0); + int sqrt = option_find_int(options, "sqrt", 0); + layer.softmax = softmax; + layer.sqrt = sqrt; return layer; } diff --git a/src/region_layer.c b/src/region_layer.c index dcdcfadc..d65c1a87 100644 --- a/src/region_layer.c +++ b/src/region_layer.c @@ -14,7 +14,7 @@ region_layer make_region_layer(int batch, int inputs, int n, int side, int class { region_layer l = {0}; l.type = REGION; - + l.n = n; l.batch = batch; l.inputs = inputs; @@ -22,15 +22,15 @@ region_layer make_region_layer(int batch, int inputs, int n, int side, int class l.coords = coords; l.rescore = rescore; l.side = side; - assert(side*side*l.coords*l.n == inputs); + assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs); l.cost = calloc(1, sizeof(float)); - int outputs = l.n*5*side*side; - l.outputs = outputs; - l.output = calloc(batch*outputs, sizeof(float)); - l.delta = calloc(batch*inputs, sizeof(float)); - #ifdef GPU - l.output_gpu = cuda_make_array(l.output, batch*outputs); - l.delta_gpu = cuda_make_array(l.delta, batch*inputs); + l.outputs = l.inputs; + l.truths = l.side*l.side*(1+l.coords+l.classes); + l.output = calloc(batch*l.outputs, sizeof(float)); + l.delta = calloc(batch*l.outputs, sizeof(float)); +#ifdef GPU + l.output_gpu = cuda_make_array(l.output, batch*l.outputs); + l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs); #endif fprintf(stderr, "Region Layer\n"); @@ -43,64 +43,69 @@ void forward_region_layer(const region_layer l, network_state state) { int locations = l.side*l.side; int i,j; + memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float)); for(i = 0; i < l.batch*locations; ++i){ - for(j = 0; j < l.n; ++j){ - int in_index = i*l.n*l.coords + j*l.coords; - int out_index = i*l.n*5 + j*5; - - float prob = state.input[in_index+0]; - float x = state.input[in_index+1]; - float y = state.input[in_index+2]; - float w = state.input[in_index+3]; - float h = state.input[in_index+4]; - /* - float min_w = state.input[in_index+5]; - float max_w = state.input[in_index+6]; - float min_h = state.input[in_index+7]; - float max_h = state.input[in_index+8]; - */ - - l.output[out_index+0] = prob; - l.output[out_index+1] = x; - l.output[out_index+2] = y; - l.output[out_index+3] = w; - l.output[out_index+4] = h; - + int index = i*((1+l.coords)*l.n + l.classes); + if(l.softmax){ + activate_array(l.output + index, l.n*(1+l.coords), LOGISTIC); + int offset = l.n*(1+l.coords); + softmax_array(l.output + index + offset, l.classes, + l.output + index + offset); } } if(state.train){ float avg_iou = 0; + float avg_cat = 0; + float avg_obj = 0; + float avg_anyobj = 0; int count = 0; *(l.cost) = 0; int size = l.inputs * l.batch; memset(l.delta, 0, size * sizeof(float)); for (i = 0; i < l.batch*locations; ++i) { - + int index = i*((1+l.coords)*l.n + l.classes); for(j = 0; j < l.n; ++j){ - int in_index = i*l.n*l.coords + j*l.coords; - l.delta[in_index+0] = .1*(0-state.input[in_index+0]); + int prob_index = index + j*(1 + l.coords); + l.delta[prob_index] = (1./l.n)*(0-l.output[prob_index]); + if(l.softmax){ + l.delta[prob_index] = 1./(l.n*l.side)*(0-l.output[prob_index]); + } + *(l.cost) += (1./l.n)*pow(l.output[prob_index], 2); + //printf("%f\n", l.output[prob_index]); + avg_anyobj += l.output[prob_index]; } - int truth_index = i*5; + int truth_index = i*(1 + l.coords + l.classes); int best_index = -1; float best_iou = 0; float best_rmse = 4; int bg = !state.truth[truth_index]; - if(bg) continue; + if(bg) { + continue; + } - box truth = {state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3], state.truth[truth_index+4]}; + int class_index = index + l.n*(1+l.coords); + for(j = 0; j < l.classes; ++j) { + l.delta[class_index+j] = state.truth[truth_index+1+j] - l.output[class_index+j]; + *(l.cost) += pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2); + if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j]; + } + truth_index += l.classes + 1; + box truth = {state.truth[truth_index+0], state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3]}; truth.x /= l.side; truth.y /= l.side; for(j = 0; j < l.n; ++j){ - int out_index = i*l.n*5 + j*5; + int out_index = index + j*(1+l.coords); box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]}; - //printf("\n%f %f %f %f %f\n", l.output[out_index+0], out.x, out.y, out.w, out.h); - out.x /= l.side; out.y /= l.side; + if (l.sqrt){ + out.w = out.w*out.w; + out.h = out.h*out.h; + } float iou = box_iou(out, truth); float rmse = box_rmse(out, truth); @@ -116,46 +121,41 @@ void forward_region_layer(const region_layer l, network_state state) } } } - printf("%d", best_index); - //int out_index = i*l.n*5 + best_index*5; - //box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]}; - int in_index = i*l.n*l.coords + best_index*l.coords; - - l.delta[in_index+0] = (1-state.input[in_index+0]); - l.delta[in_index+1] = state.truth[truth_index+1] - state.input[in_index+1]; - l.delta[in_index+2] = state.truth[truth_index+2] - state.input[in_index+2]; - l.delta[in_index+3] = state.truth[truth_index+3] - state.input[in_index+3]; - l.delta[in_index+4] = state.truth[truth_index+4] - state.input[in_index+4]; - /* - l.delta[in_index+5] = 0 - state.input[in_index+5]; - l.delta[in_index+6] = 1 - state.input[in_index+6]; - l.delta[in_index+7] = 0 - state.input[in_index+7]; - l.delta[in_index+8] = 1 - state.input[in_index+8]; - */ - - /* - float x = state.input[in_index+1]; - float y = state.input[in_index+2]; - float w = state.input[in_index+3]; - float h = state.input[in_index+4]; - float min_w = state.input[in_index+5]; - float max_w = state.input[in_index+6]; - float min_h = state.input[in_index+7]; - float max_h = state.input[in_index+8]; - */ + //printf("%d", best_index); + int in_index = index + best_index*(1+l.coords); + *(l.cost) -= pow(l.output[in_index], 2); + *(l.cost) += pow(1-l.output[in_index], 2); + avg_obj += l.output[in_index]; + l.delta[in_index+0] = (1.-l.output[in_index]); + if(l.softmax){ + l.delta[in_index+0] = 5*(1.-l.output[in_index]); + } + //printf("%f\n", l.output[in_index]); + l.delta[in_index+1] = 5*(state.truth[truth_index+0] - l.output[in_index+1]); + l.delta[in_index+2] = 5*(state.truth[truth_index+1] - l.output[in_index+2]); + if(l.sqrt){ + l.delta[in_index+3] = 5*(sqrt(state.truth[truth_index+2]) - l.output[in_index+3]); + l.delta[in_index+4] = 5*(sqrt(state.truth[truth_index+3]) - l.output[in_index+4]); + }else{ + l.delta[in_index+3] = 5*(state.truth[truth_index+2] - l.output[in_index+3]); + l.delta[in_index+4] = 5*(state.truth[truth_index+3] - l.output[in_index+4]); + } + *(l.cost) += pow(1-best_iou, 2); avg_iou += best_iou; ++count; + if(l.softmax){ + gradient_array(l.output + index, l.n*(1+l.coords), LOGISTIC, l.delta + index); + } } - printf("\nAvg IOU: %f %d\n", avg_iou/count, count); + printf("Avg IOU: %f, Avg Cat Pred: %f, Avg Obj: %f, Avg Any: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count); } } void backward_region_layer(const region_layer l, network_state state) { axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1); - //copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1); } #ifdef GPU @@ -165,8 +165,9 @@ void forward_region_layer_gpu(const region_layer l, network_state state) float *in_cpu = calloc(l.batch*l.inputs, sizeof(float)); float *truth_cpu = 0; if(state.truth){ - truth_cpu = calloc(l.batch*l.outputs, sizeof(float)); - cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs); + int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes); + truth_cpu = calloc(num_truth, sizeof(float)); + cuda_pull_array(state.truth, truth_cpu, num_truth); } cuda_pull_array(state.input, in_cpu, l.batch*l.inputs); network_state cpu_state; diff --git a/src/yoloplus.c b/src/swag.c similarity index 65% rename from src/yoloplus.c rename to src/swag.c index dcae7bce..4dcf36bf 100644 --- a/src/yoloplus.c +++ b/src/swag.c @@ -11,7 +11,7 @@ char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; -void draw_yoloplus(image im, float *box, int side, int objectness, char *label, float thresh) +void draw_swag(image im, float *box, int side, int objectness, char *label, float thresh) { int classes = 20; int elems = 4+classes+objectness; @@ -52,7 +52,7 @@ void draw_yoloplus(image im, float *box, int side, int objectness, char *label, show_image(im, label); } -void train_yoloplus(char *cfgfile, char *weightfile) +void train_swag(char *cfgfile, char *weightfile) { char *train_images = "/home/pjreddie/data/voc/test/train.txt"; char *backup_directory = "/home/pjreddie/backup/"; @@ -65,23 +65,20 @@ void train_yoloplus(char *cfgfile, char *weightfile) if(weightfile){ load_weights(&net, weightfile); } - detection_layer layer = get_network_detection_layer(net); - int imgs = 128; + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = net.batch*net.subdivisions; int i = *net.seen/imgs; - - char **paths; - list *plist = get_paths(train_images); - int N = plist->size; - paths = (char **)list_to_array(plist); - - if(i*imgs > N*120){ - net.layers[net.n-1].rescore = 1; - } data train, buffer; - int classes = layer.classes; - int background = layer.objectness; - int side = sqrt(get_detection_layer_locations(layer)); + + layer l = net.layers[net.n - 1]; + + int side = l.side; + int classes = l.classes; + + list *plist = get_paths(train_images); + int N = plist->size; + char **paths = (char **)list_to_array(plist); load_args args = {0}; args.w = net.w; @@ -91,12 +88,12 @@ void train_yoloplus(char *cfgfile, char *weightfile) args.m = plist->size; args.classes = classes; args.num_boxes = side; - args.background = background; args.d = &buffer; - args.type = DETECTION_DATA; + args.type = REGION_DATA; pthread_t load_thread = load_data_in_thread(args); clock_t time; + //while(i*imgs < N*120){ while(get_current_batch(net) < net.max_batches){ i += 1; time=clock(); @@ -105,36 +102,21 @@ void train_yoloplus(char *cfgfile, char *weightfile) load_thread = load_data_in_thread(args); printf("Loaded: %lf seconds\n", sec(clock()-time)); + +/* + image im = float_to_image(net.w, net.h, 3, train.X.vals[113]); + image copy = copy_image(im); + draw_swag(copy, train.y.vals[113], 7, "truth"); + cvWaitKey(0); + free_image(copy); + */ + time=clock(); float loss = train_network(net, train); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; - printf("%d: %f, %f avg, %lf seconds, %f rate, %d images, epoch: %f\n", get_current_batch(net), loss, avg_loss, sec(clock()-time), get_current_rate(net), *net.seen, (float)*net.seen/N); - - if((i-1)*imgs <= 80*N && i*imgs > N*80){ - fprintf(stderr, "Second stage done.\n"); - char buff[256]; - sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base); - save_weights(net, buff); - net.layers[net.n-1].joint = 1; - net.layers[net.n-1].objectness = 0; - background = 0; - - pthread_join(load_thread, 0); - free_data(buffer); - args.background = background; - load_thread = load_data_in_thread(args); - } - - if((i-1)*imgs <= 120*N && i*imgs > N*120){ - fprintf(stderr, "Third stage done.\n"); - char buff[256]; - sprintf(buff, "%s/%s_final.weights", backup_directory, base); - net.layers[net.n-1].rescore = 1; - save_weights(net, buff); - } - + printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs); if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); @@ -143,36 +125,38 @@ void train_yoloplus(char *cfgfile, char *weightfile) free_data(train); } char buff[256]; - sprintf(buff, "%s/%s_rescore.weights", backup_directory, base); + sprintf(buff, "%s/%s_final.weights", backup_directory, base); save_weights(net, buff); } -void convert_yoloplus_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes) +void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes) { - int i,j; - int per_box = 4+classes+(background || objectness); - for (i = 0; i < num_boxes*num_boxes; ++i){ - float scale = 1; - if(objectness) scale = 1-predictions[i*per_box]; - int offset = i*per_box+(background||objectness); - for(j = 0; j < classes; ++j){ - float prob = scale*predictions[offset+j]; - probs[i][j] = (prob > thresh) ? prob : 0; + int i,j,n; + int per_cell = 5*num+classes; + for (i = 0; i < side*side; ++i){ + int row = i / side; + int col = i % side; + for(n = 0; n < num; ++n){ + int offset = i*per_cell + 5*n; + float scale = predictions[offset]; + int index = i*num + n; + boxes[index].x = (predictions[offset + 1] + col) / side * w; + boxes[index].y = (predictions[offset + 2] + row) / side * h; + boxes[index].w = pow(predictions[offset + 3], (square?2:1)) * w; + boxes[index].h = pow(predictions[offset + 4], (square?2:1)) * h; + for(j = 0; j < classes; ++j){ + offset = i*per_cell + 5*num; + float prob = scale*predictions[offset+j]; + probs[index][j] = (prob > thresh) ? prob : 0; + } } - int row = i / num_boxes; - int col = i % num_boxes; - offset += classes; - boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w; - boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h; - boxes[i].w = pow(predictions[offset + 2], 2) * w; - boxes[i].h = pow(predictions[offset + 3], 2) * h; } } -void print_yoloplus_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h) +void print_swag_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h) { int i, j; - for(i = 0; i < num_boxes*num_boxes; ++i){ + for(i = 0; i < total; ++i){ float xmin = boxes[i].x - boxes[i].w/2.; float xmax = boxes[i].x + boxes[i].w/2.; float ymin = boxes[i].y - boxes[i].h/2.; @@ -190,14 +174,13 @@ void print_yoloplus_detections(FILE **fps, char *id, box *boxes, float **probs, } } -void validate_yoloplus(char *cfgfile, char *weightfile) +void validate_swag(char *cfgfile, char *weightfile) { network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } set_batch_network(&net, 1); - detection_layer layer = get_network_detection_layer(net); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); srand(time(0)); @@ -205,10 +188,10 @@ void validate_yoloplus(char *cfgfile, char *weightfile) list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt"); char **paths = (char **)list_to_array(plist); - int classes = layer.classes; - int objectness = layer.objectness; - int background = layer.background; - int num_boxes = sqrt(get_detection_layer_locations(layer)); + layer l = net.layers[net.n-1]; + int classes = l.classes; + int square = l.sqrt; + int side = l.side; int j; FILE **fps = calloc(classes, sizeof(FILE *)); @@ -217,9 +200,9 @@ void validate_yoloplus(char *cfgfile, char *weightfile) snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]); fps[j] = fopen(buff, "w"); } - box *boxes = calloc(num_boxes*num_boxes, sizeof(box)); - float **probs = calloc(num_boxes*num_boxes, sizeof(float *)); - for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *)); + box *boxes = calloc(side*side*l.n, sizeof(box)); + float **probs = calloc(side*side*l.n, sizeof(float *)); + for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *)); int m = plist->size; int i=0; @@ -268,9 +251,9 @@ void validate_yoloplus(char *cfgfile, char *weightfile) float *predictions = network_predict(net, X); int w = val[t].w; int h = val[t].h; - convert_yoloplus_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes); - if (nms) do_nms(boxes, probs, num_boxes*num_boxes, classes, iou_thresh); - print_yoloplus_detections(fps, id, boxes, probs, num_boxes, classes, w, h); + convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes); + if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh); + print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h); free(id); free_image(val[t]); free_image(val_resized[t]); @@ -279,7 +262,7 @@ void validate_yoloplus(char *cfgfile, char *weightfile) fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start)); } -void test_yoloplus(char *cfgfile, char *weightfile, char *filename, float thresh) +void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh) { network net = parse_network_cfg(cfgfile); @@ -306,7 +289,7 @@ void test_yoloplus(char *cfgfile, char *weightfile, char *filename, float thresh time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); - draw_yoloplus(im, predictions, 7, layer.objectness, "predictions", thresh); + draw_swag(im, predictions, 7, layer.objectness, "predictions", thresh); free_image(im); free_image(sized); #ifdef OPENCV @@ -317,7 +300,7 @@ void test_yoloplus(char *cfgfile, char *weightfile, char *filename, float thresh } } -void run_yoloplus(int argc, char **argv) +void run_swag(int argc, char **argv) { float thresh = find_float_arg(argc, argv, "-thresh", .2); if(argc < 4){ @@ -328,7 +311,7 @@ void run_yoloplus(int argc, char **argv) char *cfg = argv[3]; char *weights = (argc > 4) ? argv[4] : 0; char *filename = (argc > 5) ? argv[5]: 0; - if(0==strcmp(argv[2], "test")) test_yoloplus(cfg, weights, filename, thresh); - else if(0==strcmp(argv[2], "train")) train_yoloplus(cfg, weights); - else if(0==strcmp(argv[2], "valid")) validate_yoloplus(cfg, weights); + if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh); + else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights); + else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights); }