mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
faster :tree: softmax
This commit is contained in:
parent
88b9ecb414
commit
579e588c84
@ -41,6 +41,7 @@ void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, i
|
|||||||
|
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
#include "cuda.h"
|
#include "cuda.h"
|
||||||
|
#include "tree.h"
|
||||||
|
|
||||||
void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
|
void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
|
||||||
void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
|
void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
|
||||||
@ -86,6 +87,7 @@ void softmax_gpu(float *input, int n, int batch, int batch_offset, int groups, i
|
|||||||
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
|
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
|
||||||
|
|
||||||
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
|
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
|
||||||
|
void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -788,6 +788,37 @@ __device__ void softmax_device(float *input, int n, float temp, int stride, floa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
 * Softmax over every group of a tree hierarchy in a single launch.
 *
 * Each thread handles one (batch entry, group, spatial position) triple,
 * decoded from its flat thread id, and runs softmax_device on that group's
 * group_size[g] values, which are read `spatial` elements apart.
 *
 *   input / output : base pointers; the same offset is applied to both
 *   spatial        : number of positions per channel; also the element stride
 *                    handed to softmax_device
 *   batch          : number of batch entries, each `stride` floats apart
 *   groups         : number of entries in group_size / group_offset
 *   group_size     : element count of each group (indexed on the device)
 *   group_offset   : first channel of each group; scaled by `spatial` here
 *
 * The launch must provide at least spatial*batch*groups threads; surplus
 * threads do nothing.
 */
__global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset)
{
    const int tid = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    const int total = spatial*batch*groups;

    if (tid < total) {
        /* Flat id layout, fastest-varying first: spatial, then group, then batch. */
        const int pos = tid % spatial;
        const int rest = tid / spatial;
        const int grp = rest % groups;
        const int bat = rest / groups;

        const int group_base = group_offset[grp]*spatial;
        const int batch_base = bat*stride;

        float *src = input + group_base + batch_base + pos;
        float *dst = output + group_base + batch_base + pos;

        softmax_device(src, group_size[grp], temp, spatial, dst);
    }
}
|
||||||
|
|
||||||
|
/*
 * Host wrapper: run softmax_tree_kernel over all groups of `hier`.
 *
 *   input / output : device pointers forwarded to the kernel
 *   spatial, batch, stride, temp : forwarded unchanged to the kernel
 *   hier           : tree whose group_size / group_offset tables (hier.groups
 *                    entries each) describe the softmax groups
 *
 * The group tables are uploaded to the device once and cached in
 * function-local statics so repeated calls avoid a cudaMalloc/cudaMemcpy per
 * call. The cache is keyed on the host table pointers and the group count;
 * if a caller passes a different tree, the old device copies are released and
 * fresh ones uploaded. (Previously the first tree seen was reused forever,
 * so a second tree silently got the wrong tables.)
 *
 * NOTE(review): the key is pointer identity, so a tree whose tables are
 * modified in place, or freed and reallocated at the same address, is not
 * detected. The cached device arrays also belong to whichever device was
 * current at upload time -- confirm this is only used on a single GPU.
 */
extern "C" void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier)
{
    static int *tree_groups_size = 0;
    static int *tree_groups_offset = 0;
    static const void *cached_size_src = 0;
    static const void *cached_offset_src = 0;
    static int cached_groups = 0;

    int stale = !tree_groups_size
             || cached_size_src != (const void *)hier.group_size
             || cached_offset_src != (const void *)hier.group_offset
             || cached_groups != hier.groups;

    if(stale){
        if(tree_groups_size){
            /* cuda_free takes float*, same cast the rest of the code base uses for int arrays. */
            cuda_free((float *)tree_groups_size);
            cuda_free((float *)tree_groups_offset);
        }
        tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
        tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
        cached_size_src = (const void *)hier.group_size;
        cached_offset_src = (const void *)hier.group_offset;
        cached_groups = hier.groups;
    }

    /* One thread per (batch, group, spatial position). */
    int num = spatial*batch*hier.groups;
    softmax_tree_kernel<<<cuda_gridsize(num), BLOCK>>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
    /* Launches report configuration errors only through the last-error state. */
    check_error(cudaPeekAtLastError());
}
|
||||||
|
|
||||||
__global__ void softmax_kernel(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
|
__global__ void softmax_kernel(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
|
||||||
{
|
{
|
||||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||||
|
@ -128,12 +128,17 @@ float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
|
|||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
int *cuda_make_int_array(size_t n)
|
int *cuda_make_int_array(int *x, size_t n)
|
||||||
{
|
{
|
||||||
int *x_gpu;
|
int *x_gpu;
|
||||||
size_t size = sizeof(int)*n;
|
size_t size = sizeof(int)*n;
|
||||||
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
||||||
check_error(status);
|
check_error(status);
|
||||||
|
if(x){
|
||||||
|
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||||
|
check_error(status);
|
||||||
|
}
|
||||||
|
if(!x_gpu) error("Cuda malloc failed\n");
|
||||||
return x_gpu;
|
return x_gpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ extern int gpu_index;
|
|||||||
void check_error(cudaError_t status);
|
void check_error(cudaError_t status);
|
||||||
cublasHandle_t blas_handle();
|
cublasHandle_t blas_handle();
|
||||||
float *cuda_make_array(float *x, size_t n);
|
float *cuda_make_array(float *x, size_t n);
|
||||||
int *cuda_make_int_array(size_t n);
|
int *cuda_make_int_array(int *x, size_t n);
|
||||||
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
||||||
void cuda_pull_array(float *x_gpu, float *x, size_t n);
|
void cuda_pull_array(float *x_gpu, float *x, size_t n);
|
||||||
void cuda_set_device(int n);
|
void cuda_set_device(int n);
|
||||||
|
@ -43,7 +43,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
|
|||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
l.forward_gpu = forward_maxpool_layer_gpu;
|
l.forward_gpu = forward_maxpool_layer_gpu;
|
||||||
l.backward_gpu = backward_maxpool_layer_gpu;
|
l.backward_gpu = backward_maxpool_layer_gpu;
|
||||||
l.indexes_gpu = cuda_make_int_array(output_size);
|
l.indexes_gpu = cuda_make_int_array(0, output_size);
|
||||||
l.output_gpu = cuda_make_array(l.output, output_size);
|
l.output_gpu = cuda_make_array(l.output, output_size);
|
||||||
l.delta_gpu = cuda_make_array(l.delta, output_size);
|
l.delta_gpu = cuda_make_array(l.delta, output_size);
|
||||||
#endif
|
#endif
|
||||||
@ -70,7 +70,7 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
|
|||||||
cuda_free((float *)l->indexes_gpu);
|
cuda_free((float *)l->indexes_gpu);
|
||||||
cuda_free(l->output_gpu);
|
cuda_free(l->output_gpu);
|
||||||
cuda_free(l->delta_gpu);
|
cuda_free(l->delta_gpu);
|
||||||
l->indexes_gpu = cuda_make_int_array(output_size);
|
l->indexes_gpu = cuda_make_int_array(0, output_size);
|
||||||
l->output_gpu = cuda_make_array(l->output, output_size);
|
l->output_gpu = cuda_make_array(l->output, output_size);
|
||||||
l->delta_gpu = cuda_make_array(l->delta, output_size);
|
l->delta_gpu = cuda_make_array(l->delta, output_size);
|
||||||
#endif
|
#endif
|
||||||
|
@ -141,7 +141,6 @@ int entry_index(layer l, int batch, int location, int entry)
|
|||||||
return batch*l.outputs + n*l.w*l.h*(l.coords+l.classes+1) + entry*l.w*l.h + loc;
|
return batch*l.outputs + n*l.w*l.h*(l.coords+l.classes+1) + entry*l.w*l.h + loc;
|
||||||
}
|
}
|
||||||
|
|
||||||
void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output);
|
|
||||||
void forward_region_layer(const layer l, network net)
|
void forward_region_layer(const layer l, network net)
|
||||||
{
|
{
|
||||||
int i,j,b,t,n;
|
int i,j,b,t,n;
|
||||||
@ -445,6 +444,9 @@ void forward_region_layer_gpu(const layer l, network net)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (l.softmax_tree){
|
if (l.softmax_tree){
|
||||||
|
int index = entry_index(l, 0, 0, 5);
|
||||||
|
softmax_tree(net.input_gpu + index, l.w*l.h, l.batch*l.n, l.inputs/l.n, 1, l.output_gpu + index, *l.softmax_tree);
|
||||||
|
/*
|
||||||
int i;
|
int i;
|
||||||
int count = 5;
|
int count = 5;
|
||||||
for (i = 0; i < l.softmax_tree->groups; ++i) {
|
for (i = 0; i < l.softmax_tree->groups; ++i) {
|
||||||
@ -453,6 +455,7 @@ void forward_region_layer_gpu(const layer l, network net)
|
|||||||
softmax_gpu(net.input_gpu + index, group_size, l.batch*l.n, l.inputs/l.n, l.w*l.h, 1, l.w*l.h, 1, l.output_gpu + index);
|
softmax_gpu(net.input_gpu + index, group_size, l.batch*l.n, l.inputs/l.n, l.w*l.h, 1, l.w*l.h, 1, l.output_gpu + index);
|
||||||
count += group_size;
|
count += group_size;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
} else if (l.softmax) {
|
} else if (l.softmax) {
|
||||||
int index = entry_index(l, 0, 0, l.coords + !l.background);
|
int index = entry_index(l, 0, 0, l.coords + !l.background);
|
||||||
//printf("%d\n", index);
|
//printf("%d\n", index);
|
||||||
|
@ -149,7 +149,7 @@ void predict_segmenter(char *datafile, char *cfgfile, char *weightfile, char *fi
|
|||||||
float *X = sized.data;
|
float *X = sized.data;
|
||||||
time=clock();
|
time=clock();
|
||||||
float *predictions = network_predict(net, X);
|
float *predictions = network_predict(net, X);
|
||||||
image m = float_to_image(sized.w, sized.h, 80, predictions);
|
image m = float_to_image(sized.w, sized.h, 81, predictions);
|
||||||
image rgb = mask_to_rgb(m);
|
image rgb = mask_to_rgb(m);
|
||||||
show_image(sized, "orig");
|
show_image(sized, "orig");
|
||||||
show_image(rgb, "pred");
|
show_image(rgb, "pred");
|
||||||
|
Loading…
Reference in New Issue
Block a user