mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
temp fix, don't use it
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
#include "cublas_v2.h"
|
||||
|
||||
#ifdef CUDNN
|
||||
#pragma comment(lib, "cudnn.lib")
|
||||
#pragma comment(lib, "cudnn.lib")
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
@ -117,18 +117,160 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
}
|
||||
|
||||
if(l.xnor){
|
||||
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
||||
if (!l.align_bit_weights_gpu || state.train) {
|
||||
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
|
||||
}
|
||||
swap_binary(&l);
|
||||
binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
|
||||
state.input = l.binary_input_gpu;
|
||||
|
||||
if (l.align_bit_weights_gpu && !state.train)
|
||||
{
|
||||
cudaError_t status;
|
||||
status = cudaMemcpy(l.align_bit_weights, l.align_bit_weights_gpu, l.align_bit_weights_size, cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
|
||||
float *input = (float *)calloc(l.c*l.h*l.w*l.batch, sizeof(float));
|
||||
float *workspace = (float *)calloc(l.bit_align*l.size*l.size*l.c, sizeof(float));
|
||||
float *output = (float *)calloc(l.batch*l.out_c*l.out_h*l.out_w, sizeof(float));
|
||||
|
||||
status = cudaMemcpy(input, state.input, l.c*l.h*l.w*l.batch*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
|
||||
int m = l.n;
|
||||
int k = l.size*l.size*l.c;
|
||||
int n = l.out_w*l.out_h;
|
||||
float * a = l.weights_gpu;
|
||||
//float * b = state.workspace;
|
||||
float *b = workspace;
|
||||
//float * c = l.output_gpu;
|
||||
float *c = output;
|
||||
|
||||
int ldb_align = l.lda_align;
|
||||
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
|
||||
size_t t_intput_size = new_ldb * n;
|
||||
size_t t_bit_input_size = t_intput_size / 8;// +1;
|
||||
|
||||
char *t_bit_input = (char *)calloc(t_bit_input_size, sizeof(char));
|
||||
int src_size = k * l.bit_align;
|
||||
|
||||
//im2col_cpu_custom_bin(input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
|
||||
|
||||
float *align_workspace = NULL;
|
||||
int align_workspace_size = l.bit_align * k; // aligned: n*k
|
||||
status = cudaMalloc((void **)&align_workspace, align_workspace_size*sizeof(float));
|
||||
check_error(status);
|
||||
|
||||
int i = 0;
|
||||
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, align_workspace, l.bit_align);
|
||||
|
||||
float_to_bit_gpu(align_workspace, (unsigned char *)state.workspace, align_workspace_size);
|
||||
|
||||
if(1)
|
||||
{
|
||||
{
|
||||
/*
|
||||
status = cudaMemcpy(t_bit_input, state.workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
for (int y = 0; y < 8; ++y) {
|
||||
for (int x = 0; x < 8; ++x) {
|
||||
int index = x + y*l.bit_align;
|
||||
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||
else printf("0, ");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
*/
|
||||
}
|
||||
|
||||
fill_int8_gpu((unsigned char *)align_workspace, 0, t_bit_input_size);
|
||||
|
||||
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)align_workspace, k, n, l.bit_align, new_ldb, 8);
|
||||
//cudaDeviceSynchronize();
|
||||
|
||||
//int size_transposed_array = l.bit_align * new_ldb;
|
||||
status = cudaMemcpy(t_bit_input, align_workspace, t_bit_input_size, cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
|
||||
/*
|
||||
for (int y = 0; y < 8; ++y) {
|
||||
for (int x = 0; x < 8; ++x) {
|
||||
int index = x + y*new_ldb;
|
||||
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||
else printf("0, ");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("-----------\n");
|
||||
//getchar();
|
||||
*/
|
||||
}
|
||||
|
||||
if (0) {
|
||||
status = cudaMemcpy(b, state.workspace, align_workspace_size / 8, cudaMemcpyDeviceToHost);
|
||||
check_error(status);
|
||||
|
||||
for (int y = 0; y < 8; ++y) {
|
||||
for (int x = 0; x < 8; ++x) {
|
||||
int index = x + y*l.bit_align;
|
||||
if (get_bit((unsigned char *)b, index)) printf("1, ");
|
||||
else printf("0, ");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
//float *im2 = (float *)calloc(align_workspace_size, sizeof(float));
|
||||
//status = cudaMemcpy(im2, align_workspace, align_workspace_size * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
//check_error(status);
|
||||
//float_to_bit(im2, (unsigned char *)b, align_workspace_size);
|
||||
|
||||
memset(t_bit_input, 0, t_bit_input_size);
|
||||
// b - [bit_align, k] - [l.bit_align, l.size*l.size*l.c] = src_size
|
||||
// t_input - [bit_align, k] - [n', k]
|
||||
// t_bit_input - [new_ldb, n] - [k', n]
|
||||
transpose_bin((char *)b, t_bit_input, k, n, l.bit_align, new_ldb, 8);
|
||||
|
||||
for (int y = 0; y < 8; ++y) {
|
||||
for (int x = 0; x < 8; ++x) {
|
||||
int index = x + y*new_ldb;
|
||||
if (get_bit((unsigned char *)t_bit_input, index)) printf("1, ");
|
||||
else printf("0, ");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("-----------\n");
|
||||
//getchar();
|
||||
|
||||
//free(im2);
|
||||
}
|
||||
|
||||
// 5x times faster than gemm()-float32
|
||||
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char *)l.align_bit_weights, new_ldb, (unsigned char *)t_bit_input, new_ldb, c, n, l.mean_arr);
|
||||
|
||||
status = cudaMemcpy(l.output_gpu, output, l.batch*l.out_c*l.out_h*l.out_w * sizeof(float), cudaMemcpyHostToDevice);
|
||||
check_error(status);
|
||||
|
||||
free(t_bit_input);
|
||||
free(input);
|
||||
free(workspace);
|
||||
free(output);
|
||||
cudaFree(align_workspace);
|
||||
|
||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||
activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
|
||||
if (l.binary || l.xnor) swap_binary(&l);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef CUDNN
|
||||
float one = 1; // alpha[0], beta[0] is float for HALF and FLOAT
|
||||
float alpha = 1, beta = 0;
|
||||
float alpha = 1, beta = 0;
|
||||
|
||||
#ifdef CUDNN_HALF
|
||||
// Note: For improved performance it is advised to use beta[0] = 0.0.
|
||||
// Note: For improved performance it is advised to use beta[0] = 0.0.
|
||||
// For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
|
||||
// 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
|
||||
// 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
|
||||
@ -168,10 +310,10 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
&beta,
|
||||
l.dstTensorDesc,
|
||||
output16);
|
||||
|
||||
|
||||
if (l.batch_normalize)
|
||||
{
|
||||
|
||||
if (l.batch_normalize)
|
||||
{
|
||||
if (state.train) // Training
|
||||
{
|
||||
copy_ongpu(l.outputs*l.batch / 2, output16, 1, l.x_gpu, 1);
|
||||
@ -213,7 +355,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
{
|
||||
cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
|
||||
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
@ -283,11 +425,11 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
||||
float alpha = 1, beta = 0;
|
||||
|
||||
#ifdef CUDNN_HALF
|
||||
|
||||
|
||||
const size_t input16_size = l.batch*l.c*l.w*l.h;
|
||||
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
|
||||
|
||||
if (*state.net.max_input16_size < input16_size) {
|
||||
|
||||
if (*state.net.max_input16_size < input16_size) {
|
||||
*state.net.max_input16_size = input16_size;
|
||||
if(*state.net.input16_gpu) cuda_free(*state.net.input16_gpu);
|
||||
*state.net.input16_gpu = (float *)cuda_make_f16_from_f32_array(NULL, *state.net.max_input16_size);
|
||||
@ -368,7 +510,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
||||
// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
|
||||
// calculate delta for the next layer
|
||||
// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
|
||||
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
|
||||
// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)
|
||||
cudnnConvolutionBackwardData(cudnn_handle(),
|
||||
&alpha,
|
||||
l.weightDesc,
|
||||
@ -524,11 +666,11 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
|
||||
}else{
|
||||
// update weights:
|
||||
// weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
|
||||
// weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
|
||||
// weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) =
|
||||
// weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
|
||||
//
|
||||
// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
|
||||
// (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
|
||||
//
|
||||
// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum =
|
||||
// (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 =
|
||||
// weight_updates_gpu*0.9 - weights_gpu*0.2304
|
||||
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
|
||||
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
|
||||
|
Reference in New Issue
Block a user