mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Compile fix
This commit is contained in:
@ -117,6 +117,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(l.xnor){
|
if(l.xnor){
|
||||||
|
|
||||||
if (!l.align_bit_weights_gpu || state.train) {
|
if (!l.align_bit_weights_gpu || state.train) {
|
||||||
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
||||||
}
|
}
|
||||||
@ -128,6 +129,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
if (l.align_bit_weights_gpu && !state.train)
|
if (l.align_bit_weights_gpu && !state.train)
|
||||||
{
|
{
|
||||||
cudaError_t status = cudaSuccess;
|
cudaError_t status = cudaSuccess;
|
||||||
|
int input_size = l.c*l.h*l.w*l.batch;
|
||||||
|
|
||||||
int m = l.n;
|
int m = l.n;
|
||||||
int k = l.size*l.size*l.c;
|
int k = l.size*l.size*l.c;
|
||||||
@ -139,6 +141,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
size_t t_intput_size = new_ldb * n;
|
size_t t_intput_size = new_ldb * n;
|
||||||
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
||||||
|
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
|
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
|
||||||
//cudaDeviceSynchronize();
|
//cudaDeviceSynchronize();
|
||||||
@ -152,13 +155,34 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
|
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
|
||||||
//cudaDeviceSynchronize();
|
//cudaDeviceSynchronize();
|
||||||
|
|
||||||
|
|
||||||
// should be optimized
|
// should be optimized
|
||||||
gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
|
gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
|
||||||
(unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
|
(unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
|
||||||
//cudaDeviceSynchronize();
|
//cudaDeviceSynchronize();
|
||||||
//check_error(status);
|
//check_error(status);
|
||||||
|
|
||||||
|
|
||||||
|
{
|
||||||
|
//float_to_bit_gpu(state.input, (unsigned char *)l.align_workspace_gpu, input_size);
|
||||||
|
|
||||||
|
/*
|
||||||
|
float *input_cpu = (float *)calloc(input_size, sizeof(float));
|
||||||
|
status = cudaMemcpy(input_cpu, state.input, input_size* sizeof(float), cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
convolve_bin_cpu(input_cpu, l.weights, l.output, l.w, l.h, l.c, l.n, l.size, l.pad); // CPU
|
||||||
|
status = cudaMemcpy(l.output_gpu, l.output, l.outputs * sizeof(float), cudaMemcpyHostToDevice);
|
||||||
|
check_error(status);
|
||||||
|
free(input_cpu);
|
||||||
|
*/
|
||||||
|
|
||||||
|
//convolve_bin_gpu(state.input, l.weights_gpu, l.output_gpu, l.w, l.h, l.c, l.n, l.size, l.pad);
|
||||||
|
//cudaDeviceSynchronize();
|
||||||
|
//check_error(status);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||||
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||||
if (l.binary || l.xnor) swap_binary(&l);
|
if (l.binary || l.xnor) swap_binary(&l);
|
||||||
|
@ -637,6 +637,8 @@ void binary_align_weights(convolutional_layer *l)
|
|||||||
check_error(status);
|
check_error(status);
|
||||||
status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
|
status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
|
||||||
check_error(status);
|
check_error(status);
|
||||||
|
status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
|
l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
|
||||||
cudaDeviceSynchronize();
|
cudaDeviceSynchronize();
|
||||||
|
@ -1026,6 +1026,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
|
|||||||
|
|
||||||
char buff[1024];
|
char buff[1024];
|
||||||
FILE* fw = fopen("anchors.txt", "wb");
|
FILE* fw = fopen("anchors.txt", "wb");
|
||||||
|
if (fw) {
|
||||||
printf("\nSaving anchors to the file: anchors.txt \n");
|
printf("\nSaving anchors to the file: anchors.txt \n");
|
||||||
printf("anchors = ");
|
printf("anchors = ");
|
||||||
for (i = 0; i < num_of_clusters; ++i) {
|
for (i = 0; i < num_of_clusters; ++i) {
|
||||||
@ -1039,6 +1040,10 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
|
|||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
fclose(fw);
|
fclose(fw);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf(" Error: file anchors.txt can't be open \n");
|
||||||
|
}
|
||||||
|
|
||||||
if (show) {
|
if (show) {
|
||||||
size_t img_size = 700;
|
size_t img_size = 700;
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#ifndef IM2COL_H
|
#ifndef IM2COL_H
|
||||||
#define IM2COL_H
|
#define IM2COL_H
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
void im2col_cpu(float* data_im,
|
void im2col_cpu(float* data_im,
|
||||||
int channels, int height, int width,
|
int channels, int height, int width,
|
||||||
int ksize, int stride, int pad, float* data_col);
|
int ksize, int stride, int pad, float* data_col);
|
||||||
@ -27,5 +29,9 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
|
|||||||
unsigned char *B, int ldb,
|
unsigned char *B, int ldb,
|
||||||
float *C, int ldc, float *mean_arr);
|
float *C, int ldc, float *mean_arr);
|
||||||
|
|
||||||
|
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||||
|
|
||||||
|
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -461,3 +461,118 @@ void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
|
|||||||
C, ldc,
|
C, ldc,
|
||||||
mean_arr);
|
mean_arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void convolve_bin_gpu_kernel(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
|
||||||
|
{
|
||||||
|
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
int fil;
|
||||||
|
// filter index
|
||||||
|
//for (fil = 0; fil < n; ++fil)
|
||||||
|
int chan, y, x, f_y, f_x;
|
||||||
|
// channel index
|
||||||
|
//for (chan = 0; chan < in_c; ++chan)
|
||||||
|
// input - y
|
||||||
|
//for (y = 0; y < in_h; ++y)
|
||||||
|
// input - x
|
||||||
|
//for (x = 0; x < in_w; ++x)
|
||||||
|
x = index % in_w;
|
||||||
|
int index2 = index / in_w;
|
||||||
|
y = index2 % in_h;
|
||||||
|
fil = index2 / in_h;
|
||||||
|
if (fil < n)
|
||||||
|
{
|
||||||
|
|
||||||
|
int const output_index = fil*in_w*in_h + y*in_w + x;
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
for (chan = 0; chan < in_c; ++chan)
|
||||||
|
{
|
||||||
|
int const weights_pre_index = fil*in_c*size*size + chan*size*size;
|
||||||
|
int const input_pre_index = chan*in_w*in_h;
|
||||||
|
|
||||||
|
// filter - y
|
||||||
|
for (f_y = 0; f_y < size; ++f_y)
|
||||||
|
{
|
||||||
|
int input_y = y + f_y - pad;
|
||||||
|
// filter - x
|
||||||
|
for (f_x = 0; f_x < size; ++f_x)
|
||||||
|
{
|
||||||
|
int input_x = x + f_x - pad;
|
||||||
|
if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
|
||||||
|
|
||||||
|
int input_index = input_pre_index + input_y*in_w + input_x;
|
||||||
|
int weights_index = weights_pre_index + f_y*size + f_x;
|
||||||
|
|
||||||
|
sum += input[input_index] *weights[weights_index];
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// l.output[filters][width][height] +=
|
||||||
|
// state.input[channels][width][height] *
|
||||||
|
// l.weights[filters][channels][filter_width][filter_height];
|
||||||
|
//output[output_index] += sum;
|
||||||
|
}
|
||||||
|
output[output_index] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void convolve_bin_gpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
|
||||||
|
{
|
||||||
|
size_t array_size = in_w*in_h*n; // width X height X filters
|
||||||
|
const int num_blocks = array_size / BLOCK + 1;
|
||||||
|
//printf("\n array_size = %d, num_blocks = %d, w = %d, h = %d, n = %d, c = %d, pad = %d \n", array_size, num_blocks, in_w, in_h, n, in_c, pad);
|
||||||
|
|
||||||
|
convolve_bin_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (input, weights, output, in_w, in_h, in_c, n, size, pad);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void convolve_bin_cpu(float *input, float *weights, float *output, int in_w, int in_h, int in_c, int n, int size, int pad)
|
||||||
|
{
|
||||||
|
int fil;
|
||||||
|
// filter index
|
||||||
|
#pragma omp parallel for // "omp parallel for" - automatic parallelization of loop by using OpenMP
|
||||||
|
for (fil = 0; fil < n; ++fil) {
|
||||||
|
int chan, y, x, f_y, f_x;
|
||||||
|
// channel index
|
||||||
|
for (chan = 0; chan < in_c; ++chan)
|
||||||
|
// input - y
|
||||||
|
for (y = 0; y < in_h; ++y)
|
||||||
|
// input - x
|
||||||
|
for (x = 0; x < in_w; ++x)
|
||||||
|
{
|
||||||
|
int const output_index = fil*in_w*in_h + y*in_w + x;
|
||||||
|
int const weights_pre_index = fil*in_c*size*size + chan*size*size;
|
||||||
|
int const input_pre_index = chan*in_w*in_h;
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
// filter - y
|
||||||
|
for (f_y = 0; f_y < size; ++f_y)
|
||||||
|
{
|
||||||
|
int input_y = y + f_y - pad;
|
||||||
|
// filter - x
|
||||||
|
for (f_x = 0; f_x < size; ++f_x)
|
||||||
|
{
|
||||||
|
int input_x = x + f_x - pad;
|
||||||
|
if (input_y < 0 || input_x < 0 || input_y >= in_h || input_x >= in_w) continue;
|
||||||
|
|
||||||
|
int input_index = input_pre_index + input_y*in_w + input_x;
|
||||||
|
int weights_index = weights_pre_index + f_y*size + f_x;
|
||||||
|
|
||||||
|
sum += input[input_index] * weights[weights_index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// l.output[filters][width][height] +=
|
||||||
|
// state.input[channels][width][height] *
|
||||||
|
// l.weights[filters][channels][filter_width][filter_height];
|
||||||
|
output[output_index] += sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user