Mirror of https://github.com/pjreddie/darknet.git (synced 2023-08-10 21:13:14 +03:00)

Commit: Added [net] dynamic_minibatch=1 for increasing mini_batch_size when random=1 is used
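For orientation, a minimal sketch of how the new option would be enabled in a model .cfg file (the surrounding values are illustrative, not part of this commit). dynamic_minibatch is read from the [net] section by parse_net_options (see the parser.c hunk below) and only takes effect when random network resizing is enabled in the detection layer:

    [net]
    batch=64
    subdivisions=16
    width=416
    height=416
    dynamic_minibatch=1    # new option added by this commit

    ...

    [yolo]
    random=1               # random resizing must be on for dynamic_minibatch to matter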
@@ -218,6 +218,7 @@ struct layer {
     int batch_normalize;
     int shortcut;
     int batch;
+    int dynamic_minibatch;
     int forced;
     int flipped;
     int inputs;
@@ -640,6 +641,7 @@ typedef struct network {
     int n;
     int batch;
     uint64_t *seen;
+    int *cur_iteration;
     int *t;
     float epoch;
     int subdivisions;
@@ -739,6 +741,7 @@ typedef struct network {
     size_t max_delta_gpu_size;
 //#endif // GPU
     int optimized_memory;
+    int dynamic_minibatch;
     size_t workspace_size_limit;
 } network;

@@ -419,7 +419,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

 //#ifdef CUDNN_HALF
     //if (state.use_mixed_precision) {
-    int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+    int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
     if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
         (l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
     {
@@ -671,7 +671,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
     float alpha = 1, beta = 0;

 //#ifdef CUDNN_HALF
-    int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+    int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
     if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
         (l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train && l.groups <= 1 && l.size > 1)
     {
@@ -978,7 +978,7 @@ void assisted_activation2_gpu(float alpha, float *output, float *gt_gpu, float *

 void assisted_excitation_forward_gpu(convolutional_layer l, network_state state)
 {
-    const int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+    const int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);

     // epoch
     //const float epoch = (float)(*state.net.seen) / state.net.train_images_num;
@@ -786,7 +786,7 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)

     if (l->activation == SWISH || l->activation == MISH) l->activation_input = (float*)realloc(l->activation_input, total_batch*l->outputs * sizeof(float));
 #ifdef GPU
-    if (old_w < w || old_h < h) {
+    if (old_w < w || old_h < h || l->dynamic_minibatch) {
         if (l->train) {
             cuda_free(l->delta_gpu);
             l->delta_gpu = cuda_make_array(l->delta, total_batch*l->outputs);
src/detector.c (108 lines changed)
@@ -66,19 +66,19 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i

     srand(time(0));
     int seed = rand();
-    int i;
-    for (i = 0; i < ngpus; ++i) {
+    int k;
+    for (k = 0; k < ngpus; ++k) {
         srand(seed);
 #ifdef GPU
-        cuda_set_device(gpus[i]);
+        cuda_set_device(gpus[k]);
 #endif
-        nets[i] = parse_network_cfg(cfgfile);
-        nets[i].benchmark_layers = benchmark_layers;
+        nets[k] = parse_network_cfg(cfgfile);
+        nets[k].benchmark_layers = benchmark_layers;
         if (weightfile) {
-            load_weights(&nets[i], weightfile);
+            load_weights(&nets[k], weightfile);
         }
-        if (clear) *nets[i].seen = 0;
-        nets[i].learning_rate *= ngpus;
+        if (clear) *nets[k].seen = 0;
+        nets[k].learning_rate *= ngpus;
     }
     srand(time(0));
     network net = nets[0];
@@ -105,12 +105,13 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
     int train_images_num = plist->size;
     char **paths = (char **)list_to_array(plist);

-    int init_w = net.w;
-    int init_h = net.h;
+    const int init_w = net.w;
+    const int init_h = net.h;
+    const int init_b = net.batch;
     int iter_save, iter_save_last, iter_map;
-    iter_save = get_current_batch(net);
-    iter_save_last = get_current_batch(net);
-    iter_map = get_current_batch(net);
+    iter_save = get_current_iteration(net);
+    iter_save_last = get_current_iteration(net);
+    iter_map = get_current_iteration(net);
     float mean_average_precision = -1;
     float best_map = mean_average_precision;

@@ -165,7 +166,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
     pthread_t load_thread = load_data(args);
     int count = 0;
     //while(i*imgs < N*120){
-    while (get_current_batch(net) < net.max_batches) {
+    while (get_current_iteration(net) < net.max_batches) {
         if (l.random && count++ % 10 == 0) {
             float rand_coef = 1.4;
             if (l.random != 1.0) rand_coef = l.random;
@@ -175,26 +176,48 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
             int dim_h = roundl(random_val*init_h / net.resize_step + 1) * net.resize_step;
             if (random_val < 1 && (dim_w > init_w || dim_h > init_h)) dim_w = init_w, dim_h = init_h;

-            // at the beginning
-            if (avg_loss < 0) {
-                dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
-                dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step;
+            int max_dim_w = roundl(rand_coef*init_w / net.resize_step + 1) * net.resize_step;
+            int max_dim_h = roundl(rand_coef*init_h / net.resize_step + 1) * net.resize_step;
+
+            // at the beginning (check if enough memory) and at the end (calc rolling mean/variance)
+            if (avg_loss < 0 || get_current_iteration(net) > net.max_batches - 100) {
+                dim_w = max_dim_w;
+                dim_h = max_dim_h;
             }

             if (dim_w < net.resize_step) dim_w = net.resize_step;
             if (dim_h < net.resize_step) dim_h = net.resize_step;
+            int dim_b = (init_b * max_dim_w * max_dim_h) / (dim_w * dim_h);
+            int new_dim_b = (int)(dim_b * 0.8);
+            if (new_dim_b > init_b) dim_b = new_dim_b;

-            printf("%d x %d \n", dim_w, dim_h);
             args.w = dim_w;
             args.h = dim_h;

+            int k;
+            if (net.dynamic_minibatch) {
+                for (k = 0; k < ngpus; ++k) {
+                    (*nets[k].seen) = init_b * net.subdivisions * get_current_iteration(net);   // remove this line, when you will save to weights-file both: seen & cur_iteration
+                    nets[k].batch = dim_b;
+                    int j;
+                    for (j = 0; j < nets[k].n; ++j)
+                        nets[k].layers[j].batch = dim_b;
+                }
+                net.batch = dim_b;
+                imgs = net.batch * net.subdivisions * ngpus;
+                args.n = imgs;
+                printf("\n %d x %d (batch = %d) \n", dim_w, dim_h, net.batch);
+            }
+            else
+                printf("\n %d x %d \n", dim_w, dim_h);
+
             pthread_join(load_thread, 0);
             train = buffer;
             free_data(train);
             load_thread = load_data(args);

-            for (i = 0; i < ngpus; ++i) {
-                resize_network(nets + i, dim_w, dim_h);
+            for (k = 0; k < ngpus; ++k) {
+                resize_network(nets + k, dim_w, dim_h);
             }
             net = nets[0];
         }
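The hunk above is the core of dynamic_minibatch: it keeps batch × width × height roughly constant across random resizes, discounting the gain by 20% as out-of-memory headroom. A self-contained sketch of that arithmetic with hypothetical numbers (init_b, max_dim_*, dim_w/dim_h mirror the variables above; the values are illustrative, not from darknet):

    #include <stdio.h>

    int main(void) {
        const int init_b = 64;                        // configured [net] batch
        const int max_dim_w = 608, max_dim_h = 608;   // largest random size (rand_coef * init size)
        const int dim_w = 320, dim_h = 320;           // current randomly chosen size

        // Keep batch * width * height roughly constant across resizes.
        int dim_b = (init_b * max_dim_w * max_dim_h) / (dim_w * dim_h);
        // Take only 80% of that as OOM headroom, but never go below the configured batch:
        // at the maximum size, dim_b == init_b and the discounted value is rejected.
        int new_dim_b = (int)(dim_b * 0.8);
        if (new_dim_b > init_b) dim_b = new_dim_b;

        printf("%d x %d -> batch = %d\n", dim_w, dim_h, dim_b);   // 320 x 320 -> batch = 184
        return 0;
    }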
@@ -246,7 +269,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
         if (avg_loss < 0 || avg_loss != avg_loss) avg_loss = loss;    // if(-inf or nan)
         avg_loss = avg_loss*.9 + loss*.1;

-        i = get_current_batch(net);
+        const int iteration = get_current_iteration(net);
+        //i = get_current_batch(net);

         int calc_map_for_each = 4 * train_images_num / (net.batch * net.subdivisions);  // calculate mAP for each 4 Epochs
         calc_map_for_each = fmax(calc_map_for_each, 100);
@@ -259,22 +283,36 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
         }

         if (net.cudnn_half) {
-            if (i < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
+            if (iteration < net.burn_in * 3) fprintf(stderr, "\n Tensor Cores are disabled until the first %d iterations are reached.", 3 * net.burn_in);
             else fprintf(stderr, "\n Tensor Cores are used.");
         }
-        printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), i*imgs);
+        printf("\n %d: %f, %f avg loss, %f rate, %lf seconds, %d images\n", iteration, loss, avg_loss, get_current_rate(net), (what_time_is_it_now() - time), iteration*imgs);

         int draw_precision = 0;
-        if (calc_map && (i >= next_map_calc || i == net.max_batches)) {
+        if (calc_map && (iteration >= next_map_calc || iteration == net.max_batches)) {
             if (l.random) {
-                printf("Resizing to initial size: %d x %d \n", init_w, init_h);
+                printf("Resizing to initial size: %d x %d ", init_w, init_h);
                 args.w = init_w;
                 args.h = init_h;
+                int k;
+                if (net.dynamic_minibatch) {
+                    for (k = 0; k < ngpus; ++k) {
+                        nets[k].batch = init_b;
+                        int j;
+                        for (j = 0; j < nets[k].n; ++j)
+                            nets[k].layers[j].batch = init_b;
+                    }
+                    net.batch = init_b;
+                    imgs = init_b * net.subdivisions * ngpus;
+                    args.n = imgs;
+                    printf("\n %d x %d (batch = %d) \n", init_w, init_h, init_b);
+                }
                 pthread_join(load_thread, 0);
                 free_data(train);
                 train = buffer;
                 load_thread = load_data(args);
-                int k;
                 for (k = 0; k < ngpus; ++k) {
                     resize_network(nets + k, init_w, init_h);
                 }
@@ -286,7 +324,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
             // combine Training and Validation networks
             //network net_combined = combine_train_valid_networks(net, net_map);

-            iter_map = i;
+            iter_map = iteration;
             mean_average_precision = validate_detector_map(datacfg, cfgfile, weightfile, 0.25, 0.5, 0, net.letter_box, &net_map);// &net_combined);
             printf("\n mean_average_precision (mAP@0.5) = %f \n", mean_average_precision);
             if (mean_average_precision > best_map) {
@@ -300,23 +338,23 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
                 draw_precision = 1;
             }
 #ifdef OPENCV
-        draw_train_loss(windows_name, img, img_size, avg_loss, max_img_loss, i, net.max_batches, mean_average_precision, draw_precision, "mAP%", dont_show, mjpeg_port);
+        draw_train_loss(windows_name, img, img_size, avg_loss, max_img_loss, iteration, net.max_batches, mean_average_precision, draw_precision, "mAP%", dont_show, mjpeg_port);
 #endif    // OPENCV

         //if (i % 1000 == 0 || (i < 1000 && i % 100 == 0)) {
         //if (i % 100 == 0) {
-        if (i >= (iter_save + 1000) || i % 1000 == 0) {
-            iter_save = i;
+        if (iteration >= (iter_save + 1000) || iteration % 1000 == 0) {
+            iter_save = iteration;
 #ifdef GPU
             if (ngpus != 1) sync_nets(nets, ngpus, 0);
 #endif
             char buff[256];
-            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
+            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, iteration);
             save_weights(net, buff);
         }

-        if (i >= (iter_save_last + 100) || i % 100 == 0) {
-            iter_save_last = i;
+        if (iteration >= (iter_save_last + 100) || iteration % 100 == 0) {
+            iter_save_last = iteration;
 #ifdef GPU
             if (ngpus != 1) sync_nets(nets, ngpus, 0);
 #endif
@@ -350,7 +388,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
     free_list_contents_kvp(options);
     free_list(options);

-    for (i = 0; i < ngpus; ++i) free_network(nets[i]);
+    for (k = 0; k < ngpus; ++k) free_network(nets[k]);
     free(nets);
     //free_network(net);

@@ -33,8 +33,10 @@ dropout_layer make_dropout_layer(int batch, int inputs, float probability, int d
     l.forward_gpu = forward_dropout_layer_gpu;
     l.backward_gpu = backward_dropout_layer_gpu;
     l.rand_gpu = cuda_make_array(l.rand, inputs*batch);
-    l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch);
-    l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch);
+    if (l.dropblock) {
+        l.drop_blocks_scale = cuda_make_array_pinned(l.rand, l.batch);
+        l.drop_blocks_scale_gpu = cuda_make_array(l.rand, l.batch);
+    }
 #endif
     if (l.dropblock) {
         if(l.dropblock_size_abs) fprintf(stderr, "dropblock p = %.3f l.dropblock_size_abs = %d %4d -> %4d\n", probability, l.dropblock_size_abs, inputs, inputs);
@@ -48,11 +50,18 @@ void resize_dropout_layer(dropout_layer *l, int inputs)
 {
     l->inputs = l->outputs = inputs;
     l->rand = (float*)xrealloc(l->rand, l->inputs * l->batch * sizeof(float));
 #ifdef GPU
     cuda_free(l->rand_gpu);
+
     l->rand_gpu = cuda_make_array(l->rand, l->inputs*l->batch);
+
+    if (l->dropblock) {
+        cudaFreeHost(l->drop_blocks_scale);
+        l->drop_blocks_scale = cuda_make_array_pinned(l->rand, l->batch);
+
+        cuda_free(l->drop_blocks_scale_gpu);
+        l->drop_blocks_scale_gpu = cuda_make_array(l->rand, l->batch);
+    }
 #endif
 }

 void forward_dropout_layer(dropout_layer l, network_state state)
@@ -95,7 +95,7 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
 void forward_dropout_layer_gpu(dropout_layer l, network_state state)
 {
     if (!state.train) return;
-    int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+    int iteration_num = get_current_iteration(state.net); // (*state.net.seen) / (state.net.batch*state.net.subdivisions);
     //if (iteration_num < state.net.burn_in) return;

     // We gradually increase the block size and the probability of dropout - during the first half of the training
@@ -141,9 +141,9 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
         for (int b = 0; b < l.batch; ++b) {
             const float prob = l.drop_blocks_scale[b] * block_size * block_size / (float)l.outputs;
             const float scale = 1.0f / (1.0f - prob);
-            printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size);
-            printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n",
-                l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale);
+            //printf(" %d x %d - block_size = %d, block_size*block_size = %d , ", l.w, l.h, block_size, block_size*block_size);
+            //printf(" , l.drop_blocks_scale[b] = %f, prob = %f, calc scale = %f \t cur_prob = %f, cur_scale = %f \n",
+            //    l.drop_blocks_scale[b], prob, scale, cur_prob, cur_scale);
             l.drop_blocks_scale[b] = scale;
         }

@@ -176,14 +176,14 @@ void forward_dropout_layer_gpu(dropout_layer l, network_state state)
 void backward_dropout_layer_gpu(dropout_layer l, network_state state)
 {
     if(!state.delta) return;
-    //int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+    //int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
     //if (iteration_num < state.net.burn_in) return;

     int size = l.inputs*l.batch;

     // dropblock
     if (l.dropblock) {
-        int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
+        int iteration_num = get_current_iteration(state.net); //(*state.net.seen) / (state.net.batch*state.net.subdivisions);
         float multiplier = 1.0;
         if (iteration_num < (state.net.max_batches*0.85))
             multiplier = (iteration_num / (float)(state.net.max_batches*0.85));
@@ -57,6 +57,11 @@ load_args get_base_args(network *net)
     return args;
 }

+int64_t get_current_iteration(network net)
+{
+    return *net.cur_iteration;
+}
+
 int get_current_batch(network net)
 {
     int batch_num = (*net.seen)/(net.batch*net.subdivisions);
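The helper above is the pivot of the whole commit: once dynamic_minibatch can change net.batch at run time, the old derived count (*net.seen) / (net.batch * net.subdivisions) no longer measures optimizer steps, so iterations are tracked explicitly in cur_iteration. A standalone sketch of the drift (hypothetical trace, not darknet API):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint64_t seen = 0;       // images consumed, the value saved in the weights file
        int cur_iteration = 0;   // optimizer steps, the new explicit counter
        const int subdivisions = 16;

        int batches[] = { 64, 64, 184, 184 };   // dynamic_minibatch varies the batch like this
        for (int i = 0; i < 4; ++i) {
            seen += (uint64_t)batches[i] * subdivisions;   // one full iteration of images
            cur_iteration += 1;
            // The derived count drifts as soon as batch departs from its initial value:
            printf("iter %d: derived = %llu, explicit = %d\n", cur_iteration,
                   (unsigned long long)(seen / (uint64_t)(batches[i] * subdivisions)),
                   cur_iteration);
        }
        return 0;
    }

This is also why train_detector rewrites (*nets[k].seen) from the iteration count whenever the batch changes, and why load_weights_upto below seeds *net->cur_iteration from get_current_batch(): the weights file currently stores only seen, which the in-code comment flags as a temporary measure.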
@@ -240,6 +245,7 @@ network make_network(int n)
     net.n = n;
     net.layers = (layer*)xcalloc(net.n, sizeof(layer));
     net.seen = (uint64_t*)xcalloc(1, sizeof(uint64_t));
+    net.cur_iteration = (int*)xcalloc(1, sizeof(int));
 #ifdef GPU
     net.input_gpu = (float**)xcalloc(1, sizeof(float*));
     net.truth_gpu = (float**)xcalloc(1, sizeof(float*));
@@ -359,7 +365,7 @@ float train_network_datum(network net, float *x, float *y)
     forward_network(net, state);
     backward_network(net, state);
     float error = get_network_cost(net);
-    if(((*net.seen)/net.batch)%net.subdivisions == 0) update_network(net);
+    //if(((*net.seen)/net.batch)%net.subdivisions == 0) update_network(net);
     return error;
 }

@@ -404,6 +410,12 @@ float train_network_waitkey(network net, data d, int wait_key)
         sum += err;
         if(wait_key) wait_key_cv(5);
     }
+    (*net.cur_iteration) += 1;
+#ifdef GPU
+    update_network_gpu(net);
+#else   // GPU
+    update_network(net);
+#endif  // GPU
     free(X);
     free(y);
     return (float)sum/(n*batch);
@@ -523,7 +535,7 @@ int resize_network(network *net, int w, int h)
     //fflush(stderr);
     for (i = 0; i < net->n; ++i){
         layer l = net->layers[i];
-        //printf(" %d: layer = %d,", i, l.type);
+        //printf(" (resize %d: layer = %d) , ", i, l.type);
         if(l.type == CONVOLUTIONAL){
             resize_convolutional_layer(&l, w, h);
         }
@@ -1048,6 +1060,7 @@ void free_network(network net)
     free(net.scales);
     free(net.steps);
     free(net.seen);
+    free(net.cur_iteration);

 #ifdef GPU
     if (gpu_index >= 0) cuda_free(net.workspace);
@@ -108,6 +108,7 @@ float get_current_seq_subdivisions(network net);
 int get_sequence_value(network net);
 float get_current_rate(network net);
 int get_current_batch(network net);
+int64_t get_current_iteration(network net);
 //void free_network(network net); // darknet.h
 void compare_networks(network n1, network n2, data d);
 char *get_layer_string(LAYER_TYPE a);
@@ -318,7 +318,7 @@ float train_network_datum_gpu(network net, float *x, float *y)
     float error = get_network_cost(net);
     //if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
     const int sequence = get_sequence_value(net);
-    if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net);
+    //if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net);

     return error;
 }
@@ -564,7 +564,9 @@ float train_networks(network *nets, int n, data d, int interval)
         sum += errors[i];
     }
     //cudaDeviceSynchronize();
-    if (get_current_batch(nets[0]) % interval == 0) {
+    *nets[0].cur_iteration += (n - 1);
+    if (get_current_iteration(nets[0]) % interval == 0)
+    {
         printf("Syncing... ");
         fflush(stdout);
         sync_nets(nets, n, interval);
@@ -1053,6 +1053,8 @@ void parse_net_options(list *options, network *net)
     net->batch *= net->time_steps;
     net->subdivisions = subdivs;

+    *net->cur_iteration = 0;
+    net->dynamic_minibatch = option_find_int_quiet(options, "dynamic_minibatch", 0);
     net->optimized_memory = option_find_int_quiet(options, "optimized_memory", 0);
     net->workspace_size_limit = (size_t)1024*1024 * option_find_float_quiet(options, "workspace_size_limit_MB", 1024);  // 1024 MB by default

@@ -1357,6 +1359,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
         }
 #endif // GPU

+        l.dynamic_minibatch = net.dynamic_minibatch;
         l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
         l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
         l.dontload = option_find_int_quiet(options, "dontload", 0);
@@ -1873,6 +1876,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
             fread(&iseen, sizeof(uint32_t), 1, fp);
             *net->seen = iseen;
         }
+        *net->cur_iteration = get_current_batch(*net);
         printf(", trained: %.0f K-images (%.0f Kilo-batches_64) \n", (float)(*net->seen / 1000), (float)(*net->seen / 64000));
         int transpose = (major > 1000) || (minor > 1000);

|
@ -67,7 +67,8 @@ void resize_region_layer(layer *l, int w, int h)
|
||||
l->delta = (float*)xrealloc(l->delta, l->batch * l->outputs * sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
if (old_w < w || old_h < h) {
|
||||
//if (old_w < w || old_h < h)
|
||||
{
|
||||
cuda_free(l->delta_gpu);
|
||||
cuda_free(l->output_gpu);
|
||||
|
||||
|