Visualizations?

Joseph Redmon 2014-04-16 17:05:29 -07:00
parent cc06817efa
commit 738cd4c2d7
10 changed files with 438 additions and 55 deletions

View File

@ -1,20 +1,21 @@
CC=gcc
COMMON=-Wall `pkg-config --cflags opencv`
UNAME = $(shell uname)
OPTS=-O3
ifeq ($(UNAME), Darwin)
COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
LDFLAGS= -framework OpenCL
else
COMMON+= -march=native -flto
OPTS+= -march=native -flto
LDFLAGS= -lOpenCL
endif
CFLAGS= $(COMMON) -Ofast
CFLAGS= $(COMMON) $(OPTS)
#CFLAGS= $(COMMON) -O0 -g
LDFLAGS+=`pkg-config --libs opencv` -lm
VPATH=./src/
EXEC=cnn
OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o opencl.o gpu_gemm.o cpu_gemm.o
OBJ=network.o image.o tests.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o convolutional_layer.o opencl.o gpu_gemm.o cpu_gemm.o normalization_layer.o
all: $(EXEC)

View File

@ -320,11 +320,12 @@ image *visualize_convolutional_layer(convolutional_layer layer, char *window, im
image *single_filters = weighted_sum_filters(layer, 0);
show_images(single_filters, layer.n, window);
image delta = get_convolutional_delta(layer);
image delta = get_convolutional_image(layer);
image dc = collapse_image_layers(delta, 1);
char buff[256];
sprintf(buff, "%s: Delta", window);
//show_image(dc, buff);
sprintf(buff, "%s: Output", window);
show_image(dc, buff);
save_image(dc, buff);
free_image(dc);
return single_filters;
}

View File

@ -264,7 +264,7 @@ void add_into_image(image src, image dest, int h, int w)
}
}
void add_scalar_image(image m, float s)
void translate_image(image m, float s)
{
int i;
for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s;
@ -645,15 +645,49 @@ void print_image(image m)
for(i =0 ; i < m.h*m.w*m.c; ++i) printf("%lf, ", m.data[i]);
printf("\n");
}
image collapse_images_vert(image *ims, int n)
{
int color = 1;
int border = 1;
int h,w,c;
w = ims[0].w;
h = (ims[0].h + border) * n - border;
c = ims[0].c;
if(c != 3 || !color){
w = (w+border)*c - border;
c = 1;
}
image collapse_images(image *ims, int n)
image filters = make_image(h,w,c);
int i,j;
for(i = 0; i < n; ++i){
int h_offset = i*(ims[0].h+border);
image copy = copy_image(ims[i]);
//normalize_image(copy);
if(c == 3 && color){
embed_image(copy, filters, h_offset, 0);
}
else{
for(j = 0; j < copy.c; ++j){
int w_offset = j*(ims[0].w+border);
image layer = get_image_layer(copy, j);
embed_image(layer, filters, h_offset, w_offset);
free_image(layer);
}
}
free_image(copy);
}
return filters;
}
image collapse_images_horz(image *ims, int n)
{
int color = 1;
int border = 1;
int h,w,c;
int size = ims[0].h;
h = size;
w = (size + border) * n - border;
w = (ims[0].w + border) * n - border;
c = ims[0].c;
if(c != 3 || !color){
h = (h+border)*c - border;
@ -665,7 +699,7 @@ image collapse_images(image *ims, int n)
for(i = 0; i < n; ++i){
int w_offset = i*(size+border);
image copy = copy_image(ims[i]);
normalize_image(copy);
//normalize_image(copy);
if(c == 3 && color){
embed_image(copy, filters, 0, w_offset);
}
@ -684,11 +718,49 @@ image collapse_images(image *ims, int n)
void show_images(image *ims, int n, char *window)
{
image m = collapse_images(ims, n);
image m = collapse_images_vert(ims, n);
save_image(m, window);
show_image(m, window);
free_image(m);
}
image grid_images(image **ims, int h, int w)
{
int i;
image *rows = calloc(h, sizeof(image));
for(i = 0; i < h; ++i){
rows[i] = collapse_images_horz(ims[i], w);
}
image out = collapse_images_vert(rows, h);
for(i = 0; i < h; ++i){
free_image(rows[i]);
}
free(rows);
return out;
}
void test_grid()
{
int i,j;
int num = 3;
int topk = 3;
image **vizs = calloc(num, sizeof(image*));
for(i = 0; i < num; ++i){
vizs[i] = calloc(topk, sizeof(image));
for(j = 0; j < topk; ++j) vizs[i][j] = make_image(3,3,3);
}
image grid = grid_images(vizs, num, topk);
save_image(grid, "Test Grid");
free_image(grid);
}
void show_images_grid(image **ims, int h, int w, char *window)
{
image out = grid_images(ims, h, w);
show_image(out, window);
free_image(out);
}
void free_image(image m)
{
free(m.data);

View File

@ -1,6 +1,7 @@
#ifndef IMAGE_H
#define IMAGE_H
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
typedef struct {
@ -12,7 +13,7 @@ typedef struct {
image image_distance(image a, image b);
void scale_image(image m, float s);
void add_scalar_image(image m, float s);
void translate_image(image m, float s);
void normalize_image(image p);
void z_normalize_image(image p);
void threshold_image(image p, float t);
@ -23,6 +24,8 @@ float avg_image_layer(image m, int l);
void embed_image(image source, image dest, int h, int w);
void add_into_image(image src, image dest, int h, int w);
image collapse_image_layers(image source, int border);
image collapse_images_horz(image *ims, int n);
image collapse_images_vert(image *ims, int n);
image get_sub_image(image m, int h, int w, int dh, int dw);
void show_image(image p, char *name);
@ -30,6 +33,9 @@ void save_image(image p, char *name);
void show_images(image *ims, int n, char *window);
void show_image_layers(image p, char *name);
void show_image_collapsed(image p, char *name);
void show_images_grid(image **ims, int h, int w, char *window);
void test_grid();
image grid_images(image **ims, int h, int w);
void print_image(image m);
image make_image(int h, int w, int c);

View File

@ -8,6 +8,7 @@
#include "convolutional_layer.h"
//#include "old_conv.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "softmax_layer.h"
network make_network(int n, int batch)
@ -40,6 +41,17 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, int first)
fprintf(fp, "data=");
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
/*
int j,k;
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
for(i = 0; i < l->n; ++i){
for(j = l->c-1; j >= 0; --j){
for(k = 0; k < l->size*l->size; ++k){
fprintf(fp, "%g,", l->filters[i*(l->c*l->size*l->size)+j*l->size*l->size+k]);
}
}
}
*/
fprintf(fp, "\n\n");
}
void print_connected_cfg(FILE *fp, connected_layer *l, int first)
@ -48,9 +60,9 @@ void print_connected_cfg(FILE *fp, connected_layer *l, int first)
fprintf(fp, "[connected]\n");
if(first) fprintf(fp, "batch=%d\ninput=%d\n", l->batch, l->inputs);
fprintf(fp, "output=%d\n"
"activation=%s\n",
l->outputs,
get_activation_string(l->activation));
"activation=%s\n",
l->outputs,
get_activation_string(l->activation));
fprintf(fp, "data=");
for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]);
for(i = 0; i < l->inputs*l->outputs; ++i) fprintf(fp, "%g,", l->weights[i]);
@ -61,13 +73,27 @@ void print_maxpool_cfg(FILE *fp, maxpool_layer *l, int first)
{
fprintf(fp, "[maxpool]\n");
if(first) fprintf(fp, "batch=%d\n"
"height=%d\n"
"width=%d\n"
"channels=%d\n",
l->batch,l->h, l->w, l->c);
"height=%d\n"
"width=%d\n"
"channels=%d\n",
l->batch,l->h, l->w, l->c);
fprintf(fp, "stride=%d\n\n", l->stride);
}
void print_normalization_cfg(FILE *fp, normalization_layer *l, int first)
{
fprintf(fp, "[localresponsenormalization]\n");
if(first) fprintf(fp, "batch=%d\n"
"height=%d\n"
"width=%d\n"
"channels=%d\n",
l->batch,l->h, l->w, l->c);
fprintf(fp, "size=%d\n"
"alpha=%g\n"
"beta=%g\n"
"kappa=%g\n\n", l->size, l->alpha, l->beta, l->kappa);
}
void print_softmax_cfg(FILE *fp, softmax_layer *l, int first)
{
fprintf(fp, "[softmax]\n");
@ -88,6 +114,8 @@ void save_network(network net, char *filename)
print_connected_cfg(fp, (connected_layer *)net.layers[i], i==0);
else if(net.types[i] == MAXPOOL)
print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], i==0);
else if(net.types[i] == NORMALIZATION)
print_normalization_cfg(fp, (normalization_layer *)net.layers[i], i==0);
else if(net.types[i] == SOFTMAX)
print_softmax_cfg(fp, (softmax_layer *)net.layers[i], i==0);
}
@ -118,6 +146,11 @@ void forward_network(network net, float *input)
forward_maxpool_layer(layer, input);
input = layer.output;
}
else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
forward_normalization_layer(layer, input);
input = layer.output;
}
}
}
@ -135,6 +168,9 @@ void update_network(network net, float step, float momentum, float decay)
else if(net.types[i] == SOFTMAX){
//maxpool_layer layer = *(maxpool_layer *)net.layers[i];
}
else if(net.types[i] == NORMALIZATION){
//maxpool_layer layer = *(maxpool_layer *)net.layers[i];
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
update_connected_layer(layer, step, momentum, decay);
@ -156,6 +192,9 @@ float *get_network_output_layer(network net, int i)
} else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.output;
} else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
return layer.output;
}
return 0;
}
@ -233,6 +272,10 @@ float backward_network(network net, float *input, float *truth)
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
if(i != 0) backward_maxpool_layer(layer, prev_input, prev_delta);
}
else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta);
}
else if(net.types[i] == SOFTMAX){
softmax_layer layer = *(softmax_layer *)net.layers[i];
if(i != 0) backward_softmax_layer(layer, prev_input, prev_delta);
@ -272,7 +315,7 @@ float train_network_sgd(network net, data d, int n, float step, float momentum,f
error += err;
++pos;
}
//printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]);
//if((i+1)%10 == 0){
@ -341,34 +384,34 @@ int get_network_output_size_layer(network net, int i)
}
/*
int resize_network(network net, int h, int w, int c)
{
int i;
for (i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer *layer = (convolutional_layer *)net.layers[i];
layer->h = h;
layer->w = w;
layer->c = c;
image output = get_convolutional_image(*layer);
h = output.h;
w = output.w;
c = output.c;
}
else if(net.types[i] == MAXPOOL){
maxpool_layer *layer = (maxpool_layer *)net.layers[i];
layer->h = h;
layer->w = w;
layer->c = c;
image output = get_maxpool_image(*layer);
h = output.h;
w = output.w;
c = output.c;
}
}
return 0;
}
*/
int resize_network(network net, int h, int w, int c)
{
@ -381,16 +424,21 @@ int resize_network(network net, int h, int w, int c)
h = output.h;
w = output.w;
c = output.c;
}
else if(net.types[i] == MAXPOOL){
}else if(net.types[i] == MAXPOOL){
maxpool_layer *layer = (maxpool_layer *)net.layers[i];
resize_maxpool_layer(layer, h, w, c);
image output = get_maxpool_image(*layer);
h = output.h;
w = output.w;
c = output.c;
}
else{
}else if(net.types[i] == NORMALIZATION){
normalization_layer *layer = (normalization_layer *)net.layers[i];
resize_normalization_layer(layer, h, w, c);
image output = get_normalization_image(*layer);
h = output.h;
w = output.w;
c = output.c;
}else{
error("Cannot resize this type of layer");
}
}
@ -413,6 +461,10 @@ image get_network_image_layer(network net, int i)
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
return get_maxpool_image(layer);
}
else if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
return get_normalization_image(layer);
}
return make_empty_image(0,0,0);
}
@ -437,6 +489,10 @@ void visualize_network(network net)
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
prev = visualize_convolutional_layer(layer, buff, prev);
}
if(net.types[i] == NORMALIZATION){
normalization_layer layer = *(normalization_layer *)net.layers[i];
visualize_normalization_layer(layer, buff);
}
}
}

View File

@ -9,7 +9,8 @@ typedef enum {
CONVOLUTIONAL,
CONNECTED,
MAXPOOL,
SOFTMAX
SOFTMAX,
NORMALIZATION
} LAYER_TYPE;
typedef struct {

src/normalization_layer.c (new file, 96 lines)
View File

@ -0,0 +1,96 @@
#include "normalization_layer.h"
#include <stdio.h>
image get_normalization_image(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(h,w,c,layer.output);
}
image get_normalization_delta(normalization_layer layer)
{
int h = layer.h;
int w = layer.w;
int c = layer.c;
return float_to_image(h,w,c,layer.delta);
}
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa)
{
fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size);
normalization_layer *layer = calloc(1, sizeof(normalization_layer));
layer->batch = batch;
layer->h = h;
layer->w = w;
layer->c = c;
layer->kappa = kappa;
layer->size = size;
layer->alpha = alpha;
layer->beta = beta;
layer->output = calloc(h * w * c * batch, sizeof(float));
layer->delta = calloc(h * w * c * batch, sizeof(float));
layer->sums = calloc(h*w, sizeof(float));
return layer;
}
void resize_normalization_layer(normalization_layer *layer, int h, int w, int c)
{
layer->h = h;
layer->w = w;
layer->c = c;
layer->output = realloc(layer->output, h * w * c * layer->batch * sizeof(float));
layer->delta = realloc(layer->delta, h * w * c * layer->batch * sizeof(float));
layer->sums = realloc(layer->sums, h*w * sizeof(float));
}
void add_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] += src[i]*src[i];
}
}
void sub_square_array(float *src, float *dest, int n)
{
int i;
for(i = 0; i < n; ++i){
dest[i] -= src[i]*src[i];
}
}
void forward_normalization_layer(const normalization_layer layer, float *in)
{
int i,j,k;
memset(layer.sums, 0, layer.h*layer.w*sizeof(float));
int imsize = layer.h*layer.w;
for(j = 0; j < layer.size/2; ++j){
if(j < layer.c) add_square_array(in+j*imsize, layer.sums, imsize);
}
for(k = 0; k < layer.c; ++k){
int next = k+layer.size/2;
int prev = k-layer.size/2-1;
if(next < layer.c) add_square_array(in+next*imsize, layer.sums, imsize);
if(prev >= 0) sub_square_array(in+prev*imsize, layer.sums, imsize); // >= 0 so channel 0 also drops out of the window
for(i = 0; i < imsize; ++i){
layer.output[k*imsize + i] = in[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta);
}
}
}
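
Note: the loop above is the sliding-window form of AlexNet-style local response normalization across channels. A sketch of the per-pixel formula it computes (with the sum clamped to valid channels), written with the layer's own parameters:

    o_k(x,y) = \frac{a_k(x,y)}{\left(\kappa + \alpha \sum_{j = k - \lfloor size/2 \rfloor}^{k + \lfloor size/2 \rfloor} a_j(x,y)^2\right)^{\beta}}

where a is the input, o is layer.output, and kappa, alpha, beta, size are the fields of normalization_layer.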
void backward_normalization_layer(const normalization_layer layer, float *in, float *delta)
{
//TODO!
}
void visualize_normalization_layer(normalization_layer layer, char *window)
{
image delta = get_normalization_image(layer);
image dc = collapse_image_layers(delta, 1);
char buff[256];
sprintf(buff, "%s: Output", window);
show_image(dc, buff);
save_image(dc, buff);
free_image(dc);
}

src/normalization_layer.h (new file, 26 lines)
View File

@ -0,0 +1,26 @@
#ifndef NORMALIZATION_LAYER_H
#define NORMALIZATION_LAYER_H
#include "image.h"
typedef struct {
int batch;
int h,w,c;
int size;
float alpha;
float beta;
float kappa;
float *delta;
float *output;
float *sums;
} normalization_layer;
image get_normalization_image(normalization_layer layer);
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
void resize_normalization_layer(normalization_layer *layer, int h, int w, int c);
void forward_normalization_layer(const normalization_layer layer, float *in);
void backward_normalization_layer(const normalization_layer layer, float *in, float *delta);
void visualize_normalization_layer(normalization_layer layer, char *window);
#endif
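
A minimal, hypothetical usage sketch of the new layer on its own (not part of this commit; it assumes the image helpers declared in image.h, which normalization_layer.h already includes):

#include "normalization_layer.h"

/* Hypothetical smoke test: run the LRN forward pass once on a
   zero-filled 16x16x3 input with AlexNet-style constants. */
int main()
{
    normalization_layer *layer = make_normalization_layer(1, 16, 16, 3,
                                                          5, .0001, .75, 2.);
    image in = make_image(16, 16, 3);             /* zero-initialized input */
    forward_normalization_layer(*layer, in.data); /* writes layer->output */
    image out = get_normalization_image(*layer);  /* wraps layer->output */
    print_image(out);
    free_image(in);
    return 0;
}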

View File

@ -7,6 +7,7 @@
#include "convolutional_layer.h"
#include "connected_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "softmax_layer.h"
#include "list.h"
#include "option_list.h"
@ -21,6 +22,7 @@ int is_convolutional(section *s);
int is_connected(section *s);
int is_maxpool(section *s);
int is_softmax(section *s);
int is_normalization(section *s);
list *read_cfg(char *filename);
void free_section(section *s)
@ -152,6 +154,30 @@ maxpool_layer *parse_maxpool(list *options, network net, int count)
return layer;
}
normalization_layer *parse_normalization(list *options, network net, int count)
{
int h,w,c;
int size = option_find_int(options, "size",1);
float alpha = option_find_float(options, "alpha", 0.);
float beta = option_find_float(options, "beta", 1.);
float kappa = option_find_float(options, "kappa", 1.);
if(count == 0){
h = option_find_int(options, "height",1);
w = option_find_int(options, "width",1);
c = option_find_int(options, "channels",1);
net.batch = option_find_int(options, "batch",1);
}else{
image m = get_network_image_layer(net, count-1);
h = m.h;
w = m.w;
c = m.c;
if(h == 0) error("Layer before normalization layer must output image.");
}
normalization_layer *layer = make_normalization_layer(net.batch,h,w,c,size, alpha, beta, kappa);
option_unused(options);
return layer;
}
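
Given the options read here (and the fields written by print_normalization_cfg earlier in this commit), a hypothetical cfg section for the new layer would look something like the following; the numeric values are only illustrative, AlexNet-style constants:

[localresponsenormalization]
size=5
alpha=0.0001
beta=0.75
kappa=2

If the layer comes first in the network, batch, height, width and channels are also read from the section (each defaulting to 1), and is_normalization below accepts the shorter [lrnorm] spelling as well.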
network parse_network_cfg(char *filename)
{
list *sections = read_cfg(filename);
@ -182,6 +208,11 @@ network parse_network_cfg(char *filename)
net.types[count] = MAXPOOL;
net.layers[count] = layer;
net.batch = layer->batch;
}else if(is_normalization(s)){
normalization_layer *layer = parse_normalization(options, net, count);
net.types[count] = NORMALIZATION;
net.layers[count] = layer;
net.batch = layer->batch;
}else{
fprintf(stderr, "Type not recognized: %s\n", s->type);
}
@ -216,6 +247,11 @@ int is_softmax(section *s)
return (strcmp(s->type, "[soft]")==0
|| strcmp(s->type, "[softmax]")==0);
}
int is_normalization(section *s)
{
return (strcmp(s->type, "[lrnorm]")==0
|| strcmp(s->type, "[localresponsenormalization]")==0);
}
int read_option(char *s, list *options)
{

View File

@ -1,4 +1,5 @@
#include "connected_layer.h"
//#include "old_conv.h"
#include "convolutional_layer.h"
#include "maxpool_layer.h"
@ -223,7 +224,7 @@ void train_full()
void test_visualize()
{
network net = parse_network_cfg("cfg/imagenet.cfg");
network net = parse_network_cfg("cfg/voc_imagenet.cfg");
srand(2222222);
visualize_network(net);
cvWaitKey(0);
@ -445,6 +446,12 @@ void test_im2row()
}
}
void flip_network()
{
network net = parse_network_cfg("cfg/voc_imagenet_orig.cfg");
save_network(net, "cfg/voc_imagenet_rev.cfg");
}
void train_VOC()
{
network net = parse_network_cfg("cfg/voc_start.cfg");
@ -498,6 +505,7 @@ image features_output_size(network net, IplImage *src, int outh, int outw)
IplImage *sized = cvCreateImage(cvSize(w,h), src->depth, src->nChannels);
cvResize(src, sized, CV_INTER_LINEAR);
image im = ipl_to_image(sized);
normalize_array(im.data, im.h*im.w*im.c);
resize_network(net, im.h, im.w, im.c);
forward_network(net, im.data);
image out = get_network_image_layer(net, 6);
@ -523,6 +531,69 @@ void features_VOC_image_size(char *image_path, int h, int w)
free_image(out);
cvReleaseImage(&src);
}
void visualize_imagenet_topk(char *filename)
{
int i,j,k,l;
int topk = 10;
network net = parse_network_cfg("cfg/voc_imagenet.cfg");
list *plist = get_paths(filename);
node *n = plist->front;
int h = voc_size(1), w = voc_size(1);
int num = get_network_image(net).c;
image **vizs = calloc(num, sizeof(image*));
float **score = calloc(num, sizeof(float *));
for(i = 0; i < num; ++i){
vizs[i] = calloc(topk, sizeof(image));
for(j = 0; j < topk; ++j) vizs[i][j] = make_image(h,w,3);
score[i] = calloc(topk, sizeof(float));
}
while(n){
char *image_path = (char *)n->val;
image im = load_image(image_path, 0, 0);
n = n->next;
if(im.h < 200 || im.w < 200) continue;
printf("Processing %dx%d image\n", im.h, im.w);
resize_network(net, im.h, im.w, im.c);
//scale_image(im, 1./255);
translate_image(im, -144);
forward_network(net, im.data);
image out = get_network_image(net);
int dh = (im.h - h)/h;
int dw = (im.w - w)/w;
for(i = 0; i < out.h; ++i){
for(j = 0; j < out.w; ++j){
image sub = get_sub_image(im, dh*i, dw*j, h, w);
for(k = 0; k < out.c; ++k){
float val = get_pixel(out, i, j, k);
//printf("%f, ", val);
image sub_c = copy_image(sub);
for(l = 0; l < topk; ++l){
if(val > score[k][l]){
float swap = score[k][l];
score[k][l] = val;
val = swap;
image swapi = vizs[k][l];
vizs[k][l] = sub_c;
sub_c = swapi;
}
}
free_image(sub_c);
}
free_image(sub);
}
}
free_image(im);
//printf("\n");
image grid = grid_images(vizs, num, topk);
show_image(grid, "IMAGENET Visualization");
save_image(grid, "IMAGENET Grid");
free_image(grid);
}
//cvWaitKey(0);
}
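
The inner loop above keeps, for every output filter, a descending list of the topk best scores and their crops: each new value is bubbled into place and whatever falls off the bottom is freed. A stripped-down, hypothetical sketch of that insertion on plain floats:

/* Hypothetical illustration of the top-k insertion used in
   visualize_imagenet_topk: score[] stays sorted descending and the
   evicted minimum is left in val when the loop finishes. */
void topk_insert(float *score, int topk, float val)
{
    int l;
    for(l = 0; l < topk; ++l){
        if(val > score[l]){
            float swap = score[l];
            score[l] = val;
            val = swap;
        }
    }
}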
void visualize_imagenet_features(char *filename)
{
@ -566,6 +637,20 @@ void visualize_imagenet_features(char *filename)
cvWaitKey(0);
}
void visualize_cat()
{
network net = parse_network_cfg("cfg/voc_imagenet.cfg");
image im = load_image("data/cat.png", 0, 0);
printf("Processing %dx%d image\n", im.h, im.w);
resize_network(net, im.h, im.w, im.c);
forward_network(net, im.data);
image out = get_network_image(net);
visualize_network(net);
cvWaitKey(1000);
cvWaitKey(0);
}
void features_VOC_image(char *image_file, char *image_dir, char *out_dir)
{
int i,j;
@ -693,7 +778,10 @@ int main(int argc, char *argv[])
//features_VOC_image(argv[1], argv[2], argv[3]);
//features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
//visualize_imagenet_features("data/assira/train.list");
visualize_imagenet_features("data/VOC2011.list");
visualize_imagenet_topk("data/VOC2011.list");
//visualize_cat();
//flip_network();
//test_visualize();
fprintf(stderr, "Success!\n");
//test_random_preprocess();
//test_random_classify();