LRNorm layer, Viz, better cfg

This commit is contained in:
Joseph Redmon 2014-04-17 09:51:38 -07:00
parent 738cd4c2d7
commit 5e468d1c13

View File

@ -548,7 +548,9 @@ void visualize_imagenet_topk(char *filename)
score[i] = calloc(topk, sizeof(float));
}
int count = 0;
while(n){
++count;
char *image_path = (char *)n->val;
image im = load_image(image_path, 0, 0);
n = n->next;
@ -560,37 +562,46 @@ void visualize_imagenet_topk(char *filename)
forward_network(net, im.data);
image out = get_network_image(net);
int dh = (im.h - h)/h;
int dw = (im.w - w)/w;
for(i = 0; i < out.h; ++i){
for(j = 0; j < out.w; ++j){
image sub = get_sub_image(im, dh*i, dw*j, h, w);
for(k = 0; k < out.c; ++k){
int dh = (im.h - h)/(out.h-1);
int dw = (im.w - w)/(out.w-1);
//printf("%d %d\n", dh, dw);
for(k = 0; k < out.c; ++k){
float topv = 0;
int topi = -1;
int topj = -1;
for(i = 0; i < out.h; ++i){
for(j = 0; j < out.w; ++j){
float val = get_pixel(out, i, j, k);
//printf("%f, ", val);
image sub_c = copy_image(sub);
for(l = 0; l < topk; ++l){
if(val > score[k][l]){
float swap = score[k][l];
score[k][l] = val;
val = swap;
image swapi = vizs[k][l];
vizs[k][l] = sub_c;
sub_c = swapi;
}
if(val > topv){
topv = val;
topi = i;
topj = j;
}
}
}
if(topv){
image sub = get_sub_image(im, dh*topi, dw*topj, h, w);
for(l = 0; l < topk; ++l){
if(topv > score[k][l]){
float swap = score[k][l];
score[k][l] = topv;
topv = swap;
image swapi = vizs[k][l];
vizs[k][l] = sub;
sub = swapi;
}
free_image(sub_c);
}
free_image(sub);
}
}
free_image(im);
//printf("\n");
image grid = grid_images(vizs, num, topk);
show_image(grid, "IMAGENET Visualization");
save_image(grid, "IMAGENET Grid");
free_image(grid);
if(count%50 == 0){
image grid = grid_images(vizs, num, topk);
//show_image(grid, "IMAGENET Visualization");
save_image(grid, "IMAGENET Grid Single Nonorm");
free_image(grid);
}
}
//cvWaitKey(0);
}
@ -644,7 +655,7 @@ void visualize_cat()
printf("Processing %dx%d image\n", im.h, im.w);
resize_network(net, im.h, im.w, im.c);
forward_network(net, im.data);
image out = get_network_image(net);
visualize_network(net);
cvWaitKey(1000);
@ -778,7 +789,7 @@ int main(int argc, char *argv[])
//features_VOC_image(argv[1], argv[2], argv[3]);
//features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
//visualize_imagenet_features("data/assira/train.list");
visualize_imagenet_topk("data/VOC2011.list");
visualize_imagenet_topk("data/VOC2012.list");
//visualize_cat();
//flip_network();
//test_visualize();