From 1d53b6414e0cd81043d7c76aa89f4f97da5e479f Mon Sep 17 00:00:00 2001 From: Joseph Redmon Date: Thu, 23 Jan 2014 11:24:37 -0800 Subject: [PATCH] Stable on MNIST, about to change a lot --- full.cfg | 4 ---- nist.cfg | 4 ++-- src/network.c | 18 ++++++++++++++++++ src/network.h | 1 + src/tests.c | 17 ++++++++++++----- 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/full.cfg b/full.cfg index a18da176..78e938fb 100644 --- a/full.cfg +++ b/full.cfg @@ -10,10 +10,6 @@ activation=ramp [maxpool] stride=2 -[conn] -output = 100 -activation=ramp - [conn] output = 2 activation=ramp diff --git a/nist.cfg b/nist.cfg index 5b0541cf..46e32233 100644 --- a/nist.cfg +++ b/nist.cfg @@ -2,7 +2,7 @@ width=28 height=28 channels=1 -filters=5 +filters=20 size=5 stride=1 activation=ramp @@ -20,7 +20,7 @@ activation=ramp stride=2 [conn] -output = 100 +output = 500 activation=ramp [conn] diff --git a/src/network.c b/src/network.c index 10ad110e..07ac6213 100644 --- a/src/network.c +++ b/src/network.c @@ -187,6 +187,24 @@ double train_network_sgd(network net, data d, int n, double step, double momentu } return (double)correct/n; } +double train_network_batch(network net, data d, int n, double step, double momentum,double decay) +{ + int i; + int correct = 0; + for(i = 0; i < n; ++i){ + int index = rand()%d.X.rows; + double *x = d.X.vals[index]; + double *y = d.y.vals[index]; + forward_network(net, x); + int class = get_predicted_class_network(net); + backward_network(net, x, y); + correct += (y[class]?1:0); + } + update_network(net, step, momentum, decay); + return (double)correct/n; + +} + void train_network(network net, data d, double step, double momentum, double decay) { diff --git a/src/network.h b/src/network.h index 2ffc76bf..975c3ddc 100644 --- a/src/network.h +++ b/src/network.h @@ -25,6 +25,7 @@ void forward_network(network net, double *input); void backward_network(network net, double *input, double *truth); void update_network(network net, double step, double momentum, double decay); double train_network_sgd(network net, data d, int n, double step, double momentum,double decay); +double train_network_batch(network net, data d, int n, double step, double momentum,double decay); void train_network(network net, data d, double step, double momentum, double decay); matrix network_predict_data(network net, data test); double network_accuracy(network net, data d); diff --git a/src/tests.c b/src/tests.c index 4638645e..2a50bacf 100644 --- a/src/tests.c +++ b/src/tests.c @@ -184,9 +184,12 @@ void test_full() srand(0); int i = 0; char *labels[] = {"cat","dog"}; + double lr = .00001; + double momentum = .9; + double decay = 0.01; while(i++ < 1000 || 1){ data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2); - train_network(net, train, .0005, 0, 0); + train_network(net, train, lr, momentum, decay); free_data(train); printf("Round %d\n", i); } @@ -206,9 +209,13 @@ void test_nist() double lr = .0005; double momentum = .9; double decay = 0.01; + clock_t start = clock(), end; while(++count <= 1000){ - double acc = train_network_sgd(net, train, 1000, lr, momentum, decay); - printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay); + double acc = train_network_sgd(net, train, 6400, lr, momentum, decay); + printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay); + end = clock(); + printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC); + start=end; visualize_network(net); cvWaitKey(100); //lr /= 2; @@ -334,8 +341,8 @@ int main() { //test_kernel_update(); //test_split(); - test_ensemble(); - //test_nist(); + //test_ensemble(); + test_nist(); //test_full(); //test_random_preprocess(); //test_random_classify();