diff --git a/src/blas.h b/src/blas.h index 90f1a9b7..99099253 100644 --- a/src/blas.h +++ b/src/blas.h @@ -22,7 +22,7 @@ void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY); void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); void scal_ongpu(int N, float ALPHA, float * X, int INCX); -void mask_ongpu(int N, float * X, float * mask); +void mask_ongpu(int N, float * X, float mask_num, float * mask); void const_ongpu(int N, float ALPHA, float *X, int INCX); void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY); void mul_ongpu(int N, float *X, int INCX, float *Y, int INCY); diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index 2155801a..0c89c475 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -1,6 +1,7 @@ extern "C" { #include "blas.h" #include "cuda.h" +#include "utils.h" } __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) @@ -27,10 +28,10 @@ __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX) if(i < N) X[i*INCX] *= ALPHA; } -__global__ void mask_kernel(int n, float *x, float *mask) +__global__ void mask_kernel(int n, float *x, float mask_num, float *mask) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; - if(i < n && mask[i] == 0) x[i] = 0; + if(i < n && mask[i] == mask_num) x[i] = mask_num; } __global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY) @@ -79,9 +80,9 @@ extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * check_error(cudaPeekAtLastError()); } -extern "C" void mask_ongpu(int N, float * X, float * mask) +extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask) { - mask_kernel<<>>(N, X, mask); + mask_kernel<<>>(N, X, mask_num, mask); check_error(cudaPeekAtLastError()); } diff --git a/src/captcha.c b/src/captcha.c index e369ffbc..ccefa450 100644 --- a/src/captcha.c +++ b/src/captcha.c @@ -2,11 +2,35 @@ #include "utils.h" #include "parser.h" - -void train_captcha(char *cfgfile, char *weightfile) +void fix_data_captcha(data d, int mask) { - float avg_loss = -1; + matrix labels = d.y; + int i, j; + for(i = 0; i < d.y.rows; ++i){ + for(j = 0; j < d.y.cols; j += 2){ + if (mask){ + if(!labels.vals[i][j]){ + labels.vals[i][j] = SECRET_NUM; + labels.vals[i][j+1] = SECRET_NUM; + }else if(labels.vals[i][j+1]){ + labels.vals[i][j] = 0; + } + } else{ + if (labels.vals[i][j]) { + labels.vals[i][j+1] = 0; + } else { + labels.vals[i][j+1] = 1; + } + } + } + } +} + +void train_captcha2(char *cfgfile, char *weightfile) +{ + data_seed = time(0); srand(time(0)); + float avg_loss = -1; char *base = basecfg(cfgfile); printf("%s\n", base); network net = parse_network_cfg(cfgfile); @@ -14,18 +38,38 @@ void train_captcha(char *cfgfile, char *weightfile) load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + //net.seen=0; int imgs = 1024; int i = net.seen/imgs; - list *plist = get_paths("/data/captcha/train.auto5"); + int solved = 1; + list *plist; + char **labels = get_labels("/data/captcha/reimgs.labels2.list"); + if (solved){ + plist = get_paths("/data/captcha/reimgs.solved.list"); + }else{ + plist = get_paths("/data/captcha/reimgs.train.list"); + } char **paths = (char **)list_to_array(plist); printf("%d\n", plist->size); clock_t time; + pthread_t load_thread; + data train; + data buffer; + load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer); while(1){ ++i; time=clock(); - data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60); - translate_data_rows(train, -128); - scale_data_rows(train, 1./128); + pthread_join(load_thread, 0); + train = buffer; + fix_data_captcha(train, solved); + + /* + image im = float_to_image(256, 256, 3, train.X.vals[114]); + show_image(im, "training"); + cvWaitKey(0); + */ + + load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); float loss = train_network(net, train); @@ -34,7 +78,7 @@ void train_captcha(char *cfgfile, char *weightfile) avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); free_data(train); - if(i%10==0){ + if(i%100==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); save_weights(net, buff); @@ -42,37 +86,51 @@ void train_captcha(char *cfgfile, char *weightfile) } } -void decode_captcha(char *cfgfile, char *weightfile) +void test_captcha2(char *cfgfile, char *weightfile, char *filename) { - setbuf(stdout, NULL); - srand(time(0)); network net = parse_network_cfg(cfgfile); - set_batch_network(&net, 1); if(weightfile){ load_weights(&net, weightfile); } - char filename[256]; + set_batch_network(&net, 1); + srand(2222222); + int i = 0; + char **names = get_labels("/data/captcha/reimgs.labels2.list"); + clock_t time; + char input[256]; + int indexes[26]; while(1){ - printf("Enter filename: "); - fgets(filename, 256, stdin); - strtok(filename, "\n"); - image im = load_image_color(filename, 300, 57); - scale_image(im, 1./255.); + if(filename){ + strncpy(input, filename, 256); + }else{ + //printf("Enter Image Path: "); + //fflush(stdout); + fgets(input, 256, stdin); + strtok(input, "\n"); + } + image im = load_image_color(input, net.w, net.h); float *X = im.data; + time=clock(); float *predictions = network_predict(net, X); - image out = float_to_image(300, 57, 1, predictions); - show_image(out, "decoded"); - #ifdef OPENCV - cvWaitKey(0); - #endif + top_predictions(net, 26, indexes); + //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + for(i = 0; i < 26; ++i){ + int index = indexes[i]; + if(i != 0) printf(", "); + printf("%s %f", names[index], predictions[index]); + } + printf("\n"); + fflush(stdout); free_image(im); + if (filename) break; } } -void encode_captcha(char *cfgfile, char *weightfile) +void train_captcha(char *cfgfile, char *weightfile) { - float avg_loss = -1; + data_seed = time(0); srand(time(0)); + float avg_loss = -1; char *base = basecfg(cfgfile); printf("%s\n", base); network net = parse_network_cfg(cfgfile); @@ -80,17 +138,31 @@ void encode_captcha(char *cfgfile, char *weightfile) load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + //net.seen=0; int imgs = 1024; int i = net.seen/imgs; - list *plist = get_paths("/data/captcha/encode.list"); + char **labels = get_labels("/data/captcha/reimgs.labels.list"); + list *plist = get_paths("/data/captcha/reimgs.train.list"); char **paths = (char **)list_to_array(plist); printf("%d\n", plist->size); clock_t time; + pthread_t load_thread; + data train; + data buffer; + load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer); while(1){ ++i; time=clock(); - data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57); - scale_data_rows(train, 1./255); + pthread_join(load_thread, 0); + train = buffer; + + /* + image im = float_to_image(256, 256, 3, train.X.vals[114]); + show_image(im, "training"); + cvWaitKey(0); + */ + + load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer); printf("Loaded: %lf seconds\n", sec(clock()-time)); time=clock(); float loss = train_network(net, train); @@ -98,7 +170,7 @@ void encode_captcha(char *cfgfile, char *weightfile) if(avg_loss == -1) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); - free_matrix(train.X); + free_data(train); if(i%100==0){ char buff[256]; sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); @@ -107,6 +179,153 @@ void encode_captcha(char *cfgfile, char *weightfile) } } +void test_captcha(char *cfgfile, char *weightfile, char *filename) +{ + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + set_batch_network(&net, 1); + srand(2222222); + int i = 0; + char **names = get_labels("/data/captcha/reimgs.labels.list"); + clock_t time; + char input[256]; + int indexes[13]; + while(1){ + if(filename){ + strncpy(input, filename, 256); + }else{ + //printf("Enter Image Path: "); + //fflush(stdout); + fgets(input, 256, stdin); + strtok(input, "\n"); + } + image im = load_image_color(input, net.w, net.h); + float *X = im.data; + time=clock(); + float *predictions = network_predict(net, X); + top_predictions(net, 13, indexes); + //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); + for(i = 0; i < 13; ++i){ + int index = indexes[i]; + if(i != 0) printf(", "); + printf("%s %f", names[index], predictions[index]); + } + printf("\n"); + fflush(stdout); + free_image(im); + if (filename) break; + } +} + + + +/* + void train_captcha(char *cfgfile, char *weightfile) + { + float avg_loss = -1; + srand(time(0)); + char *base = basecfg(cfgfile); + printf("%s\n", base); + network net = parse_network_cfg(cfgfile); + if(weightfile){ + load_weights(&net, weightfile); + } + printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); + int imgs = 1024; + int i = net.seen/imgs; + list *plist = get_paths("/data/captcha/train.auto5"); + char **paths = (char **)list_to_array(plist); + printf("%d\n", plist->size); + clock_t time; + while(1){ + ++i; + time=clock(); + data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60); + translate_data_rows(train, -128); + scale_data_rows(train, 1./128); + printf("Loaded: %lf seconds\n", sec(clock()-time)); + time=clock(); + float loss = train_network(net, train); + net.seen += imgs; + if(avg_loss == -1) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); + free_data(train); + if(i%10==0){ + char buff[256]; + sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); + save_weights(net, buff); + } + } + } + + void decode_captcha(char *cfgfile, char *weightfile) + { + setbuf(stdout, NULL); + srand(time(0)); + network net = parse_network_cfg(cfgfile); + set_batch_network(&net, 1); + if(weightfile){ + load_weights(&net, weightfile); + } + char filename[256]; + while(1){ + printf("Enter filename: "); + fgets(filename, 256, stdin); + strtok(filename, "\n"); + image im = load_image_color(filename, 300, 57); + scale_image(im, 1./255.); + float *X = im.data; + float *predictions = network_predict(net, X); + image out = float_to_image(300, 57, 1, predictions); + show_image(out, "decoded"); +#ifdef OPENCV +cvWaitKey(0); +#endif +free_image(im); +} +} + +void encode_captcha(char *cfgfile, char *weightfile) +{ +float avg_loss = -1; +srand(time(0)); +char *base = basecfg(cfgfile); +printf("%s\n", base); +network net = parse_network_cfg(cfgfile); +if(weightfile){ + load_weights(&net, weightfile); +} +printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); +int imgs = 1024; +int i = net.seen/imgs; +list *plist = get_paths("/data/captcha/encode.list"); +char **paths = (char **)list_to_array(plist); +printf("%d\n", plist->size); +clock_t time; +while(1){ + ++i; + time=clock(); + data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57); + scale_data_rows(train, 1./255); + printf("Loaded: %lf seconds\n", sec(clock()-time)); + time=clock(); + float loss = train_network(net, train); + net.seen += imgs; + if(avg_loss == -1) avg_loss = loss; + avg_loss = avg_loss*.9 + loss*.1; + printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); + free_matrix(train.X); + if(i%100==0){ + char buff[256]; + sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i); + save_weights(net, buff); + } +} +} + void validate_captcha(char *cfgfile, char *weightfile) { srand(time(0)); @@ -168,6 +387,7 @@ void test_captcha(char *cfgfile, char *weightfile) free_image(im); } } + */ void run_captcha(int argc, char **argv) { if(argc < 4){ @@ -177,10 +397,12 @@ void run_captcha(int argc, char **argv) char *cfg = argv[3]; char *weights = (argc > 4) ? argv[4] : 0; - if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights); - else if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights); - else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights); - else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights); - else if(0==strcmp(argv[2], "valid")) validate_captcha(cfg, weights); + char *filename = (argc > 5) ? argv[5]: 0; + if(0==strcmp(argv[2], "train")) train_captcha2(cfg, weights); + else if(0==strcmp(argv[2], "test")) test_captcha2(cfg, weights, filename); + //if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights); + //else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights); + //else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights); + //else if(0==strcmp(argv[2], "valid")) validate_captcha(cfg, weights); } diff --git a/src/cost_layer.c b/src/cost_layer.c index 24f6ffa3..76aa17e1 100644 --- a/src/cost_layer.c +++ b/src/cost_layer.c @@ -50,7 +50,7 @@ void forward_cost_layer(cost_layer l, network_state state) if(l.cost_type == MASKED){ int i; for(i = 0; i < l.batch*l.inputs; ++i){ - if(state.truth[i] == 0) state.input[i] = 0; + if(state.truth[i] == SECRET_NUM) state.input[i] = SECRET_NUM; } } copy_cpu(l.batch*l.inputs, state.truth, 1, l.delta, 1); @@ -80,7 +80,7 @@ void forward_cost_layer_gpu(cost_layer l, network_state state) { if (!state.truth) return; if (l.cost_type == MASKED) { - mask_ongpu(l.batch*l.inputs, state.input, state.truth); + mask_ongpu(l.batch*l.inputs, state.input, SECRET_NUM, state.truth); } copy_ongpu(l.batch*l.inputs, state.truth, 1, l.delta_gpu, 1); diff --git a/src/data.c b/src/data.c index dafcc98f..f6df50f5 100644 --- a/src/data.c +++ b/src/data.c @@ -332,7 +332,7 @@ void fill_truth(char *path, char **labels, int k, float *truth) ++count; } } - if(count != 1) printf("%d, %s\n", count, path); + if(count != 1) printf("Too many or too few labels: %d, %s\n", count, path); } matrix load_labels_paths(char **paths, int n, char **labels, int k) diff --git a/src/utils.h b/src/utils.h index 5e6c5071..e93cdd01 100644 --- a/src/utils.h +++ b/src/utils.h @@ -4,6 +4,8 @@ #include #include "list.h" +#define SECRET_NUM -1234 + char *basecfg(char *cfgfile); int alphanum_to_int(char c); char int_to_alphanum(int i);