mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Detection is back, baby\!
This commit is contained in:
parent
979d02126b
commit
0f645836f1
4
Makefile
4
Makefile
@ -25,9 +25,9 @@ CFLAGS+=-DGPU
|
||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
|
||||
endif
|
||||
|
||||
OBJ=gemm.o utils.o cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o
|
||||
ifeq ($(GPU), 1)
|
||||
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
|
||||
endif
|
||||
|
||||
OBJS = $(addprefix $(OBJDIR), $(OBJ))
|
||||
|
@ -6,7 +6,7 @@ void col2im_cpu(float* data_col,
|
||||
int ksize, int stride, int pad, float* data_im);
|
||||
|
||||
#ifdef GPU
|
||||
void col2im_ongpu(float *data_col, int batch,
|
||||
void col2im_ongpu(float *data_col,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float *data_im);
|
||||
#endif
|
||||
|
@ -3,7 +3,7 @@ extern "C" {
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
__global__ void col2im_kernel(float *data_col, int offset,
|
||||
__global__ void col2im_kernel(float *data_col,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float *data_im)
|
||||
{
|
||||
@ -46,17 +46,17 @@ __global__ void col2im_kernel(float *data_col, int offset,
|
||||
val += part;
|
||||
}
|
||||
}
|
||||
data_im[index+offset] = val;
|
||||
data_im[index] = val;
|
||||
}
|
||||
|
||||
|
||||
extern "C" void col2im_ongpu(float *data_col, int offset,
|
||||
extern "C" void col2im_ongpu(float *data_col,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float *data_im)
|
||||
{
|
||||
|
||||
size_t n = channels*height*width;
|
||||
|
||||
col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, offset, channels, height, width, ksize, stride, pad, data_im);
|
||||
col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, channels, height, width, ksize, stride, pad, data_im);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float
|
||||
bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
|
||||
|
||||
for(i = 0; i < layer.batch; ++i){
|
||||
im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
||||
im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
||||
float * a = layer.filters_gpu;
|
||||
float * b = layer.col_image_gpu;
|
||||
float * c = layer.output_gpu;
|
||||
@ -93,7 +93,7 @@ extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, floa
|
||||
float * b = layer.col_image_gpu;
|
||||
float * c = layer.filter_updates_gpu;
|
||||
|
||||
im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
||||
im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
|
||||
gemm_ongpu(0,1,m,n,k,alpha,a + i*m*k,k,b,k,1,c,n);
|
||||
|
||||
if(delta_gpu){
|
||||
@ -104,7 +104,7 @@ extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, floa
|
||||
|
||||
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
|
||||
|
||||
col2im_ongpu(layer.col_image_gpu, i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta_gpu);
|
||||
col2im_ongpu(layer.col_image_gpu, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta_gpu + i*layer.c*layer.h*layer.w);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +44,6 @@ image get_convolutional_delta(convolutional_layer layer)
|
||||
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay)
|
||||
{
|
||||
int i;
|
||||
size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
|
||||
convolutional_layer *layer = calloc(1, sizeof(convolutional_layer));
|
||||
|
||||
layer->learning_rate = learning_rate;
|
||||
@ -95,11 +94,10 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
|
||||
void resize_convolutional_layer(convolutional_layer *layer, int h, int w)
|
||||
{
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
layer->c = c;
|
||||
int out_h = convolutional_out_height(*layer);
|
||||
int out_w = convolutional_out_width(*layer);
|
||||
|
||||
@ -109,6 +107,16 @@ void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
|
||||
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
||||
layer->delta = realloc(layer->delta,
|
||||
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
cuda_free(layer->col_image_gpu);
|
||||
cuda_free(layer->delta_gpu);
|
||||
cuda_free(layer->output_gpu);
|
||||
|
||||
layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c);
|
||||
layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n);
|
||||
layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n);
|
||||
#endif
|
||||
}
|
||||
|
||||
void bias_output(float *output, float *biases, int batch, int n, int size)
|
||||
|
@ -54,7 +54,7 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int
|
||||
#endif
|
||||
|
||||
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
|
||||
void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c);
|
||||
void resize_convolutional_layer(convolutional_layer *layer, int h, int w);
|
||||
void forward_convolutional_layer(const convolutional_layer layer, float *in);
|
||||
void update_convolutional_layer(convolutional_layer layer);
|
||||
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
|
||||
|
121
src/darknet.c
121
src/darknet.c
@ -57,8 +57,8 @@ void draw_detection(image im, float *box, int side)
|
||||
int d = im.w/side;
|
||||
int y = r*d+box[j+1]*d;
|
||||
int x = c*d+box[j+2]*d;
|
||||
int h = box[j+3]*256;
|
||||
int w = box[j+4]*256;
|
||||
int h = box[j+3]*im.h;
|
||||
int w = box[j+4]*im.w;
|
||||
//printf("%f %f %f %f\n", box[j+1], box[j+2], box[j+3], box[j+4]);
|
||||
//printf("%d %d %d %d\n", x, y, w, h);
|
||||
//printf("%d %d %d %d\n", x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
@ -70,54 +70,79 @@ void draw_detection(image im, float *box, int side)
|
||||
cvWaitKey(0);
|
||||
}
|
||||
|
||||
|
||||
void train_detection_net(char *cfgfile)
|
||||
char *basename(char *cfgfile)
|
||||
{
|
||||
char *c = cfgfile;
|
||||
char *next;
|
||||
while((next = strchr(c, '/')))
|
||||
{
|
||||
c = next+1;
|
||||
}
|
||||
c = copy_string(c);
|
||||
next = strchr(c, '_');
|
||||
if (next) *next = 0;
|
||||
next = strchr(c, '.');
|
||||
if (next) *next = 0;
|
||||
return c;
|
||||
}
|
||||
|
||||
void train_detection_net(char *cfgfile, char *weightfile)
|
||||
{
|
||||
char *base = basename(cfgfile);
|
||||
printf("%s\n", base);
|
||||
float avg_loss = 1;
|
||||
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 1024;
|
||||
int imgs = 128;
|
||||
srand(time(0));
|
||||
//srand(23410);
|
||||
int i = 0;
|
||||
list *plist = get_paths("/home/pjreddie/data/imagenet/horse.txt");
|
||||
int i = net.seen/imgs;
|
||||
list *plist = get_paths("/home/pjreddie/data/imagenet/horse_pos.txt");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
data train, buffer;
|
||||
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer);
|
||||
int im_dim = 512;
|
||||
int jitter = 64;
|
||||
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, im_dim, im_dim, 7, 7, jitter, &buffer);
|
||||
clock_t time;
|
||||
while(1){
|
||||
i += 1;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer);
|
||||
//data train = load_data_detection_random(imgs, paths, plist->size, 224, 224, 7, 7, 256);
|
||||
load_thread = load_data_detection_thread(imgs, paths, plist->size, im_dim, im_dim, 7, 7, jitter, &buffer);
|
||||
|
||||
/*
|
||||
image im = float_to_image(224, 224, 3, train.X.vals[923]);
|
||||
/*
|
||||
image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[923]);
|
||||
draw_detection(im, train.y.vals[923], 7);
|
||||
show_image(im, "truth");
|
||||
cvWaitKey(0);
|
||||
*/
|
||||
|
||||
normalize_data_rows(train);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
float loss = train_network(net, train);
|
||||
net.seen += imgs;
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
|
||||
if(i%100==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "/home/pjreddie/imagenet_backup/detnet_%d.cfg", i);
|
||||
save_network(net, buff);
|
||||
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
free_data(train);
|
||||
}
|
||||
}
|
||||
|
||||
void validate_detection_net(char *cfgfile)
|
||||
void validate_detection_net(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
srand(time(0));
|
||||
|
||||
@ -137,7 +162,6 @@ void validate_detection_net(char *cfgfile)
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
val = buffer;
|
||||
normalize_data_rows(val);
|
||||
|
||||
num = (i+1)*m/splits - i*m/splits;
|
||||
char **part = paths+(i*m/splits);
|
||||
@ -206,20 +230,13 @@ void train_imagenet_distributed(char *address)
|
||||
}
|
||||
*/
|
||||
|
||||
char *basename(char *cfgfile)
|
||||
void convert(char *cfgfile, char *outfile, char *weightfile)
|
||||
{
|
||||
char *c = cfgfile;
|
||||
char *next;
|
||||
while((next = strchr(c, '/')))
|
||||
{
|
||||
c = next+1;
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
c = copy_string(c);
|
||||
next = strchr(c, '_');
|
||||
if (next) *next = 0;
|
||||
next = strchr(c, '.');
|
||||
if (next) *next = 0;
|
||||
return c;
|
||||
save_network(net, outfile);
|
||||
}
|
||||
|
||||
void train_imagenet(char *cfgfile, char *weightfile)
|
||||
@ -232,8 +249,6 @@ void train_imagenet(char *cfgfile, char *weightfile)
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
//test_learn_bias(*(convolutional_layer *)net.layers[1]);
|
||||
//set_learning_network(&net, net.learning_rate, 0, net.decay);
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 1024;
|
||||
int i = net.seen/imgs;
|
||||
@ -279,7 +294,7 @@ void validate_imagenet(char *filename, char *weightfile)
|
||||
|
||||
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.val.labels.list");
|
||||
|
||||
list *plist = get_paths("/home/pjreddie/data/imagenet/cls.val.list");
|
||||
list *plist = get_paths("/data/imagenet/cls.val.list");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
int m = plist->size;
|
||||
free_list(plist);
|
||||
@ -312,9 +327,12 @@ void validate_imagenet(char *filename, char *weightfile)
|
||||
}
|
||||
}
|
||||
|
||||
void test_detection(char *cfgfile)
|
||||
void test_detection(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
clock_t time;
|
||||
@ -323,7 +341,8 @@ void test_detection(char *cfgfile)
|
||||
fgets(filename, 256, stdin);
|
||||
strtok(filename, "\n");
|
||||
image im = load_image_color(filename, 224, 224);
|
||||
z_normalize_image(im);
|
||||
translate_image(im, -128);
|
||||
scale_image(im, 1/128.);
|
||||
printf("%d %d %d\n", im.h, im.w, im.c);
|
||||
float *X = im.data;
|
||||
time=clock();
|
||||
@ -386,6 +405,30 @@ void test_dog(char *cfgfile)
|
||||
cvWaitKey(0);
|
||||
}
|
||||
|
||||
void test_voc_segment(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
while(1){
|
||||
char filename[256];
|
||||
fgets(filename, 256, stdin);
|
||||
strtok(filename, "\n");
|
||||
image im = load_image_color(filename, 500, 500);
|
||||
//resize_network(net, im.h, im.w, im.c);
|
||||
translate_image(im, -128);
|
||||
scale_image(im, 1/128.);
|
||||
//float *predictions = network_predict(net, im.data);
|
||||
network_predict(net, im.data);
|
||||
free_image(im);
|
||||
image output = get_network_image_layer(net, net.n-2);
|
||||
show_image(output, "Segment Output");
|
||||
cvWaitKey(0);
|
||||
}
|
||||
}
|
||||
|
||||
void test_imagenet(char *cfgfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
@ -764,25 +807,27 @@ int main(int argc, char **argv)
|
||||
fprintf(stderr, "usage: %s <function> <filename>\n", argv[0]);
|
||||
return 0;
|
||||
}
|
||||
else if(0==strcmp(argv[1], "detection")) train_detection_net(argv[2]);
|
||||
else if(0==strcmp(argv[1], "detection")) train_detection_net(argv[2], (argc > 3)? argv[3] : 0);
|
||||
else if(0==strcmp(argv[1], "test")) test_imagenet(argv[2]);
|
||||
else if(0==strcmp(argv[1], "dog")) test_dog(argv[2]);
|
||||
else if(0==strcmp(argv[1], "ctrain")) train_cifar10(argv[2]);
|
||||
else if(0==strcmp(argv[1], "nist")) train_nist(argv[2]);
|
||||
else if(0==strcmp(argv[1], "ctest")) test_cifar10(argv[2]);
|
||||
else if(0==strcmp(argv[1], "train")) train_imagenet(argv[2], (argc > 3)? argv[3] : 0);
|
||||
else if(0==strcmp(argv[1], "testseg")) test_voc_segment(argv[2], (argc > 3)? argv[3] : 0);
|
||||
//else if(0==strcmp(argv[1], "client")) train_imagenet_distributed(argv[2]);
|
||||
else if(0==strcmp(argv[1], "detect")) test_detection(argv[2]);
|
||||
else if(0==strcmp(argv[1], "detect")) test_detection(argv[2], (argc > 3)? argv[3] : 0);
|
||||
else if(0==strcmp(argv[1], "init")) test_init(argv[2]);
|
||||
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
|
||||
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2], (argc > 3)? argv[3] : 0);
|
||||
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
|
||||
else if(0==strcmp(argv[1], "validetect")) validate_detection_net(argv[2]);
|
||||
else if(0==strcmp(argv[1], "validetect")) validate_detection_net(argv[2], (argc > 3)? argv[3] : 0);
|
||||
else if(argc < 4){
|
||||
fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]);
|
||||
return 0;
|
||||
}
|
||||
else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]);
|
||||
else if(0==strcmp(argv[1], "convert")) convert(argv[2], argv[3], (argc > 4)? argv[4] : 0);
|
||||
else if(0==strcmp(argv[1], "scale")) scale_rate(argv[2], atof(argv[3]));
|
||||
fprintf(stderr, "Success!\n");
|
||||
return 0;
|
||||
|
60
src/data.c
60
src/data.c
@ -16,7 +16,7 @@ struct load_args{
|
||||
int w;
|
||||
int nh;
|
||||
int nw;
|
||||
float scale;
|
||||
int jitter;
|
||||
data *d;
|
||||
};
|
||||
|
||||
@ -33,16 +33,18 @@ list *get_paths(char *filename)
|
||||
return lines;
|
||||
}
|
||||
|
||||
void fill_truth_detection(char *path, float *truth, int height, int width, int num_height, int num_width, float scale, int dx, int dy)
|
||||
void fill_truth_detection(char *path, float *truth, int height, int width, int num_height, int num_width, int dy, int dx, int jitter)
|
||||
{
|
||||
int box_height = height/num_height;
|
||||
int box_width = width/num_width;
|
||||
char *labelpath = find_replace(path, "imgs", "det");
|
||||
char *labelpath = find_replace(path, "imgs", "det/train");
|
||||
labelpath = find_replace(labelpath, ".JPEG", ".txt");
|
||||
FILE *file = fopen(labelpath, "r");
|
||||
if(!file) file_error(labelpath);
|
||||
int x, y, h, w;
|
||||
while(fscanf(file, "%d %d %d %d", &x, &y, &w, &h) == 4){
|
||||
float x, y, h, w;
|
||||
while(fscanf(file, "%f %f %f %f", &x, &y, &w, &h) == 4){
|
||||
x *= width + jitter;
|
||||
y *= height + jitter;
|
||||
x -= dx;
|
||||
y -= dy;
|
||||
int i = x/box_width;
|
||||
@ -53,17 +55,15 @@ void fill_truth_detection(char *path, float *truth, int height, int width, int n
|
||||
if(j < 0) j = 0;
|
||||
if(j >= num_height) j = num_height-1;
|
||||
|
||||
float dw = (float)(x%box_width)/box_height;
|
||||
float dh = (float)(y%box_width)/box_width;
|
||||
float sh = h/scale;
|
||||
float sw = w/scale;
|
||||
float dw = (x - i*box_width)/box_width;
|
||||
float dh = (y - j*box_height)/box_height;
|
||||
//printf("%d %d %f %f\n", i, j, dh, dw);
|
||||
int index = (i+j*num_width)*5;
|
||||
truth[index++] = 1;
|
||||
truth[index++] = dh;
|
||||
truth[index++] = dw;
|
||||
truth[index++] = sh;
|
||||
truth[index++] = sw;
|
||||
truth[index++] = h*(height+jitter)/height;
|
||||
truth[index++] = w*(width+jitter)/width;
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
@ -120,13 +120,13 @@ matrix load_labels_paths(char **paths, int n, char **labels, int k)
|
||||
return y;
|
||||
}
|
||||
|
||||
matrix load_labels_detection(char **paths, int n, int height, int width, int num_height, int num_width, float scale)
|
||||
matrix load_labels_detection(char **paths, int n, int height, int width, int num_height, int num_width)
|
||||
{
|
||||
int k = num_height*num_width*5;
|
||||
matrix y = make_matrix(n, k);
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
fill_truth_detection(paths[i], y.vals[i], height, width, num_height, num_width, scale,0,0);
|
||||
fill_truth_detection(paths[i], y.vals[i], height, width, num_height, num_width, 0, 0, 0);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
@ -165,7 +165,7 @@ void free_data(data d)
|
||||
}
|
||||
}
|
||||
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale)
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, int jitter)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
@ -175,13 +175,13 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w,
|
||||
int k = nh*nw*5;
|
||||
d.y = make_matrix(n, k);
|
||||
for(i = 0; i < n; ++i){
|
||||
int dx = rand()%32;
|
||||
int dy = rand()%32;
|
||||
fill_truth_detection(random_paths[i], d.y.vals[i], 224, 224, nh, nw, scale, dx, dy);
|
||||
int dx = rand()%jitter;
|
||||
int dy = rand()%jitter;
|
||||
fill_truth_detection(random_paths[i], d.y.vals[i], h-jitter, w-jitter, nh, nw, dy, dx, jitter);
|
||||
image a = float_to_image(h, w, 3, d.X.vals[i]);
|
||||
jitter_image(a,224,224,dy,dx);
|
||||
jitter_image(a,h-jitter,w-jitter,dy,dx);
|
||||
}
|
||||
d.X.cols = 224*224*3;
|
||||
d.X.cols = (h-jitter)*(w-jitter)*3;
|
||||
free(random_paths);
|
||||
return d;
|
||||
}
|
||||
@ -189,12 +189,14 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w,
|
||||
void *load_detection_thread(void *ptr)
|
||||
{
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.h, a.w, a.nh, a.nw, a.scale);
|
||||
*a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.h, a.w, a.nh, a.nw, a.jitter);
|
||||
translate_data_rows(*a.d, -128);
|
||||
scale_data_rows(*a.d, 1./128);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d)
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, int jitter, data *d)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
@ -205,7 +207,7 @@ pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, i
|
||||
args->w = w;
|
||||
args->nh = nh;
|
||||
args->nw = nw;
|
||||
args->scale = scale;
|
||||
args->jitter = jitter;
|
||||
args->d = d;
|
||||
if(pthread_create(&thread, 0, load_detection_thread, args)) {
|
||||
error("Thread creation failed");
|
||||
@ -213,13 +215,13 @@ pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, i
|
||||
return thread;
|
||||
}
|
||||
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale)
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
data d;
|
||||
d.shallow = 0;
|
||||
d.X = load_image_paths(random_paths, n, h, w);
|
||||
d.y = load_labels_detection(random_paths, n, h, w, nh, nw, scale);
|
||||
d.y = load_labels_detection(random_paths, n, h, w, nh, nw);
|
||||
free(random_paths);
|
||||
return d;
|
||||
}
|
||||
@ -239,8 +241,8 @@ void *load_in_thread(void *ptr)
|
||||
{
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data(a.paths, a.n, a.m, a.labels, a.k, a.h, a.w);
|
||||
translate_data_rows(*a.d, -128);
|
||||
scale_data_rows(*a.d, 1./128);
|
||||
translate_data_rows(*a.d, -128);
|
||||
scale_data_rows(*a.d, 1./128);
|
||||
free(ptr);
|
||||
return 0;
|
||||
}
|
||||
@ -301,9 +303,9 @@ data load_cifar10_data(char *filename)
|
||||
X.vals[i][j] = (double)bytes[j+1];
|
||||
}
|
||||
}
|
||||
translate_data_rows(d, -144);
|
||||
scale_data_rows(d, 1./128);
|
||||
//normalize_data_rows(d);
|
||||
translate_data_rows(d, -144);
|
||||
scale_data_rows(d, 1./128);
|
||||
//normalize_data_rows(d);
|
||||
fclose(fp);
|
||||
return d;
|
||||
}
|
||||
|
@ -17,10 +17,10 @@ void free_data(data d);
|
||||
data load_data(char **paths, int n, int m, char **labels, int k, int h, int w);
|
||||
pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int h, int w, data *d);
|
||||
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d);
|
||||
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, int jitter, data *d);
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, int jitter);
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw);
|
||||
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
|
||||
data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w);
|
||||
data load_cifar10_data(char *filename);
|
||||
data load_all_cifar10();
|
||||
|
104
src/deconvolutional_kernels.cu
Normal file
104
src/deconvolutional_kernels.cu
Normal file
@ -0,0 +1,104 @@
|
||||
extern "C" {
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "gemm.h"
|
||||
#include "blas.h"
|
||||
#include "im2col.h"
|
||||
#include "col2im.h"
|
||||
#include "utils.h"
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
extern "C" void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, float *in)
|
||||
{
|
||||
int i;
|
||||
int out_h = deconvolutional_out_height(layer);
|
||||
int out_w = deconvolutional_out_width(layer);
|
||||
int size = out_h*out_w;
|
||||
|
||||
int m = layer.size*layer.size*layer.n;
|
||||
int n = layer.h*layer.w;
|
||||
int k = layer.c;
|
||||
|
||||
bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, size);
|
||||
|
||||
for(i = 0; i < layer.batch; ++i){
|
||||
float *a = layer.filters_gpu;
|
||||
float *b = in + i*layer.c*layer.h*layer.w;
|
||||
float *c = layer.col_image_gpu;
|
||||
|
||||
gemm_ongpu(1,0,m,n,k,1,a,m,b,n,0,c,n);
|
||||
|
||||
col2im_ongpu(c, layer.n, out_h, out_w, layer.size, layer.stride, 0, layer.output_gpu+i*layer.n*size);
|
||||
}
|
||||
activate_array(layer.output_gpu, layer.batch*layer.n*size, layer.activation);
|
||||
}
|
||||
|
||||
extern "C" void backward_deconvolutional_layer_gpu(deconvolutional_layer layer, float *in, float *delta_gpu)
|
||||
{
|
||||
float alpha = 1./layer.batch;
|
||||
int out_h = deconvolutional_out_height(layer);
|
||||
int out_w = deconvolutional_out_width(layer);
|
||||
int size = out_h*out_w;
|
||||
int i;
|
||||
|
||||
gradient_array(layer.output_gpu, size*layer.n*layer.batch, layer.activation, layer.delta_gpu);
|
||||
backward_bias(layer.bias_updates_gpu, layer.delta, layer.batch, layer.n, size);
|
||||
|
||||
if(delta_gpu) memset(delta_gpu, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
|
||||
|
||||
for(i = 0; i < layer.batch; ++i){
|
||||
int m = layer.c;
|
||||
int n = layer.size*layer.size*layer.n;
|
||||
int k = layer.h*layer.w;
|
||||
|
||||
float *a = in + i*m*n;
|
||||
float *b = layer.col_image_gpu;
|
||||
float *c = layer.filter_updates_gpu;
|
||||
|
||||
im2col_ongpu(layer.delta_gpu + i*layer.n*size, layer.n, out_h, out_w,
|
||||
layer.size, layer.stride, 0, b);
|
||||
gemm_ongpu(0,1,m,n,k,alpha,a,k,b,k,1,c,n);
|
||||
|
||||
if(delta_gpu){
|
||||
int m = layer.c;
|
||||
int n = layer.h*layer.w;
|
||||
int k = layer.size*layer.size*layer.n;
|
||||
|
||||
float *a = layer.filters_gpu;
|
||||
float *b = layer.col_image_gpu;
|
||||
float *c = delta_gpu + i*n*m;
|
||||
|
||||
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void pull_deconvolutional_layer(deconvolutional_layer layer)
|
||||
{
|
||||
cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
|
||||
cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
|
||||
cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
|
||||
cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
|
||||
}
|
||||
|
||||
extern "C" void push_deconvolutional_layer(deconvolutional_layer layer)
|
||||
{
|
||||
cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
|
||||
cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
|
||||
cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
|
||||
cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
|
||||
}
|
||||
|
||||
extern "C" void update_deconvolutional_layer_gpu(deconvolutional_layer layer)
|
||||
{
|
||||
int size = layer.size*layer.size*layer.c*layer.n;
|
||||
|
||||
axpy_ongpu(layer.n, layer.learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
|
||||
scal_ongpu(layer.n,layer.momentum, layer.bias_updates_gpu, 1);
|
||||
|
||||
axpy_ongpu(size, -layer.decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
|
||||
axpy_ongpu(size, layer.learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
|
||||
scal_ongpu(size, layer.momentum, layer.filter_updates_gpu, 1);
|
||||
}
|
||||
|
200
src/deconvolutional_layer.c
Normal file
200
src/deconvolutional_layer.c
Normal file
@ -0,0 +1,200 @@
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "utils.h"
|
||||
#include "im2col.h"
|
||||
#include "col2im.h"
|
||||
#include "blas.h"
|
||||
#include "gemm.h"
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
|
||||
int deconvolutional_out_height(deconvolutional_layer layer)
|
||||
{
|
||||
int h = layer.stride*(layer.h - 1) + layer.size;
|
||||
return h;
|
||||
}
|
||||
|
||||
int deconvolutional_out_width(deconvolutional_layer layer)
|
||||
{
|
||||
int w = layer.stride*(layer.w - 1) + layer.size;
|
||||
return w;
|
||||
}
|
||||
|
||||
int deconvolutional_out_size(deconvolutional_layer layer)
|
||||
{
|
||||
return deconvolutional_out_height(layer) * deconvolutional_out_width(layer);
|
||||
}
|
||||
|
||||
image get_deconvolutional_image(deconvolutional_layer layer)
|
||||
{
|
||||
int h,w,c;
|
||||
h = deconvolutional_out_height(layer);
|
||||
w = deconvolutional_out_width(layer);
|
||||
c = layer.n;
|
||||
return float_to_image(h,w,c,layer.output);
|
||||
}
|
||||
|
||||
image get_deconvolutional_delta(deconvolutional_layer layer)
|
||||
{
|
||||
int h,w,c;
|
||||
h = deconvolutional_out_height(layer);
|
||||
w = deconvolutional_out_width(layer);
|
||||
c = layer.n;
|
||||
return float_to_image(h,w,c,layer.delta);
|
||||
}
|
||||
|
||||
deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation, float learning_rate, float momentum, float decay)
|
||||
{
|
||||
int i;
|
||||
deconvolutional_layer *layer = calloc(1, sizeof(deconvolutional_layer));
|
||||
|
||||
layer->learning_rate = learning_rate;
|
||||
layer->momentum = momentum;
|
||||
layer->decay = decay;
|
||||
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
layer->c = c;
|
||||
layer->n = n;
|
||||
layer->batch = batch;
|
||||
layer->stride = stride;
|
||||
layer->size = size;
|
||||
|
||||
layer->filters = calloc(c*n*size*size, sizeof(float));
|
||||
layer->filter_updates = calloc(c*n*size*size, sizeof(float));
|
||||
|
||||
layer->biases = calloc(n, sizeof(float));
|
||||
layer->bias_updates = calloc(n, sizeof(float));
|
||||
float scale = 1./sqrt(size*size*c);
|
||||
for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal();
|
||||
for(i = 0; i < n; ++i){
|
||||
layer->biases[i] = scale;
|
||||
}
|
||||
int out_h = deconvolutional_out_height(*layer);
|
||||
int out_w = deconvolutional_out_width(*layer);
|
||||
|
||||
layer->col_image = calloc(h*w*size*size*n, sizeof(float));
|
||||
layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
|
||||
layer->delta = calloc(layer->batch*out_h * out_w * n, sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
layer->filters_gpu = cuda_make_array(layer->filters, c*n*size*size);
|
||||
layer->filter_updates_gpu = cuda_make_array(layer->filter_updates, c*n*size*size);
|
||||
|
||||
layer->biases_gpu = cuda_make_array(layer->biases, n);
|
||||
layer->bias_updates_gpu = cuda_make_array(layer->bias_updates, n);
|
||||
|
||||
layer->col_image_gpu = cuda_make_array(layer->col_image, h*w*size*size*n);
|
||||
layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*n);
|
||||
layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*n);
|
||||
#endif
|
||||
|
||||
layer->activation = activation;
|
||||
|
||||
fprintf(stderr, "Deconvolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
|
||||
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w)
|
||||
{
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
int out_h = deconvolutional_out_height(*layer);
|
||||
int out_w = deconvolutional_out_width(*layer);
|
||||
|
||||
layer->col_image = realloc(layer->col_image,
|
||||
out_h*out_w*layer->size*layer->size*layer->c*sizeof(float));
|
||||
layer->output = realloc(layer->output,
|
||||
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
||||
layer->delta = realloc(layer->delta,
|
||||
layer->batch*out_h * out_w * layer->n*sizeof(float));
|
||||
#ifdef GPU
|
||||
cuda_free(layer->col_image_gpu);
|
||||
cuda_free(layer->delta_gpu);
|
||||
cuda_free(layer->output_gpu);
|
||||
|
||||
layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c);
|
||||
layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n);
|
||||
layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n);
|
||||
#endif
|
||||
}
|
||||
|
||||
void forward_deconvolutional_layer(const deconvolutional_layer layer, float *in)
|
||||
{
|
||||
int i;
|
||||
int out_h = deconvolutional_out_height(layer);
|
||||
int out_w = deconvolutional_out_width(layer);
|
||||
int size = out_h*out_w;
|
||||
|
||||
int m = layer.size*layer.size*layer.n;
|
||||
int n = layer.h*layer.w;
|
||||
int k = layer.c;
|
||||
|
||||
bias_output(layer.output, layer.biases, layer.batch, layer.n, size);
|
||||
|
||||
for(i = 0; i < layer.batch; ++i){
|
||||
float *a = layer.filters;
|
||||
float *b = in + i*layer.c*layer.h*layer.w;
|
||||
float *c = layer.col_image;
|
||||
|
||||
gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
|
||||
|
||||
col2im_cpu(c, layer.n, out_h, out_w, layer.size, layer.stride, 0, layer.output+i*layer.n*size);
|
||||
}
|
||||
activate_array(layer.output, layer.batch*layer.n*size, layer.activation);
|
||||
}
|
||||
|
||||
void backward_deconvolutional_layer(deconvolutional_layer layer, float *in, float *delta)
|
||||
{
|
||||
float alpha = 1./layer.batch;
|
||||
int out_h = deconvolutional_out_height(layer);
|
||||
int out_w = deconvolutional_out_width(layer);
|
||||
int size = out_h*out_w;
|
||||
int i;
|
||||
|
||||
gradient_array(layer.output, size*layer.n*layer.batch, layer.activation, layer.delta);
|
||||
backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, size);
|
||||
|
||||
if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
|
||||
|
||||
for(i = 0; i < layer.batch; ++i){
|
||||
int m = layer.c;
|
||||
int n = layer.size*layer.size*layer.n;
|
||||
int k = layer.h*layer.w;
|
||||
|
||||
float *a = in + i*m*n;
|
||||
float *b = layer.col_image;
|
||||
float *c = layer.filter_updates;
|
||||
|
||||
im2col_cpu(layer.delta + i*layer.n*size, layer.n, out_h, out_w,
|
||||
layer.size, layer.stride, 0, b);
|
||||
gemm(0,1,m,n,k,alpha,a,k,b,k,1,c,n);
|
||||
|
||||
if(delta){
|
||||
int m = layer.c;
|
||||
int n = layer.h*layer.w;
|
||||
int k = layer.size*layer.size*layer.n;
|
||||
|
||||
float *a = layer.filters;
|
||||
float *b = layer.col_image;
|
||||
float *c = delta + i*n*m;
|
||||
|
||||
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void update_deconvolutional_layer(deconvolutional_layer layer)
|
||||
{
|
||||
int size = layer.size*layer.size*layer.c*layer.n;
|
||||
axpy_cpu(layer.n, layer.learning_rate, layer.bias_updates, 1, layer.biases, 1);
|
||||
scal_cpu(layer.n, layer.momentum, layer.bias_updates, 1);
|
||||
|
||||
axpy_cpu(size, -layer.decay, layer.filters, 1, layer.filter_updates, 1);
|
||||
axpy_cpu(size, layer.learning_rate, layer.filter_updates, 1, layer.filters, 1);
|
||||
scal_cpu(size, layer.momentum, layer.filter_updates, 1);
|
||||
}
|
||||
|
||||
|
||||
|
65
src/deconvolutional_layer.h
Normal file
65
src/deconvolutional_layer.h
Normal file
@ -0,0 +1,65 @@
|
||||
#ifndef DECONVOLUTIONAL_LAYER_H
|
||||
#define DECONVOLUTIONAL_LAYER_H
|
||||
|
||||
#include "cuda.h"
|
||||
#include "image.h"
|
||||
#include "activations.h"
|
||||
|
||||
typedef struct {
|
||||
float learning_rate;
|
||||
float momentum;
|
||||
float decay;
|
||||
|
||||
int batch;
|
||||
int h,w,c;
|
||||
int n;
|
||||
int size;
|
||||
int stride;
|
||||
float *filters;
|
||||
float *filter_updates;
|
||||
|
||||
float *biases;
|
||||
float *bias_updates;
|
||||
|
||||
float *col_image;
|
||||
float *delta;
|
||||
float *output;
|
||||
|
||||
#ifdef GPU
|
||||
float * filters_gpu;
|
||||
float * filter_updates_gpu;
|
||||
|
||||
float * biases_gpu;
|
||||
float * bias_updates_gpu;
|
||||
|
||||
float * col_image_gpu;
|
||||
float * delta_gpu;
|
||||
float * output_gpu;
|
||||
#endif
|
||||
|
||||
ACTIVATION activation;
|
||||
} deconvolutional_layer;
|
||||
|
||||
#ifdef GPU
|
||||
void forward_deconvolutional_layer_gpu(deconvolutional_layer layer, float * in);
|
||||
void backward_deconvolutional_layer_gpu(deconvolutional_layer layer, float * in, float * delta_gpu);
|
||||
void update_deconvolutional_layer_gpu(deconvolutional_layer layer);
|
||||
void push_deconvolutional_layer(deconvolutional_layer layer);
|
||||
void pull_deconvolutional_layer(deconvolutional_layer layer);
|
||||
#endif
|
||||
|
||||
deconvolutional_layer *make_deconvolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation, float learning_rate, float momentum, float decay);
|
||||
void resize_deconvolutional_layer(deconvolutional_layer *layer, int h, int w);
|
||||
void forward_deconvolutional_layer(const deconvolutional_layer layer, float *in);
|
||||
void update_deconvolutional_layer(deconvolutional_layer layer);
|
||||
void backward_deconvolutional_layer(deconvolutional_layer layer, float *in, float *delta);
|
||||
|
||||
image get_deconvolutional_image(deconvolutional_layer layer);
|
||||
image get_deconvolutional_delta(deconvolutional_layer layer);
|
||||
image get_deconvolutional_filter(deconvolutional_layer layer, int i);
|
||||
|
||||
int deconvolutional_out_height(deconvolutional_layer layer);
|
||||
int deconvolutional_out_width(deconvolutional_layer layer);
|
||||
|
||||
#endif
|
||||
|
@ -21,6 +21,19 @@ dropout_layer *make_dropout_layer(int batch, int inputs, float probability)
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_dropout_layer(dropout_layer *layer, int inputs)
|
||||
{
|
||||
layer->output = realloc(layer->output, layer->inputs*layer->batch*sizeof(float));
|
||||
layer->rand = realloc(layer->rand, layer->inputs*layer->batch*sizeof(float));
|
||||
#ifdef GPU
|
||||
cuda_free(layer->output_gpu);
|
||||
cuda_free(layer->rand_gpu);
|
||||
|
||||
layer->output_gpu = cuda_make_array(layer->output, inputs*layer->batch);
|
||||
layer->rand_gpu = cuda_make_array(layer->rand, inputs*layer->batch);
|
||||
#endif
|
||||
}
|
||||
|
||||
void forward_dropout_layer(dropout_layer layer, float *input)
|
||||
{
|
||||
int i;
|
||||
|
@ -18,6 +18,7 @@ dropout_layer *make_dropout_layer(int batch, int inputs, float probability);
|
||||
|
||||
void forward_dropout_layer(dropout_layer layer, float *input);
|
||||
void backward_dropout_layer(dropout_layer layer, float *delta);
|
||||
void resize_dropout_layer(dropout_layer *layer, int inputs);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_dropout_layer_gpu(dropout_layer layer, float * input);
|
||||
|
@ -7,7 +7,7 @@ void im2col_cpu(float* data_im,
|
||||
|
||||
#ifdef GPU
|
||||
|
||||
void im2col_ongpu(float *im, int offset,
|
||||
void im2col_ongpu(float *im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad,float *data_col);
|
||||
|
||||
|
@ -3,7 +3,7 @@ extern "C" {
|
||||
#include "cuda.h"
|
||||
}
|
||||
|
||||
__global__ void im2col_pad_kernel(float *im, int offset,
|
||||
__global__ void im2col_pad_kernel(float *im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, float *data_col)
|
||||
{
|
||||
@ -32,13 +32,13 @@ __global__ void im2col_pad_kernel(float *im, int offset,
|
||||
int im_row = h_offset + h * stride - pad;
|
||||
int im_col = w_offset + w * stride - pad;
|
||||
|
||||
int im_index = offset + im_col + width*(im_row + height*im_channel);
|
||||
int im_index = im_col + width*(im_row + height*im_channel);
|
||||
float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
|
||||
|
||||
data_col[col_index] = val;
|
||||
}
|
||||
|
||||
__global__ void im2col_nopad_kernel(float *im, int offset,
|
||||
__global__ void im2col_nopad_kernel(float *im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, float *data_col)
|
||||
{
|
||||
@ -65,13 +65,13 @@ __global__ void im2col_nopad_kernel(float *im, int offset,
|
||||
int im_row = h_offset + h * stride;
|
||||
int im_col = w_offset + w * stride;
|
||||
|
||||
int im_index = offset + im_col + width*(im_row + height*im_channel);
|
||||
int im_index = im_col + width*(im_row + height*im_channel);
|
||||
float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
|
||||
|
||||
data_col[col_index] = val;
|
||||
}
|
||||
|
||||
extern "C" void im2col_ongpu(float *im, int offset,
|
||||
extern "C" void im2col_ongpu(float *im,
|
||||
int channels, int height, int width,
|
||||
int ksize, int stride, int pad, float *data_col)
|
||||
{
|
||||
@ -87,7 +87,7 @@ extern "C" void im2col_ongpu(float *im, int offset,
|
||||
|
||||
size_t n = channels_col*height_col*width_col;
|
||||
|
||||
if(pad)im2col_pad_kernel<<<cuda_gridsize(n),BLOCK>>>(im, offset, channels, height, width, ksize, stride, data_col);
|
||||
else im2col_nopad_kernel<<<cuda_gridsize(n),BLOCK>>>(im, offset, channels, height, width, ksize, stride, data_col);
|
||||
if(pad)im2col_pad_kernel<<<cuda_gridsize(n),BLOCK>>>(im, channels, height, width, ksize, stride, data_col);
|
||||
else im2col_nopad_kernel<<<cuda_gridsize(n),BLOCK>>>(im, channels, height, width, ksize, stride, data_col);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -40,13 +40,22 @@ maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c)
|
||||
void resize_maxpool_layer(maxpool_layer *layer, int h, int w)
|
||||
{
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
layer->c = c;
|
||||
layer->output = realloc(layer->output, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch* sizeof(float));
|
||||
layer->delta = realloc(layer->delta, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch*sizeof(float));
|
||||
int output_size = ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * layer->c * layer->batch;
|
||||
layer->output = realloc(layer->output, output_size * sizeof(float));
|
||||
layer->delta = realloc(layer->delta, output_size * sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
cuda_free((float *)layer->indexes_gpu);
|
||||
cuda_free(layer->output_gpu);
|
||||
cuda_free(layer->delta_gpu);
|
||||
layer->indexes_gpu = cuda_make_int_array(output_size);
|
||||
layer->output_gpu = cuda_make_array(layer->output, output_size);
|
||||
layer->delta_gpu = cuda_make_array(layer->delta, output_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void forward_maxpool_layer(const maxpool_layer layer, float *input)
|
||||
|
@ -21,7 +21,7 @@ typedef struct {
|
||||
|
||||
image get_maxpool_image(maxpool_layer layer);
|
||||
maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride);
|
||||
void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c);
|
||||
void resize_maxpool_layer(maxpool_layer *layer, int h, int w);
|
||||
void forward_maxpool_layer(const maxpool_layer layer, float *input);
|
||||
void backward_maxpool_layer(const maxpool_layer layer, float *delta);
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "crop_layer.h"
|
||||
#include "connected_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
@ -20,6 +21,8 @@ char *get_layer_string(LAYER_TYPE a)
|
||||
switch(a){
|
||||
case CONVOLUTIONAL:
|
||||
return "convolutional";
|
||||
case DECONVOLUTIONAL:
|
||||
return "deconvolutional";
|
||||
case CONNECTED:
|
||||
return "connected";
|
||||
case MAXPOOL:
|
||||
@ -68,6 +71,11 @@ void forward_network(network net, float *input, float *truth, int train)
|
||||
forward_convolutional_layer(layer, input);
|
||||
input = layer.output;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
forward_deconvolutional_layer(layer, input);
|
||||
input = layer.output;
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
forward_connected_layer(layer, input);
|
||||
@ -122,14 +130,9 @@ void update_network(network net)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
update_convolutional_layer(layer);
|
||||
}
|
||||
else if(net.types[i] == MAXPOOL){
|
||||
//maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
}
|
||||
else if(net.types[i] == SOFTMAX){
|
||||
//maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
}
|
||||
else if(net.types[i] == NORMALIZATION){
|
||||
//maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
update_deconvolutional_layer(layer);
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
@ -143,6 +146,9 @@ float *get_network_output_layer(network net, int i)
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.output;
|
||||
} else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.output;
|
||||
} else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
return layer.output;
|
||||
@ -178,6 +184,9 @@ float *get_network_delta_layer(network net, int i)
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.delta;
|
||||
} else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.delta;
|
||||
} else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
return layer.delta;
|
||||
@ -247,9 +256,13 @@ void backward_network(network net, float *input)
|
||||
prev_input = get_network_output_layer(net, i-1);
|
||||
prev_delta = get_network_delta_layer(net, i-1);
|
||||
}
|
||||
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
backward_convolutional_layer(layer, prev_input, prev_delta);
|
||||
} else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
backward_deconvolutional_layer(layer, prev_input, prev_delta);
|
||||
}
|
||||
else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
@ -377,6 +390,9 @@ void set_batch_network(network *net, int b)
|
||||
if(net->types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer *layer = (convolutional_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}else if(net->types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer *layer = (deconvolutional_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == MAXPOOL){
|
||||
maxpool_layer *layer = (maxpool_layer *)net->layers[i];
|
||||
@ -415,6 +431,10 @@ int get_network_input_size_layer(network net, int i)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.h*layer.w*layer.c;
|
||||
}
|
||||
if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.h*layer.w*layer.c;
|
||||
}
|
||||
else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
return layer.h*layer.w*layer.c;
|
||||
@ -448,6 +468,11 @@ int get_network_output_size_layer(network net, int i)
|
||||
image output = get_convolutional_image(layer);
|
||||
return output.h*output.w*output.c;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
image output = get_deconvolutional_image(layer);
|
||||
return output.h*output.w*output.c;
|
||||
}
|
||||
else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
image output = get_maxpool_image(layer);
|
||||
@ -483,21 +508,31 @@ int resize_network(network net, int h, int w, int c)
|
||||
for (i = 0; i < net.n; ++i){
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer *layer = (convolutional_layer *)net.layers[i];
|
||||
resize_convolutional_layer(layer, h, w, c);
|
||||
resize_convolutional_layer(layer, h, w);
|
||||
image output = get_convolutional_image(*layer);
|
||||
h = output.h;
|
||||
w = output.w;
|
||||
c = output.c;
|
||||
} else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer *layer = (deconvolutional_layer *)net.layers[i];
|
||||
resize_deconvolutional_layer(layer, h, w);
|
||||
image output = get_deconvolutional_image(*layer);
|
||||
h = output.h;
|
||||
w = output.w;
|
||||
c = output.c;
|
||||
}else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer *layer = (maxpool_layer *)net.layers[i];
|
||||
resize_maxpool_layer(layer, h, w, c);
|
||||
resize_maxpool_layer(layer, h, w);
|
||||
image output = get_maxpool_image(*layer);
|
||||
h = output.h;
|
||||
w = output.w;
|
||||
c = output.c;
|
||||
}else if(net.types[i] == DROPOUT){
|
||||
dropout_layer *layer = (dropout_layer *)net.layers[i];
|
||||
resize_dropout_layer(layer, h*w*c);
|
||||
}else if(net.types[i] == NORMALIZATION){
|
||||
normalization_layer *layer = (normalization_layer *)net.layers[i];
|
||||
resize_normalization_layer(layer, h, w, c);
|
||||
resize_normalization_layer(layer, h, w);
|
||||
image output = get_normalization_image(*layer);
|
||||
h = output.h;
|
||||
w = output.w;
|
||||
@ -527,6 +562,10 @@ image get_network_image_layer(network net, int i)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return get_convolutional_image(layer);
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return get_deconvolutional_image(layer);
|
||||
}
|
||||
else if(net.types[i] == MAXPOOL){
|
||||
maxpool_layer layer = *(maxpool_layer *)net.layers[i];
|
||||
return get_maxpool_image(layer);
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
typedef enum {
|
||||
CONVOLUTIONAL,
|
||||
DECONVOLUTIONAL,
|
||||
CONNECTED,
|
||||
MAXPOOL,
|
||||
SOFTMAX,
|
||||
|
@ -10,6 +10,7 @@ extern "C" {
|
||||
#include "crop_layer.h"
|
||||
#include "connected_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
@ -31,6 +32,11 @@ void forward_network_gpu(network net, float * input, float * truth, int train)
|
||||
forward_convolutional_layer_gpu(layer, input);
|
||||
input = layer.output_gpu;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
forward_deconvolutional_layer_gpu(layer, input);
|
||||
input = layer.output_gpu;
|
||||
}
|
||||
else if(net.types[i] == COST){
|
||||
cost_layer layer = *(cost_layer *)net.layers[i];
|
||||
forward_cost_layer_gpu(layer, input, truth);
|
||||
@ -84,6 +90,10 @@ void backward_network_gpu(network net, float * input)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
backward_convolutional_layer_gpu(layer, prev_input, prev_delta);
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
backward_deconvolutional_layer_gpu(layer, prev_input, prev_delta);
|
||||
}
|
||||
else if(net.types[i] == COST){
|
||||
cost_layer layer = *(cost_layer *)net.layers[i];
|
||||
backward_cost_layer_gpu(layer, prev_input, prev_delta);
|
||||
@ -116,6 +126,10 @@ void update_network_gpu(network net)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
update_convolutional_layer_gpu(layer);
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
update_deconvolutional_layer_gpu(layer);
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
update_connected_layer_gpu(layer);
|
||||
@ -129,6 +143,10 @@ float * get_network_output_gpu_layer(network net, int i)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.output_gpu;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.output_gpu;
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
return layer.output_gpu;
|
||||
@ -157,6 +175,10 @@ float * get_network_delta_gpu_layer(network net, int i)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.delta_gpu;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.delta_gpu;
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
return layer.delta_gpu;
|
||||
@ -208,6 +230,10 @@ float *get_network_output_layer_gpu(network net, int i)
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
return layer.output;
|
||||
}
|
||||
else if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
|
||||
return layer.output;
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch);
|
||||
|
@ -35,13 +35,12 @@ normalization_layer *make_normalization_layer(int batch, int h, int w, int c, in
|
||||
return layer;
|
||||
}
|
||||
|
||||
void resize_normalization_layer(normalization_layer *layer, int h, int w, int c)
|
||||
void resize_normalization_layer(normalization_layer *layer, int h, int w)
|
||||
{
|
||||
layer->h = h;
|
||||
layer->w = w;
|
||||
layer->c = c;
|
||||
layer->output = realloc(layer->output, h * w * c * layer->batch * sizeof(float));
|
||||
layer->delta = realloc(layer->delta, h * w * c * layer->batch * sizeof(float));
|
||||
layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
|
||||
layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
|
||||
layer->sums = realloc(layer->sums, h*w * sizeof(float));
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ typedef struct {
|
||||
|
||||
image get_normalization_image(normalization_layer layer);
|
||||
normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
|
||||
void resize_normalization_layer(normalization_layer *layer, int h, int w, int c);
|
||||
void resize_normalization_layer(normalization_layer *layer, int h, int w);
|
||||
void forward_normalization_layer(const normalization_layer layer, float *in);
|
||||
void backward_normalization_layer(const normalization_layer layer, float *in, float *delta);
|
||||
void visualize_normalization_layer(normalization_layer layer, char *window);
|
||||
|
123
src/parser.c
123
src/parser.c
@ -7,6 +7,7 @@
|
||||
#include "crop_layer.h"
|
||||
#include "cost_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
#include "deconvolutional_layer.h"
|
||||
#include "connected_layer.h"
|
||||
#include "maxpool_layer.h"
|
||||
#include "normalization_layer.h"
|
||||
@ -23,6 +24,7 @@ typedef struct{
|
||||
}section;
|
||||
|
||||
int is_convolutional(section *s);
|
||||
int is_deconvolutional(section *s);
|
||||
int is_connected(section *s);
|
||||
int is_maxpool(section *s);
|
||||
int is_dropout(section *s);
|
||||
@ -65,6 +67,49 @@ void parse_data(char *data, float *a, int n)
|
||||
}
|
||||
}
|
||||
|
||||
deconvolutional_layer *parse_deconvolutional(list *options, network *net, int count)
|
||||
{
|
||||
int h,w,c;
|
||||
float learning_rate, momentum, decay;
|
||||
int n = option_find_int(options, "filters",1);
|
||||
int size = option_find_int(options, "size",1);
|
||||
int stride = option_find_int(options, "stride",1);
|
||||
char *activation_s = option_find_str(options, "activation", "sigmoid");
|
||||
ACTIVATION activation = get_activation(activation_s);
|
||||
if(count == 0){
|
||||
learning_rate = option_find_float(options, "learning_rate", .001);
|
||||
momentum = option_find_float(options, "momentum", .9);
|
||||
decay = option_find_float(options, "decay", .0001);
|
||||
h = option_find_int(options, "height",1);
|
||||
w = option_find_int(options, "width",1);
|
||||
c = option_find_int(options, "channels",1);
|
||||
net->batch = option_find_int(options, "batch",1);
|
||||
net->learning_rate = learning_rate;
|
||||
net->momentum = momentum;
|
||||
net->decay = decay;
|
||||
net->seen = option_find_int(options, "seen",0);
|
||||
}else{
|
||||
learning_rate = option_find_float_quiet(options, "learning_rate", net->learning_rate);
|
||||
momentum = option_find_float_quiet(options, "momentum", net->momentum);
|
||||
decay = option_find_float_quiet(options, "decay", net->decay);
|
||||
image m = get_network_image_layer(*net, count-1);
|
||||
h = m.h;
|
||||
w = m.w;
|
||||
c = m.c;
|
||||
if(h == 0) error("Layer before deconvolutional layer must output image.");
|
||||
}
|
||||
deconvolutional_layer *layer = make_deconvolutional_layer(net->batch,h,w,c,n,size,stride,activation,learning_rate,momentum,decay);
|
||||
char *weights = option_find_str(options, "weights", 0);
|
||||
char *biases = option_find_str(options, "biases", 0);
|
||||
parse_data(weights, layer->filters, c*n*size*size);
|
||||
parse_data(biases, layer->biases, n);
|
||||
#ifdef GPU
|
||||
if(weights || biases) push_deconvolutional_layer(*layer);
|
||||
#endif
|
||||
option_unused(options);
|
||||
return layer;
|
||||
}
|
||||
|
||||
convolutional_layer *parse_convolutional(list *options, network *net, int count)
|
||||
{
|
||||
int h,w,c;
|
||||
@ -306,6 +351,10 @@ network parse_network_cfg(char *filename)
|
||||
convolutional_layer *layer = parse_convolutional(options, &net, count);
|
||||
net.types[count] = CONVOLUTIONAL;
|
||||
net.layers[count] = layer;
|
||||
}else if(is_deconvolutional(s)){
|
||||
deconvolutional_layer *layer = parse_deconvolutional(options, &net, count);
|
||||
net.types[count] = DECONVOLUTIONAL;
|
||||
net.layers[count] = layer;
|
||||
}else if(is_connected(s)){
|
||||
connected_layer *layer = parse_connected(options, &net, count);
|
||||
net.types[count] = CONNECTED;
|
||||
@ -360,6 +409,11 @@ int is_cost(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[cost]")==0);
|
||||
}
|
||||
int is_deconvolutional(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[deconv]")==0
|
||||
|| strcmp(s->type, "[deconvolutional]")==0);
|
||||
}
|
||||
int is_convolutional(section *s)
|
||||
{
|
||||
return (strcmp(s->type, "[conv]")==0
|
||||
@ -438,7 +492,7 @@ list *read_cfg(char *filename)
|
||||
break;
|
||||
default:
|
||||
if(!read_option(line, current->options)){
|
||||
printf("Config file error line %d, could parse: %s\n", nu, line);
|
||||
fprintf(stderr, "Config file error line %d, could parse: %s\n", nu, line);
|
||||
free(line);
|
||||
}
|
||||
break;
|
||||
@ -488,6 +542,45 @@ void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int
|
||||
fprintf(fp, "\n\n");
|
||||
}
|
||||
|
||||
void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count)
|
||||
{
|
||||
#ifdef GPU
|
||||
if(gpu_index >= 0) pull_deconvolutional_layer(*l);
|
||||
#endif
|
||||
int i;
|
||||
fprintf(fp, "[deconvolutional]\n");
|
||||
if(count == 0) {
|
||||
fprintf(fp, "batch=%d\n"
|
||||
"height=%d\n"
|
||||
"width=%d\n"
|
||||
"channels=%d\n"
|
||||
"learning_rate=%g\n"
|
||||
"momentum=%g\n"
|
||||
"decay=%g\n"
|
||||
"seen=%d\n",
|
||||
l->batch,l->h, l->w, l->c, l->learning_rate, l->momentum, l->decay, net.seen);
|
||||
} else {
|
||||
if(l->learning_rate != net.learning_rate)
|
||||
fprintf(fp, "learning_rate=%g\n", l->learning_rate);
|
||||
if(l->momentum != net.momentum)
|
||||
fprintf(fp, "momentum=%g\n", l->momentum);
|
||||
if(l->decay != net.decay)
|
||||
fprintf(fp, "decay=%g\n", l->decay);
|
||||
}
|
||||
fprintf(fp, "filters=%d\n"
|
||||
"size=%d\n"
|
||||
"stride=%d\n"
|
||||
"activation=%s\n",
|
||||
l->n, l->size, l->stride,
|
||||
get_activation_string(l->activation));
|
||||
fprintf(fp, "biases=");
|
||||
for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
|
||||
fprintf(fp, "\n");
|
||||
fprintf(fp, "weights=");
|
||||
for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
|
||||
fprintf(fp, "\n\n");
|
||||
}
|
||||
|
||||
void print_freeweight_cfg(FILE *fp, freeweight_layer *l, network net, int count)
|
||||
{
|
||||
fprintf(fp, "[freeweight]\n");
|
||||
@ -599,7 +692,7 @@ void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
|
||||
|
||||
void save_weights(network net, char *filename)
|
||||
{
|
||||
printf("Saving weights to %s\n", filename);
|
||||
fprintf(stderr, "Saving weights to %s\n", filename);
|
||||
FILE *fp = fopen(filename, "w");
|
||||
if(!fp) file_error(filename);
|
||||
|
||||
@ -621,6 +714,17 @@ void save_weights(network net, char *filename)
|
||||
fwrite(layer.biases, sizeof(float), layer.n, fp);
|
||||
fwrite(layer.filters, sizeof(float), num, fp);
|
||||
}
|
||||
if(net.types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *) net.layers[i];
|
||||
#ifdef GPU
|
||||
if(gpu_index >= 0){
|
||||
pull_deconvolutional_layer(layer);
|
||||
}
|
||||
#endif
|
||||
int num = layer.n*layer.c*layer.size*layer.size;
|
||||
fwrite(layer.biases, sizeof(float), layer.n, fp);
|
||||
fwrite(layer.filters, sizeof(float), num, fp);
|
||||
}
|
||||
if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *) net.layers[i];
|
||||
#ifdef GPU
|
||||
@ -637,7 +741,7 @@ void save_weights(network net, char *filename)
|
||||
|
||||
void load_weights(network *net, char *filename)
|
||||
{
|
||||
printf("Loading weights from %s\n", filename);
|
||||
fprintf(stderr, "Loading weights from %s\n", filename);
|
||||
FILE *fp = fopen(filename, "r");
|
||||
if(!fp) file_error(filename);
|
||||
|
||||
@ -660,6 +764,17 @@ void load_weights(network *net, char *filename)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if(net->types[i] == DECONVOLUTIONAL){
|
||||
deconvolutional_layer layer = *(deconvolutional_layer *) net->layers[i];
|
||||
int num = layer.n*layer.c*layer.size*layer.size;
|
||||
fread(layer.biases, sizeof(float), layer.n, fp);
|
||||
fread(layer.filters, sizeof(float), num, fp);
|
||||
#ifdef GPU
|
||||
if(gpu_index >= 0){
|
||||
push_deconvolutional_layer(layer);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if(net->types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *) net->layers[i];
|
||||
fread(layer.biases, sizeof(float), layer.outputs, fp);
|
||||
@ -683,6 +798,8 @@ void save_network(network net, char *filename)
|
||||
{
|
||||
if(net.types[i] == CONVOLUTIONAL)
|
||||
print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i);
|
||||
else if(net.types[i] == DECONVOLUTIONAL)
|
||||
print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i);
|
||||
else if(net.types[i] == CONNECTED)
|
||||
print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i);
|
||||
else if(net.types[i] == CROP)
|
||||
|
Loading…
Reference in New Issue
Block a user