mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
temp fix, don't use it
This commit is contained in:
@ -3,7 +3,7 @@
|
|||||||
#include "cublas_v2.h"
|
#include "cublas_v2.h"
|
||||||
|
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
#pragma comment(lib, "cudnn.lib")
|
#pragma comment(lib, "cudnn.lib")
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -117,18 +117,160 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(l.xnor){
|
if(l.xnor){
|
||||||
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
if (!l.align_bit_weights_gpu || state.train) {
|
||||||
|
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
||||||
|
}
|
||||||
swap_binary(&l);
|
swap_binary(&l);
|
||||||
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
|
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
|
||||||
state.input = l.binary_input_gpu;
|
state.input = l.binary_input_gpu;
|
||||||
|
|
||||||
|
if (l.align_bit_weights_gpu && !state.train)
|
||||||
|
{
|
||||||
|
cudaError_t status;
|
||||||
|
status = cudaMemcpy(l.align_bit_weights, l.align_bit_weights_gpu, l.align_bit_weights_size, cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
float *input = (float *)calloc(l.c*l.h*l.w*l.batch, sizeof(float));
|
||||||
|
float *workspace = (float *)calloc(l.bit_align*l.size*l.size*l.c, sizeof(float));
|
||||||
|
float *output = (float *)calloc(l.batch*l.out_c*l.out_h*l.out_w, sizeof(float));
|
||||||
|
|
||||||
|
status = cudaMemcpy(input, state.input, l.c*l.h*l.w*l.batch*sizeof(float), cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
int m = l.n;
|
||||||
|
int k = l.size*l.size*l.c;
|
||||||
|
int n = l.out_w*l.out_h;
|
||||||
|
float * a = l.weights_gpu;
|
||||||
|
//float * b = state.workspace;
|
||||||
|
float *b = workspace;
|
||||||
|
//float * c = l.output_gpu;
|
||||||
|
float *c = output;
|
||||||
|
|
||||||
|
int ldb_align = l.lda_align;
|
||||||
|
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
|
||||||
|
size_t t_intput_size = new_ldb * n;
|
||||||
|
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
||||||
|
|
||||||
|
char *t_bit_input = (char *)calloc(t_bit_input_size, sizeof(char));
|
||||||
|
int src_size = k * l.bit_align;
|
||||||
|
|
||||||
|
//im2col_cpu_custom_bin(input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
||||||
|
|
||||||
|
float *align_workspace = NULL;
|
||||||
|
int align_workspace_size = l.bit_align * k; // aligned: n*k
|
||||||
|
status = cudaMalloc((void **)&align_workspace, align_workspace_size*sizeof(float));
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, align_workspace, l.bit_align);
|
||||||
|
|
||||||
|
float_to_bit_gpu(align_workspace, (unsigned char *)state.workspace, align_workspace_size);
|
||||||
|
|
||||||
|
if(1)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
status = cudaMemcpy(t_bit_input, state.workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
for (int y = 0; y < 8; ++y) {
|
||||||
|
for (int x = 0; x < 8; ++x) {
|
||||||
|
int index = x + y*l.bit_align;
|
||||||
|
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||||
|
else printf("0, ");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
fill_int8_gpu((unsigned char *)align_workspace, 0, t_bit_input_size);
|
||||||
|
|
||||||
|
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)align_workspace, k, n, l.bit_align, new_ldb, 8);
|
||||||
|
//cudaDeviceSynchronize();
|
||||||
|
|
||||||
|
//int size_transposed_array = l.bit_align * new_ldb;
|
||||||
|
status = cudaMemcpy(t_bit_input, align_workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
/*
|
||||||
|
for (int y = 0; y < 8; ++y) {
|
||||||
|
for (int x = 0; x < 8; ++x) {
|
||||||
|
int index = x + y*new_ldb;
|
||||||
|
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||||
|
else printf("0, ");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
printf("-----------\n");
|
||||||
|
//getchar();
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
if (0) {
|
||||||
|
status = cudaMemcpy(b, state.workspace, align_workspace_size / 8, cudaMemcpyDeviceToHost);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
for (int y = 0; y < 8; ++y) {
|
||||||
|
for (int x = 0; x < 8; ++x) {
|
||||||
|
int index = x + y*l.bit_align;
|
||||||
|
if (get_bit((unsigned char *)b, index)) printf("1, ");
|
||||||
|
else printf("0, ");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
//float *im2 = (float *)calloc(align_workspace_size, sizeof(float));
|
||||||
|
//status = cudaMemcpy(im2, align_workspace, align_workspace_size * sizeof(float), cudaMemcpyDeviceToHost);
|
||||||
|
//check_error(status);
|
||||||
|
//float_to_bit(im2, (unsigned char *)b, align_workspace_size);
|
||||||
|
|
||||||
|
memset(t_bit_input, 0, t_bit_input_size);
|
||||||
|
// b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
|
||||||
|
// t_input - [bit_align, k] - [n', k]
|
||||||
|
// t_bit_input - [new_ldb, n] - [k', n]
|
||||||
|
transpose_bin((char *)b, t_bit_input, k, n, l.bit_align, new_ldb, 8);
|
||||||
|
|
||||||
|
for (int y = 0; y < 8; ++y) {
|
||||||
|
for (int x = 0; x < 8; ++x) {
|
||||||
|
int index = x + y*new_ldb;
|
||||||
|
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||||
|
else printf("0, ");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
printf("-----------\n");
|
||||||
|
//getchar();
|
||||||
|
|
||||||
|
//free(im2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5x times faster than gemm()-float32
|
||||||
|
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char *)l.align_bit_weights, new_ldb, (unsigned char *)t_bit_input, new_ldb, c, n, l.mean_arr);
|
||||||
|
|
||||||
|
status = cudaMemcpy(l.output_gpu, output, l.batch*l.out_c*l.out_h*l.out_w * sizeof(float), cudaMemcpyHostToDevice);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
free(t_bit_input);
|
||||||
|
free(input);
|
||||||
|
free(workspace);
|
||||||
|
free(output);
|
||||||
|
cudaFree(align_workspace);
|
||||||
|
|
||||||
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||||
|
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||||
|
if (l.binary || l.xnor) swap_binary(&l);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
|
float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
|
||||||
float alpha = 1, beta = 0;
|
float alpha = 1, beta = 0;
|
||||||
|
|
||||||
#ifdef CUDNN_HALF
|
#ifdef CUDNN_HALF
|
||||||
// Note: For improved performance it is advised to use beta[0] = 0.0.
|
// Note: For improved performance it is advised to use beta[0] = 0.0.
|
||||||
// For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
|
// For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
|
||||||
// 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
|
// 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
|
||||||
// 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
|
// 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
|
||||||
@ -168,10 +310,10 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
&beta,
|
&beta,
|
||||||
l.dstTensorDesc,
|
l.dstTensorDesc,
|
||||||
output16);
|
output16);
|
||||||
|
|
||||||
|
|
||||||
if (l.batch_normalize)
|
|
||||||
{
|
if (l.batch_normalize)
|
||||||
|
{
|
||||||
if (state.train) // Training
|
if (state.train) // Training
|
||||||
{
|
{
|
||||||
copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
|
copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
|
||||||
@ -213,7 +355,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
{
|
{
|
||||||
cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
|
cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
|
||||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
@ -283,11 +425,11 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
|||||||
float alpha = 1, beta = 0;
|
float alpha = 1, beta = 0;
|
||||||
|
|
||||||
#ifdef CUDNN_HALF
|
#ifdef CUDNN_HALF
|
||||||
|
|
||||||
const size_t input16_size = l.batch*l.c*l.w*l.h;
|
const size_t input16_size = l.batch*l.c*l.w*l.h;
|
||||||
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
|
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
|
||||||
|
|
||||||
if (*state.net.max_input16_size < input16_size) {
|
if (*state.net.max_input16_size < input16_size) {
|
||||||
*state.net.max_input16_size = input16_size;
|
*state.net.max_input16_size = input16_size;
|
||||||
if(*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
|
if(*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
|
||||||
*state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
|
*state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
|
||||||
@ -368,7 +510,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
|||||||
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
|
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
|
||||||
// calculate delta for the next layer
|
// calculate delta for the next layer
|
||||||
// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
|
// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
|
||||||
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
|
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
|
||||||
cudnnConvolutionBackwardData(cudnn_handle(),
|
cudnnConvolutionBackwardData(cudnn_handle(),
|
||||||
&alpha,
|
&alpha,
|
||||||
l.weightDesc,
|
l.weightDesc,
|
||||||
@ -524,11 +666,11 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
|
|||||||
}else{
|
}else{
|
||||||
// update weights:
|
// update weights:
|
||||||
// weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
|
// weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
|
||||||
// weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
|
// weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
|
||||||
// weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
|
// weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
|
||||||
//
|
//
|
||||||
// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
|
// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
|
||||||
// (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
|
// (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
|
||||||
// weight_updates_gpu*0.9 - weights_gpu*0.2304
|
// weight_updates_gpu*0.9 - weights_gpu*0.2304
|
||||||
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
||||||
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
|
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
|
||||||
|
@ -609,9 +609,9 @@ void binary_align_weights(convolutional_layer *l)
|
|||||||
binarize_weights(l->weights, m, k, l->binary_weights);
|
binarize_weights(l->weights, m, k, l->binary_weights);
|
||||||
|
|
||||||
size_t align_weights_size = new_lda * m;
|
size_t align_weights_size = new_lda * m;
|
||||||
size_t align_bit_weights_size = align_weights_size / 8;// +1;
|
l->align_bit_weights_size = align_weights_size / 8;// +1;
|
||||||
float *align_weights = calloc(align_weights_size, sizeof(float));
|
float *align_weights = calloc(align_weights_size, sizeof(float));
|
||||||
l->align_bit_weights = calloc(align_bit_weights_size, sizeof(char));
|
l->align_bit_weights = calloc(l->align_bit_weights_size, sizeof(char));
|
||||||
|
|
||||||
size_t i, j;
|
size_t i, j;
|
||||||
// align A without transpose
|
// align A without transpose
|
||||||
@ -625,29 +625,28 @@ void binary_align_weights(convolutional_layer *l)
|
|||||||
l->mean_arr = calloc(l->n, sizeof(float));
|
l->mean_arr = calloc(l->n, sizeof(float));
|
||||||
get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
|
get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
|
||||||
|
|
||||||
|
#ifdef GPU
|
||||||
|
//l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float));
|
||||||
|
cudaError_t status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size);
|
||||||
|
check_error(status);
|
||||||
|
status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice);
|
||||||
|
check_error(status);
|
||||||
|
|
||||||
|
l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
|
||||||
|
#endif // GPU
|
||||||
|
|
||||||
free(align_weights);
|
free(align_weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
// further optimizations: im2col_bin() for XNOR, and then transpose_aling_bin()
|
// binary transpose
|
||||||
size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align, int bit_align)
|
size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align, int bit_align)
|
||||||
{
|
{
|
||||||
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
|
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
|
||||||
size_t t_intput_size = new_ldb * n;
|
size_t t_intput_size = new_ldb * n;
|
||||||
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
||||||
//float *t_input = calloc(t_intput_size, sizeof(float));
|
|
||||||
//char *
|
|
||||||
*t_bit_input = calloc(t_bit_input_size, sizeof(char));
|
*t_bit_input = calloc(t_bit_input_size, sizeof(char));
|
||||||
|
|
||||||
//printf("\n bit_input_size = %d, n = %d, k = %d, ldb = %d \n", bit_input_size, n, k, n);
|
|
||||||
//printf("\n t_bit_input_size = %d, k = %d, n = %d, new_ldb = %d \n", t_bit_input_size, k, n, new_ldb);
|
|
||||||
|
|
||||||
//printf("\n align_weights_size = %d, k = %d, m = %d, lda = %d \n", align_weights_size, k, m, k);
|
|
||||||
//printf("\n align_bit_weights_size = %d, k = %d, m = %d, new_lda = %d \n", align_bit_weights_size, k, m, new_ldb);
|
|
||||||
|
|
||||||
int src_size = k * bit_align;
|
int src_size = k * bit_align;
|
||||||
//printf("\n src_size = %d \n", src_size);
|
|
||||||
|
|
||||||
//float_to_bit(b, t_input, src_size);
|
|
||||||
|
|
||||||
// b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
|
// b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
|
||||||
// t_input - [bit_align, k] - [n', k]
|
// t_input - [bit_align, k] - [n', k]
|
||||||
@ -656,8 +655,6 @@ size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input,
|
|||||||
//transpose_bin(t_input, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
//transpose_bin(t_input, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
||||||
transpose_bin(b, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
transpose_bin(b, *t_bit_input, k, n, bit_align, new_ldb, 8);
|
||||||
|
|
||||||
//free(t_input);
|
|
||||||
|
|
||||||
return t_intput_size;
|
return t_intput_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -671,7 +668,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
|||||||
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
|
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
|
||||||
|
|
||||||
if(l.xnor){
|
if(l.xnor){
|
||||||
if (!l.align_bit_weights) {
|
if (!l.align_bit_weights || state.train) {
|
||||||
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
|
binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
|
||||||
//printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights);
|
//printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights);
|
||||||
}
|
}
|
||||||
@ -709,7 +706,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
|||||||
|
|
||||||
//gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
|
//gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
|
||||||
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
|
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
|
||||||
if (l.xnor && (l.stride == 1 && l.pad == 1)) {
|
if (l.xnor && l.align_bit_weights && !state.train && (l.stride == 1 && l.pad == 1)) {
|
||||||
memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));
|
memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));
|
||||||
//im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
//im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
||||||
im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
||||||
@ -812,7 +809,6 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
|||||||
//float_to_bit(t_input, t_bit_input, new_ldb * n); // for im2col_cpu_custom_transpose() only
|
//float_to_bit(t_input, t_bit_input, new_ldb * n); // for im2col_cpu_custom_transpose() only
|
||||||
|
|
||||||
// 5x times faster than gemm()-float32
|
// 5x times faster than gemm()-float32
|
||||||
// further optimizations: accelerate maxpool-layer with OpenMP/AVX
|
|
||||||
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
|
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
|
||||||
|
|
||||||
//gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
|
//gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
|
||||||
|
@ -1224,7 +1224,7 @@ void run_detector(int argc, char **argv)
|
|||||||
int ext_output = find_arg(argc, argv, "-ext_output");
|
int ext_output = find_arg(argc, argv, "-ext_output");
|
||||||
int save_labels = find_arg(argc, argv, "-save_labels");
|
int save_labels = find_arg(argc, argv, "-save_labels");
|
||||||
if(argc < 4){
|
if(argc < 4){
|
||||||
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
fprintf(stderr, "usage: %s %s [train/test/valid/demo/map] [data] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
||||||
|
@ -324,7 +324,7 @@ unsigned char reverse_byte_1(char a)
|
|||||||
((a & 0x40) >> 5) | ((a & 0x80) >> 7);
|
((a & 0x40) >> 5) | ((a & 0x80) >> 7);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned char reverse_byte_2(unsigned char a)
|
unsigned char reverse_byte(unsigned char a)
|
||||||
{
|
{
|
||||||
return ((a * 0x0802LU & 0x22110LU) | (a * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16;
|
return ((a * 0x0802LU & 0x22110LU) | (a * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16;
|
||||||
}
|
}
|
||||||
@ -333,7 +333,7 @@ static unsigned char lookup[16] = {
|
|||||||
0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe,
|
0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe,
|
||||||
0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf, };
|
0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf, };
|
||||||
|
|
||||||
unsigned char reverse_byte(unsigned char n) {
|
unsigned char reverse_byte_3(unsigned char n) {
|
||||||
// Reverse the top and bottom nibble then swap them.
|
// Reverse the top and bottom nibble then swap them.
|
||||||
return (lookup[n & 0b1111] << 4) | lookup[n >> 4];
|
return (lookup[n & 0b1111] << 4) | lookup[n >> 4];
|
||||||
}
|
}
|
||||||
|
11
src/im2col.h
11
src/im2col.h
@ -11,5 +11,16 @@ void im2col_ongpu(float *im,
|
|||||||
int channels, int height, int width,
|
int channels, int height, int width,
|
||||||
int ksize, int stride, int pad,float *data_col);
|
int ksize, int stride, int pad,float *data_col);
|
||||||
|
|
||||||
|
void im2col_align_ongpu(float *im,
|
||||||
|
int channels, int height, int width,
|
||||||
|
int ksize, int stride, int pad, float *data_col, int bit_align);
|
||||||
|
|
||||||
|
void float_to_bit_gpu(float *src, unsigned char *dst, size_t size);
|
||||||
|
|
||||||
|
void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,
|
||||||
|
const int lda, const int ldb, const int block_size);
|
||||||
|
|
||||||
|
void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -59,3 +59,196 @@ void im2col_ongpu(float *im,
|
|||||||
stride, height_col,
|
stride, height_col,
|
||||||
width_col, data_col);
|
width_col, data_col);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
|
||||||
|
const int height, const int width, const int ksize,
|
||||||
|
const int pad,
|
||||||
|
const int stride,
|
||||||
|
const int height_col, const int width_col,
|
||||||
|
float *data_col, const int bit_align)
|
||||||
|
{
|
||||||
|
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||||
|
for (; index < n; index += blockDim.x*gridDim.x) {
|
||||||
|
int w_out = index % width_col;
|
||||||
|
int h_index = index / width_col;
|
||||||
|
int h_out = h_index % height_col;
|
||||||
|
int channel_in = h_index / height_col;
|
||||||
|
int channel_out = channel_in * ksize * ksize;
|
||||||
|
int h_in = h_out * stride - pad;
|
||||||
|
int w_in = w_out * stride - pad;
|
||||||
|
float* data_col_ptr = data_col;
|
||||||
|
//data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
|
||||||
|
data_col_ptr += channel_out * bit_align + h_out * width_col + w_out;
|
||||||
|
const float* data_im_ptr = data_im;
|
||||||
|
data_im_ptr += (channel_in * height + h_in) * width + w_in;
|
||||||
|
for (int i = 0; i < ksize; ++i) {
|
||||||
|
for (int j = 0; j < ksize; ++j) {
|
||||||
|
int h = h_in + i;
|
||||||
|
int w = w_in + j;
|
||||||
|
|
||||||
|
*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
|
||||||
|
data_im_ptr[i * width + j] : 0;
|
||||||
|
|
||||||
|
|
||||||
|
//data_col_ptr += height_col * width_col;
|
||||||
|
data_col_ptr += bit_align;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void im2col_align_ongpu(float *im,
|
||||||
|
int channels, int height, int width,
|
||||||
|
int ksize, int stride, int pad, float *data_col, int bit_align) {
|
||||||
|
// We are going to launch channels * height_col * width_col kernels, each
|
||||||
|
// kernel responsible for copying a single-channel grid.
|
||||||
|
int height_col = (height + 2 * pad - ksize) / stride + 1;
|
||||||
|
int width_col = (width + 2 * pad - ksize) / stride + 1;
|
||||||
|
int num_kernels = channels * height_col * width_col;
|
||||||
|
im2col_align_gpu_kernel << <(num_kernels + BLOCK - 1) / BLOCK,
|
||||||
|
BLOCK, 0, get_cuda_stream() >> >(
|
||||||
|
num_kernels, im, height, width, ksize, pad,
|
||||||
|
stride, height_col,
|
||||||
|
width_col, data_col, bit_align);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// --------------------------------
|
||||||
|
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
|
__global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t size)
|
||||||
|
{
|
||||||
|
//size_t dst_size = size / 8 + 1;
|
||||||
|
//memset(dst, 0, dst_size);
|
||||||
|
//uint32_t bit_mask = __ballot_sync(FULL_MASK, src[i] > 0);
|
||||||
|
const int size_aligned = size + (WARP_SIZE - size % WARP_SIZE);
|
||||||
|
|
||||||
|
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||||
|
float src_val;
|
||||||
|
|
||||||
|
for (; index < size_aligned; index += blockDim.x*gridDim.x)
|
||||||
|
{
|
||||||
|
if(index < size) src_val = src[index];
|
||||||
|
else src_val = 0;
|
||||||
|
unsigned int bit_mask = __ballot_sync(0xffffffff, src_val > 0);
|
||||||
|
if (threadIdx.x % WARP_SIZE == 0) ((unsigned int*)dst)[index / 32] = bit_mask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void float_to_bit_gpu(float *src, unsigned char *dst, size_t size)
|
||||||
|
{
|
||||||
|
const int num_blocks = size / BLOCK + 1;
|
||||||
|
float_to_bit_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(src, dst, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
__device__ __host__ static inline void set_bit(unsigned char *const dst, size_t index) {
|
||||||
|
size_t dst_i = index / 8;
|
||||||
|
int dst_shift = index % 8;
|
||||||
|
dst[dst_i] |= 1 << dst_shift;
|
||||||
|
//dst[dst_i] |= 1 << (8 - dst_shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __host__ static inline unsigned char get_bit(unsigned char const*const src, size_t index) {
|
||||||
|
size_t src_i = index / 8;
|
||||||
|
int src_shift = index % 8;
|
||||||
|
unsigned char val = (src[src_i] & (1 << src_shift)) > 0;
|
||||||
|
//unsigned char val = (src[src_i] & (1 << (8 - src_shift))) > 0;
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intel CPUs and nVidia CUDA GPU are little endian
|
||||||
|
__device__ __host__ unsigned char reverse_byte(unsigned char a)
|
||||||
|
{
|
||||||
|
return ((a & 0x1) << 7) | ((a & 0x2) << 5) |
|
||||||
|
((a & 0x4) << 3) | ((a & 0x8) << 1) |
|
||||||
|
((a & 0x10) >> 1) | ((a & 0x20) >> 3) |
|
||||||
|
((a & 0x40) >> 5) | ((a & 0x80) >> 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __host__ unsigned char reverse_byte_2(unsigned char a)
|
||||||
|
{
|
||||||
|
return ((a * 0x0802LU & 0x22110LU) | (a * 0x8020LU & 0x88440LU)) * 0x10101LU >> 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__device__ __host__ void transpose8rS32_reversed_diagonale(unsigned char* A, int m, int n, unsigned char* B)
|
||||||
|
{
|
||||||
|
unsigned x, y, t;
|
||||||
|
|
||||||
|
// Load the array and pack it into x and y.
|
||||||
|
x = (A[0] << 24) | (A[m] << 16) | (A[2 * m] << 8) | A[3 * m];
|
||||||
|
y = (A[4 * m] << 24) | (A[5 * m] << 16) | (A[6 * m] << 8) | A[7 * m];
|
||||||
|
|
||||||
|
t = (x ^ (x >> 7)) & 0x00AA00AA; x = x ^ t ^ (t << 7);
|
||||||
|
t = (y ^ (y >> 7)) & 0x00AA00AA; y = y ^ t ^ (t << 7);
|
||||||
|
|
||||||
|
t = (x ^ (x >> 14)) & 0x0000CCCC; x = x ^ t ^ (t << 14);
|
||||||
|
t = (y ^ (y >> 14)) & 0x0000CCCC; y = y ^ t ^ (t << 14);
|
||||||
|
|
||||||
|
t = (x & 0xF0F0F0F0) | ((y >> 4) & 0x0F0F0F0F);
|
||||||
|
y = ((x << 4) & 0xF0F0F0F0) | (y & 0x0F0F0F0F);
|
||||||
|
x = t;
|
||||||
|
|
||||||
|
B[7 * n] = reverse_byte(x >> 24); B[6 * n] = reverse_byte(x >> 16); B[5 * n] = reverse_byte(x >> 8); B[4 * n] = reverse_byte(x);
|
||||||
|
B[3 * n] = reverse_byte(y >> 24); B[2 * n] = reverse_byte(y >> 16); B[1 * n] = reverse_byte(y >> 8); B[0 * n] = reverse_byte(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void transpose_bin_gpu_kernel(unsigned char *A, unsigned char *B, const int n, const int m,
|
||||||
|
const int lda, const int ldb, const int block_size)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
//for (i = 0; i < n; i += 8)
|
||||||
|
{
|
||||||
|
i = (index*8) % n;
|
||||||
|
int j;
|
||||||
|
//for (j = 0; j < m - 8; j += 8)
|
||||||
|
{
|
||||||
|
j = ((index * 8) / n) * 8;
|
||||||
|
if (j < m - 8) {
|
||||||
|
int a_index = i*lda + j;
|
||||||
|
int b_index = j*ldb + i;
|
||||||
|
//transpose_8x8_bits_my(&A[a_index/8], &B[b_index/8], lda/8, ldb/8);
|
||||||
|
transpose8rS32_reversed_diagonale(&A[a_index / 8], lda / 8, ldb / 8, &B[b_index / 8]);
|
||||||
|
}
|
||||||
|
else if (j < m) {
|
||||||
|
for (; j < m; ++j) {
|
||||||
|
if (get_bit(A, i*lda + j)) set_bit(B, j*ldb + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,
|
||||||
|
const int lda, const int ldb, const int block_size)
|
||||||
|
{
|
||||||
|
size_t size = n*m/64 + 1;
|
||||||
|
const int num_blocks = size / BLOCK + 1;
|
||||||
|
transpose_bin_gpu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(A, B, n, m, lda, ldb, block_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// --------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void fill_int8_gpu_kernel(unsigned char *src, unsigned char val, size_t size) {
|
||||||
|
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||||
|
if(index < size) src[index] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size) {
|
||||||
|
const int num_blocks = size / BLOCK + 1;
|
||||||
|
fill_int8_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream() >>>(src, val, size);
|
||||||
|
}
|
@ -179,8 +179,11 @@ struct layer{
|
|||||||
float *weights;
|
float *weights;
|
||||||
float *weight_updates;
|
float *weight_updates;
|
||||||
|
|
||||||
|
char *align_bit_weights_gpu;
|
||||||
|
float *mean_arr_gpu;
|
||||||
char *align_bit_weights;
|
char *align_bit_weights;
|
||||||
float *mean_arr;
|
float *mean_arr;
|
||||||
|
int align_bit_weights_size;
|
||||||
int lda_align;
|
int lda_align;
|
||||||
int bit_align;
|
int bit_align;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user