softmax does cost now, special case 1x1 convs

Joseph Redmon 2018-05-04 01:28:59 -07:00
parent 508381b37f
commit e7405b513d
21 changed files with 41 additions and 65 deletions

View File

@@ -68,8 +68,8 @@ EXECOBJ = $(addprefix $(OBJDIR), $(EXECOBJA))
 OBJS = $(addprefix $(OBJDIR), $(OBJ))
 DEPS = $(wildcard src/*.h) Makefile include/darknet.h
-#all: obj backup results $(SLIB) $(ALIB) $(EXEC)
-all: obj results $(SLIB) $(ALIB) $(EXEC)
+all: obj backup results $(SLIB) $(ALIB) $(EXEC)
+#all: obj results $(SLIB) $(ALIB) $(EXEC)
 $(EXEC): $(EXECOBJ) $(ALIB)

View File

@@ -90,6 +90,3 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse
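Across these .cfg files the trailing [cost] section (type=sse) after a classification [softmax] is dropped, since, per the commit title, the softmax layer now produces the cost itself. A minimal sketch of the idea, not darknet's actual implementation: a softmax forward pass that also accumulates a cross-entropy loss when ground-truth labels are available; the names `truth`, `output`, and the returned cost are illustrative assumptions.

```c
#include <math.h>

/* Illustrative sketch only: a softmax that also produces its own cost,
 * so a separate SSE cost layer is no longer needed in the .cfg.
 * This is not darknet's API; names and signature are hypothetical. */
float softmax_with_cost(const float *input, const float *truth, float *output, int n)
{
    /* numerically stable softmax */
    float largest = input[0];
    for (int i = 1; i < n; ++i) if (input[i] > largest) largest = input[i];

    float sum = 0;
    for (int i = 0; i < n; ++i) {
        output[i] = expf(input[i] - largest);
        sum += output[i];
    }

    float cost = 0;
    for (int i = 0; i < n; ++i) {
        output[i] /= sum;
        /* cross-entropy against the (possibly one-hot) truth vector */
        if (truth) cost += -truth[i] * logf(output[i] + 1e-9f);
    }
    return cost;
}
```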

View File

@@ -119,6 +119,3 @@ activation=leaky
 [softmax]
 groups=1
-[cost]

View File

@@ -115,5 +115,3 @@ activation=leaky
 groups=1
 temperature=3
-[cost]

View File

@@ -203,6 +203,3 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -195,6 +195,3 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -1949,6 +1949,3 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -201,6 +201,3 @@ activation=leaky
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -204,6 +204,3 @@ activation=leaky
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -130,6 +130,3 @@ stride=1
 [softmax]
-[cost]
-type=sse

View File

@@ -129,6 +129,4 @@ stride=1
 [softmax]
-[cost]
-type=sse

View File

@@ -27,6 +27,4 @@ activation=linear
 [softmax]
-[cost]
-type=sse

View File

@@ -1458,6 +1458,3 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -506,6 +506,4 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -35,6 +35,4 @@ activation=leaky
 [softmax]
-[cost]
-type=sse

View File

@@ -35,6 +35,4 @@ activation=leaky
 [softmax]
-[cost]
-type=sse

View File

@@ -180,6 +180,3 @@ activation=ramp
 [softmax]
-[cost]
-type=sse

View File

@@ -171,6 +171,4 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -148,6 +148,4 @@ activation=linear
 [softmax]
 groups=1
-[cost]
-type=sse

View File

@@ -111,9 +111,13 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network net)
             float *a = l.weights_gpu + j*l.nweights/l.groups;
             float *b = net.workspace;
             float *c = l.output_gpu + (i*l.groups + j)*n*m;
+            float *im = net.input_gpu + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
-            im2col_gpu(net.input_gpu + (i*l.groups + j)*l.c/l.groups*l.h*l.w,
-                    l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            if (l.size == 1){
+                b = im;
+            } else {
+                im2col_gpu(im, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            }
             gemm_gpu(0,0,m,n,k,1,a,k,b,n,1,c,n);
         }
     }
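The special case works because, for a 1×1 kernel with stride 1 and no padding, the "column" matrix that im2col would build is exactly the input image already laid out as channels × (h·w), so the GEMM can read the input buffer directly and the workspace copy is skipped. A minimal sketch of that identity under those assumptions; `im2col_1x1` is an illustrative helper, not darknet code:

```c
#include <string.h>

/* Illustrative sketch: for a 1x1 kernel, stride 1, pad 0, building the
 * im2col column matrix is just a copy of the input -- which is why the
 * forward pass above can set b = im instead of calling im2col. */
static void im2col_1x1(const float *im, int channels, int h, int w, float *col)
{
    /* column matrix is (channels*1*1) rows x (h*w) columns,
     * i.e. the same channels x (h*w) layout the input already uses */
    memcpy(col, im, (size_t)channels * h * w * sizeof(float));
}
```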
@@ -236,22 +240,26 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network net)
             float *b = net.workspace;
             float *c = l.weight_updates_gpu + j*l.nweights/l.groups;
             float *im = net.input_gpu+(i*l.groups + j)*l.c/l.groups*l.h*l.w;
+            float *imd = net.delta_gpu+(i*l.groups + j)*l.c/l.groups*l.h*l.w;
-            im2col_gpu(im, l.c/l.groups, l.h, l.w,
-                    l.size, l.stride, l.pad, b);
+            im2col_gpu(im, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
             gemm_gpu(0,1,m,n,k,1,a,k,b,k,1,c,n);
-            if(net.delta_gpu){
-                if(l.binary || l.xnor) swap_binary(&l);
+            if (net.delta_gpu) {
+                if (l.binary || l.xnor) swap_binary(&l);
                 a = l.weights_gpu + j*l.nweights/l.groups;
                 b = l.delta_gpu + (i*l.groups + j)*m*k;
                 c = net.workspace;
+                if (l.size == 1) {
+                    c = imd;
+                }
                 gemm_gpu(1,0,n,k,m,1,a,n,b,k,0,c,k);
-                col2im_gpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride,
-                        l.pad, net.delta_gpu + (i*l.groups + j)*l.c/l.groups*l.h*l.w);
+                if (l.size != 1) {
+                    col2im_gpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, imd);
+                }
                 if(l.binary || l.xnor) {
                     swap_binary(&l);
                 }

View File

@@ -463,9 +463,13 @@ void forward_convolutional_layer(convolutional_layer l, network net)
             float *a = l.weights + j*l.nweights/l.groups;
             float *b = net.workspace;
             float *c = l.output + (i*l.groups + j)*n*m;
+            float *im = net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
-            im2col_cpu(net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w,
-                    l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            if (l.size == 1) {
+                b = im;
+            } else {
+                im2col_cpu(im, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
+            }
             gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
         }
     }
@@ -501,21 +505,31 @@ void backward_convolutional_layer(convolutional_layer l, network net)
             float *b = net.workspace;
             float *c = l.weight_updates + j*l.nweights/l.groups;
-            float *im = net.input+(i*l.groups + j)*l.c/l.groups*l.h*l.w;
-            im2col_cpu(im, l.c/l.groups, l.h, l.w,
-                    l.size, l.stride, l.pad, b);
+            float *im = net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
+            float *imd = net.delta + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
+            if(l.size == 1){
+                b = im;
+            } else {
+                im2col_cpu(im, l.c/l.groups, l.h, l.w,
+                        l.size, l.stride, l.pad, b);
+            }
             gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
-            if(net.delta){
+            if (net.delta) {
                 a = l.weights + j*l.nweights/l.groups;
                 b = l.delta + (i*l.groups + j)*m*k;
                 c = net.workspace;
+                if (l.size == 1) {
+                    c = imd;
+                }
                 gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
-                col2im_cpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride,
-                        l.pad, net.delta + (i*l.groups + j)*l.c/l.groups*l.h*l.w);
+                if (l.size != 1) {
+                    col2im_cpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, imd);
+                }
             }
         }
     }
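In the backward pass the same observation removes the workspace round-trip for the input gradient: when l.size == 1 the GEMM that multiplies the transposed weights by the layer's delta writes its result straight into the incoming delta slice (imd), and the col2im scatter is only needed for larger kernels. A small shape-check sketch of why that layout matches, using made-up dimensions and the same row/column roles as the surrounding code; this is an illustration under stated assumptions, not darknet code:

```c
#include <assert.h>

/* Illustrative shape check for the 1x1 backward special case: the
 * input-gradient GEMM produces a (c/groups) x (out_h*out_w) block, which
 * for a 1x1, stride-1, pad-0 convolution is exactly the layout of the
 * net.delta slice, so col2im can be skipped and the result written in place. */
int main(void)
{
    int c = 64, h = 19, w = 19, groups = 1, size = 1, stride = 1, pad = 0;
    int out_h = (h + 2*pad - size)/stride + 1;
    int out_w = (w + 2*pad - size)/stride + 1;

    int rows = size*size*c/groups;   /* rows of the col2im-style matrix  */
    int cols = out_h*out_w;          /* one column per output position   */

    /* for size == 1 the gradient matrix already matches the image layout */
    assert(rows == c/groups && cols == h*w);
    return 0;
}
```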