mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Stable on MNIST, about to change a lot
This commit is contained in:
parent
ad9dbfe164
commit
1d53b6414e
4
full.cfg
4
full.cfg
@ -10,10 +10,6 @@ activation=ramp
|
|||||||
[maxpool]
|
[maxpool]
|
||||||
stride=2
|
stride=2
|
||||||
|
|
||||||
[conn]
|
|
||||||
output = 100
|
|
||||||
activation=ramp
|
|
||||||
|
|
||||||
[conn]
|
[conn]
|
||||||
output = 2
|
output = 2
|
||||||
activation=ramp
|
activation=ramp
|
||||||
|
4
nist.cfg
4
nist.cfg
@ -2,7 +2,7 @@
|
|||||||
width=28
|
width=28
|
||||||
height=28
|
height=28
|
||||||
channels=1
|
channels=1
|
||||||
filters=5
|
filters=20
|
||||||
size=5
|
size=5
|
||||||
stride=1
|
stride=1
|
||||||
activation=ramp
|
activation=ramp
|
||||||
@ -20,7 +20,7 @@ activation=ramp
|
|||||||
stride=2
|
stride=2
|
||||||
|
|
||||||
[conn]
|
[conn]
|
||||||
output = 100
|
output = 500
|
||||||
activation=ramp
|
activation=ramp
|
||||||
|
|
||||||
[conn]
|
[conn]
|
||||||
|
@ -187,6 +187,24 @@ double train_network_sgd(network net, data d, int n, double step, double momentu
|
|||||||
}
|
}
|
||||||
return (double)correct/n;
|
return (double)correct/n;
|
||||||
}
|
}
|
||||||
|
double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
int correct = 0;
|
||||||
|
for(i = 0; i < n; ++i){
|
||||||
|
int index = rand()%d.X.rows;
|
||||||
|
double *x = d.X.vals[index];
|
||||||
|
double *y = d.y.vals[index];
|
||||||
|
forward_network(net, x);
|
||||||
|
int class = get_predicted_class_network(net);
|
||||||
|
backward_network(net, x, y);
|
||||||
|
correct += (y[class]?1:0);
|
||||||
|
}
|
||||||
|
update_network(net, step, momentum, decay);
|
||||||
|
return (double)correct/n;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void train_network(network net, data d, double step, double momentum, double decay)
|
void train_network(network net, data d, double step, double momentum, double decay)
|
||||||
{
|
{
|
||||||
|
@ -25,6 +25,7 @@ void forward_network(network net, double *input);
|
|||||||
void backward_network(network net, double *input, double *truth);
|
void backward_network(network net, double *input, double *truth);
|
||||||
void update_network(network net, double step, double momentum, double decay);
|
void update_network(network net, double step, double momentum, double decay);
|
||||||
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
|
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
|
||||||
|
double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
|
||||||
void train_network(network net, data d, double step, double momentum, double decay);
|
void train_network(network net, data d, double step, double momentum, double decay);
|
||||||
matrix network_predict_data(network net, data test);
|
matrix network_predict_data(network net, data test);
|
||||||
double network_accuracy(network net, data d);
|
double network_accuracy(network net, data d);
|
||||||
|
17
src/tests.c
17
src/tests.c
@ -184,9 +184,12 @@ void test_full()
|
|||||||
srand(0);
|
srand(0);
|
||||||
int i = 0;
|
int i = 0;
|
||||||
char *labels[] = {"cat","dog"};
|
char *labels[] = {"cat","dog"};
|
||||||
|
double lr = .00001;
|
||||||
|
double momentum = .9;
|
||||||
|
double decay = 0.01;
|
||||||
while(i++ < 1000 || 1){
|
while(i++ < 1000 || 1){
|
||||||
data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
|
data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
|
||||||
train_network(net, train, .0005, 0, 0);
|
train_network(net, train, lr, momentum, decay);
|
||||||
free_data(train);
|
free_data(train);
|
||||||
printf("Round %d\n", i);
|
printf("Round %d\n", i);
|
||||||
}
|
}
|
||||||
@ -206,9 +209,13 @@ void test_nist()
|
|||||||
double lr = .0005;
|
double lr = .0005;
|
||||||
double momentum = .9;
|
double momentum = .9;
|
||||||
double decay = 0.01;
|
double decay = 0.01;
|
||||||
|
clock_t start = clock(), end;
|
||||||
while(++count <= 1000){
|
while(++count <= 1000){
|
||||||
double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
|
double acc = train_network_sgd(net, train, 6400, lr, momentum, decay);
|
||||||
printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
|
printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay);
|
||||||
|
end = clock();
|
||||||
|
printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
|
||||||
|
start=end;
|
||||||
visualize_network(net);
|
visualize_network(net);
|
||||||
cvWaitKey(100);
|
cvWaitKey(100);
|
||||||
//lr /= 2;
|
//lr /= 2;
|
||||||
@ -334,8 +341,8 @@ int main()
|
|||||||
{
|
{
|
||||||
//test_kernel_update();
|
//test_kernel_update();
|
||||||
//test_split();
|
//test_split();
|
||||||
test_ensemble();
|
//test_ensemble();
|
||||||
//test_nist();
|
test_nist();
|
||||||
//test_full();
|
//test_full();
|
||||||
//test_random_preprocess();
|
//test_random_preprocess();
|
||||||
//test_random_classify();
|
//test_random_classify();
|
||||||
|
Loading…
Reference in New Issue
Block a user