Mirror of https://github.com/pjreddie/darknet.git (synced 2023-08-10 21:13:14 +03:00)
Commit: Another round of CUDA performance improvements
@@ -316,6 +316,8 @@ struct layer {
     float *col_image;
     float * delta;
     float * output;
+    int delta_pinned;
+    int output_pinned;
     float * loss;
     float * squared;
     float * norms;
@@ -582,6 +584,8 @@ typedef struct network {
     float *output_gpu;

     float *input_state_gpu;
+    float *input_pinned_cpu;
+    int input_pinned_cpu_flag;

     float **input_gpu;
     float **truth_gpu;
@@ -777,6 +781,7 @@ LIB_API pthread_t load_data_in_thread(load_args args);

 // cuda.h
 LIB_API void cuda_pull_array(float *x_gpu, float *x, size_t n);
+LIB_API void cuda_pull_array_async(float *x_gpu, float *x, size_t n);
 LIB_API void cuda_set_device(int n);

 // utils.h
@@ -692,6 +692,14 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
     check_error(cudaPeekAtLastError());
 }

+__global__ void simple_input_shortcut_kernel(float *in, int size, float *add, float *out)
+{
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (id >= size) return;
+
+    out[id] = in[id] + add[id];
+}
+
 __global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
 {
     int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -711,6 +719,13 @@ __global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)

 extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
 {
+    if (w1 == w2 && h1 == h2 && c1 == c2) {
+        int size = batch * w1 * h1 * c1;
+        simple_input_shortcut_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(in, size, add, out);
+        check_error(cudaPeekAtLastError());
+        return;
+    }
+
     int minw = (w1 < w2) ? w1 : w2;
     int minh = (h1 < h2) ? h1 : h2;
     int minc = (c1 < c2) ? c1 : c2;
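The fast path above launches through the repository's cuda_gridsize() helper and BLOCK macro. As a rough sketch only (the constant 512 and the helper body are assumptions, not taken from this commit), such a helper splits large launches into a 2D grid, which is also why the kernels compute their index with blockIdx.y*gridDim.x:

    // Sketch only: approximates the cuda_gridsize()/BLOCK pair used for the launch above.
    // BLOCK_SKETCH = 512 is an assumption; the real definitions live in the repo's cuda.h/cuda.c.
    #include <cuda_runtime.h>
    #include <math.h>

    #define BLOCK_SKETCH 512

    static dim3 gridsize_sketch(size_t n)
    {
        size_t k = (n - 1) / BLOCK_SKETCH + 1;    // blocks needed to cover n elements
        size_t x = k, y = 1;
        if (x > 65535) {                          // split into a 2D grid for very large n
            x = (size_t)ceil(sqrt((double)k));
            y = (n - 1) / (x * BLOCK_SKETCH) + 1;
        }
        return dim3((unsigned int)x, (unsigned int)y, 1);
    }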
src/cuda.c | 30
@@ -82,6 +82,27 @@ cudaStream_t get_cuda_stream() {
     return streamsArray[i];
 }

+static cudaStream_t streamsArray2[16];    // cudaStreamSynchronize( get_cuda_memcpy_stream() );
+static int streamInit2[16] = { 0 };
+
+cudaStream_t get_cuda_memcpy_stream() {
+    int i = cuda_get_device();
+    if (!streamInit2[i]) {
+        cudaError_t status = cudaStreamCreate(&streamsArray2[i]);
+        //cudaError_t status = cudaStreamCreateWithFlags(&streamsArray2[i], cudaStreamNonBlocking);
+        if (status != cudaSuccess) {
+            printf(" cudaStreamCreate Memcpy error: %d \n", status);
+            const char *s = cudaGetErrorString(status);
+            char buffer[256];
+            printf("CUDA Error: %s\n", s);
+            status = cudaStreamCreateWithFlags(&streamsArray2[i], cudaStreamDefault);
+            check_error(status);
+        }
+        streamInit2[i] = 1;
+    }
+    return streamsArray2[i];
+}
+

 #ifdef CUDNN
 cudnnHandle_t cudnn_handle()
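get_cuda_memcpy_stream() gives the code a second stream dedicated to memory copies. A minimal sketch of how such a copy stream can overlap a device-to-host transfer with kernel work on the main stream (scale_kernel and the buffer names are illustrative; the host buffer must be pinned for cudaMemcpyAsync to actually run asynchronously):

    // Sketch: overlap a D2H copy on the memcpy stream with compute on the main stream.
    #include <cuda_runtime.h>

    extern "C" cudaStream_t get_cuda_stream();         // from src/cuda.c
    extern "C" cudaStream_t get_cuda_memcpy_stream();  // added in this commit

    __global__ void scale_kernel(float *x, float s, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] *= s;
    }

    // h_prev must come from cudaHostAlloc()/cudaMallocHost(); with pageable memory
    // cudaMemcpyAsync silently falls back to a blocking copy.
    void overlap_pull_and_compute(float *d_prev, float *h_prev, float *d_cur, int n)
    {
        cudaMemcpyAsync(h_prev, d_prev, n * sizeof(float), cudaMemcpyDeviceToHost,
                        get_cuda_memcpy_stream());                     // copy on stream 2
        scale_kernel<<<(n + 511) / 512, 512, 0, get_cuda_stream()>>>(d_cur, 2.f, n);  // compute on stream 1
        cudaStreamSynchronize(get_cuda_memcpy_stream());               // h_prev valid after this
    }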
@@ -116,6 +137,7 @@ float *cuda_make_array(float *x, size_t n)
     float *x_gpu;
     size_t size = sizeof(float)*n;
     cudaError_t status = cudaMalloc((void **)&x_gpu, size);
+    if (status != cudaSuccess) fprintf(stderr, " Try to set subdivisions=64 in your cfg-file. \n");
     check_error(status);
     if(x){
         //status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
@@ -200,6 +222,14 @@ void cuda_pull_array(float *x_gpu, float *x, size_t n)
     cudaStreamSynchronize(get_cuda_stream());
 }

+void cuda_pull_array_async(float *x_gpu, float *x, size_t n)
+{
+    size_t size = sizeof(float)*n;
+    cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDeviceToHost, get_cuda_stream());
+    check_error(status);
+    //cudaStreamSynchronize(get_cuda_stream());
+}
+
 #else // GPU
 #include "cuda.h"
 void cuda_set_device(int n) {}
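cuda_pull_array_async() only enqueues the device-to-host copy and returns; the host buffer is not safe to read until the stream has been synchronized. A hedged usage sketch (everything except the two darknet helpers is illustrative):

    // Sketch: pair the async pull with a stream sync before the CPU touches the buffer.
    #include <cuda_runtime.h>

    void cuda_pull_array_async(float *x_gpu, float *x, size_t n);  // from src/cuda.c
    cudaStream_t get_cuda_stream();                                // from src/cuda.c

    void pull_outputs(float *gpu_output, float *host_output, size_t n)
    {
        // host_output should itself be pinned for the copy to be truly asynchronous
        cuda_pull_array_async(gpu_output, host_output, n);  // returns immediately
        /* ... enqueue more kernels/copies on the same stream ... */
        cudaStreamSynchronize(get_cuda_stream());            // host_output is valid after this
    }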
@@ -37,6 +37,7 @@ extern "C" {
     float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
     dim3 cuda_gridsize(size_t n);
     cudaStream_t get_cuda_stream();
+    cudaStream_t get_cuda_memcpy_stream();
 #ifdef __cplusplus
 }
 #endif // __cplusplus
@@ -1030,7 +1030,6 @@ void repack_input_gpu_2(float *input, float *re_packed_input, int w, int h, int
 __global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c)
 {
     __shared__ uint32_t tmp[32];
-
     int index = blockIdx.x*blockDim.x + threadIdx.x;

     const int num_of_warps = blockDim.x / WARP_SIZE;
@@ -1350,6 +1349,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
     unsigned char *B, int ldb,
     float *C, int ldc, float *mean_arr, float *bias_arr)
 {
+    // total 57%
     int index = blockIdx.x*blockDim.x + threadIdx.x;

     __shared__ uint8_t A_s[6144*8/4];
@@ -1363,7 +1363,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int

     int i_cur = index / N;
     int local_i = i_cur - start_i;
-
+    // ~10%
     for (int k = threadIdx.x * 64; k < shared_size; k += blockDim.x * 64) {
         int x = start_i*lda + k;
         if (x < (M*lda)) *((uint64_t *)(A_s + k / 8)) = *((uint64_t *)(A + x / 8));
@@ -1371,7 +1371,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
     __syncthreads();

     int i, j, k, h;
-
+    // 47% = 29 + 10 + 8
     j = index % N;
     { // out_h*out_w - one channel output size [169 - 173056]
         i = index / N;
@@ -1413,7 +1413,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
 #endif

         //#ifdef NOT_USED
-        // 32 thread X 64 bit = 2048 bit
+        // 32 thread X 64 bit = 2048 bit // 29%
         for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216]
             uint64_t c_bit64;

@@ -1444,7 +1444,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
         //#endif

         //#ifdef NOT_USED
-        // 32 thread X 32 bit = 1024 bit
+        // 32 thread X 32 bit = 1024 bit // 10%
         for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216]

             //int64_t A_cur_index = (i*lda + k) / 8;
@@ -1479,6 +1479,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
         float bias_val = bias_arr[i];

         //#ifdef NOT_USED
+        // 8%
         for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
             //ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights
             ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + (local_i*lda + k) / 8)); // weights
src/layer.c | 10
@@ -35,6 +35,16 @@ void free_layer(layer l)
     if (l.weight_updates) free(l.weight_updates);
     if (l.align_bit_weights) free(l.align_bit_weights);
     if (l.mean_arr) free(l.mean_arr);
+#ifdef GPU
+    if (l.delta && l.delta_pinned) {
+        cudaFreeHost(l.delta);
+        l.delta = NULL;
+    }
+    if (l.output && l.output_pinned) {
+        cudaFreeHost(l.output);
+        l.output = NULL;
+    }
+#endif // GPU
     if (l.delta) free(l.delta);
     if (l.output) free(l.output);
     if (l.squared) free(l.squared);
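Buffers obtained with cudaHostAlloc() must be released with cudaFreeHost() rather than free(), which is why free_layer now checks the *_pinned flags first and NULLs the pointer so the plain free() below is skipped. The same logic as a standalone sketch (the helper name is illustrative):

    // Sketch: free a buffer that may be either pinned (cudaHostAlloc) or ordinary heap
    // memory, mirroring the flag-based logic added to free_layer(). 'pinned' plays the
    // role of l.output_pinned / l.delta_pinned.
    #include <stdlib.h>
    #include <cuda_runtime.h>

    static void free_host_buffer(float **buf, int pinned)
    {
        if (!*buf) return;
        if (pinned) cudaFreeHost(*buf);   // allocated with cudaHostAlloc()
        else free(*buf);                  // allocated with calloc()/malloc()
        *buf = NULL;
    }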
@@ -856,6 +856,10 @@ void free_network(network net)
     if (gpu_index >= 0) cuda_free(net.workspace);
     else free(net.workspace);
     if (net.input_state_gpu) cuda_free(net.input_state_gpu);
+    if (net.input_pinned_cpu) {   // CPU
+        if (net.input_pinned_cpu_flag) cudaFreeHost(net.input_pinned_cpu);
+        else free(net.input_pinned_cpu);
+    }
     if (*net.input_gpu) cuda_free(*net.input_gpu);
     if (*net.truth_gpu) cuda_free(*net.truth_gpu);
     if (net.input_gpu) free(net.input_gpu);
@@ -87,6 +87,8 @@ void forward_network_gpu(network net, network_state state)
         }
         */
     }
+    cudaStreamSynchronize(get_cuda_stream());   // sync CUDA-functions
+    //cudaStreamSynchronize(get_cuda_memcpy_stream());   // sync cudaMemcpyAsync()
     //cudaDeviceSynchronize();
     //show_total_time();
 }
@@ -444,7 +446,8 @@ float *network_predict_gpu(network net, float *input)
     state.net = net;
     //state.input = cuda_make_array(input, size);   // memory will be allocated in the parse_network_cfg_custom()
     state.input = net.input_state_gpu;
-    cuda_push_array(state.input, input, size);
+    memcpy(net.input_pinned_cpu, input, size * sizeof(float));
+    cuda_push_array(state.input, net.input_pinned_cpu, size);
     state.truth = 0;
     state.train = 0;
     state.delta = 0;
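Routing the input through net.input_pinned_cpu means cuda_push_array() copies from page-locked memory, so the driver can DMA it to the device directly instead of staging it through an internal pinned buffer. A rough sketch of the same pattern in isolation (names are illustrative):

    // Sketch: stage user data through a pinned buffer before the host-to-device push.
    // pinned_host is assumed to come from cudaHostAlloc(); d_input is device memory.
    #include <string.h>
    #include <cuda_runtime.h>

    cudaStream_t get_cuda_stream();   // from src/cuda.c

    void push_input(float *d_input, float *pinned_host, const float *user_input, size_t n)
    {
        memcpy(pinned_host, user_input, n * sizeof(float));          // CPU copy into page-locked memory
        cudaMemcpyAsync(d_input, pinned_host, n * sizeof(float),
                        cudaMemcpyHostToDevice, get_cuda_stream());  // DMA directly from pinned memory
    }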
@@ -829,10 +829,14 @@ network parse_network_cfg_custom(char *filename, int batch)
     if(workspace_size){
         //printf("%ld\n", workspace_size);
 #ifdef GPU
+        get_cuda_stream();
+        get_cuda_memcpy_stream();
         if(gpu_index >= 0){
             net.workspace = cuda_make_array(0, workspace_size/sizeof(float) + 1);
             int size = get_network_input_size(net) * net.batch;
             net.input_state_gpu = cuda_make_array(0, size);
+            if (cudaSuccess == cudaHostAlloc(&net.input_pinned_cpu, size*sizeof(float), cudaHostRegisterMapped)) net.input_pinned_cpu_flag = 1;
+            else net.input_pinned_cpu = calloc(size, sizeof(float));

             // pre-allocate memory for inference on Tensor Cores (fp16)
             if (net.cudnn_half) {
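The pinned input buffer is allocated with a fallback: if cudaHostAlloc() fails, an ordinary calloc() buffer is used and input_pinned_cpu_flag stays 0, so free_network knows which deallocator to call. A minimal sketch of that pattern as a standalone helper (the helper name is hypothetical; the sketch uses cudaHostAllocMapped, the flag documented for cudaHostAlloc, whereas the diff passes cudaHostRegisterMapped):

    // Sketch: try to allocate pinned (page-locked) host memory, fall back to calloc().
    // Sets *pinned_flag so the caller knows how to free the buffer later.
    #include <stdlib.h>
    #include <string.h>
    #include <cuda_runtime.h>

    static float *alloc_host_buffer(size_t n, int *pinned_flag)
    {
        float *buf = NULL;
        *pinned_flag = 0;
        if (cudaHostAlloc((void **)&buf, n * sizeof(float), cudaHostAllocMapped) == cudaSuccess) {
            *pinned_flag = 1;                          // must later be freed with cudaFreeHost()
            memset(buf, 0, n * sizeof(float));
        } else {
            buf = (float *)calloc(n, sizeof(float));   // pageable fallback, freed with free()
        }
        return buf;
    }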
@@ -67,7 +67,7 @@ void resize_route_layer(route_layer *l, network *net)
     l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
     l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
 #endif

 }

 void forward_route_layer(const route_layer l, network_state state)
@@ -110,7 +110,8 @@ void forward_route_layer_gpu(const route_layer l, network_state state)
         float *input = state.net.layers[index].output_gpu;
         int input_size = l.input_sizes[i];
         for(j = 0; j < l.batch; ++j){
-            copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1);
+            //copy_ongpu(input_size, input + j*input_size, 1, l.output_gpu + offset + j*l.outputs, 1);
+            simple_copy_ongpu(input_size, input + j*input_size, l.output_gpu + offset + j*l.outputs);
         }
         offset += input_size;
     }
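simple_copy_ongpu() is not defined in this diff; presumably it is a plain contiguous copy that drops the increment arguments of the BLAS-style copy_ongpu(). A speculative sketch of what such a helper might look like, reusing the launch conventions seen above (an assumption, not the commit's actual implementation):

    // Sketch only: a stride-1 device-to-device copy of the kind simple_copy_ongpu() likely performs.
    // "cuda.h" here is darknet's own header, providing cuda_gridsize, BLOCK, get_cuda_stream, check_error.
    #include "cuda.h"

    __global__ void simple_copy_kernel_sketch(int size, const float *src, float *dst)
    {
        int id = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        if (id < size) dst[id] = src[id];
    }

    extern "C" void simple_copy_sketch(int size, const float *src, float *dst)
    {
        simple_copy_kernel_sketch<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, src, dst);
        check_error(cudaPeekAtLastError());
    }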
@@ -53,6 +53,14 @@ layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int
     l.backward_gpu = backward_yolo_layer_gpu;
     l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
     l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
+
+    free(l.output);
+    if (cudaSuccess == cudaHostAlloc(&l.output, batch*l.outputs*sizeof(float), cudaHostRegisterMapped)) l.output_pinned = 1;
+    else l.output = calloc(batch*l.outputs, sizeof(float));
+
+    free(l.delta);
+    if (cudaSuccess == cudaHostAlloc(&l.delta, batch*l.outputs*sizeof(float), cudaHostRegisterMapped)) l.delta_pinned = 1;
+    else l.delta = calloc(batch*l.outputs, sizeof(float));
 #endif

     fprintf(stderr, "yolo\n");
@@ -411,13 +419,14 @@ void forward_yolo_layer_gpu(const layer l, network_state state)
         }
     }
     if(!state.train || l.onlyforward){
-        cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
+        //cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
+        cuda_pull_array_async(l.output_gpu, l.output, l.batch*l.outputs);
         return;
     }

-    //cuda_pull_array(l.output_gpu, state.input, l.batch*l.inputs);
     float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
-    cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
+    cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
+    memcpy(in_cpu, l.output, l.batch*l.outputs*sizeof(float));
     float *truth_cpu = 0;
     if (state.truth) {
         int num_truth = l.batch*l.truths;
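Because the pull is now asynchronous, l.output is only guaranteed to be complete once the stream is synchronized; in this commit that is the cudaStreamSynchronize(get_cuda_stream()) call added at the end of forward_network_gpu, which runs before any caller reads the detections. A compressed sketch of that issue-everything-then-sync ordering (field and type names follow darknet's headers, so treat the details as assumptions):

    // Sketch: all kernels and the async pull are queued on the same stream, so a single
    // synchronization at the end of the forward pass makes every pulled buffer valid.
    #include "network.h"   // assumed darknet types: network, network_state, layer
    #include "cuda.h"      // get_cuda_stream()

    void forward_all_layers_sketch(network net, network_state state)
    {
        int i;
        for (i = 0; i < net.n; ++i) {
            layer l = net.layers[i];
            state.index = i;
            l.forward_gpu(l, state);               // a YOLO layer may enqueue cuda_pull_array_async()
            state.input = l.output_gpu;
        }
        cudaStreamSynchronize(get_cuda_stream());  // now every pulled l.output can be read on the CPU
    }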