From f26da0ad5c679936274917c3d1e53821250414f6 Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Sun, 28 Dec 2014 09:42:35 -0800 Subject: [PATCH] Need to fix line reads --- src/cnn.c | 12 ++++++++---- src/data.c | 53 +++++++++++++++++++++++++++++++++++++++------------ src/data.h | 2 ++ src/network.c | 4 ++-- src/parser.c | 1 + src/utils.c | 9 +++++---- 6 files changed, 59 insertions(+), 22 deletions(-) diff --git a/src/cnn.c b/src/cnn.c index 59948aae..1c74e5c0 100644 --- a/src/cnn.c +++ b/src/cnn.c @@ -84,11 +84,15 @@ void train_detection_net() list *plist = get_paths("/home/pjreddie/data/imagenet/horse.txt"); char **paths = (char **)list_to_array(plist); printf("%d\n", plist->size); + data train, buffer; + pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer); clock_t time; while(1){ i += 1; time=clock(); - data train = load_data_detection_jitter_random(imgs, paths, plist->size, 256, 256, 7, 7, 256); + pthread_join(load_thread, 0); + train = buffer; + load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer); //data train = load_data_detection_random(imgs, paths, plist->size, 224, 224, 7, 7, 256); /* @@ -102,7 +106,7 @@ void train_detection_net() float loss = train_network(net, train); avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs*net.batch); - if(i%10==0){ + if(i%100==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/detnet_%d.cfg", i); save_network(net, buff); @@ -155,10 +159,10 @@ void train_imagenet(char *cfgfile) //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg"); srand(time(0)); network net = parse_network_cfg(cfgfile); - set_learning_network(&net, net.learning_rate/10., .5, .0005); + //set_learning_network(&net, net.learning_rate, 0, .0005); printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int imgs = 1024; - int i = 44700; + int i = 47900; char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list"); list *plist = get_paths("/data/imagenet/cls.train.list"); char **paths = (char **)list_to_array(plist); diff --git a/src/data.c b/src/data.c index 3f74f6bd..12dc1017 100644 --- a/src/data.c +++ b/src/data.c @@ -6,6 +6,20 @@ #include #include +struct load_args{ + char **paths; + int n; + int m; + char **labels; + int k; + int h; + int w; + int nh; + int nw; + float scale; + data *d; +}; + list *get_paths(char *filename) { char *path; @@ -165,11 +179,36 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, jitter_image(a,224,224,dy,dx); } d.X.cols = 224*224*3; - // print_matrix(d.y); free(random_paths); return d; } +void *load_detection_thread(void *ptr) +{ + struct load_args a = *(struct load_args*)ptr; + *a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.h, a.w, a.nh, a.nw, a.scale); + free(ptr); + return 0; +} + +pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d) +{ + pthread_t thread; + struct load_args *args = calloc(1, sizeof(struct load_args)); + args->n = n; + args->paths = paths; + args->m = m; + args->h = h; + args->w = w; + args->nh = nh; + args->nw = nw; + args->scale = scale; + args->d = d; + if(pthread_create(&thread, 0, load_detection_thread, args)) { + error("Thread creation failed"); + } + return thread; +} data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale) { @@ -193,21 +232,11 @@ data load_data(char **paths, int n, int m, char **labels, int k, int h, int w) return d; } -struct load_args{ - char **paths; - int n; - int m; - char **labels; - int k; - int h; - int w; - data *d; -}; - void *load_in_thread(void *ptr) { struct load_args a = *(struct load_args*)ptr; *a.d = load_data(a.paths, a.n, a.m, a.labels, a.k, a.h, a.w); + free(ptr); return 0; } diff --git a/src/data.h b/src/data.h index 1c0b732a..367416eb 100644 --- a/src/data.h +++ b/src/data.h @@ -17,6 +17,8 @@ void free_data(data d); data load_data(char **paths, int n, int m, char **labels, int k, int h, int w); pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int h, int w, data *d); +pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d); + data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale); data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale); data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w); diff --git a/src/network.c b/src/network.c index 42253dc9..e4831efe 100644 --- a/src/network.c +++ b/src/network.c @@ -103,8 +103,8 @@ void update_network(network net) } else if(net.types[i] == CONNECTED){ connected_layer layer = *(connected_layer *)net.layers[i]; - secret_update_connected_layer((connected_layer *)net.layers[i]); - //update_connected_layer(layer); + //secret_update_connected_layer((connected_layer *)net.layers[i]); + update_connected_layer(layer); } } } diff --git a/src/parser.c b/src/parser.c index d53e87cb..37ceb085 100644 --- a/src/parser.c +++ b/src/parser.c @@ -416,6 +416,7 @@ list *read_cfg(char *filename) strip(line); switch(line[0]){ case '[': + printf("%s\n", line); current = malloc(sizeof(section)); list_insert(sections, current); current->options = make_list(); diff --git a/src/utils.c b/src/utils.c index 682d3043..0878b746 100644 --- a/src/utils.c +++ b/src/utils.c @@ -106,16 +106,17 @@ void strip_char(char *s, char bad) char *fgetl(FILE *fp) { if(feof(fp)) return 0; - unsigned long size = 512; + size_t size = 512; char *line = malloc(size*sizeof(char)); if(!fgets(line, size, fp)){ free(line); return 0; } - int curr = strlen(line); + size_t curr = strlen(line); - while(line[curr-1]!='\n'){ + while((line[curr-1] != '\n') && !feof(fp)){ + printf("%ld %ld\n", curr, size); size *= 2; line = realloc(line, size*sizeof(char)); if(!line) { @@ -125,7 +126,7 @@ char *fgetl(FILE *fp) fgets(&line[curr], size-curr, fp); curr = strlen(line); } - line[curr-1] = '\0'; + if(line[curr-1] == '\n') line[curr-1] = '\0'; return line; }