diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 26dee46a..6d2ce905 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -117,6 +117,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) } if(l.xnor){ + if (!l.align_bit_weights_gpu || state.train) { binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu); } @@ -128,6 +129,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) if (l.align_bit_weights_gpu && !state.train) { cudaError_t status = cudaSuccess; + int input_size = l.c*l.h*l.w*l.batch; int m = l.n; int k = l.size*l.size*l.c; @@ -139,6 +141,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) size_t t_intput_size = new_ldb * n; size_t t_bit_input_size = t_intput_size / 8;// +1; + int i = 0; im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align); //cudaDeviceSynchronize(); @@ -152,13 +155,34 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8); //cudaDeviceSynchronize(); - // should be optimized gemm_nn_custom_bin_mean_transposed_gpu(m, n, k, (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu); //cudaDeviceSynchronize(); //check_error(status); + + { + //float_to_bit_gpu(state.input, (unsigned char *)l.align_workspace_gpu, input_size); + + /* + float *input_cpu = (float *)calloc(input_size, sizeof(float)); + status = cudaMemcpy(input_cpu, state.input, input_size* sizeof(float), cudaMemcpyDeviceToHost); + check_error(status); + + convolve_bin_cpu(input_cpu, l.weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad); // CPU + status = cudaMemcpy(l.output_gpu, l.output, l.outputs * sizeof(float), cudaMemcpyHostToDevice); + check_error(status); + free(input_cpu); + */ + + //convolve_bin_gpu(state.input, l.weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad); + //cudaDeviceSynchronize(); + //check_error(status); + + + } + add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h); activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation); if (l.binary || l.xnor) swap_binary(&l); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 5efcfd68..452f792e 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -637,6 +637,8 @@ void binary_align_weights(convolutional_layer *l) check_error(status); status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice); check_error(status); + status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice); + check_error(status); l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n); cudaDeviceSynchronize(); diff --git a/src/detector.c b/src/detector.c index 0c1c3986..dc312345 100644 --- a/src/detector.c +++ b/src/detector.c @@ -1026,19 +1026,24 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int char buff[1024]; FILE* fw = fopen("anchors.txt", "wb"); - printf("\nSaving anchors to the file: anchors.txt \n"); - printf("anchors = "); - for (i = 0; i < num_of_clusters; ++i) { - sprintf(buff, "%2.4f,%2.4f", centers->data.fl[i * 2], centers->data.fl[i * 2 + 1]); - 
printf("%s", buff); - fwrite(buff, sizeof(char), strlen(buff), fw); - if (i + 1 < num_of_clusters) { - fwrite(", ", sizeof(char), 2, fw); - printf(", "); + if (fw) { + printf("\nSaving anchors to the file: anchors.txt \n"); + printf("anchors = "); + for (i = 0; i < num_of_clusters; ++i) { + sprintf(buff, "%2.4f,%2.4f", centers->data.fl[i * 2], centers->data.fl[i * 2 + 1]); + printf("%s", buff); + fwrite(buff, sizeof(char), strlen(buff), fw); + if (i + 1 < num_of_clusters) { + fwrite(", ", sizeof(char), 2, fw); + printf(", "); + } } + printf("\n"); + fclose(fw); + } + else { + printf(" Error: file anchors.txt can't be open \n"); } - printf("\n"); - fclose(fw); if (show) { size_t img_size = 700; diff --git a/src/im2col.h b/src/im2col.h index c96bfea9..b518e803 100644 --- a/src/im2col.h +++ b/src/im2col.h @@ -1,6 +1,8 @@ #ifndef IM2COL_H #define IM2COL_H +#include + void im2col_cpu(float* data_im, int channels, int height, int width, int ksize, int stride, int pad, float* data_col); @@ -27,5 +29,9 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, unsigned char *B, int ldb, float *C, int ldc, float *mean_arr); +void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad); + +void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad); + #endif #endif diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu index 91fedddf..e545260d 100644 --- a/src/im2col_kernels.cu +++ b/src/im2col_kernels.cu @@ -460,4 +460,119 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, B, ldb, C, ldc, mean_arr); +} + +// -------------------------------- + + +__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + + int fil; + // filter index + //for (fil = 0; fil < n; ++fil) + int chan, y, x, f_y, f_x; + // channel index + //for (chan = 0; chan < in_c; ++chan) + // input - y + //for (y = 0; y < in_h; ++y) + // input - x + //for (x = 0; x < in_w; ++x) + x = index % in_w; + int index2 = index / in_w; + y = index2 % in_h; + fil = index2 / in_h; + if (fil < n) + { + + int const output_index = fil*in_w*in_h + y*in_w + x; + float sum = 0; + + for (chan = 0; chan < in_c; ++chan) + { + int const weights_pre_index = fil*in_c*size*size + chan*size*size; + int const input_pre_index = chan*in_w*in_h; + + // filter - y + for (f_y = 0; f_y < size; ++f_y) + { + int input_y = y + f_y - pad; + // filter - x + for (f_x = 0; f_x < size; ++f_x) + { + int input_x = x + f_x - pad; + if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue; + + int input_index = input_pre_index + input_y*in_w + input_x; + int weights_index = weights_pre_index + f_y*size + f_x; + + sum += input[input_index] *weights[weights_index]; + + } + } + // l.output[filters][width][height] += + // state.input[channels][width][height] * + // l.weights[filters][channels][filter_width][filter_height]; + //output[output_index] += sum; + } + output[output_index] = sum; + } + +} + +void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad) +{ + size_t array_size = in_w*in_h*n; // width X height X filters + const int num_blocks = array_size / BLOCK + 1; + //printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, 
+
+    convolve_bin_gpu_kernel<<<num_blocks, BLOCK, 0, 0>>>(input, weights, output, in_w, in_h, in_c, n, size, pad);
+}
+
+
+
+void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
+{
+    int fil;
+    // filter index
+#pragma omp parallel for    // "omp parallel for" - automatic parallelization of loop by using OpenMP
+    for (fil = 0; fil < n; ++fil) {
+        int chan, y, x, f_y, f_x;
+        // channel index
+        for (chan = 0; chan < in_c; ++chan)
+            // input - y
+            for (y = 0; y < in_h; ++y)
+                // input - x
+                for (x = 0; x < in_w; ++x)
+                {
+                    int const output_index = fil*in_w*in_h + y*in_w + x;
+                    int const weights_pre_index = fil*in_c*size*size + chan*size*size;
+                    int const input_pre_index = chan*in_w*in_h;
+                    float sum = 0;
+
+                    // filter - y
+                    for (f_y = 0; f_y < size; ++f_y)
+                    {
+                        int input_y = y + f_y - pad;
+                        // filter - x
+                        for (f_x = 0; f_x < size; ++f_x)
+                        {
+                            int input_x = x + f_x - pad;
+                            if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
+
+                            int input_index = input_pre_index + input_y*in_w + input_x;
+                            int weights_index = weights_pre_index + f_y*size + f_x;
+
+                            sum += input[input_index] * weights[weights_index];
+                        }
+                    }
+                    // l.output[filters][width][height] +=
+                    //        state.input[channels][width][height] *
+                    //        l.weights[filters][channels][filter_width][filter_height];
+                    output[output_index] += sum;
+                }
+    }
+
+}
\ No newline at end of file
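
As a quick sanity check for the new convolve_bin_cpu()/convolve_bin_gpu() entry points declared in im2col.h above, a small standalone harness along the following lines can compare the two paths on random data. This is a minimal sketch, not part of the patch: the file name, main() and the layer dimensions are illustrative, and it assumes the harness is compiled with darknet's -DGPU define and linked against the object built from im2col_kernels.cu plus the CUDA runtime (-lcudart). Note that convolve_bin_cpu() accumulates into output with +=, so the host reference buffer must start zeroed, while the GPU kernel overwrites its output.

/* test_convolve_bin.c -- hypothetical harness, not part of the patch.
 * Compares convolve_bin_cpu() against convolve_bin_gpu() on random data. */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>
#include "im2col.h"

int main(void)
{
    /* Small layer: 3 input channels, 8 filters, 13x13 input, 3x3 kernel, pad 1.
     * With the indexing used in the kernels above the output keeps the input
     * resolution, so it holds n * in_w * in_h floats (stride is fixed at 1). */
    const int in_w = 13, in_h = 13, in_c = 3, n = 8, size = 3, pad = 1;
    const int in_elems  = in_c * in_w * in_h;
    const int w_elems   = n * in_c * size * size;
    const int out_elems = n * in_w * in_h;

    float *input   = (float *)calloc(in_elems,  sizeof(float));
    float *weights = (float *)calloc(w_elems,   sizeof(float));
    float *out_cpu = (float *)calloc(out_elems, sizeof(float)); /* must be zeroed: CPU path accumulates */
    float *out_gpu = (float *)calloc(out_elems, sizeof(float));

    int i;
    for (i = 0; i < in_elems; ++i) input[i]   = (float)rand() / RAND_MAX - 0.5f;
    for (i = 0; i < w_elems;  ++i) weights[i] = (float)rand() / RAND_MAX - 0.5f;

    /* CPU reference */
    convolve_bin_cpu(input, weights, out_cpu, in_w, in_h, in_c, n, size, pad);

    /* GPU path: copy operands to the device, launch, copy the result back */
    float *d_input, *d_weights, *d_output;
    cudaMalloc((void **)&d_input,   in_elems  * sizeof(float));
    cudaMalloc((void **)&d_weights, w_elems   * sizeof(float));
    cudaMalloc((void **)&d_output,  out_elems * sizeof(float));
    cudaMemcpy(d_input,   input,   in_elems  * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_weights, weights, w_elems   * sizeof(float), cudaMemcpyHostToDevice);

    convolve_bin_gpu(d_input, d_weights, d_output, in_w, in_h, in_c, n, size, pad);
    cudaDeviceSynchronize();
    cudaMemcpy(out_gpu, d_output, out_elems * sizeof(float), cudaMemcpyDeviceToHost);

    /* Compare element-wise */
    float max_diff = 0;
    for (i = 0; i < out_elems; ++i) {
        float d = fabsf(out_cpu[i] - out_gpu[i]);
        if (d > max_diff) max_diff = d;
    }
    printf("max |cpu - gpu| = %g\n", max_diff);

    cudaFree(d_input); cudaFree(d_weights); cudaFree(d_output);
    free(input); free(weights); free(out_cpu); free(out_gpu);
    return 0;
}

Small non-zero values of max |cpu - gpu| are expected, since the two paths sum the per-channel contributions in different floating-point orders.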