testing other losses

This commit is contained in:
Joseph Redmon 2015-05-15 10:25:05 -07:00
parent 7399dd1af5
commit 46e1b263e1
3 changed files with 23 additions and 22 deletions

View File

@ -165,7 +165,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int num_boxes,
w = constrain(0, 1, w); w = constrain(0, 1, w);
h = constrain(0, 1, h); h = constrain(0, 1, h);
if (w == 0 || h == 0) continue; if (w < .01 || h < .01) continue;
if(1){ if(1){
//w = sqrt(w); //w = sqrt(w);
//h = sqrt(h); //h = sqrt(h);

View File

@ -309,8 +309,8 @@ void predict_detections(network net, data d, float threshold, int offset, int cl
float y = (pred.vals[j][ci + 1] + row)/num_boxes; float y = (pred.vals[j][ci + 1] + row)/num_boxes;
float w = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes); float w = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes);
float h = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes); float h = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes);
w = pow(w, 2); w = pow(w, 1);
h = pow(h, 2); h = pow(h, 1);
float prob = scale*pred.vals[j][k+class+background+nuisance]; float prob = scale*pred.vals[j][k+class+background+nuisance];
if(prob < threshold) continue; if(prob < threshold) continue;
printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, x, y, w, h); printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, x, y, w, h);

View File

@ -330,9 +330,8 @@ void forward_detection_layer(const detection_layer l, network_state state)
l.output[out_i++] = mask*state.input[in_i++]; l.output[out_i++] = mask*state.input[in_i++];
} }
} }
if(l.does_cost && state.train && 0){ if(l.does_cost && state.train){
int count = 0; int count = 0;
float avg = 0;
*(l.cost) = 0; *(l.cost) = 0;
int size = get_detection_layer_output_size(l) * l.batch; int size = get_detection_layer_output_size(l) * l.batch;
memset(l.delta, 0, size * sizeof(float)); memset(l.delta, 0, size * sizeof(float));
@ -354,26 +353,28 @@ void forward_detection_layer(const detection_layer l, network_state state)
out.w = l.output[j+2]; out.w = l.output[j+2];
out.h = l.output[j+3]; out.h = l.output[j+3];
if(!(truth.w*truth.h)) continue; if(!(truth.w*truth.h)) continue;
//printf("iou: %f\n", iou); l.delta[j+0] = (truth.x - out.x);
dbox d = diou(out, truth); l.delta[j+1] = (truth.y - out.y);
l.delta[j+0] = d.dx; l.delta[j+2] = (truth.w - out.w);
l.delta[j+1] = d.dy; l.delta[j+3] = (truth.h - out.h);
l.delta[j+2] = d.dw; *(l.cost) += pow((out.x - truth.x), 2);
l.delta[j+3] = d.dh; *(l.cost) += pow((out.y - truth.y), 2);
*(l.cost) += pow((out.w - truth.w), 2);
*(l.cost) += pow((out.h - truth.h), 2);
int sqr = 1; /*
if(sqr){ l.delta[j+0] = .1 * (truth.x - out.x) / (49 * truth.w * truth.w);
truth.w *= truth.w; l.delta[j+1] = .1 * (truth.y - out.y) / (49 * truth.h * truth.h);
truth.h *= truth.h; l.delta[j+2] = .1 * (truth.w - out.w) / ( truth.w * truth.w);
out.w *= out.w; l.delta[j+3] = .1 * (truth.h - out.h) / ( truth.h * truth.h);
out.h *= out.h;
} *(l.cost) += pow((out.x - truth.x)/truth.w/7., 2);
float iou = box_iou(truth, out); *(l.cost) += pow((out.y - truth.y)/truth.h/7., 2);
*(l.cost) += pow((1-iou), 2); *(l.cost) += pow((out.w - truth.w)/truth.w, 2);
avg += iou; *(l.cost) += pow((out.h - truth.h)/truth.h, 2);
*/
++count; ++count;
} }
fprintf(stderr, "Avg IOU: %f\n", avg/count);
} }
/* /*
int count = 0; int count = 0;