mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Activation improvement, more robust timer.
This commit is contained in:
@ -47,7 +47,7 @@ __device__ float stair_activate_kernel(float x)
|
||||
if (n%2 == 0) return floor(x/2.);
|
||||
else return (x - n) + floor(x/2.);
|
||||
}
|
||||
|
||||
|
||||
|
||||
__device__ float hardtan_gradient_kernel(float x)
|
||||
{
|
||||
@ -146,19 +146,29 @@ __global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
|
||||
if(i < n) x[i] = activate_kernel(x[i], a);
|
||||
}
|
||||
|
||||
__global__ void activate_array_leaky_kernel(float *x, int n)
|
||||
{
|
||||
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
if (index < n) {
|
||||
float val = x[index];
|
||||
x[index] = (val > 0) ? val : val / 10;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(i < n) delta[i] *= gradient_kernel(x[i], a);
|
||||
}
|
||||
|
||||
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
|
||||
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
|
||||
{
|
||||
activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
|
||||
if(a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, get_cuda_stream() >> >(x, n);
|
||||
else activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
|
||||
extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
|
||||
{
|
||||
gradient_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a, delta);
|
||||
check_error(cudaPeekAtLastError());
|
||||
|
@ -17,6 +17,16 @@ extern "C" {
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
double get_time_point();
|
||||
void start_timer();
|
||||
void stop_timer();
|
||||
double get_time();
|
||||
void stop_timer_and_show();
|
||||
void stop_timer_and_show_name(char *name);
|
||||
void show_total_time();
|
||||
}
|
||||
|
||||
__global__ void binarize_kernel(float *x, int n, float *binary)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
@ -146,25 +156,33 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
//cudaDeviceSynchronize();
|
||||
|
||||
int i = 0;
|
||||
if (l.stride == 1 && l.c >= 256 && l.w > 13 && l.size > 1 && 0) // disabled
|
||||
// if (l.stride == 1 && l.c >= 256 && l.size > 1)
|
||||
if (l.stride == 1 && l.c >= 1024 && l.size > 1 && 0)// && l.w >= 13) // disabled
|
||||
{
|
||||
// stride=1 only
|
||||
//start_timer();
|
||||
im2col_align_bin_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show_name("im2col_align_bin_ongpu");
|
||||
}
|
||||
else
|
||||
{
|
||||
//start_timer();
|
||||
im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show_name("im2col_align_ongpu");
|
||||
//getchar();
|
||||
|
||||
// should be optimized
|
||||
//start_timer();
|
||||
float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show_name("float_to_bit_gpu");
|
||||
}
|
||||
|
||||
//start_timer();
|
||||
transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show_name("transpose_bin_gpu");
|
||||
|
||||
// should be optimized
|
||||
//if(0) {//if (k > 1000) { // sequentially input-shared - BAD
|
||||
@ -172,9 +190,12 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
// (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
|
||||
//}
|
||||
//else { // coalescing & weights-shared-memory - GOOD
|
||||
//start_timer();
|
||||
gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
|
||||
(unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu,
|
||||
new_ldb, l.output_gpu, n, l.mean_arr_gpu, l.biases_gpu);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show_name("gemm_nn_custom_bin_mean_transposed_gpu");
|
||||
//}
|
||||
//cudaDeviceSynchronize();
|
||||
//check_error(status);
|
||||
@ -325,6 +346,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||
|
||||
|
||||
#else
|
||||
fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
|
||||
|
||||
int i;
|
||||
int m = l.n;
|
||||
int k = l.size*l.size*l.c;
|
||||
|
@ -25,6 +25,8 @@
|
||||
#pragma comment(lib, "opencv_highgui" OPENCV_VERSION ".lib")
|
||||
#endif
|
||||
|
||||
#include "http_stream.h"
|
||||
|
||||
IplImage* draw_train_chart(float max_img_loss, int max_batches, int number_of_lines, int img_size);
|
||||
void draw_train_loss(IplImage* img, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches);
|
||||
#endif // OPENCV
|
||||
@ -1142,13 +1144,14 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
|
||||
//for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
|
||||
|
||||
float *X = sized.data;
|
||||
time= what_time_is_it_now();
|
||||
|
||||
//time= what_time_is_it_now();
|
||||
double time = get_time_point();
|
||||
network_predict(net, X);
|
||||
//network_predict_image(&net, im); letterbox = 1;
|
||||
printf("%s: Predicted in %f seconds.\n", input, (what_time_is_it_now()-time));
|
||||
//get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, 0);
|
||||
// if (nms) do_nms_sort_v2(boxes, probs, l.w*l.h*l.n, l.classes, nms);
|
||||
//draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes);
|
||||
printf("%s: Predicted in %lf milli-seconds.\n", input, ((double)get_time_point() - time) / 1000);
|
||||
//printf("%s: Predicted in %f seconds.\n", input, (what_time_is_it_now()-time));
|
||||
|
||||
int nboxes = 0;
|
||||
detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
|
||||
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
|
||||
|
@ -2146,7 +2146,7 @@ void time_ongpu(int TA, int TB, int m, int k, int n)
|
||||
clock_t start = clock(), end;
|
||||
for(i = 0; i<iter; ++i){
|
||||
gemm_ongpu(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
|
||||
cudaThreadSynchronize();
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
double flop = ((double)m)*n*(2.*k + 2.)*iter;
|
||||
double gflop = flop/pow(10., 9);
|
||||
|
@ -329,3 +329,59 @@ image image_data_augmentation(IplImage* ipl, int w, int h,
|
||||
|
||||
|
||||
#endif // OPENCV
|
||||
|
||||
#if __cplusplus >= 201103L || _MSC_VER >= 1900 // C++11
|
||||
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
|
||||
static std::chrono::steady_clock::time_point steady_start, steady_end;
|
||||
static double total_time;
|
||||
|
||||
double get_time_point() {
|
||||
std::chrono::steady_clock::time_point current_time = std::chrono::steady_clock::now();
|
||||
//uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>(current_time.time_since_epoch()).count();
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(current_time.time_since_epoch()).count();
|
||||
}
|
||||
|
||||
void start_timer() {
|
||||
steady_start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void stop_timer() {
|
||||
steady_end = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
double get_time() {
|
||||
double took_time = std::chrono::duration<double>(steady_end - steady_start).count();
|
||||
total_time += took_time;
|
||||
return took_time;
|
||||
}
|
||||
|
||||
void stop_timer_and_show() {
|
||||
stop_timer();
|
||||
std::cout << " " << get_time()*1000 << " msec" << std::endl;
|
||||
}
|
||||
|
||||
void stop_timer_and_show_name(char *name) {
|
||||
std::cout << " " << name;
|
||||
stop_timer_and_show();
|
||||
}
|
||||
|
||||
void show_total_time() {
|
||||
std::cout << " Total: " << total_time * 1000 << " msec" << std::endl;
|
||||
}
|
||||
|
||||
#else // C++11
|
||||
#include <iostream>
|
||||
|
||||
double get_time_point() { return 0; }
|
||||
void start_timer() {}
|
||||
void stop_timer() {}
|
||||
double get_time() { return 0; }
|
||||
void stop_timer_and_show() {
|
||||
std::cout << " stop_timer_and_show() isn't implemented " << std::endl;
|
||||
}
|
||||
void stop_timer_and_show_name(char *name) { stop_timer_and_show(); }
|
||||
void total_time() {}
|
||||
#endif // C++11
|
@ -6,6 +6,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
#include "image.h"
|
||||
#include <stdint.h>
|
||||
|
||||
void send_mjpeg(IplImage* ipl, int port, int timeout, int quality);
|
||||
CvCapture* get_capture_webcam(int index);
|
||||
@ -17,6 +18,14 @@ image image_data_augmentation(IplImage* ipl, int w, int h,
|
||||
int pleft, int ptop, int swidth, int sheight, int flip,
|
||||
float jitter, float dhue, float dsat, float dexp);
|
||||
|
||||
double get_time_point();
|
||||
void start_timer();
|
||||
void stop_timer();
|
||||
double get_time();
|
||||
void stop_timer_and_show();
|
||||
void stop_timer_and_show_name(char *name);
|
||||
void show_total_time();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -131,13 +131,13 @@ __global__ void im2col_align_gpu_kernel(const int n, const float* data_im,
|
||||
const int height_col, const int width_col,
|
||||
float *data_col, const int bit_align)
|
||||
{
|
||||
__shared__ float tmp_s[1];
|
||||
//__shared__ float tmp_s[1];
|
||||
|
||||
//#define SHRED_VALS ((BLOCK / 169) * )
|
||||
__shared__ float dst_s[1024];
|
||||
//__shared__ float dst_s[1024];
|
||||
//__shared__ float dst_s[1024];
|
||||
//__shared__ uint32_t bit_s[32];
|
||||
__shared__ uint8_t bit_s[128];
|
||||
//__shared__ uint8_t bit_s[128];
|
||||
|
||||
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
for (; index < n; index += blockDim.x*gridDim.x) {
|
||||
@ -551,7 +551,7 @@ void im2col_align_bin_ongpu(float *im,
|
||||
}
|
||||
// --------------------------------
|
||||
|
||||
|
||||
/*
|
||||
__global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t size)
|
||||
{
|
||||
//const int size_aligned = size + (WARP_SIZE - size % WARP_SIZE);
|
||||
@ -569,12 +569,45 @@ __global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t s
|
||||
if (threadIdx.x % WARP_SIZE == 0) ((unsigned int*)dst)[index / 32] = bit_mask;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
__global__ void float_to_bit_gpu_kernel(float *src, unsigned char *dst, size_t size)
|
||||
{
|
||||
//const int size_aligned = size + (WARP_SIZE - size % WARP_SIZE);
|
||||
__shared__ uint32_t tmp[WARP_SIZE];
|
||||
|
||||
int index = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
float src_val;
|
||||
uint32_t *dst32_ptr = ((unsigned int*)dst);
|
||||
|
||||
//for (; index < size_aligned; index += blockDim.x*gridDim.x)
|
||||
{
|
||||
//src_val = src[index];
|
||||
if (index < size) src_val = src[index];
|
||||
else src_val = 0;
|
||||
//unsigned int bit_mask = __ballot_sync(0xffffffff, src_val > 0);
|
||||
const int num_of_warps = blockDim.x / WARP_SIZE;
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
uint32_t bit_mask = __ballot(src_val > 0);
|
||||
if (lane_id == 0) tmp[warp_id] = bit_mask;
|
||||
|
||||
__syncthreads();
|
||||
if (warp_id == 0) {
|
||||
if (lane_id < num_of_warps) {
|
||||
dst32_ptr[index / 32 + lane_id] = tmp[lane_id];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void float_to_bit_gpu(float *src, unsigned char *dst, size_t size)
|
||||
{
|
||||
const int num_blocks = size / BLOCK + 1;
|
||||
float_to_bit_gpu_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(src, dst, size);
|
||||
const int num_blocks = size / 1024 + 1;
|
||||
float_to_bit_gpu_kernel<<<num_blocks, 1024, 0, get_cuda_stream()>>>(src, dst, size);
|
||||
}
|
||||
// --------------------------------
|
||||
|
||||
|
@ -866,10 +866,10 @@ void calculate_binary_weights(network net)
|
||||
//if (l->size*l->size*l->c >= 2048) l->lda_align = 512;
|
||||
|
||||
binary_align_weights(l);
|
||||
}
|
||||
|
||||
if(net.layers[j].use_bin_output) {
|
||||
l->activation = LINEAR;
|
||||
}
|
||||
if (net.layers[j].use_bin_output) {
|
||||
l->activation = LINEAR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -40,12 +40,16 @@ extern "C" {
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
#include "http_stream.h"
|
||||
|
||||
float * get_network_output_gpu_layer(network net, int i);
|
||||
float * get_network_delta_gpu_layer(network net, int i);
|
||||
float * get_network_output_gpu(network net);
|
||||
|
||||
void forward_network_gpu(network net, network_state state)
|
||||
{
|
||||
//cudaDeviceSynchronize();
|
||||
//printf("\n");
|
||||
state.workspace = net.workspace;
|
||||
int i;
|
||||
for(i = 0; i < net.n; ++i){
|
||||
@ -54,7 +58,12 @@ void forward_network_gpu(network net, network_state state)
|
||||
if(l.delta_gpu && state.train){
|
||||
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
|
||||
}
|
||||
//printf("%d - type: %d - ", i, l.type);
|
||||
//start_timer();
|
||||
l.forward_gpu(l, state);
|
||||
//cudaDeviceSynchronize();
|
||||
//stop_timer_and_show();
|
||||
|
||||
if(net.wait_stream)
|
||||
cudaStreamSynchronize(get_cuda_stream());
|
||||
state.input = l.output_gpu;
|
||||
@ -75,6 +84,8 @@ void forward_network_gpu(network net, network_state state)
|
||||
}
|
||||
*/
|
||||
}
|
||||
//cudaDeviceSynchronize();
|
||||
//show_total_time();
|
||||
}
|
||||
|
||||
void backward_network_gpu(network net, network_state state)
|
||||
|
Reference in New Issue
Block a user