diff --git a/Makefile b/Makefile index fda7d888..4c1bb148 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ CC=gcc COMMON=-Wall `pkg-config --cflags opencv` -CFLAGS= $(COMMON) -O3 -ffast-math -flto UNAME = $(shell uname) ifeq ($(UNAME), Darwin) COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include else -CFLAGS += -march=native +COMMON += -march=native endif +CFLAGS= $(COMMON) -Ofast -flto #CFLAGS= $(COMMON) -O0 -g LDFLAGS=`pkg-config --libs opencv` -lm VPATH=./src/ diff --git a/connected.cfg b/connected.cfg deleted file mode 100644 index dc2c073c..00000000 --- a/connected.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[conn] -input=1690 -output = 10 -activation=relu - -[conn] -output = 1 -activation=relu diff --git a/convolutional.cfg b/convolutional.cfg deleted file mode 100644 index 1612c9cb..00000000 --- a/convolutional.cfg +++ /dev/null @@ -1,9 +0,0 @@ -[conv] -width=200 -height=200 -channels=3 -filters=10 -size=15 -stride=16 -activation=relu - diff --git a/full.cfg b/full.cfg deleted file mode 100644 index 78e938fb..00000000 --- a/full.cfg +++ /dev/null @@ -1,17 +0,0 @@ -[conv] -width=64 -height=64 -channels=3 -filters=10 -size=11 -stride=2 -activation=ramp - -[maxpool] -stride=2 - -[conn] -output = 2 -activation=ramp - -[softmax] diff --git a/nist.cfg b/nist.cfg deleted file mode 100644 index 46e32233..00000000 --- a/nist.cfg +++ /dev/null @@ -1,30 +0,0 @@ -[conv] -width=28 -height=28 -channels=1 -filters=20 -size=5 -stride=1 -activation=ramp - -[maxpool] -stride=2 - -[conv] -filters=50 -size=5 -stride=1 -activation=ramp - -[maxpool] -stride=2 - -[conn] -output = 500 -activation=ramp - -[conn] -output = 10 -activation=ramp - -[softmax] diff --git a/nist_basic.cfg b/nist_basic.cfg deleted file mode 100644 index 71427358..00000000 --- a/nist_basic.cfg +++ /dev/null @@ -1,14 +0,0 @@ -[conv] -width=28 -height=28 -channels=1 -filters=20 -size=11 -stride=1 -activation=linear - -[conn] -output = 10 -activation=ramp - -[softmax] diff --git a/src/activations.c b/src/activations.c index cc923d0e..c81d6aa5 100644 --- a/src/activations.c +++ b/src/activations.c @@ -4,6 +4,25 @@ #include #include +char *get_activation_string(ACTIVATION a) +{ + switch(a){ + case SIGMOID: + return "sigmoid"; + case RELU: + return "relu"; + case RAMP: + return "ramp"; + case LINEAR: + return "linear"; + case TANH: + return "tanh"; + default: + break; + } + return "relu"; +} + ACTIVATION get_activation(char *s) { if (strcmp(s, "sigmoid")==0) return SIGMOID; diff --git a/src/activations.h b/src/activations.h index fb2c54f4..94741215 100644 --- a/src/activations.h +++ b/src/activations.h @@ -7,6 +7,7 @@ typedef enum{ ACTIVATION get_activation(char *s); +char *get_activation_string(ACTIVATION a); float activate(float x, ACTIVATION a); float gradient(float x, ACTIVATION a); diff --git a/src/connected_layer.c b/src/connected_layer.c index 5f6631cb..07fad695 100644 --- a/src/connected_layer.c +++ b/src/connected_layer.c @@ -19,23 +19,46 @@ connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activa layer->delta = calloc(outputs, sizeof(float*)); layer->weight_updates = calloc(inputs*outputs, sizeof(float)); + layer->weight_adapt = calloc(inputs*outputs, sizeof(float)); layer->weight_momentum = calloc(inputs*outputs, sizeof(float)); layer->weights = calloc(inputs*outputs, sizeof(float)); - float scale = 2./inputs; + float scale = 1./inputs; for(i = 0; i < inputs*outputs; ++i) - layer->weights[i] = rand_normal()*scale; + layer->weights[i] = scale*(rand_uniform()); layer->bias_updates = calloc(outputs, sizeof(float)); + layer->bias_adapt = calloc(outputs, sizeof(float)); layer->bias_momentum = calloc(outputs, sizeof(float)); layer->biases = calloc(outputs, sizeof(float)); for(i = 0; i < outputs; ++i) //layer->biases[i] = rand_normal()*scale + scale; - layer->biases[i] = 0; + layer->biases[i] = 1; layer->activation = activation; return layer; } +/* +void update_connected_layer(connected_layer layer, float step, float momentum, float decay) +{ + int i; + for(i = 0; i < layer.outputs; ++i){ + float delta = layer.bias_updates[i]; + layer.bias_adapt[i] += delta*delta; + layer.bias_momentum[i] = step/sqrt(layer.bias_adapt[i])*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i]; + layer.biases[i] += layer.bias_momentum[i]; + } + for(i = 0; i < layer.outputs*layer.inputs; ++i){ + float delta = layer.weight_updates[i]; + layer.weight_adapt[i] += delta*delta; + layer.weight_momentum[i] = step/sqrt(layer.weight_adapt[i])*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i]; + layer.weights[i] += layer.weight_momentum[i]; + } + memset(layer.bias_updates, 0, layer.outputs*sizeof(float)); + memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float)); +} +*/ + void update_connected_layer(connected_layer layer, float step, float momentum, float decay) { int i; @@ -65,6 +88,7 @@ void forward_connected_layer(connected_layer layer, float *input) for(i = 0; i < layer.outputs; ++i){ layer.output[i] = activate(layer.output[i], layer.activation); } + //for(i = 0; i < layer.outputs; ++i) if(i%(layer.outputs/10+1)==0) printf("%f, ", layer.output[i]); printf("\n"); } void learn_connected_layer(connected_layer layer, float *input) diff --git a/src/connected_layer.h b/src/connected_layer.h index ce0181d4..4b17c59b 100644 --- a/src/connected_layer.h +++ b/src/connected_layer.h @@ -12,6 +12,9 @@ typedef struct{ float *weight_updates; float *bias_updates; + float *weight_adapt; + float *bias_adapt; + float *weight_momentum; float *bias_momentum; diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index cdfe9e1a..6a103f6e 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -41,8 +41,8 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si layer->biases = calloc(n, sizeof(float)); layer->bias_updates = calloc(n, sizeof(float)); layer->bias_momentum = calloc(n, sizeof(float)); - float scale = 2./(size*size); - for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = rand_normal()*scale; + float scale = 1./(size*size*c); + for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform()); for(i = 0; i < n; ++i){ //layer->biases[i] = rand_normal()*scale + scale; layer->biases[i] = 0; @@ -65,6 +65,7 @@ convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int si void forward_convolutional_layer(const convolutional_layer layer, float *in) { + int i; int m = layer.n; int k = layer.size*layer.size*layer.c; int n = ((layer.h-layer.size)/layer.stride + 1)* @@ -79,6 +80,11 @@ void forward_convolutional_layer(const convolutional_layer layer, float *in) im2col_cpu(in, layer.c, layer.h, layer.w, layer.size, layer.stride, b); gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); + for(i = 0; i < m*n; ++i){ + layer.output[i] = activate(layer.output[i], layer.activation); + } + //for(i = 0; i < m*n; ++i) if(i%(m*n/10+1)==0) printf("%f, ", layer.output[i]); printf("\n"); + } void gradient_delta_convolutional_layer(convolutional_layer layer) diff --git a/src/data.c b/src/data.c index 2c5932b0..035efa18 100644 --- a/src/data.c +++ b/src/data.c @@ -30,7 +30,7 @@ void fill_truth(char *path, char **labels, int k, float *truth) } } -data load_data_image_paths(char **paths, int n, char **labels, int k) +data load_data_image_paths(char **paths, int n, char **labels, int k, int h, int w) { int i; data d; @@ -40,7 +40,7 @@ data load_data_image_paths(char **paths, int n, char **labels, int k) d.y = make_matrix(n, k); for(i = 0; i < n; ++i){ - image im = load_image(paths[i]); + image im = load_image(paths[i], h, w); d.X.vals[i] = im.data; d.X.cols = im.h*im.w*im.c; fill_truth(paths[i], labels, k, d.y.vals[i]); @@ -48,11 +48,11 @@ data load_data_image_paths(char **paths, int n, char **labels, int k) return d; } -data load_data_image_pathfile(char *filename, char **labels, int k) +data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w) { list *plist = get_paths(filename); char **paths = (char **)list_to_array(plist); - data d = load_data_image_paths(paths, plist->size, labels, k); + data d = load_data_image_paths(paths, plist->size, labels, k, h, w); free_list_contents(plist); free_list(plist); free(paths); @@ -70,20 +70,20 @@ void free_data(data d) } } -data load_data_image_pathfile_part(char *filename, int part, int total, char **labels, int k) +data load_data_image_pathfile_part(char *filename, int part, int total, char **labels, int k, int h, int w) { list *plist = get_paths(filename); char **paths = (char **)list_to_array(plist); int start = part*plist->size/total; int end = (part+1)*plist->size/total; - data d = load_data_image_paths(paths+start, end-start, labels, k); + data d = load_data_image_paths(paths+start, end-start, labels, k, h, w); free_list_contents(plist); free_list(plist); free(paths); return d; } -data load_data_image_pathfile_random(char *filename, int n, char **labels, int k) +data load_data_image_pathfile_random(char *filename, int n, char **labels, int k, int h, int w) { int i; list *plist = get_paths(filename); @@ -92,8 +92,9 @@ data load_data_image_pathfile_random(char *filename, int n, char **labels, int k for(i = 0; i < n; ++i){ int index = rand()%plist->size; random_paths[i] = paths[index]; + if(i == 0) printf("%s\n", paths[index]); } - data d = load_data_image_paths(random_paths, n, labels, k); + data d = load_data_image_paths(random_paths, n, labels, k, h, w); free_list_contents(plist); free_list(plist); free(paths); @@ -133,6 +134,14 @@ void randomize_data(data d) } } +void scale_data_rows(data d, float s) +{ + int i; + for(i = 0; i < d.X.rows; ++i){ + scale_array(d.X.vals[i], d.X.cols, s); + } +} + void normalize_data_rows(data d) { int i; diff --git a/src/data.h b/src/data.h index e887d0b1..e1709741 100644 --- a/src/data.h +++ b/src/data.h @@ -10,14 +10,15 @@ typedef struct{ } data; -data load_data_image_pathfile(char *filename, char **labels, int k); void free_data(data d); -data load_data_image_pathfile(char *filename, char **labels, int k); +data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w); data load_data_image_pathfile_part(char *filename, int part, int total, - char **labels, int k); -data load_data_image_pathfile_random(char *filename, int n, char **labels, int k); + char **labels, int k, int h, int w); +data load_data_image_pathfile_random(char *filename, int n, char **labels, + int k, int h, int w); data load_categorical_data_csv(char *filename, int target, int k); void normalize_data_rows(data d); +void scale_data_rows(data d, float s); void randomize_data(data d); data *split_data(data d, int part, int total); diff --git a/src/image.c b/src/image.c index 62ee5f7e..460df3d8 100644 --- a/src/image.c +++ b/src/image.c @@ -242,8 +242,107 @@ image make_random_kernel(int size, int c, float scale) return out; } +// Returns a new image that is a cropped version (rectangular cut-out) +// of the original image. +IplImage* cropImage(const IplImage *img, const CvRect region) +{ + IplImage *imageCropped; + CvSize size; -image load_image(char *filename) + if (img->width <= 0 || img->height <= 0 + || region.width <= 0 || region.height <= 0) { + //cerr << "ERROR in cropImage(): invalid dimensions." << endl; + exit(1); + } + + if (img->depth != IPL_DEPTH_8U) { + //cerr << "ERROR in cropImage(): image depth is not 8." << endl; + exit(1); + } + + // Set the desired region of interest. + cvSetImageROI((IplImage*)img, region); + // Copy region of interest into a new iplImage and return it. + size.width = region.width; + size.height = region.height; + imageCropped = cvCreateImage(size, IPL_DEPTH_8U, img->nChannels); + cvCopy(img, imageCropped,NULL); // Copy just the region. + + return imageCropped; +} + +// Creates a new image copy that is of a desired size. The aspect ratio will +// be kept constant if 'keepAspectRatio' is true, by cropping undesired parts +// so that only pixels of the original image are shown, instead of adding +// extra blank space. +// Remember to free the new image later. +IplImage* resizeImage(const IplImage *origImg, int newHeight, int newWidth, + int keepAspectRatio) +{ + IplImage *outImg = 0; + int origWidth = 0; + int origHeight = 0; + if (origImg) { + origWidth = origImg->width; + origHeight = origImg->height; + } + if (newWidth <= 0 || newHeight <= 0 || origImg == 0 + || origWidth <= 0 || origHeight <= 0) { + //cerr << "ERROR: Bad desired image size of " << newWidth + // << "x" << newHeight << " in resizeImage().\n"; + exit(1); + } + + if (keepAspectRatio) { + // Resize the image without changing its aspect ratio, + // by cropping off the edges and enlarging the middle section. + CvRect r; + // input aspect ratio + float origAspect = (origWidth / (float)origHeight); + // output aspect ratio + float newAspect = (newWidth / (float)newHeight); + // crop width to be origHeight * newAspect + if (origAspect > newAspect) { + int tw = (origHeight * newWidth) / newHeight; + r = cvRect((origWidth - tw)/2, 0, tw, origHeight); + } + else { // crop height to be origWidth / newAspect + int th = (origWidth * newHeight) / newWidth; + r = cvRect(0, (origHeight - th)/2, origWidth, th); + } + IplImage *croppedImg = cropImage(origImg, r); + + // Call this function again, with the new aspect ratio image. + // Will do a scaled image resize with the correct aspect ratio. + outImg = resizeImage(croppedImg, newHeight, newWidth, 0); + cvReleaseImage( &croppedImg ); + + } + else { + + // Scale the image to the new dimensions, + // even if the aspect ratio will be changed. + outImg = cvCreateImage(cvSize(newWidth, newHeight), + origImg->depth, origImg->nChannels); + if (newWidth > origImg->width && newHeight > origImg->height) { + // Make the image larger + cvResetImageROI((IplImage*)origImg); + // CV_INTER_LINEAR: good at enlarging. + // CV_INTER_CUBIC: good at enlarging. + cvResize(origImg, outImg, CV_INTER_LINEAR); + } + else { + // Make the image smaller + cvResetImageROI((IplImage*)origImg); + // CV_INTER_AREA: good at shrinking (decimation) only. + cvResize(origImg, outImg, CV_INTER_AREA); + } + + } + return outImg; +} + +image load_image(char *filename, int h, int w) { IplImage* src = 0; if( (src = cvLoadImage(filename,-1)) == 0 ) @@ -251,10 +350,14 @@ image load_image(char *filename) printf("Cannot load file image %s\n", filename); exit(0); } + cvShowImage("Orig", src); + IplImage *resized = resizeImage(src, h, w, 1); + cvShowImage("Sized", resized); + cvWaitKey(0); + cvReleaseImage(&src); + src = resized; unsigned char *data = (unsigned char *)src->imageData; int c = src->nChannels; - int h = src->height; - int w = src->width; int step = src->widthStep; image out = make_image(h,w,c); int i, j, k, count=0;; @@ -363,14 +466,14 @@ void convolve(image m, image kernel, int stride, int channel, image out, int edg two_d_convolve(m, i, kernel, i, stride, out, channel, edge); } /* - int j; - for(i = 0; i < m.h; i += stride){ - for(j = 0; j < m.w; j += stride){ - float val = single_convolve(m, kernel, i, j); - set_pixel(out, i/stride, j/stride, channel, val); - } - } - */ + int j; + for(i = 0; i < m.h; i += stride){ + for(j = 0; j < m.w; j += stride){ + float val = single_convolve(m, kernel, i, j); + set_pixel(out, i/stride, j/stride, channel, val); + } + } + */ } void upsample_image(image m, int stride, image out) @@ -422,10 +525,10 @@ void kernel_update(image m, image update, int stride, int channel, image out, in } } /* - for(i = 0; i < update.h*update.w*update.c; ++i){ - update.data[i] /= (m.h/stride)*(m.w/stride); - } - */ + for(i = 0; i < update.h*update.w*update.c; ++i){ + update.data[i] /= (m.h/stride)*(m.w/stride); + } + */ } void single_back_convolve(image m, image kernel, int x, int y, float val) diff --git a/src/image.h b/src/image.h index 72c4b2c0..2c5d38ac 100644 --- a/src/image.h +++ b/src/image.h @@ -33,13 +33,12 @@ image make_random_image(int h, int w, int c); image make_random_kernel(int size, int c, float scale); image float_to_image(int h, int w, int c, float *data); image copy_image(image p); -image load_image(char *filename); +image load_image(char *filename, int h, int w); float get_pixel(image m, int x, int y, int c); float get_pixel_extend(image m, int x, int y, int c); void set_pixel(image m, int x, int y, int c, float val); - image get_image_layer(image m, int l); void two_d_convolve(image m, int mc, image kernel, int kc, int stride, image out, int oc, int edge); diff --git a/src/mini_blas.c b/src/mini_blas.c index b9a43049..262798bc 100644 --- a/src/mini_blas.c +++ b/src/mini_blas.c @@ -159,7 +159,7 @@ void time_random_matrix(int TA, int TB, int m, int k, int n) gemm(TA,TB,m,n,k,1,a,k,b,n,1,c,n); } end = clock(); - printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (double)(end-start)/CLOCKS_PER_SEC); + printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC); } void test_blas() diff --git a/src/network.c b/src/network.c index 29e22e4e..f7abf580 100644 --- a/src/network.c +++ b/src/network.c @@ -21,6 +21,77 @@ network make_network(int n) return net; } +void print_convolutional_cfg(FILE *fp, convolutional_layer *l) +{ + int i; + fprintf(fp, "[convolutional]\n" + "height=%d\n" + "width=%d\n" + "channels=%d\n" + "filters=%d\n" + "size=%d\n" + "stride=%d\n" + "activation=%s\n", + l->h, l->w, l->c, + l->n, l->size, l->stride, + get_activation_string(l->activation)); + fprintf(fp, "data="); + for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]); + for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]); + fprintf(fp, "\n\n"); +} +void print_connected_cfg(FILE *fp, connected_layer *l) +{ + int i; + fprintf(fp, "[connected]\n" + "input=%d\n" + "output=%d\n" + "activation=%s\n", + l->inputs, l->outputs, + get_activation_string(l->activation)); + fprintf(fp, "data="); + for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]); + for(i = 0; i < l->inputs*l->outputs; ++i) fprintf(fp, "%g,", l->weights[i]); + fprintf(fp, "\n\n"); +} + +void print_maxpool_cfg(FILE *fp, maxpool_layer *l) +{ + fprintf(fp, "[maxpool]\n" + "height=%d\n" + "width=%d\n" + "channels=%d\n" + "stride=%d\n\n", + l->h, l->w, l->c, + l->stride); +} + +void print_softmax_cfg(FILE *fp, softmax_layer *l) +{ + fprintf(fp, "[softmax]\n" + "input=%d\n\n", + l->inputs); +} + +void save_network(network net, char *filename) +{ + FILE *fp = fopen(filename, "w"); + if(!fp) file_error(filename); + int i; + for(i = 0; i < net.n; ++i) + { + if(net.types[i] == CONVOLUTIONAL) + print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i]); + else if(net.types[i] == CONNECTED) + print_connected_cfg(fp, (connected_layer *)net.layers[i]); + else if(net.types[i] == MAXPOOL) + print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i]); + else if(net.types[i] == SOFTMAX) + print_softmax_cfg(fp, (softmax_layer *)net.layers[i]); + } + fclose(fp); +} + void forward_network(network net, float *input) { int i; @@ -64,7 +135,7 @@ void update_network(network net, float step, float momentum, float decay) } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; - update_connected_layer(layer, step, momentum, 0); + update_connected_layer(layer, step, momentum, decay); } } } @@ -121,9 +192,11 @@ float calculate_error_network(network net, float *truth) float *out = get_network_output(net); int i, k = get_network_output_size(net); for(i = 0; i < k; ++i){ + printf("%f, ", out[i]); delta[i] = truth[i] - out[i]; sum += delta[i]*delta[i]; } + printf("\n"); return sum; } @@ -173,25 +246,31 @@ float backward_network(network net, float *input, float *truth) float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay) { - forward_network(net, x); - int class = get_predicted_class_network(net); - float error = backward_network(net, x, y); - update_network(net, step, momentum, decay); - //return (y[class]?1:0); - return error; + forward_network(net, x); + //int class = get_predicted_class_network(net); + float error = backward_network(net, x, y); + update_network(net, step, momentum, decay); + //return (y[class]?1:0); + return error; } float train_network_sgd(network net, data d, int n, float step, float momentum,float decay) { int i; float error = 0; + int correct = 0; for(i = 0; i < n; ++i){ int index = rand()%d.X.rows; error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); + float *y = d.y.vals[index]; + int class = get_predicted_class_network(net); + correct += (y[class]?1:0); + //printf("%d %f %f\n", i,net.output[0], d.y.vals[index][0]); //if((i+1)%10 == 0){ // printf("%d: %f\n", (i+1), (float)correct/(i+1)); //} } + printf("Accuracy: %f\n",(float) correct/n); return error/n; } float train_network_batch(network net, data d, int n, float step, float momentum,float decay) diff --git a/src/network.h b/src/network.h index 17cc10bb..a8b2860f 100644 --- a/src/network.h +++ b/src/network.h @@ -40,6 +40,7 @@ image get_network_image_layer(network net, int i); int get_predicted_class_network(network net); void print_network(network net); void visualize_network(network net); +void save_network(network net, char *filename); #endif diff --git a/src/option_list.c b/src/option_list.c index 7902cd9c..bb8b7101 100644 --- a/src/option_list.c +++ b/src/option_list.c @@ -3,12 +3,6 @@ #include #include "option_list.h" -typedef struct{ - char *key; - char *val; - int used; -} kvp; - void option_insert(list *l, char *key, char *val) { kvp *p = malloc(sizeof(kvp)); @@ -47,7 +41,7 @@ char *option_find_str(list *l, char *key, char *def) { char *v = option_find(l, key); if(v) return v; - fprintf(stderr, "%s: Using default '%s'\n", key, def); + if(def) fprintf(stderr, "%s: Using default '%s'\n", key, def); return def; } diff --git a/src/option_list.h b/src/option_list.h index 60e37fec..26cd36fc 100644 --- a/src/option_list.h +++ b/src/option_list.h @@ -2,6 +2,13 @@ #define OPTION_LIST_H #include "list.h" +typedef struct{ + char *key; + char *val; + int used; +} kvp; + + void option_insert(list *l, char *key, char *val); char *option_find(list *l, char *key); char *option_find_str(list *l, char *key, char *def); diff --git a/src/parser.c b/src/parser.c index eeb6f930..cf35a94a 100644 --- a/src/parser.c +++ b/src/parser.c @@ -23,6 +23,130 @@ int is_maxpool(section *s); int is_softmax(section *s); list *read_cfg(char *filename); +void free_section(section *s) +{ + free(s->type); + node *n = s->options->front; + while(n){ + kvp *pair = (kvp *)n->val; + free(pair->key); + free(pair); + node *next = n->next; + free(n); + n = next; + } + free(s->options); + free(s); +} + +convolutional_layer *parse_convolutional(list *options, network net, int count) +{ + int i; + int h,w,c; + int n = option_find_int(options, "filters",1); + int size = option_find_int(options, "size",1); + int stride = option_find_int(options, "stride",1); + char *activation_s = option_find_str(options, "activation", "sigmoid"); + ACTIVATION activation = get_activation(activation_s); + if(count == 0){ + h = option_find_int(options, "height",1); + w = option_find_int(options, "width",1); + c = option_find_int(options, "channels",1); + }else{ + image m = get_network_image_layer(net, count-1); + h = m.h; + w = m.w; + c = m.c; + if(h == 0) error("Layer before convolutional layer must output image."); + } + convolutional_layer *layer = make_convolutional_layer(h,w,c,n,size,stride, activation); + char *data = option_find_str(options, "data", 0); + if(data){ + char *curr = data; + char *next = data; + for(i = 0; i < n; ++i){ + while(*++next !='\0' && *next != ','); + *next = '\0'; + sscanf(curr, "%g", &layer->biases[i]); + curr = next+1; + } + for(i = 0; i < c*n*size*size; ++i){ + while(*++next !='\0' && *next != ','); + *next = '\0'; + sscanf(curr, "%g", &layer->filters[i]); + curr = next+1; + } + } + option_unused(options); + return layer; +} + +connected_layer *parse_connected(list *options, network net, int count) +{ + int i; + int input; + int output = option_find_int(options, "output",1); + char *activation_s = option_find_str(options, "activation", "sigmoid"); + ACTIVATION activation = get_activation(activation_s); + if(count == 0){ + input = option_find_int(options, "input",1); + }else{ + input = get_network_output_size_layer(net, count-1); + } + connected_layer *layer = make_connected_layer(input, output, activation); + char *data = option_find_str(options, "data", 0); + if(data){ + char *curr = data; + char *next = data; + for(i = 0; i < output; ++i){ + while(*++next !='\0' && *next != ','); + *next = '\0'; + sscanf(curr, "%g", &layer->biases[i]); + curr = next+1; + } + for(i = 0; i < input*output; ++i){ + while(*++next !='\0' && *next != ','); + *next = '\0'; + sscanf(curr, "%g", &layer->weights[i]); + curr = next+1; + } + } + option_unused(options); + return layer; +} + +softmax_layer *parse_softmax(list *options, network net, int count) +{ + int input; + if(count == 0){ + input = option_find_int(options, "input",1); + }else{ + input = get_network_output_size_layer(net, count-1); + } + softmax_layer *layer = make_softmax_layer(input); + option_unused(options); + return layer; +} + +maxpool_layer *parse_maxpool(list *options, network net, int count) +{ + int h,w,c; + int stride = option_find_int(options, "stride",1); + if(count == 0){ + h = option_find_int(options, "height",1); + w = option_find_int(options, "width",1); + c = option_find_int(options, "channels",1); + }else{ + image m = get_network_image_layer(net, count-1); + h = m.h; + w = m.w; + c = m.c; + if(h == 0) error("Layer before convolutional layer must output image."); + } + maxpool_layer *layer = make_maxpool_layer(h,w,c,stride); + option_unused(options); + return layer; +} network parse_network_cfg(char *filename) { @@ -35,78 +159,29 @@ network parse_network_cfg(char *filename) section *s = (section *)n->val; list *options = s->options; if(is_convolutional(s)){ - int h,w,c; - int n = option_find_int(options, "filters",1); - int size = option_find_int(options, "size",1); - int stride = option_find_int(options, "stride",1); - char *activation_s = option_find_str(options, "activation", "sigmoid"); - ACTIVATION activation = get_activation(activation_s); - if(count == 0){ - h = option_find_int(options, "height",1); - w = option_find_int(options, "width",1); - c = option_find_int(options, "channels",1); - }else{ - image m = get_network_image_layer(net, count-1); - h = m.h; - w = m.w; - c = m.c; - if(h == 0) error("Layer before convolutional layer must output image."); - } - convolutional_layer *layer = make_convolutional_layer(h,w,c,n,size,stride, activation); + convolutional_layer *layer = parse_convolutional(options, net, count); net.types[count] = CONVOLUTIONAL; net.layers[count] = layer; - option_unused(options); - } - else if(is_connected(s)){ - int input; - int output = option_find_int(options, "output",1); - char *activation_s = option_find_str(options, "activation", "sigmoid"); - ACTIVATION activation = get_activation(activation_s); - if(count == 0){ - input = option_find_int(options, "input",1); - }else{ - input = get_network_output_size_layer(net, count-1); - } - connected_layer *layer = make_connected_layer(input, output, activation); + }else if(is_connected(s)){ + connected_layer *layer = parse_connected(options, net, count); net.types[count] = CONNECTED; net.layers[count] = layer; - option_unused(options); }else if(is_softmax(s)){ - int input; - if(count == 0){ - input = option_find_int(options, "input",1); - }else{ - input = get_network_output_size_layer(net, count-1); - } - softmax_layer *layer = make_softmax_layer(input); + softmax_layer *layer = parse_softmax(options, net, count); net.types[count] = SOFTMAX; net.layers[count] = layer; - option_unused(options); }else if(is_maxpool(s)){ - int h,w,c; - int stride = option_find_int(options, "stride",1); - //char *activation_s = option_find_str(options, "activation", "sigmoid"); - if(count == 0){ - h = option_find_int(options, "height",1); - w = option_find_int(options, "width",1); - c = option_find_int(options, "channels",1); - }else{ - image m = get_network_image_layer(net, count-1); - h = m.h; - w = m.w; - c = m.c; - if(h == 0) error("Layer before convolutional layer must output image."); - } - maxpool_layer *layer = make_maxpool_layer(h,w,c,stride); + maxpool_layer *layer = parse_maxpool(options, net, count); net.types[count] = MAXPOOL; net.layers[count] = layer; - option_unused(options); }else{ fprintf(stderr, "Type not recognized: %s\n", s->type); } + free_section(s); ++count; n = n->next; } + free_list(sections); net.outputs = get_network_output_size(net); net.output = get_network_output(net); return net; diff --git a/src/softmax_layer.c b/src/softmax_layer.c index 1e01bd20..79375de5 100644 --- a/src/softmax_layer.c +++ b/src/softmax_layer.c @@ -36,8 +36,11 @@ void forward_softmax_layer(const softmax_layer layer, float *input) } for(i = 0; i < layer.inputs; ++i){ sum += exp(input[i]-largest); + printf("%f, ", input[i]); } - sum = largest+log(sum); + printf("\n"); + if(sum) sum = largest+log(sum); + else sum = largest-100; for(i = 0; i < layer.inputs; ++i){ layer.output[i] = exp(input[i]-sum); } diff --git a/src/tests.c b/src/tests.c index 00cd1a12..09ec7b23 100644 --- a/src/tests.c +++ b/src/tests.c @@ -19,7 +19,7 @@ void test_convolve() { - image dog = load_image("dog.jpg"); + image dog = load_image("dog.jpg",300,400); printf("dog channels %d\n", dog.c); image kernel = make_random_image(3,3,dog.c); image edge = make_image(dog.h, dog.w, 1); @@ -35,7 +35,7 @@ void test_convolve() void test_convolve_matrix() { - image dog = load_image("dog.jpg"); + image dog = load_image("dog.jpg",300,400); printf("dog channels %d\n", dog.c); int size = 11; @@ -64,7 +64,7 @@ void test_convolve_matrix() void test_color() { - image dog = load_image("test_color.png"); + image dog = load_image("test_color.png", 300, 400); show_image_layers(dog, "Test Color"); } @@ -124,13 +124,13 @@ void verify_convolutional_layer() void test_load() { - image dog = load_image("dog.jpg"); + image dog = load_image("dog.jpg", 300, 400); show_image(dog, "Test Load"); show_image_layers(dog, "Test Load"); } void test_upsample() { - image dog = load_image("dog.jpg"); + image dog = load_image("dog.jpg", 300, 400); int n = 3; image up = make_image(n*dog.h, n*dog.w, dog.c); upsample_image(dog, n, up); @@ -141,7 +141,7 @@ void test_upsample() void test_rotate() { int i; - image dog = load_image("dog.jpg"); + image dog = load_image("dog.jpg",300,400); clock_t start = clock(), end; for(i = 0; i < 1001; ++i){ rotate_image(dog); @@ -184,24 +184,39 @@ void test_parser() void test_data() { char *labels[] = {"cat","dog"}; - data train = load_data_image_pathfile_random("train_paths.txt", 101,labels, 2); + data train = load_data_image_pathfile_random("train_paths.txt", 101,labels, 2, 300, 400); free_data(train); } void test_full() { network net = parse_network_cfg("full.cfg"); - srand(0); - int i = 0; + srand(2222222); + int i = 800; char *labels[] = {"cat","dog"}; float lr = .00001; float momentum = .9; float decay = 0.01; while(i++ < 1000 || 1){ - data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2); - train_network(net, train, lr, momentum, decay); + visualize_network(net); + cvWaitKey(100); + data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2, 256, 256); + image im = float_to_image(256, 256, 3,train.X.vals[0]); + show_image(im, "input"); + cvWaitKey(100); + //scale_data_rows(train, 1./255.); + normalize_data_rows(train); + clock_t start = clock(), end; + float loss = train_network_sgd(net, train, 100, lr, momentum, decay); + end = clock(); + printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); free_data(train); - printf("Round %d\n", i); + if(i%100==0){ + char buff[256]; + sprintf(buff, "backup_%d.cfg", i); + //save_network(net, buff); + } + //lr *= .99; } } @@ -218,7 +233,7 @@ void test_nist() int count = 0; float lr = .0005; float momentum = .9; - float decay = 0.01; + float decay = 0.001; clock_t start = clock(), end; while(++count <= 100){ //visualize_network(net); @@ -227,7 +242,7 @@ void test_nist() end = clock(); printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC); start=end; - cvWaitKey(100); + //cvWaitKey(100); //lr /= 2; if(count%5 == 0){ float train_acc = network_accuracy(net, train); @@ -235,7 +250,7 @@ void test_nist() float test_acc = network_accuracy(net, test); fprintf(stderr, "TEST: %f\n\n", test_acc); printf("%d, %f, %f\n", count, train_acc, test_acc); - lr *= .5; + //lr *= .5; } } } @@ -345,7 +360,38 @@ void test_im2row() int i; for(i = 0; i < 1000; ++i){ im2col_cpu(test.data, c, h, w, size, stride, matrix); - image render = float_to_image(mh, mw, mc, matrix); + //image render = float_to_image(mh, mw, mc, matrix); + } +} + +void train_VOC() +{ + network net = parse_network_cfg("cfg/voc_backup_ramp_80.cfg"); + srand(2222222); + int i = 0; + char *labels[] = {"aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"}; + float lr = .00001; + float momentum = .9; + float decay = 0.01; + while(i++ < 1000 || 1){ + visualize_network(net); + cvWaitKey(100); + data train = load_data_image_pathfile_random("images/VOC2012/train_paths.txt", 1000, labels, 20, 300, 400); + image im = float_to_image(300, 400, 3,train.X.vals[0]); + show_image(im, "input"); + cvWaitKey(100); + normalize_data_rows(train); + clock_t start = clock(), end; + float loss = train_network_sgd(net, train, 1000, lr, momentum, decay); + end = clock(); + printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); + free_data(train); + if(i%10==0){ + char buff[256]; + sprintf(buff, "cfg/voc_backup_ramp_%d.cfg", i); + save_network(net, buff); + } + //lr *= .99; } } @@ -358,8 +404,9 @@ int main() // test_im2row(); //test_split(); //test_ensemble(); - test_nist(); + //test_nist(); //test_full(); + train_VOC(); //test_random_preprocess(); //test_random_classify(); //test_parser(); diff --git a/src/utils.c b/src/utils.c index 41ee7681..67a9ba11 100644 --- a/src/utils.c +++ b/src/utils.c @@ -216,6 +216,10 @@ float rand_normal() for(i = 0; i < 12; ++i) sum += (float)rand()/RAND_MAX; return sum-6.; } +float rand_uniform() +{ + return (float)rand()/RAND_MAX; +} float **one_hot_encode(float *a, int n, int k) { diff --git a/src/utils.h b/src/utils.h index 8185107e..6fe0343a 100644 --- a/src/utils.h +++ b/src/utils.h @@ -20,6 +20,7 @@ void translate_array(float *a, int n, float s); int max_index(float *a, int n); float constrain(float a, float max); float rand_normal(); +float rand_uniform(); float mean_array(float *a, int n); float variance_array(float *a, int n); float **one_hot_encode(float *a, int n, int k); diff --git a/test.cfg b/test.cfg deleted file mode 100644 index fdbcc107..00000000 --- a/test.cfg +++ /dev/null @@ -1,37 +0,0 @@ -[conv] -width=200 -height=200 -channels=3 -filters=10 -size=15 -stride=16 -activation=relu - -#[maxpool] -#stride=2 - -#[conv] -#filters=10 -#size=10 -#stride=4 -#activation=relu - -#[maxpool] -#stride=2 - -#[conv] -#filters=10 -#size=10 -#stride=4 -#activation=relu - -#[maxpool] -#stride=2 - -[conn] -output = 10 -activation=relu - -[conn] -output = 1 -activation=relu