diff --git a/src/blas.h b/src/blas.h index 968bb955..a61b7d8c 100644 --- a/src/blas.h +++ b/src/blas.h @@ -41,6 +41,7 @@ void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, i #ifdef GPU #include "cuda.h" +#include "tree.h" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY); void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); @@ -86,6 +87,7 @@ void softmax_gpu(float *input, int n, int batch, int batch_offset, int groups, i void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t); void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out); +void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier); #endif #endif diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index ac29d3f0..9f1337ca 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -788,6 +788,37 @@ __device__ void softmax_device(float *input, int n, float temp, int stride, floa } } + +__global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset) +{ + int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (id >= spatial*batch*groups) return; + int s = id % spatial; + id = id / spatial; + int g = id % groups; + int b = id / groups; + int goff = group_offset[g]*spatial; + int boff = b*stride; + softmax_device(input + goff + boff + s, group_size[g], temp, spatial, output + goff + boff + s); +} + +extern "C" void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier) +{ + //int *tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups); + //int *tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups); + static int *tree_groups_size = 0; + static int *tree_groups_offset = 0; + if(!tree_groups_size){ + tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups); + tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups); + } + int num = spatial*batch*hier.groups; + softmax_tree_kernel<<>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset); + check_error(cudaPeekAtLastError()); + //cuda_free((float *)tree_groups_size); + //cuda_free((float *)tree_groups_offset); +} + __global__ void softmax_kernel(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output) { int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; diff --git a/src/cuda.c b/src/cuda.c index af3d412b..7e53d254 100644 --- a/src/cuda.c +++ b/src/cuda.c @@ -128,12 +128,17 @@ float cuda_compare(float *x_gpu, float *x, size_t n, char *s) return err; } -int *cuda_make_int_array(size_t n) +int *cuda_make_int_array(int *x, size_t n) { int *x_gpu; size_t size = sizeof(int)*n; cudaError_t status = cudaMalloc((void **)&x_gpu, size); check_error(status); + if(x){ + status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice); + check_error(status); + } + if(!x_gpu) error("Cuda malloc failed\n"); return x_gpu; } diff --git a/src/cuda.h b/src/cuda.h index a825ded5..52794fa4 100644 --- a/src/cuda.h +++ b/src/cuda.h @@ -18,7 +18,7 @@ extern int gpu_index; void check_error(cudaError_t status); cublasHandle_t blas_handle(); float *cuda_make_array(float *x, size_t n); -int *cuda_make_int_array(size_t n); +int *cuda_make_int_array(int *x, size_t n); void cuda_push_array(float *x_gpu, float *x, size_t n); void cuda_pull_array(float *x_gpu, float *x, size_t n); void cuda_set_device(int n); diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c index 7b3a836b..17dedf7a 100644 --- a/src/maxpool_layer.c +++ b/src/maxpool_layer.c @@ -43,7 +43,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s #ifdef GPU l.forward_gpu = forward_maxpool_layer_gpu; l.backward_gpu = backward_maxpool_layer_gpu; - l.indexes_gpu = cuda_make_int_array(output_size); + l.indexes_gpu = cuda_make_int_array(0, output_size); l.output_gpu = cuda_make_array(l.output, output_size); l.delta_gpu = cuda_make_array(l.delta, output_size); #endif @@ -70,7 +70,7 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h) cuda_free((float *)l->indexes_gpu); cuda_free(l->output_gpu); cuda_free(l->delta_gpu); - l->indexes_gpu = cuda_make_int_array(output_size); + l->indexes_gpu = cuda_make_int_array(0, output_size); l->output_gpu = cuda_make_array(l->output, output_size); l->delta_gpu = cuda_make_array(l->delta, output_size); #endif diff --git a/src/region_layer.c b/src/region_layer.c index 5fe931c6..9b645944 100644 --- a/src/region_layer.c +++ b/src/region_layer.c @@ -141,7 +141,6 @@ int entry_index(layer l, int batch, int location, int entry) return batch*l.outputs + n*l.w*l.h*(l.coords+l.classes+1) + entry*l.w*l.h + loc; } -void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output); void forward_region_layer(const layer l, network net) { int i,j,b,t,n; @@ -445,6 +444,9 @@ void forward_region_layer_gpu(const layer l, network net) } } if (l.softmax_tree){ + int index = entry_index(l, 0, 0, 5); + softmax_tree(net.input_gpu + index, l.w*l.h, l.batch*l.n, l.inputs/l.n, 1, l.output_gpu + index, *l.softmax_tree); + /* int i; int count = 5; for (i = 0; i < l.softmax_tree->groups; ++i) { @@ -453,6 +455,7 @@ void forward_region_layer_gpu(const layer l, network net) softmax_gpu(net.input_gpu + index, group_size, l.batch*l.n, l.inputs/l.n, l.w*l.h, 1, l.w*l.h, 1, l.output_gpu + index); count += group_size; } + */ } else if (l.softmax) { int index = entry_index(l, 0, 0, l.coords + !l.background); //printf("%d\n", index); diff --git a/src/segmenter.c b/src/segmenter.c index fab66cb0..32bb4843 100644 --- a/src/segmenter.c +++ b/src/segmenter.c @@ -149,7 +149,7 @@ void predict_segmenter(char *datafile, char *cfgfile, char *weightfile, char *fi float *X = sized.data; time=clock(); float *predictions = network_predict(net, X); - image m = float_to_image(sized.w, sized.h, 80, predictions); + image m = float_to_image(sized.w, sized.h, 81, predictions); image rgb = mask_to_rgb(m); show_image(sized, "orig"); show_image(rgb, "pred");