mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
stable, dropout on gpu
This commit is contained in:
parent
d407bffde9
commit
7c120aef23
@ -308,8 +308,8 @@ void train_asirra()
|
|||||||
void train_imagenet()
|
void train_imagenet()
|
||||||
{
|
{
|
||||||
float avg_loss = 1;
|
float avg_loss = 1;
|
||||||
network net = parse_network_cfg("/home/pjreddie/imagenet_backup/imagenet_2280.cfg");
|
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
|
||||||
//network net = parse_network_cfg("cfg/imagenet2.cfg");
|
network net = parse_network_cfg("cfg/imagenet.cfg");
|
||||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||||
int imgs = 1000/net.batch+1;
|
int imgs = 1000/net.batch+1;
|
||||||
srand(time(0));
|
srand(time(0));
|
||||||
@ -1042,6 +1042,7 @@ int main(int argc, char *argv[])
|
|||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
else if(0==strcmp(argv[1], "test_gpu")) test_gpu_blas();
|
else if(0==strcmp(argv[1], "test_gpu")) test_gpu_blas();
|
||||||
#endif
|
#endif
|
||||||
|
test_parser();
|
||||||
fprintf(stderr, "Success!\n");
|
fprintf(stderr, "Success!\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
5
src/dropout_layer.cl
Normal file
5
src/dropout_layer.cl
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
__kernel void forward(__global float *input, __global float *rand, float prob)
|
||||||
|
{
|
||||||
|
int id = get_global_id(0);
|
||||||
|
input[id] = (rand[id] < prob) ? 0 : input[id]/(1.-prob);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user