Mirror of https://github.com/pjreddie/darknet.git

Commit: Use non-default stream for all CUDA-functions
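Every kernel launch touched by this commit gains two extra launch-configuration arguments: the dynamic shared-memory size (0) and the stream returned by get_cuda_stream(), so the work is enqueued on a non-default stream instead of the legacy default stream, which implicitly synchronizes with all other streams. The sketch below is a minimal, self-contained illustration of that pattern; the get_cuda_stream() helper shown here is a simplified stand-in for illustration only, not darknet's actual implementation in cuda.c.

// Minimal sketch (assumptions: BLOCK size, helper names, toy kernel) of launching
// work on a non-default CUDA stream, mirroring the pattern applied in this diff.
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#define BLOCK 512

static void check_error(cudaError_t status)
{
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        exit(1);
    }
}

// Hypothetical stand-in for darknet's get_cuda_stream(): create one stream
// lazily and hand out the same handle for every launch and async copy.
static cudaStream_t get_cuda_stream(void)
{
    static cudaStream_t stream;
    static int init = 0;
    if (!init) {
        check_error(cudaStreamCreate(&stream));
        init = 1;
    }
    return stream;
}

__global__ void scale_kernel(int n, float alpha, float *x)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= alpha;
}

int main(void)
{
    const int n = 1 << 20;
    float *x_gpu;
    check_error(cudaMalloc(&x_gpu, n * sizeof(float)));

    // Launch config: grid, block, dynamic shared memory (0), stream.
    // The fourth argument is exactly what this commit adds to every launch.
    scale_kernel<<<(n + BLOCK - 1) / BLOCK, BLOCK, 0, get_cuda_stream()>>>(n, 0.5f, x_gpu);
    check_error(cudaPeekAtLastError());

    // Work on the stream is asynchronous with respect to the host; synchronize
    // before reading results or freeing buffers the kernel still uses.
    check_error(cudaStreamSynchronize(get_cuda_stream()));
    check_error(cudaFree(x_gpu));
    return 0;
}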
@@ -163,7 +163,7 @@ __global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s,
 
 extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y)
 {
-    binary_gradient_array_kernel << <cuda_gridsize(n / 2), BLOCK >> >(x, dx, n / 2, size, a, y);
+    binary_gradient_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, dx, n / 2, size, a, y);
     check_error(cudaPeekAtLastError());
 }
 
 __global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y)
@@ -178,7 +178,7 @@ __global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTI
 
 extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y)
 {
-    binary_activate_array_kernel << <cuda_gridsize(n / 2), BLOCK >> >(x, n / 2, size, a, y);
+    binary_activate_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, n / 2, size, a, y);
     check_error(cudaPeekAtLastError());
 }
 
@@ -231,6 +231,6 @@ extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
 
 extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
 {
-    gradient_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a, delta);
+    gradient_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(x, n, a, delta);
     check_error(cudaPeekAtLastError());
 }
@@ -47,7 +47,7 @@ extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state sta
 {
     size_t n = layer.c*layer.batch;
 
-    forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
+    forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
     check_error(cudaPeekAtLastError());
 }
 
@@ -55,7 +55,7 @@ extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state st
 {
     size_t n = layer.c*layer.batch;
 
-    backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
+    backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
     check_error(cudaPeekAtLastError());
 }
 
@@ -50,7 +50,7 @@ __global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, in
 
 void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
 {
-    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+    backward_scale_kernel<<<n, BLOCK, 0, get_cuda_stream() >>>(x_norm, delta, batch, n, size, scale_updates);
     check_error(cudaPeekAtLastError());
 }
 
@@ -129,14 +129,14 @@ __global__ void dot_kernel(float *output, float scale, int batch, int n, int siz
 
 void dot_error_gpu(layer l)
 {
-    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
+    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK, 0, get_cuda_stream()>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
     check_error(cudaPeekAtLastError());
 }
 */
 
 void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
 {
-    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+    backward_bias_kernel<<<n, BLOCK, 0, get_cuda_stream() >>>(bias_updates, delta, batch, n, size);
     check_error(cudaPeekAtLastError());
 }
 
@@ -153,7 +153,7 @@ __global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float
 
 extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
 {
-    adam_kernel << <cuda_gridsize(n), BLOCK >> >(n, x, m, v, B1, B2, rate, eps, t);
+    adam_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(n, x, m, v, B1, B2, rate, eps, t);
     check_error(cudaPeekAtLastError());
 }
 
@@ -192,7 +192,7 @@ __global__ void normalize_delta_kernel(int N, float *x, float *mean, float *vari
 extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
 {
     size_t N = batch*filters*spatial;
-    normalize_delta_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta);
+    normalize_delta_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta);
     check_error(cudaPeekAtLastError());
 }
 
@@ -297,19 +297,19 @@ __global__ void mean_delta_kernel(float *delta, float *variance, int batch, int
 
 extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
 {
-    mean_delta_kernel<<<cuda_gridsize(filters), BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
+    mean_delta_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(delta, variance, batch, filters, spatial, mean_delta);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
 {
-    fast_mean_delta_kernel<<<filters, BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
+    fast_mean_delta_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(delta, variance, batch, filters, spatial, mean_delta);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
 {
-    fast_variance_delta_kernel<<<filters, BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
+    fast_variance_delta_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
     check_error(cudaPeekAtLastError());
 }
 
@@ -532,13 +532,13 @@ extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters,
 
 extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
 {
-    mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, batch, filters, spatial, mean);
+    mean_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(x, batch, filters, spatial, mean);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
-    variance_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, mean, batch, filters, spatial, variance);
+    variance_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
     check_error(cudaPeekAtLastError());
 }
 
@@ -573,7 +573,7 @@ extern "C" void simple_copy_ongpu(int size, float *src, float *dst)
 
 extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
 {
-    mul_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, INCX, Y, INCY);
+    mul_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, INCX, Y, INCY);
     check_error(cudaPeekAtLastError());
 }
 
@@ -616,7 +616,7 @@ extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride
 
 extern "C" void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val)
 {
-    mask_kernel_new_api <<<cuda_gridsize(N), BLOCK >>>(N, X, mask_num, mask, val);
+    mask_kernel_new_api <<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask, val);
     check_error(cudaPeekAtLastError());
 }
 
@@ -628,13 +628,13 @@ extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 
 extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    const_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    const_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    constrain_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    constrain_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
@@ -647,7 +647,7 @@ extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
 
 extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    supp_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    supp_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
@@ -761,7 +761,7 @@ __global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta,
 
 extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
 {
-    smooth_l1_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    smooth_l1_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, pred, truth, delta, error);
     check_error(cudaPeekAtLastError());
 }
 
@@ -778,7 +778,7 @@ __global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *de
 
 extern "C" void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error)
 {
-    softmax_x_ent_kernel << <cuda_gridsize(n), BLOCK >> >(n, pred, truth, delta, error);
+    softmax_x_ent_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(n, pred, truth, delta, error);
     check_error(cudaPeekAtLastError());
 }
 
@@ -794,7 +794,7 @@ __global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float
 
 extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
 {
-    l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    l2_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, pred, truth, delta, error);
     check_error(cudaPeekAtLastError());
 }
 
@@ -810,7 +810,7 @@ __global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *
 
 extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
 {
-    weighted_sum_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, c);
+    weighted_sum_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, s, c);
     check_error(cudaPeekAtLastError());
 }
 
@@ -826,7 +826,7 @@ __global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float
 
 extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
 {
-    weighted_delta_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, da, db, ds, dc);
+    weighted_delta_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, s, da, db, ds, dc);
     check_error(cudaPeekAtLastError());
 }
 
@@ -840,7 +840,7 @@ __global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
 
 extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
 {
-    mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
+    mult_add_into_kernel<<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(num, a, b, c);
     check_error(cudaPeekAtLastError());
 }
 
@@ -909,7 +909,7 @@ __global__ void softmax_kernel_new_api(float *input, int n, int batch, int batch
 
 extern "C" void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
 {
-    softmax_kernel_new_api << <cuda_gridsize(batch*groups), BLOCK >> >(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
+    softmax_kernel_new_api << <cuda_gridsize(batch*groups), BLOCK, 0, get_cuda_stream() >> >(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
     check_error(cudaPeekAtLastError());
 }
 
@@ -971,7 +971,7 @@ extern "C" void softmax_tree_gpu(float *input, int spatial, int batch, int strid
     }
     */
     int num = spatial*batch*hier.groups;
-    softmax_tree_kernel <<<cuda_gridsize(num), BLOCK >>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
+    softmax_tree_kernel <<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
     check_error(cudaPeekAtLastError());
     cuda_free((float *)tree_groups_size);
     cuda_free((float *)tree_groups_offset);
@@ -50,7 +50,7 @@ void col2im_ongpu(float *data_col,
     int width_col = (width + 2 * pad - ksize) / stride + 1;
     int num_kernels = channels * height * width;
     col2im_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK,
-        BLOCK>>>(
+        BLOCK, 0, get_cuda_stream() >>>(
             num_kernels, data_col, height, width, ksize, pad,
             stride, height_col,
             width_col, data_im);
@@ -47,7 +47,7 @@ __global__ void binarize_input_kernel(float *input, int n, int size, float *bina
 
 void binarize_input_gpu(float *input, int n, int size, float *binary)
 {
-    binarize_input_kernel<<<cuda_gridsize(size), BLOCK>>>(input, n, size, binary);
+    binarize_input_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(input, n, size, binary);
     check_error(cudaPeekAtLastError());
 }
 
@@ -114,8 +114,8 @@ void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, f
     size_t gridsize = n * size;
     const int num_blocks = gridsize / BLOCK + 1;
 
-    set_zero_kernel << <(n/BLOCK + 1), BLOCK >> > (mean_arr_gpu, n);
-    reduce_kernel << <num_blocks, BLOCK >> > (weights, n, size, mean_arr_gpu);
+    set_zero_kernel << <(n/BLOCK + 1), BLOCK, 0, get_cuda_stream() >> > (mean_arr_gpu, n);
+    reduce_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (weights, n, size, mean_arr_gpu);
     binarize_weights_mean_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (weights, n, size, binary, mean_arr_gpu);
     check_error(cudaPeekAtLastError());
 }
@@ -18,7 +18,7 @@ __device__ float get_pixel_kernel(float *image, int w, int h, int x, int y, int
 __device__ float3 rgb_to_hsv_kernel(float3 rgb)
 {
     float r = rgb.x;
     float g = rgb.y;
     float b = rgb.z;
 
     float h, s, v;
@@ -46,7 +46,7 @@ __device__ float3 rgb_to_hsv_kernel(float3 rgb)
 __device__ float3 hsv_to_rgb_kernel(float3 hsv)
 {
     float h = hsv.x;
     float s = hsv.y;
     float v = hsv.z;
 
     float r, g, b;
@@ -88,8 +88,8 @@ __device__ float bilinear_interpolate_kernel(float *image, int w, int h, float x
     float dx = x - ix;
     float dy = y - iy;
 
     float val = (1-dy) * (1-dx) * get_pixel_kernel(image, w, h, ix, iy, c) +
        dy * (1-dx) * get_pixel_kernel(image, w, h, ix, iy+1, c) +
        (1-dy) * dx * get_pixel_kernel(image, w, h, ix+1, iy, c) +
        dy * dx * get_pixel_kernel(image, w, h, ix+1, iy+1, c);
     return val;
@@ -171,7 +171,7 @@ __global__ void forward_crop_layer_kernel(float *input, float *rand, int size, i
 
     input += w*h*c*b;
 
     float x = (flip) ? w - dw - j - 1 : j + dw;
     float y = i + dh;
 
     float rx = cos(angle)*(x-cx) - sin(angle)*(y-cy) + cx;
@@ -195,12 +195,12 @@ extern "C" void forward_crop_layer_gpu(crop_layer layer, network_state state)
 
     int size = layer.batch * layer.w * layer.h;
 
-    levels_image_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale, layer.shift);
+    levels_image_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale, layer.shift);
     check_error(cudaPeekAtLastError());
 
     size = layer.batch*layer.c*layer.out_w*layer.out_h;
 
-    forward_crop_layer_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, layer.rand_gpu, size, layer.c, layer.h, layer.w, layer.out_h, layer.out_w, state.train, layer.flip, radians, layer.output_gpu);
+    forward_crop_layer_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, layer.rand_gpu, size, layer.c, layer.h, layer.w, layer.out_h, layer.out_w, state.train, layer.flip, radians, layer.output_gpu);
     check_error(cudaPeekAtLastError());
 
     /*
@@ -215,7 +215,7 @@ extern "C" void forward_crop_layer_gpu(crop_layer layer, network_state state)
     scale_image(im2, 1/scale);
     translate_image(im3, -translate);
     scale_image(im3, 1/scale);
 
     show_image(im, "cropped");
     show_image(im2, "cropped2");
     show_image(im3, "cropped3");
|
@ -192,7 +192,7 @@ int *cuda_make_int_array_new_api(int *x, size_t n)
|
|||||||
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
||||||
check_error(status);
|
check_error(status);
|
||||||
if (x) {
|
if (x) {
|
||||||
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice, get_cuda_stream());
|
||||||
check_error(status);
|
check_error(status);
|
||||||
}
|
}
|
||||||
if (!x_gpu) error("Cuda malloc failed\n");
|
if (!x_gpu) error("Cuda malloc failed\n");
|
||||||
|
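The copy in cuda_make_int_array_new_api above is the stream-aware variant: cudaMemcpy itself takes no stream parameter, so the call that accepts get_cuda_stream() is cudaMemcpyAsync. A minimal sketch of such a stream-ordered upload (the helper name is hypothetical, not a darknet function):

// Sketch: enqueue a host-to-device copy on a given stream. It is ordered after
// previously enqueued work on that stream and before later work. True overlap
// with host execution additionally requires a pinned (page-locked) host buffer,
// e.g. one allocated with cudaMallocHost; with pageable memory the call may
// synchronize internally.
#include <cuda_runtime.h>

cudaError_t push_int_array(int *x_gpu, const int *x, size_t n, cudaStream_t stream)
{
    return cudaMemcpyAsync(x_gpu, x, n * sizeof(int), cudaMemcpyHostToDevice, stream);
}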
@@ -27,7 +27,7 @@ void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
     cuda_push_array(layer.rand_gpu, layer.rand, size);
     */
 
-    yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
+    yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
     check_error(cudaPeekAtLastError());
 }
 
@@ -36,6 +36,6 @@ void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
     if(!state.delta) return;
     int size = layer.inputs*layer.batch;
 
-    yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
+    yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
     check_error(cudaPeekAtLastError());
 }
@@ -125,7 +125,7 @@ extern "C" void backward_maxpool_layer_gpu(maxpool_layer layer, network_state st
 {
     size_t n = layer.h*layer.w*layer.c*layer.batch;
 
-    backward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, layer.delta_gpu, state.delta, layer.indexes_gpu);
+    backward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, layer.delta_gpu, state.delta, layer.indexes_gpu);
     check_error(cudaPeekAtLastError());
 }
 
|
Reference in New Issue
Block a user