darknet/src/softmax_layer_kernels.cu
2015-11-15 19:51:26 -08:00

71 lines
2.0 KiB
Plaintext

#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
extern "C" {
#include "softmax_layer.h"
#include "cuda.h"
#include "blas.h"
}
/*
 * Row-wise softmax: one thread handles one row of `n` contiguous floats.
 *
 * Launch layout: 1-D thread blocks over a possibly 2-D grid. The block index
 * is flattened as blockIdx.x + blockIdx.y*gridDim.x, so `batch` rows may be
 * spread across a 2-D grid. Threads whose row index is >= batch exit early.
 * No shared memory is used.
 *
 * n      - number of values per row
 * batch  - number of rows
 * input  - batch*n floats; row b starts at input + b*n
 * output - batch*n floats; receives the softmax of each row
 *
 * Uses the log-sum-exp form
 *   out_i = exp(x_i - (max + log(sum_j exp(x_j - max))))
 * so the intermediate exponentials stay bounded for large logits.
 */
__global__ void forward_softmax_layer_kernel(int n, int batch, const float *input, float *output)
{
    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(b >= batch) return;
    const float *in = input + (size_t)b*n;
    float *out = output + (size_t)b*n;
    int i;

    float largest = -INFINITY;
    for(i = 0; i < n; ++i){
        // Must stay float: truncating to int corrupts the max for fractional
        // values and overflows for magnitudes beyond INT_MAX.
        float val = in[i];
        largest = (val > largest) ? val : largest;
    }

    float sum = 0;
    for(i = 0; i < n; ++i){
        sum += expf(in[i] - largest);
    }

    // Log of the softmax normalizer. The sum != 0 guard is kept from the
    // original as a defensive fallback.
    float log_norm = (sum != 0) ? largest + logf(sum) : largest - 100;
    for(i = 0; i < n; ++i){
        out[i] = expf(in[i] - log_norm);
    }
}
/*
 * Copy the layer's output (batch * inputs floats) from layer.output_gpu into
 * layer.output. cuda_pull_array presumably performs a device -> host copy;
 * confirm against its definition.
 */
extern "C" void pull_softmax_layer_output(const softmax_layer layer)
{
    cuda_pull_array(layer.output_gpu,
                    layer.output,
                    layer.batch*layer.inputs);
}
/*
 * Forward pass on the GPU.
 *
 * Each batch item's `inputs` values are split into `groups` contiguous slices
 * of inputs/groups values. Every slice is normalized independently, so the
 * kernel sees batch*groups rows. This assumes groups divides inputs evenly;
 * that is not checked here.
 *
 * Reads state.input and writes layer.output_gpu. Launch errors are checked
 * with cudaPeekAtLastError.
 */
extern "C" void forward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
    int rows = layer.batch * layer.groups;
    int row_len = layer.inputs / layer.groups;
    forward_softmax_layer_kernel<<<cuda_gridsize(rows), BLOCK>>>(row_len, rows,
                                                                state.input,
                                                                layer.output_gpu);
    check_error(cudaPeekAtLastError());
}
/*
 * Backward pass on the GPU.
 *
 * Adds this layer's delta (batch * inputs floats) into state.delta with unit
 * stride and scale 1, presumably state.delta += layer.delta_gpu. The softmax
 * Jacobian is not applied; the gradient passes straight through. The note
 * after this function suggests this shortcut relies on the softmax being
 * paired with a log-loss.
 */
extern "C" void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
    axpy_ongpu(layer.batch*layer.inputs, 1,
               layer.delta_gpu, 1,
               state.delta, 1);
}
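/* Host-side backward pass for a plain softmax (not paired with a log-loss): it builds each item's softmax Jacobian, J[i][j] = y_i * (delta_ij - y_j), and multiplies it by the incoming delta. You probably don't want this.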
int i,j,b;
for(b = 0; b < layer.batch; ++b){
for(i = 0; i < layer.inputs; ++i){
for(j = 0; j < layer.inputs; ++j){
int d = (i==j);
layer.jacobian[b*layer.inputs*layer.inputs + i*layer.inputs + j] =
layer.output[b*layer.inputs + i] * (d - layer.output[b*layer.inputs + j]);
}
}
}
for(b = 0; b < layer.batch; ++b){
int M = layer.inputs;
int N = 1;
int K = layer.inputs;
float *A = layer.jacobian + b*layer.inputs*layer.inputs;
float *B = layer.delta + b*layer.inputs;
float *C = delta + b*layer.inputs;
gemm(0,0,M,N,K,1,A,K,B,N,0,C,N);
}
*/