mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Added CUDA-streams to Darknet-Yolo forward inference
This commit is contained in:
@ -115,14 +115,14 @@
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalIncludeDirectories>C:\opencv_3.0\opencv\build\include</AdditionalIncludeDirectories>
|
||||
<AdditionalIncludeDirectories>C:\opencv_source\opencv\bin\install\include</AdditionalIncludeDirectories>
|
||||
<PreprocessorDefinitions>_CRT_SECURE_NO_WARNINGS;_MBCS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<ExceptionHandling>Async</ExceptionHandling>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<EnableCOMDATFolding>true</EnableCOMDATFolding>
|
||||
<OptimizeReferences>true</OptimizeReferences>
|
||||
<AdditionalLibraryDirectories>C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc12\lib</AdditionalLibraryDirectories>
|
||||
<AdditionalLibraryDirectories>C:\opencv_source\opencv\bin\install\x64\vc14\lib;C:\opencv_3.0\opencv\build\x64\vc14\lib;C:\opencv_2.4.13\opencv\build\x64\vc12\lib</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
|
@ -154,7 +154,7 @@ __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delt
|
||||
|
||||
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
|
||||
{
|
||||
activate_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a);
|
||||
activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
|
||||
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
|
||||
dim3 dimBlock(BLOCK, 1, 1);
|
||||
|
||||
scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
|
||||
scale_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -67,7 +67,7 @@ void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
|
||||
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
|
||||
dim3 dimBlock(BLOCK, 1, 1);
|
||||
|
||||
add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
|
||||
add_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -427,7 +427,7 @@ __global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
|
||||
extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
|
||||
{
|
||||
size_t N = batch*filters*spatial;
|
||||
normalize_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, batch, filters, spatial);
|
||||
normalize_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, x, mean, variance, batch, filters, spatial);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -490,13 +490,13 @@ __global__ void fast_variance_kernel(float *x, float *mean, int batch, int filt
|
||||
|
||||
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
|
||||
{
|
||||
fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
|
||||
fast_mean_kernel<<<filters, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
|
||||
{
|
||||
fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
|
||||
fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -520,13 +520,13 @@ extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, i
|
||||
|
||||
extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
|
||||
{
|
||||
pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
|
||||
pow_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX, Y, INCY);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
|
||||
{
|
||||
axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
axpy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -543,7 +543,7 @@ extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
|
||||
|
||||
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
|
||||
{
|
||||
copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
copy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -567,20 +567,20 @@ __global__ void flatten_kernel(int N, float *x, int spatial, int layers, int bat
|
||||
extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
|
||||
{
|
||||
int size = spatial*batch*layers;
|
||||
flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
|
||||
flatten_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, spatial, layers, batch, forward, out);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
|
||||
{
|
||||
int size = w*h*c*batch;
|
||||
reorg_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, w, h, c, batch, stride, forward, out);
|
||||
reorg_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, w, h, c, batch, stride, forward, out);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
|
||||
{
|
||||
mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
|
||||
mask_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -599,7 +599,7 @@ extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
|
||||
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
{
|
||||
scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
|
||||
scal_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -611,7 +611,7 @@ extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
|
||||
extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
|
||||
{
|
||||
fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
|
||||
fill_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@ -766,6 +766,6 @@ extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float t
|
||||
{
|
||||
int inputs = n;
|
||||
int batch = groups;
|
||||
softmax_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, offset, batch, input, temp, output);
|
||||
softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
25
src/cuda.c
25
src/cuda.c
@ -61,6 +61,19 @@ dim3 cuda_gridsize(size_t n){
|
||||
return d;
|
||||
}
|
||||
|
||||
static cudaStream_t streamsArray[16]; // cudaStreamSynchronize( get_cuda_stream() );
|
||||
static int streamInit[16] = { 0 };
|
||||
|
||||
cudaStream_t get_cuda_stream() {
|
||||
int i = cuda_get_device();
|
||||
if (!streamInit[i]) {
|
||||
cudaStreamCreate(&streamsArray[i]);
|
||||
streamInit[i] = 1;
|
||||
}
|
||||
return streamsArray[i];
|
||||
}
|
||||
|
||||
|
||||
#ifdef CUDNN
|
||||
cudnnHandle_t cudnn_handle()
|
||||
{
|
||||
@ -70,6 +83,7 @@ cudnnHandle_t cudnn_handle()
|
||||
if(!init[i]) {
|
||||
cudnnCreate(&handle[i]);
|
||||
init[i] = 1;
|
||||
cudnnStatus_t status = cudnnSetStream(handle[i], get_cuda_stream());
|
||||
}
|
||||
return handle[i];
|
||||
}
|
||||
@ -94,7 +108,8 @@ float *cuda_make_array(float *x, size_t n)
|
||||
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
||||
check_error(status);
|
||||
if(x){
|
||||
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||
//status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||
status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyHostToDevice, get_cuda_stream());
|
||||
check_error(status);
|
||||
}
|
||||
if(!x_gpu) error("Cuda malloc failed\n");
|
||||
@ -139,6 +154,7 @@ int *cuda_make_int_array(size_t n)
|
||||
|
||||
void cuda_free(float *x_gpu)
|
||||
{
|
||||
//cudaStreamSynchronize(get_cuda_stream());
|
||||
cudaError_t status = cudaFree(x_gpu);
|
||||
check_error(status);
|
||||
}
|
||||
@ -146,15 +162,18 @@ void cuda_free(float *x_gpu)
|
||||
void cuda_push_array(float *x_gpu, float *x, size_t n)
|
||||
{
|
||||
size_t size = sizeof(float)*n;
|
||||
cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||
//cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||
cudaError_t status = cudaMemcpyAsync(x_gpu, x, size, cudaMemcpyHostToDevice, get_cuda_stream());
|
||||
check_error(status);
|
||||
}
|
||||
|
||||
void cuda_pull_array(float *x_gpu, float *x, size_t n)
|
||||
{
|
||||
size_t size = sizeof(float)*n;
|
||||
cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
|
||||
//cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
|
||||
cudaError_t status = cudaMemcpyAsync(x, x_gpu, size, cudaMemcpyDeviceToHost, get_cuda_stream());
|
||||
check_error(status);
|
||||
cudaStreamSynchronize(get_cuda_stream());
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -30,6 +30,7 @@ void cuda_free(float *x_gpu);
|
||||
void cuda_random(float *x_gpu, size_t n);
|
||||
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
|
||||
dim3 cuda_gridsize(size_t n);
|
||||
cudaStream_t get_cuda_stream();
|
||||
|
||||
#ifdef CUDNN
|
||||
cudnnHandle_t cudnn_handle();
|
||||
|
@ -177,6 +177,7 @@ void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
|
||||
float *C_gpu, int ldc)
|
||||
{
|
||||
cublasHandle_t handle = blas_handle();
|
||||
cudaError_t stream_status = cublasSetStream(handle, get_cuda_stream());
|
||||
cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N),
|
||||
(TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
|
||||
check_error(status);
|
||||
|
@ -54,7 +54,7 @@ void im2col_ongpu(float *im,
|
||||
int width_col = (width + 2 * pad - ksize) / stride + 1;
|
||||
int num_kernels = channels * height_col * width_col;
|
||||
im2col_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK,
|
||||
BLOCK>>>(
|
||||
BLOCK, 0, get_cuda_stream()>>>(
|
||||
num_kernels, im, height, width, ksize, pad,
|
||||
stride, height_col,
|
||||
width_col, data_col);
|
||||
|
@ -92,7 +92,7 @@ extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, network_state sta
|
||||
|
||||
size_t n = h*w*c*layer.batch;
|
||||
|
||||
forward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, state.input, layer.output_gpu, layer.indexes_gpu);
|
||||
forward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, state.input, layer.output_gpu, layer.indexes_gpu);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -409,6 +409,7 @@ void forward_region_layer_gpu(const region_layer l, network_state state)
|
||||
cuda_pull_array(state.truth, truth_cpu, num_truth);
|
||||
}
|
||||
cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
|
||||
cudaStreamSynchronize(get_cuda_stream());
|
||||
network_state cpu_state = state;
|
||||
cpu_state.train = state.train;
|
||||
cpu_state.truth = truth_cpu;
|
||||
|
@ -169,8 +169,8 @@ int main(int argc, char *argv[])
|
||||
//if (x > 10) return;
|
||||
if (result_vec.size() == 0) return;
|
||||
bbox_t i = result_vec[0];
|
||||
//cv::Rect r(i.x, i.y, i.w, i.h);
|
||||
cv::Rect r(i.x + (i.w-31)/2, i.y + (i.h - 31)/2, 31, 31);
|
||||
cv::Rect r(i.x, i.y, i.w, i.h);
|
||||
//cv::Rect r(i.x + (i.w-31)/2, i.y + (i.h - 31)/2, 31, 31);
|
||||
cv::Rect img_rect(cv::Point2i(0, 0), src_frame.size());
|
||||
cv::Rect rect_roi = r & img_rect;
|
||||
if (rect_roi.width < 1 || rect_roi.height < 1) return;
|
||||
@ -188,16 +188,25 @@ int main(int argc, char *argv[])
|
||||
|
||||
// track optical flow
|
||||
if (track_optflow_queue.size() > 0) {
|
||||
//show_flow = track_optflow_queue.front().clone();
|
||||
//draw_boxes(show_flow, result_vec, obj_names, 3, current_det_fps, current_cap_fps);
|
||||
|
||||
std::queue<cv::Mat> new_track_optflow_queue;
|
||||
std::cout << "\n !!!! all = " << track_optflow_queue.size() << ", cur = " << passed_flow_frames << std::endl;
|
||||
//draw_boxes(track_optflow_queue.front().clone(), result_vec, obj_names, 3, current_det_fps, current_cap_fps);
|
||||
//cv::waitKey(10);
|
||||
//std::cout << "\n !!!! all = " << track_optflow_queue.size() << ", cur = " << passed_flow_frames << std::endl;
|
||||
if (result_vec.size() > 0) {
|
||||
draw_boxes(track_optflow_queue.front().clone(), result_vec, obj_names, 3, current_det_fps, current_cap_fps);
|
||||
std::cout << "\n frame_size = " << track_optflow_queue.size() << std::endl;
|
||||
cv::waitKey(1000);
|
||||
}
|
||||
tracker_flow.update_tracking_flow(track_optflow_queue.front());
|
||||
lambda(show_flow, track_optflow_queue.front(), result_vec);
|
||||
track_optflow_queue.pop();
|
||||
while(track_optflow_queue.size() > 0) {
|
||||
//draw_boxes(track_optflow_queue.front().clone(), result_vec, obj_names, 3, current_det_fps, current_cap_fps);
|
||||
//cv::waitKey(10);
|
||||
if (result_vec.size() > 0) {
|
||||
draw_boxes(track_optflow_queue.front().clone(), result_vec, obj_names, 3, current_det_fps, current_cap_fps);
|
||||
std::cout << "\n frame_size = " << track_optflow_queue.size() << std::endl;
|
||||
cv::waitKey(1000);
|
||||
}
|
||||
result_vec = tracker_flow.tracking_flow(track_optflow_queue.front(), result_vec);
|
||||
if (track_optflow_queue.size() <= passed_flow_frames && new_track_optflow_queue.size() == 0)
|
||||
new_track_optflow_queue = track_optflow_queue;
|
||||
@ -207,10 +216,13 @@ int main(int argc, char *argv[])
|
||||
track_optflow_queue = new_track_optflow_queue;
|
||||
new_track_optflow_queue.swap(std::queue<cv::Mat>());
|
||||
passed_flow_frames = 0;
|
||||
std::cout << "\n !!!! now = " << track_optflow_queue.size() << ", cur = " << passed_flow_frames << std::endl;
|
||||
//std::cout << "\n !!!! now = " << track_optflow_queue.size() << ", cur = " << passed_flow_frames << std::endl;
|
||||
|
||||
cv::imshow("flow", show_flow);
|
||||
cv::waitKey(3);
|
||||
//if (result_vec.size() > 0) {
|
||||
// cv::waitKey(1000);
|
||||
//}
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -222,7 +234,8 @@ int main(int argc, char *argv[])
|
||||
consumed = true;
|
||||
while (current_image.use_count() > 0) {
|
||||
auto result = detector.detect_resized(*current_image, frame_size, 0.24, false); // true
|
||||
Sleep(500);
|
||||
//Sleep(200);
|
||||
Sleep(50);
|
||||
++fps_det_counter;
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
thread_result_vec = result;
|
||||
|
Reference in New Issue
Block a user