mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Stable on MNIST, about to change a lot
This commit is contained in:
parent
ad9dbfe164
commit
1d53b6414e
4
full.cfg
4
full.cfg
@ -10,10 +10,6 @@ activation=ramp
|
||||
[maxpool]
|
||||
stride=2
|
||||
|
||||
[conn]
|
||||
output = 100
|
||||
activation=ramp
|
||||
|
||||
[conn]
|
||||
output = 2
|
||||
activation=ramp
|
||||
|
4
nist.cfg
4
nist.cfg
@ -2,7 +2,7 @@
|
||||
width=28
|
||||
height=28
|
||||
channels=1
|
||||
filters=5
|
||||
filters=20
|
||||
size=5
|
||||
stride=1
|
||||
activation=ramp
|
||||
@ -20,7 +20,7 @@ activation=ramp
|
||||
stride=2
|
||||
|
||||
[conn]
|
||||
output = 100
|
||||
output = 500
|
||||
activation=ramp
|
||||
|
||||
[conn]
|
||||
|
@ -187,6 +187,24 @@ double train_network_sgd(network net, data d, int n, double step, double momentu
|
||||
}
|
||||
return (double)correct/n;
|
||||
}
|
||||
double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
|
||||
{
|
||||
int i;
|
||||
int correct = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
int index = rand()%d.X.rows;
|
||||
double *x = d.X.vals[index];
|
||||
double *y = d.y.vals[index];
|
||||
forward_network(net, x);
|
||||
int class = get_predicted_class_network(net);
|
||||
backward_network(net, x, y);
|
||||
correct += (y[class]?1:0);
|
||||
}
|
||||
update_network(net, step, momentum, decay);
|
||||
return (double)correct/n;
|
||||
|
||||
}
|
||||
|
||||
|
||||
void train_network(network net, data d, double step, double momentum, double decay)
|
||||
{
|
||||
|
@ -25,6 +25,7 @@ void forward_network(network net, double *input);
|
||||
void backward_network(network net, double *input, double *truth);
|
||||
void update_network(network net, double step, double momentum, double decay);
|
||||
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
|
||||
double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
|
||||
void train_network(network net, data d, double step, double momentum, double decay);
|
||||
matrix network_predict_data(network net, data test);
|
||||
double network_accuracy(network net, data d);
|
||||
|
17
src/tests.c
17
src/tests.c
@ -184,9 +184,12 @@ void test_full()
|
||||
srand(0);
|
||||
int i = 0;
|
||||
char *labels[] = {"cat","dog"};
|
||||
double lr = .00001;
|
||||
double momentum = .9;
|
||||
double decay = 0.01;
|
||||
while(i++ < 1000 || 1){
|
||||
data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
|
||||
train_network(net, train, .0005, 0, 0);
|
||||
train_network(net, train, lr, momentum, decay);
|
||||
free_data(train);
|
||||
printf("Round %d\n", i);
|
||||
}
|
||||
@ -206,9 +209,13 @@ void test_nist()
|
||||
double lr = .0005;
|
||||
double momentum = .9;
|
||||
double decay = 0.01;
|
||||
clock_t start = clock(), end;
|
||||
while(++count <= 1000){
|
||||
double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
|
||||
printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
|
||||
double acc = train_network_sgd(net, train, 6400, lr, momentum, decay);
|
||||
printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay);
|
||||
end = clock();
|
||||
printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
|
||||
start=end;
|
||||
visualize_network(net);
|
||||
cvWaitKey(100);
|
||||
//lr /= 2;
|
||||
@ -334,8 +341,8 @@ int main()
|
||||
{
|
||||
//test_kernel_update();
|
||||
//test_split();
|
||||
test_ensemble();
|
||||
//test_nist();
|
||||
//test_ensemble();
|
||||
test_nist();
|
||||
//test_full();
|
||||
//test_random_preprocess();
|
||||
//test_random_classify();
|
||||
|
Loading…
Reference in New Issue
Block a user