tree stuff

This commit is contained in:
Joseph Redmon 2016-10-21 13:16:43 -07:00
parent ae53edc6a4
commit d8adaf8ea6
17 changed files with 287 additions and 127 deletions

View File

@ -41,10 +41,10 @@ CFLAGS+= -DCUDNN
LDFLAGS+= -lcudnn LDFLAGS+= -lcudnn
endif endif
OBJ=gemm.o utils.o cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o super.o voxel.o OBJ=gemm.o utils.o cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o super.o voxel.o tree.o
ifeq ($(GPU), 1) ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++ LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif endif
OBJS = $(addprefix $(OBJDIR), $(OBJ)) OBJS = $(addprefix $(OBJDIR), $(OBJ))

View File

@ -1,6 +1,7 @@
#include "blas.h" #include "blas.h"
#include "math.h" #include "math.h"
#include <assert.h> #include <assert.h>
#include <float.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -179,3 +180,21 @@ float dot_cpu(int N, float *X, int INCX, float *Y, int INCY)
return dot; return dot;
} }
void softmax(float *input, int n, float temp, float *output)
{
int i;
float sum = 0;
float largest = -FLT_MAX;
for(i = 0; i < n; ++i){
if(input[i] > largest) largest = input[i];
}
for(i = 0; i < n; ++i){
sum += exp(input[i]/temp-largest/temp);
}
if(sum) sum = largest/temp+log(sum);
else sum = largest-100;
for(i = 0; i < n; ++i){
output[i] = exp(input[i]/temp-sum);
}
}

View File

@ -34,7 +34,11 @@ void smooth_l1_cpu(int n, float *pred, float *truth, float *delta, float *error)
void l2_cpu(int n, float *pred, float *truth, float *delta, float *error); void l2_cpu(int n, float *pred, float *truth, float *delta, float *error);
void weighted_sum_cpu(float *a, float *b, float *s, int num, float *c); void weighted_sum_cpu(float *a, float *b, float *s, int num, float *c);
void softmax(float *input, int n, float temp, float *output);
#ifdef GPU #ifdef GPU
#include "cuda.h"
void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY); void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY); void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY); void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
@ -73,5 +77,7 @@ void mult_add_into_gpu(int num, float *a, float *b, float *c);
void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out); void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
void softmax_gpu(float *input, int n, int groups, float temp, float *output, cudaStream_t stream);
#endif #endif
#endif #endif

View File

@ -691,3 +691,33 @@ extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c); mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
check_error(cudaPeekAtLastError()); check_error(cudaPeekAtLastError());
} }
__global__ void softmax_kernel(int n, int batch, float *input, float temp, float *output)
{
int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(b >= batch) return;
int i;
float sum = 0;
float largest = -INFINITY;
for(i = 0; i < n; ++i){
int val = input[i+b*n];
largest = (val>largest) ? val : largest;
}
for(i = 0; i < n; ++i){
sum += exp(input[i+b*n]/temp-largest/temp);
}
sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
for(i = 0; i < n; ++i){
output[i+b*n] = exp(input[i+b*n]/temp-sum);
}
}
extern "C" void softmax_gpu(float *input, int n, int groups, float temp, float *output, cudaStream_t stream)
{
int inputs = n;
int batch = groups;
softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, stream>>>(inputs, batch, input, temp, output);
check_error(cudaPeekAtLastError());
}

View File

@ -41,6 +41,20 @@ list *read_data_cfg(char *filename)
return options; return options;
} }
void hierarchy_predictions(float *predictions, int n, tree *hier)
{
int j;
for(j = 0; j < n; ++j){
int parent = hier->parent[j];
if(parent >= 0){
predictions[j] *= predictions[parent];
}
}
for(j = 0; j < n; ++j){
if(!hier->leaf[j]) predictions[j] = 0;
}
}
float *get_regression_values(char **labels, int n) float *get_regression_values(char **labels, int n)
{ {
float *v = calloc(n, sizeof(float)); float *v = calloc(n, sizeof(float));
@ -99,7 +113,8 @@ void train_classifier_multi(char *datacfg, char *cfgfile, char *weightfile, int
load_args args = {0}; load_args args = {0};
args.w = net.w; args.w = net.w;
args.h = net.h; args.h = net.h;
args.threads = 16; args.threads = 32;
args.hierarchy = net.hierarchy;
args.min = net.min_crop; args.min = net.min_crop;
args.max = net.max_crop; args.max = net.max_crop;
@ -206,6 +221,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
args.saturation = net.saturation; args.saturation = net.saturation;
args.hue = net.hue; args.hue = net.hue;
args.size = net.w; args.size = net.w;
args.hierarchy = net.hierarchy;
args.paths = paths; args.paths = paths;
args.classes = classes; args.classes = classes;
@ -394,6 +410,7 @@ void validate_classifier_10(char *datacfg, char *filename, char *weightfile)
float *pred = calloc(classes, sizeof(float)); float *pred = calloc(classes, sizeof(float));
for(j = 0; j < 10; ++j){ for(j = 0; j < 10; ++j){
float *p = network_predict(net, images[j].data); float *p = network_predict(net, images[j].data);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
axpy_cpu(classes, 1, p, 1, pred, 1); axpy_cpu(classes, 1, p, 1, pred, 1);
free_image(images[j]); free_image(images[j]);
} }
@ -454,6 +471,7 @@ void validate_classifier_full(char *datacfg, char *filename, char *weightfile)
//show_image(crop, "cropped"); //show_image(crop, "cropped");
//cvWaitKey(0); //cvWaitKey(0);
float *pred = network_predict(net, resized.data); float *pred = network_predict(net, resized.data);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
free_image(im); free_image(im);
free_image(resized); free_image(resized);
@ -513,6 +531,7 @@ void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
//show_image(crop, "cropped"); //show_image(crop, "cropped");
//cvWaitKey(0); //cvWaitKey(0);
float *pred = network_predict(net, crop.data); float *pred = network_predict(net, crop.data);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
if(resized.data != im.data) free_image(resized); if(resized.data != im.data) free_image(resized);
free_image(im); free_image(im);
@ -573,6 +592,7 @@ void validate_classifier_multi(char *datacfg, char *filename, char *weightfile)
image r = resize_min(im, scales[j]); image r = resize_min(im, scales[j]);
resize_network(&net, r.w, r.h); resize_network(&net, r.w, r.h);
float *p = network_predict(net, r.data); float *p = network_predict(net, r.data);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
axpy_cpu(classes, 1, p, 1, pred, 1); axpy_cpu(classes, 1, p, 1, pred, 1);
flip_image(r); flip_image(r);
p = network_predict(net, r.data); p = network_predict(net, r.data);
@ -672,7 +692,6 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
} }
} }
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename) void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename)
{ {
network net = parse_network_cfg(cfgfile); network net = parse_network_cfg(cfgfile);
@ -713,11 +732,13 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
float *X = r.data; float *X = r.data;
time=clock(); time=clock();
float *predictions = network_predict(net, X); float *predictions = network_predict(net, X);
top_predictions(net, top, indexes); if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
top_k(predictions, net.outputs, top, indexes);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
for(i = 0; i < top; ++i){ for(i = 0; i < top; ++i){
int index = indexes[i]; int index = indexes[i];
printf("%s: %f\n", names[index], predictions[index]); if(net.hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net.hierarchy->parent[index] >= 0) ? names[net.hierarchy->parent[index]] : "Root");
else printf("%s: %f\n",names[index], predictions[index]);
} }
if(r.data != im.data) free_image(r); if(r.data != im.data) free_image(r);
free_image(im); free_image(im);
@ -899,15 +920,15 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
float curr_threat = 0; float curr_threat = 0;
if(1){ if(1){
curr_threat = predictions[0] * 0 + curr_threat = predictions[0] * 0 +
predictions[1] * .6 + predictions[1] * .6 +
predictions[2]; predictions[2];
} else { } else {
curr_threat = predictions[218] + curr_threat = predictions[218] +
predictions[539] + predictions[539] +
predictions[540] + predictions[540] +
predictions[368] + predictions[368] +
predictions[369] + predictions[369] +
predictions[370]; predictions[370];
} }
threat = roll * curr_threat + (1-roll) * threat; threat = roll * curr_threat + (1-roll) * threat;
@ -1092,6 +1113,7 @@ void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
show_image(in, "Classifier"); show_image(in, "Classifier");
float *predictions = network_predict(net, in_s.data); float *predictions = network_predict(net, in_s.data);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
top_predictions(net, top, indexes); top_predictions(net, top, indexes);
printf("\033[2J"); printf("\033[2J");

View File

@ -388,12 +388,47 @@ void fill_truth(char *path, char **labels, int k, float *truth)
if(count != 1) printf("Too many or too few labels: %d, %s\n", count, path); if(count != 1) printf("Too many or too few labels: %d, %s\n", count, path);
} }
matrix load_labels_paths(char **paths, int n, char **labels, int k) void fill_hierarchy(float *truth, int k, tree *hierarchy)
{
int j;
for(j = 0; j < k; ++j){
if(truth[j]){
int parent = hierarchy->parent[j];
while(parent >= 0){
truth[parent] = 1;
parent = hierarchy->parent[parent];
}
}
}
int i;
int count = 0;
for(j = 0; j < hierarchy->groups; ++j){
//printf("%d\n", count);
int mask = 1;
for(i = 0; i < hierarchy->group_size[j]; ++i){
if(truth[count + i]){
mask = 0;
break;
}
}
if (mask) {
for(i = 0; i < hierarchy->group_size[j]; ++i){
truth[count + i] = SECRET_NUM;
}
}
count += hierarchy->group_size[j];
}
}
matrix load_labels_paths(char **paths, int n, char **labels, int k, tree *hierarchy)
{ {
matrix y = make_matrix(n, k); matrix y = make_matrix(n, k);
int i; int i;
for(i = 0; i < n && labels; ++i){ for(i = 0; i < n && labels; ++i){
fill_truth(paths[i], labels, k, y.vals[i]); fill_truth(paths[i], labels, k, y.vals[i]);
if(hierarchy){
fill_hierarchy(y.vals[i], k, hierarchy);
}
} }
return y; return y;
} }
@ -540,7 +575,7 @@ data load_data_compare(int n, char **paths, int m, int classes, int w, int h)
while(fscanf(fp2, "%d %f", &id, &iou) == 2){ while(fscanf(fp2, "%d %f", &id, &iou) == 2){
if (d.y.vals[i][2*id + 1] < iou) d.y.vals[i][2*id + 1] = iou; if (d.y.vals[i][2*id + 1] < iou) d.y.vals[i][2*id + 1] = iou;
} }
for (j = 0; j < classes; ++j){ for (j = 0; j < classes; ++j){
if (d.y.vals[i][2*j] > .5 && d.y.vals[i][2*j+1] < .5){ if (d.y.vals[i][2*j] > .5 && d.y.vals[i][2*j+1] < .5){
d.y.vals[i][2*j] = 1; d.y.vals[i][2*j] = 1;
@ -567,7 +602,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
{ {
int index = rand()%n; int index = rand()%n;
char *random_path = paths[index]; char *random_path = paths[index];
image orig = load_image_color(random_path, 0, 0); image orig = load_image_color(random_path, 0, 0);
int h = orig.h; int h = orig.h;
int w = orig.w; int w = orig.w;
@ -680,7 +715,7 @@ void *load_thread(void *ptr)
if (a.type == OLD_CLASSIFICATION_DATA){ if (a.type == OLD_CLASSIFICATION_DATA){
*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h); *a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
} else if (a.type == CLASSIFICATION_DATA){ } else if (a.type == CLASSIFICATION_DATA){
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure); *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
} else if (a.type == SUPER_DATA){ } else if (a.type == SUPER_DATA){
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale); *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
} else if (a.type == WRITING_DATA){ } else if (a.type == WRITING_DATA){
@ -771,24 +806,24 @@ data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int
data d = {0}; data d = {0};
d.shallow = 0; d.shallow = 0;
d.X = load_image_paths(paths, n, w, h); d.X = load_image_paths(paths, n, w, h);
d.y = load_labels_paths(paths, n, labels, k); d.y = load_labels_paths(paths, n, labels, k, 0);
if(m) free(paths); if(m) free(paths);
return d; return d;
} }
/* /*
data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure) data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure)
{ {
data d = {0}; data d = {0};
d.indexes = calloc(n, sizeof(int)); d.indexes = calloc(n, sizeof(int));
if(m) paths = get_random_paths_indexes(paths, n, m, d.indexes); if(m) paths = get_random_paths_indexes(paths, n, m, d.indexes);
d.shallow = 0; d.shallow = 0;
d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure); d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure);
d.y = load_labels_paths(paths, n, labels, k); d.y = load_labels_paths(paths, n, labels, k);
if(m) free(paths); if(m) free(paths);
return d; return d;
} }
*/ */
data load_data_super(char **paths, int n, int m, int w, int h, int scale) data load_data_super(char **paths, int n, int m, int w, int h, int scale)
{ {
@ -820,13 +855,13 @@ data load_data_super(char **paths, int n, int m, int w, int h, int scale)
return d; return d;
} }
data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure) data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure)
{ {
if(m) paths = get_random_paths(paths, n, m); if(m) paths = get_random_paths(paths, n, m);
data d = {0}; data d = {0};
d.shallow = 0; d.shallow = 0;
d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure); d.X = load_image_augment_paths(paths, n, min, max, size, angle, aspect, hue, saturation, exposure);
d.y = load_labels_paths(paths, n, labels, k); d.y = load_labels_paths(paths, n, labels, k, hierarchy);
if(m) free(paths); if(m) free(paths);
return d; return d;
} }

View File

@ -5,6 +5,7 @@
#include "matrix.h" #include "matrix.h"
#include "list.h" #include "list.h"
#include "image.h" #include "image.h"
#include "tree.h"
static inline float distance_from_edge(int x, int max) static inline float distance_from_edge(int x, int max)
{ {
@ -58,6 +59,7 @@ typedef struct load_args{
image *im; image *im;
image *resized; image *resized;
data_type type; data_type type;
tree *hierarchy;
} load_args; } load_args;
typedef struct{ typedef struct{
@ -80,7 +82,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in
data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); matrix load_image_augment_paths(char **paths, int n, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_super(char **paths, int n, int m, int w, int h, int scale); data load_data_super(char **paths, int n, int m, int w, int h, int scale);
data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
data load_go(char *filename); data load_go(char *filename);
box_label *read_boxes(char *filename, int *n); box_label *read_boxes(char *filename, int *n);

View File

@ -58,7 +58,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
int index = b*l.inputs; int index = b*l.inputs;
for (i = 0; i < locations; ++i) { for (i = 0; i < locations; ++i) {
int offset = i*l.classes; int offset = i*l.classes;
softmax_array(l.output + index + offset, l.classes, 1, softmax(l.output + index + offset, l.classes, 1,
l.output + index + offset); l.output + index + offset);
} }
} }

View File

@ -3,6 +3,7 @@
#include "activations.h" #include "activations.h"
#include "stddef.h" #include "stddef.h"
#include "tree.h"
struct network_state; struct network_state;
@ -93,6 +94,8 @@ struct layer{
int reorg; int reorg;
int log; int log;
tree *softmax_tree;
float alpha; float alpha;
float beta; float beta;
float kappa; float kappa;

View File

@ -565,7 +565,6 @@ float *network_accuracies(network net, data d, int n)
return acc; return acc;
} }
float network_accuracy_multi(network net, data d, int n) float network_accuracy_multi(network net, data d, int n)
{ {
matrix guess = network_predict_data_multi(net, d, n); matrix guess = network_predict_data_multi(net, d, n);

View File

@ -5,6 +5,7 @@
#include "image.h" #include "image.h"
#include "layer.h" #include "layer.h"
#include "data.h" #include "data.h"
#include "tree.h"
typedef enum { typedef enum {
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
@ -47,6 +48,7 @@ typedef struct network{
float hue; float hue;
int gpu_index; int gpu_index;
tree *hierarchy;
#ifdef GPU #ifdef GPU
float **input_gpu; float **input_gpu;

View File

@ -221,6 +221,8 @@ softmax_layer parse_softmax(list *options, size_params params)
int groups = option_find_int_quiet(options, "groups",1); int groups = option_find_int_quiet(options, "groups",1);
softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups); softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
layer.temperature = option_find_float_quiet(options, "temperature", 1); layer.temperature = option_find_float_quiet(options, "temperature", 1);
char *tree_file = option_find_str(options, "tree", 0);
if (tree_file) layer.softmax_tree = read_tree(tree_file);
return layer; return layer;
} }
@ -598,6 +600,7 @@ network parse_network_cfg(char *filename)
l = parse_detection(options, params); l = parse_detection(options, params);
}else if(lt == SOFTMAX){ }else if(lt == SOFTMAX){
l = parse_softmax(options, params); l = parse_softmax(options, params);
net.hierarchy = l.softmax_tree;
}else if(lt == NORMALIZATION){ }else if(lt == NORMALIZATION){
l = parse_normalization(options, params); l = parse_normalization(options, params);
}else if(lt == BATCHNORM){ }else if(lt == BATCHNORM){

View File

@ -1,6 +1,5 @@
#include "region_layer.h" #include "region_layer.h"
#include "activations.h" #include "activations.h"
#include "softmax_layer.h"
#include "blas.h" #include "blas.h"
#include "box.h" #include "box.h"
#include "cuda.h" #include "cuda.h"
@ -99,7 +98,7 @@ void forward_region_layer(const region_layer l, network_state state)
int index = size*i + b*l.outputs; int index = size*i + b*l.outputs;
l.output[index + 4] = logistic_activate(l.output[index + 4]); l.output[index + 4] = logistic_activate(l.output[index + 4]);
if(l.softmax){ if(l.softmax){
softmax_array(l.output + index + 5, l.classes, 1, l.output + index + 5); softmax(l.output + index + 5, l.classes, 1, l.output + index + 5);
} }
} }
} }

View File

@ -32,31 +32,25 @@ softmax_layer make_softmax_layer(int batch, int inputs, int groups)
return l; return l;
} }
void softmax_array(float *input, int n, float temp, float *output)
{
int i;
float sum = 0;
float largest = -FLT_MAX;
for(i = 0; i < n; ++i){
if(input[i] > largest) largest = input[i];
}
for(i = 0; i < n; ++i){
sum += exp(input[i]/temp-largest/temp);
}
if(sum) sum = largest/temp+log(sum);
else sum = largest-100;
for(i = 0; i < n; ++i){
output[i] = exp(input[i]/temp-sum);
}
}
void forward_softmax_layer(const softmax_layer l, network_state state) void forward_softmax_layer(const softmax_layer l, network_state state)
{ {
int b; int b;
int inputs = l.inputs / l.groups; int inputs = l.inputs / l.groups;
int batch = l.batch * l.groups; int batch = l.batch * l.groups;
for(b = 0; b < batch; ++b){ if(l.softmax_tree){
softmax_array(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs); for(b = 0; b < batch; ++b){
int i;
int count = 0;
for(i = 0; i < l.softmax_tree->groups; ++i){
int group_size = l.softmax_tree->group_size[i];
softmax(state.input+b*inputs + count, group_size, l.temperature, l.output+b*inputs + count);
count += group_size;
}
}
} else {
for(b = 0; b < batch; ++b){
softmax(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs);
}
} }
} }
@ -68,3 +62,54 @@ void backward_softmax_layer(const softmax_layer l, network_state state)
} }
} }
#ifdef GPU
void pull_softmax_layer_output(const softmax_layer layer)
{
cuda_pull_array(layer.output_gpu, layer.output, layer.inputs*layer.batch);
}
void forward_softmax_layer_gpu(const softmax_layer l, network_state state)
{
int inputs = l.inputs / l.groups;
int batch = l.batch * l.groups;
int b;
if(l.softmax_tree){
if(0){
float *buff = calloc(inputs * batch, sizeof(float));
cuda_pull_array(state.input, buff, batch * inputs);
state.input = buff;
forward_softmax_layer(l, state);
cuda_push_array(l.output_gpu, l.output, batch*inputs);
free(buff);
} else {
int i;
const int nstreams = 32;
cudaStream_t streams[nstreams];
for (i = 0; i < nstreams; ++i) {
cudaStreamCreate(&streams[i]);
}
for (b = 0; b < batch; ++b) {
int i;
int count = 0;
for (i = 0; i < l.softmax_tree->groups; ++i) {
int group_size = l.softmax_tree->group_size[i];
softmax_gpu(state.input+b*inputs + count, group_size, 1, l.temperature, l.output_gpu+b*inputs + count, streams[(b*l.softmax_tree->groups + i) % nstreams]);
count += group_size;
}
}
for(i = 0; i < nstreams; ++i){
cudaStreamDestroy(streams[i]);
}
}
} else {
softmax_gpu(state.input, inputs, batch, l.temperature, l.output_gpu, 0);
}
}
void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
}
#endif

View File

@ -1,70 +0,0 @@
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
extern "C" {
#include "softmax_layer.h"
#include "cuda.h"
#include "blas.h"
}
__global__ void forward_softmax_layer_kernel(int n, int batch, float *input, float temp, float *output)
{
int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(b >= batch) return;
int i;
float sum = 0;
float largest = -INFINITY;
for(i = 0; i < n; ++i){
int val = input[i+b*n];
largest = (val>largest) ? val : largest;
}
for(i = 0; i < n; ++i){
sum += exp(input[i+b*n]/temp-largest/temp);
}
sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
for(i = 0; i < n; ++i){
output[i+b*n] = exp(input[i+b*n]/temp-sum);
}
}
extern "C" void pull_softmax_layer_output(const softmax_layer layer)
{
cuda_pull_array(layer.output_gpu, layer.output, layer.inputs*layer.batch);
}
extern "C" void forward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
int inputs = layer.inputs / layer.groups;
int batch = layer.batch * layer.groups;
forward_softmax_layer_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, batch, state.input, layer.temperature, layer.output_gpu);
check_error(cudaPeekAtLastError());
}
extern "C" void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
{
axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
}
/* This is if you want softmax w/o log-loss classification. You probably don't.
int i,j,b;
for(b = 0; b < layer.batch; ++b){
for(i = 0; i < layer.inputs; ++i){
for(j = 0; j < layer.inputs; ++j){
int d = (i==j);
layer.jacobian[b*layer.inputs*layer.inputs + i*layer.inputs + j] =
layer.output[b*layer.inputs + i] * (d - layer.output[b*layer.inputs + j]);
}
}
}
for(b = 0; b < layer.batch; ++b){
int M = layer.inputs;
int N = 1;
int K = layer.inputs;
float *A = layer.jacobian + b*layer.inputs*layer.inputs;
float *B = layer.delta + b*layer.inputs;
float *C = delta + b*layer.inputs;
gemm(0,0,M,N,K,1,A,K,B,N,0,C,N);
}
*/

49
src/tree.c Normal file
View File

@ -0,0 +1,49 @@
#include <stdio.h>
#include <stdlib.h>
#include "tree.h"
#include "utils.h"
tree *read_tree(char *filename)
{
tree t = {0};
FILE *fp = fopen(filename, "r");
char *line;
int last_parent = -1;
int group_size = 0;
int groups = 0;
int n = 0;
while((line=fgetl(fp)) != 0){
char *id = calloc(256, sizeof(char));
int parent = -1;
sscanf(line, "%s %d", id, &parent);
t.parent = realloc(t.parent, (n+1)*sizeof(int));
t.parent[n] = parent;
t.name = realloc(t.name, (n+1)*sizeof(char *));
t.name[n] = id;
if(parent != last_parent){
++groups;
t.group_size = realloc(t.group_size, groups * sizeof(int));
t.group_size[groups - 1] = group_size;
group_size = 0;
last_parent = parent;
}
++n;
++group_size;
}
++groups;
t.group_size = realloc(t.group_size, groups * sizeof(int));
t.group_size[groups - 1] = group_size;
t.n = n;
t.groups = groups;
t.leaf = calloc(n, sizeof(int));
int i;
for(i = 0; i < n; ++i) t.leaf[i] = 1;
for(i = 0; i < n; ++i) if(t.parent[i] >= 0) t.leaf[t.parent[i]] = 0;
fclose(fp);
tree *tree_ptr = calloc(1, sizeof(tree));
*tree_ptr = t;
//error(0);
return tree_ptr;
}

16
src/tree.h Normal file
View File

@ -0,0 +1,16 @@
#ifndef TREE_H
#define TREE_H
typedef struct{
int *leaf;
int n;
int *parent;
char **name;
int groups;
int *group_size;
} tree;
tree *read_tree(char *filename);
#endif