testing other losses

This commit is contained in:
Joseph Redmon
2015-05-15 10:25:05 -07:00
parent 7399dd1af5
commit 46e1b263e1
3 changed files with 23 additions and 22 deletions

View File

@ -330,9 +330,8 @@ void forward_detection_layer(const detection_layer l, network_state state)
l.output[out_i++] = mask*state.input[in_i++];
}
}
if(l.does_cost && state.train && 0){
if(l.does_cost && state.train){
int count = 0;
float avg = 0;
*(l.cost) = 0;
int size = get_detection_layer_output_size(l) * l.batch;
memset(l.delta, 0, size * sizeof(float));
@ -354,26 +353,28 @@ void forward_detection_layer(const detection_layer l, network_state state)
out.w = l.output[j+2];
out.h = l.output[j+3];
if(!(truth.w*truth.h)) continue;
//printf("iou: %f\n", iou);
dbox d = diou(out, truth);
l.delta[j+0] = d.dx;
l.delta[j+1] = d.dy;
l.delta[j+2] = d.dw;
l.delta[j+3] = d.dh;
l.delta[j+0] = (truth.x - out.x);
l.delta[j+1] = (truth.y - out.y);
l.delta[j+2] = (truth.w - out.w);
l.delta[j+3] = (truth.h - out.h);
*(l.cost) += pow((out.x - truth.x), 2);
*(l.cost) += pow((out.y - truth.y), 2);
*(l.cost) += pow((out.w - truth.w), 2);
*(l.cost) += pow((out.h - truth.h), 2);
int sqr = 1;
if(sqr){
truth.w *= truth.w;
truth.h *= truth.h;
out.w *= out.w;
out.h *= out.h;
}
float iou = box_iou(truth, out);
*(l.cost) += pow((1-iou), 2);
avg += iou;
/*
l.delta[j+0] = .1 * (truth.x - out.x) / (49 * truth.w * truth.w);
l.delta[j+1] = .1 * (truth.y - out.y) / (49 * truth.h * truth.h);
l.delta[j+2] = .1 * (truth.w - out.w) / ( truth.w * truth.w);
l.delta[j+3] = .1 * (truth.h - out.h) / ( truth.h * truth.h);
*(l.cost) += pow((out.x - truth.x)/truth.w/7., 2);
*(l.cost) += pow((out.y - truth.y)/truth.h/7., 2);
*(l.cost) += pow((out.w - truth.w)/truth.w, 2);
*(l.cost) += pow((out.h - truth.h)/truth.h, 2);
*/
++count;
}
fprintf(stderr, "Avg IOU: %f\n", avg/count);
}
/*
int count = 0;