mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Better imagenet distributed training
This commit is contained in:
parent
aea3bceeb1
commit
79fffcce3c
250
src/cnn.c
250
src/cnn.c
@ -8,6 +8,7 @@
|
||||
#include "matrix.h"
|
||||
#include "utils.h"
|
||||
#include "mini_blas.h"
|
||||
#include "matrix.h"
|
||||
#include "server.h"
|
||||
|
||||
#include <time.h>
|
||||
@ -310,6 +311,44 @@ void train_asirra()
|
||||
}
|
||||
}
|
||||
|
||||
void draw_detection(image im, float *box, int side)
|
||||
{
|
||||
int j;
|
||||
int r, c;
|
||||
float amount[5];
|
||||
for(r = 0; r < side*side; ++r){
|
||||
for(j = 0; j < 5; ++j){
|
||||
if(box[r*5] > amount[j]) {
|
||||
amount[j] = box[r*5];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
float smallest = amount[0];
|
||||
for(j = 1; j < 5; ++j) if(amount[j] < smallest) smallest = amount[j];
|
||||
|
||||
for(r = 0; r < side; ++r){
|
||||
for(c = 0; c < side; ++c){
|
||||
j = (r*side + c) * 5;
|
||||
printf("Prob: %f\n", box[j]);
|
||||
if(box[j] >= smallest){
|
||||
int d = im.w/side;
|
||||
int y = r*d+box[j+1]*d;
|
||||
int x = c*d+box[j+2]*d;
|
||||
int h = box[j+3]*256;
|
||||
int w = box[j+4]*256;
|
||||
printf("%f %f %f %f\n", box[j+1], box[j+2], box[j+3], box[j+4]);
|
||||
printf("%d %d %d %d\n", x, y, w, h);
|
||||
printf("%d %d %d %d\n", x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
}
|
||||
}
|
||||
}
|
||||
show_image(im, "box");
|
||||
cvWaitKey(0);
|
||||
}
|
||||
|
||||
|
||||
void train_detection_net()
|
||||
{
|
||||
float avg_loss = 1;
|
||||
@ -317,8 +356,8 @@ void train_detection_net()
|
||||
network net = parse_network_cfg("cfg/detnet.cfg");
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 1000/net.batch+1;
|
||||
//srand(time(0));
|
||||
srand(23410);
|
||||
srand(time(0));
|
||||
//srand(23410);
|
||||
int i = 0;
|
||||
list *plist = get_paths("/home/pjreddie/data/imagenet/horse.txt");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
@ -327,31 +366,10 @@ void train_detection_net()
|
||||
while(1){
|
||||
i += 1;
|
||||
time=clock();
|
||||
data train = load_data_detection_random(imgs*net.batch, paths, plist->size, 256, 256, 8, 8, 256);
|
||||
//translate_data_rows(train, -144);
|
||||
data train = load_data_detection_jitter_random(imgs*net.batch, paths, plist->size, 256, 256, 7, 7, 256);
|
||||
/*
|
||||
image im = float_to_image(256, 256, 3, train.X.vals[0]);
|
||||
float *truth = train.y.vals[0];
|
||||
int j;
|
||||
int r, c;
|
||||
for(r = 0; r < 8; ++r){
|
||||
for(c = 0; c < 8; ++c){
|
||||
j = (r*8 + c) * 5;
|
||||
if(truth[j]){
|
||||
int d = 256/8;
|
||||
int y = r*d+truth[j+1]*d;
|
||||
int x = c*d+truth[j+2]*d;
|
||||
int h = truth[j+3]*256;
|
||||
int w = truth[j+4]*256;
|
||||
printf("%f %f %f %f\n", truth[j+1], truth[j+2], truth[j+3], truth[j+4]);
|
||||
printf("%d %d %d %d\n", x, y, w, h);
|
||||
printf("%d %d %d %d\n", x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
}
|
||||
}
|
||||
}
|
||||
show_image(im, "box");
|
||||
cvWaitKey(0);
|
||||
image im = float_to_image(224, 224, 3, train.X.vals[0]);
|
||||
draw_detection(im, train.y.vals[0], 7);
|
||||
*/
|
||||
|
||||
normalize_data_rows(train);
|
||||
@ -362,12 +380,12 @@ void train_detection_net()
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs*net.batch);
|
||||
#endif
|
||||
free_data(train);
|
||||
if(i%10==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "/home/pjreddie/imagenet_backup/detnet_%d.cfg", i);
|
||||
save_network(net, buff);
|
||||
}
|
||||
free_data(train);
|
||||
}
|
||||
}
|
||||
|
||||
@ -375,36 +393,39 @@ void train_imagenet_distributed(char *address)
|
||||
{
|
||||
float avg_loss = 1;
|
||||
srand(time(0));
|
||||
network net = parse_network_cfg("cfg/alexnet.client");
|
||||
network net = parse_network_cfg("cfg/net.cfg");
|
||||
set_learning_network(&net, 0, 1, 0);
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 1000/net.batch+1;
|
||||
imgs = 1;
|
||||
int imgs = 1;
|
||||
int i = 0;
|
||||
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
|
||||
list *plist = get_paths("/data/imagenet/cls.train.list");
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
clock_t time;
|
||||
data train, buffer;
|
||||
pthread_t load_thread = load_data_random_thread(imgs*net.batch, paths, plist->size, labels, 1000, 224, 224, &buffer);
|
||||
while(1){
|
||||
i += 1;
|
||||
|
||||
time=clock();
|
||||
data train = load_data_random(imgs*net.batch, paths, plist->size, labels, 1000, 256, 256);
|
||||
//translate_data_rows(train, -144);
|
||||
client_update(net, address);
|
||||
printf("Updated: %lf seconds\n", sec(clock()-time));
|
||||
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
normalize_data_rows(train);
|
||||
load_thread = load_data_random_thread(imgs*net.batch, paths, plist->size, labels, 1000, 224, 224, &buffer);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
|
||||
#ifdef GPU
|
||||
float loss = train_network_data_gpu(net, train, imgs);
|
||||
client_update(net, address);
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs*net.batch);
|
||||
#endif
|
||||
free_data(train);
|
||||
if(i%10==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "/home/pjreddie/imagenet_backup/alexnet_%d.cfg", i);
|
||||
save_network(net, buff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -413,7 +434,7 @@ void train_imagenet()
|
||||
float avg_loss = 1;
|
||||
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
|
||||
srand(time(0));
|
||||
network net = parse_network_cfg("cfg/alexnet.cfg");
|
||||
network net = parse_network_cfg("cfg/net.cfg");
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = 1000/net.batch+1;
|
||||
//imgs=1;
|
||||
@ -423,12 +444,17 @@ void train_imagenet()
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
clock_t time;
|
||||
pthread_t load_thread;
|
||||
data train;
|
||||
data buffer;
|
||||
load_thread = load_data_random_thread(imgs*net.batch, paths, plist->size, labels, 1000, 224, 224, &buffer);
|
||||
while(1){
|
||||
i += 1;
|
||||
time=clock();
|
||||
data train = load_data_random(imgs*net.batch, paths, plist->size, labels, 1000, 256, 256);
|
||||
//translate_data_rows(train, -144);
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
normalize_data_rows(train);
|
||||
load_thread = load_data_random_thread(imgs*net.batch, paths, plist->size, labels, 1000, 224, 224, &buffer);
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
#ifdef GPU
|
||||
@ -460,51 +486,28 @@ void validate_imagenet(char *filename)
|
||||
|
||||
clock_t time;
|
||||
float avg_acc = 0;
|
||||
float avg_top5 = 0;
|
||||
int splits = 50;
|
||||
|
||||
for(i = 0; i < splits; ++i){
|
||||
time=clock();
|
||||
char **part = paths+(i*m/splits);
|
||||
int num = (i+1)*m/splits - i*m/splits;
|
||||
data val = load_data(part, num, labels, 1000, 256, 256);
|
||||
data val = load_data(part, num, labels, 1000, 224, 224);
|
||||
|
||||
normalize_data_rows(val);
|
||||
printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
||||
time=clock();
|
||||
#ifdef GPU
|
||||
float acc = network_accuracy_gpu(net, val);
|
||||
avg_acc += acc;
|
||||
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, acc, avg_acc/(i+1), sec(clock()-time), val.X.rows);
|
||||
float *acc = network_accuracies_gpu(net, val);
|
||||
avg_acc += acc[0];
|
||||
avg_top5 += acc[1];
|
||||
printf("%d: top1: %f, top5: %f, %lf seconds, %d images\n", i, avg_acc/(i+1), avg_top5/(i+1), sec(clock()-time), val.X.rows);
|
||||
#endif
|
||||
free_data(val);
|
||||
}
|
||||
}
|
||||
|
||||
void draw_detection(image im, float *box)
|
||||
{
|
||||
int j;
|
||||
int r, c;
|
||||
for(r = 0; r < 8; ++r){
|
||||
for(c = 0; c < 8; ++c){
|
||||
j = (r*8 + c) * 5;
|
||||
printf("Prob: %f\n", box[j]);
|
||||
if(box[j] > .01){
|
||||
int d = 256/8;
|
||||
int y = r*d+box[j+1]*d;
|
||||
int x = c*d+box[j+2]*d;
|
||||
int h = box[j+3]*256;
|
||||
int w = box[j+4]*256;
|
||||
printf("%f %f %f %f\n", box[j+1], box[j+2], box[j+3], box[j+4]);
|
||||
printf("%d %d %d %d\n", x, y, w, h);
|
||||
printf("%d %d %d %d\n", x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2);
|
||||
}
|
||||
}
|
||||
}
|
||||
show_image(im, "box");
|
||||
cvWaitKey(0);
|
||||
}
|
||||
|
||||
void test_detection()
|
||||
{
|
||||
network net = parse_network_cfg("cfg/detnet.test");
|
||||
@ -514,18 +517,50 @@ void test_detection()
|
||||
while(1){
|
||||
fgets(filename, 256, stdin);
|
||||
strtok(filename, "\n");
|
||||
image im = load_image_color(filename, 256, 256);
|
||||
image im = load_image_color(filename, 224, 224);
|
||||
z_normalize_image(im);
|
||||
printf("%d %d %d\n", im.h, im.w, im.c);
|
||||
float *X = im.data;
|
||||
time=clock();
|
||||
float *predictions = network_predict(net, X);
|
||||
printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time));
|
||||
draw_detection(im, predictions);
|
||||
draw_detection(im, predictions, 7);
|
||||
free_image(im);
|
||||
}
|
||||
}
|
||||
|
||||
void test_init(char *cfgfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
int i = 0;
|
||||
char *filename = "data/test.jpg";
|
||||
|
||||
image im = load_image_color(filename, 224, 224);
|
||||
z_normalize_image(im);
|
||||
float *X = im.data;
|
||||
forward_network(net, X, 0, 1);
|
||||
for(i = 0; i < net.n; ++i){
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
|
||||
image output = get_convolutional_image(layer);
|
||||
int size = output.h*output.w*output.c;
|
||||
float v = variance_array(layer.output, size);
|
||||
float m = mean_array(layer.output, size);
|
||||
printf("%d: Convolutional, mean: %f, variance %f\n", i, m, v);
|
||||
}
|
||||
else if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *)net.layers[i];
|
||||
int size = layer.outputs;
|
||||
float v = variance_array(layer.output, size);
|
||||
float m = mean_array(layer.output, size);
|
||||
printf("%d: Connected, mean: %f, variance %f\n", i, m, v);
|
||||
}
|
||||
}
|
||||
free_image(im);
|
||||
}
|
||||
|
||||
void test_imagenet()
|
||||
{
|
||||
network net = parse_network_cfg("cfg/imagenet_test.cfg");
|
||||
@ -633,14 +668,14 @@ void test_nist_single()
|
||||
|
||||
}
|
||||
|
||||
void test_nist()
|
||||
void test_nist(char *path)
|
||||
{
|
||||
srand(222222);
|
||||
network net = parse_network_cfg("cfg/nist_final.cfg");
|
||||
network net = parse_network_cfg(path);
|
||||
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
|
||||
translate_data_rows(test, -144);
|
||||
normalize_data_rows(test);
|
||||
clock_t start = clock(), end;
|
||||
float test_acc = network_accuracy_multi(net, test,16);
|
||||
float test_acc = network_accuracy_gpu(net, test);
|
||||
end = clock();
|
||||
printf("Accuracy: %f, Time: %lf seconds\n", test_acc,(float)(end-start)/CLOCKS_PER_SEC);
|
||||
}
|
||||
@ -654,14 +689,14 @@ void train_nist()
|
||||
normalize_data_rows(train);
|
||||
normalize_data_rows(test);
|
||||
int count = 0;
|
||||
int iters = 50000/net.batch;
|
||||
iters = 1000/net.batch + 1;
|
||||
int iters = 60000/net.batch + 1;
|
||||
//iters = 6000/net.batch + 1;
|
||||
while(++count <= 2000){
|
||||
clock_t start = clock(), end;
|
||||
float loss = train_network_sgd_gpu(net, train, iters);
|
||||
end = clock();
|
||||
float test_acc = network_accuracy_gpu(net, test);
|
||||
//float test_acc = 0;
|
||||
float test_acc = 0;
|
||||
if(count%1 == 0) test_acc = network_accuracy_gpu(net, test);
|
||||
printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC);
|
||||
}
|
||||
}
|
||||
@ -714,14 +749,14 @@ void test_ensemble()
|
||||
lr /= 2;
|
||||
}
|
||||
matrix partial = network_predict_data(net, test);
|
||||
float acc = matrix_accuracy(test.y, partial);
|
||||
float acc = matrix_topk_accuracy(test.y, partial,1);
|
||||
printf("Model Accuracy: %lf\n", acc);
|
||||
matrix_add_matrix(partial, prediction);
|
||||
acc = matrix_accuracy(test.y, prediction);
|
||||
acc = matrix_topk_accuracy(test.y, prediction,1);
|
||||
printf("Current Ensemble Accuracy: %lf\n", acc);
|
||||
free_matrix(partial);
|
||||
}
|
||||
float acc = matrix_accuracy(test.y, prediction);
|
||||
float acc = matrix_topk_accuracy(test.y, prediction,1);
|
||||
printf("Full Ensemble Accuracy: %lf\n", acc);
|
||||
}
|
||||
|
||||
@ -778,26 +813,26 @@ void test_split()
|
||||
}
|
||||
|
||||
/*
|
||||
void test_im2row()
|
||||
{
|
||||
int h = 20;
|
||||
int w = 20;
|
||||
int c = 3;
|
||||
int stride = 1;
|
||||
int size = 11;
|
||||
image test = make_random_image(h,w,c);
|
||||
int mc = 1;
|
||||
int mw = ((h-size)/stride+1)*((w-size)/stride+1);
|
||||
int mh = (size*size*c);
|
||||
int msize = mc*mw*mh;
|
||||
float *matrix = calloc(msize, sizeof(float));
|
||||
int i;
|
||||
for(i = 0; i < 1000; ++i){
|
||||
//im2col_cpu(test.data,1, c, h, w, size, stride, 0, matrix);
|
||||
//image render = float_to_image(mh, mw, mc, matrix);
|
||||
}
|
||||
void test_im2row()
|
||||
{
|
||||
int h = 20;
|
||||
int w = 20;
|
||||
int c = 3;
|
||||
int stride = 1;
|
||||
int size = 11;
|
||||
image test = make_random_image(h,w,c);
|
||||
int mc = 1;
|
||||
int mw = ((h-size)/stride+1)*((w-size)/stride+1);
|
||||
int mh = (size*size*c);
|
||||
int msize = mc*mw*mh;
|
||||
float *matrix = calloc(msize, sizeof(float));
|
||||
int i;
|
||||
for(i = 0; i < 1000; ++i){
|
||||
//im2col_cpu(test.data,1, c, h, w, size, stride, 0, matrix);
|
||||
//image render = float_to_image(mh, mw, mc, matrix);
|
||||
}
|
||||
*/
|
||||
}
|
||||
*/
|
||||
|
||||
void flip_network()
|
||||
{
|
||||
@ -897,7 +932,8 @@ void test_correct_alexnet()
|
||||
void run_server()
|
||||
{
|
||||
srand(time(0));
|
||||
network net = parse_network_cfg("cfg/nist.server");
|
||||
network net = parse_network_cfg("cfg/net.cfg");
|
||||
set_batch_network(&net, 1);
|
||||
server_update(net);
|
||||
}
|
||||
void test_client()
|
||||
@ -927,9 +963,9 @@ int main(int argc, char *argv[])
|
||||
return 0;
|
||||
}
|
||||
int index = find_int_arg(argc, argv, "-i");
|
||||
#ifdef GPU
|
||||
#ifdef GPU
|
||||
cl_setup(index);
|
||||
#endif
|
||||
#endif
|
||||
if(0==strcmp(argv[1], "train")) train_imagenet();
|
||||
else if(0==strcmp(argv[1], "detection")) train_detection_net();
|
||||
else if(0==strcmp(argv[1], "asirra")) train_asirra();
|
||||
@ -945,9 +981,11 @@ int main(int argc, char *argv[])
|
||||
fprintf(stderr, "usage: %s <function>\n", argv[0]);
|
||||
return 0;
|
||||
}
|
||||
else if(0==strcmp(argv[1], "client")) train_nist_distributed(argv[2]);
|
||||
else if(0==strcmp(argv[1], "client")) train_imagenet_distributed(argv[2]);
|
||||
else if(0==strcmp(argv[1], "init")) test_init(argv[2]);
|
||||
else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
|
||||
else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
|
||||
else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
|
||||
fprintf(stderr, "Success!\n");
|
||||
return 0;
|
||||
}
|
||||
|
@ -25,8 +25,8 @@ connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVA
|
||||
|
||||
layer->weight_updates = calloc(inputs*outputs, sizeof(float));
|
||||
layer->weights = calloc(inputs*outputs, sizeof(float));
|
||||
float scale = 1./inputs;
|
||||
scale = .01;
|
||||
float scale = 1./sqrt(inputs);
|
||||
//scale = .01;
|
||||
for(i = 0; i < inputs*outputs; ++i){
|
||||
layer->weights[i] = scale*rand_normal();
|
||||
}
|
||||
@ -34,7 +34,7 @@ connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVA
|
||||
layer->bias_updates = calloc(outputs, sizeof(float));
|
||||
layer->biases = calloc(outputs, sizeof(float));
|
||||
for(i = 0; i < outputs; ++i){
|
||||
layer->biases[i] = .01;
|
||||
layer->biases[i] = scale;
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
|
@ -62,12 +62,12 @@ convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, in
|
||||
|
||||
layer->biases = calloc(n, sizeof(float));
|
||||
layer->bias_updates = calloc(n, sizeof(float));
|
||||
float scale = 1./(size*size*c);
|
||||
scale = .01;
|
||||
float scale = 1./sqrt(size*size*c);
|
||||
//scale = .05;
|
||||
for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal();
|
||||
for(i = 0; i < n; ++i){
|
||||
//layer->biases[i] = rand_normal()*scale + scale;
|
||||
layer->biases[i] = .01;
|
||||
layer->biases[i] = scale;
|
||||
}
|
||||
int out_h = convolutional_out_height(*layer);
|
||||
int out_w = convolutional_out_width(*layer);
|
||||
|
75
src/data.c
75
src/data.c
@ -19,7 +19,7 @@ list *get_paths(char *filename)
|
||||
return lines;
|
||||
}
|
||||
|
||||
void fill_truth_detection(char *path, float *truth, int height, int width, int num_height, int num_width, float scale)
|
||||
void fill_truth_detection(char *path, float *truth, int height, int width, int num_height, int num_width, float scale, int dx, int dy)
|
||||
{
|
||||
int box_height = height/num_height;
|
||||
int box_width = width/num_width;
|
||||
@ -29,8 +29,16 @@ void fill_truth_detection(char *path, float *truth, int height, int width, int n
|
||||
if(!file) file_error(labelpath);
|
||||
int x, y, h, w;
|
||||
while(fscanf(file, "%d %d %d %d", &x, &y, &w, &h) == 4){
|
||||
x -= dx;
|
||||
y -= dy;
|
||||
int i = x/box_width;
|
||||
int j = y/box_height;
|
||||
|
||||
if(i < 0) i = 0;
|
||||
if(i >= num_width) i = num_width-1;
|
||||
if(j < 0) j = 0;
|
||||
if(j >= num_height) j = num_height-1;
|
||||
|
||||
float dw = (float)(x%box_width)/box_height;
|
||||
float dh = (float)(y%box_width)/box_width;
|
||||
float sh = h/scale;
|
||||
@ -89,7 +97,7 @@ matrix load_labels_detection(char **paths, int n, int height, int width, int num
|
||||
matrix y = make_matrix(n, k);
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
fill_truth_detection(paths[i], y.vals[i], height, width, num_height, num_width, scale);
|
||||
fill_truth_detection(paths[i], y.vals[i], height, width, num_height, num_width, scale,0,0);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
@ -128,6 +136,33 @@ void free_data(data d)
|
||||
}
|
||||
}
|
||||
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale)
|
||||
{
|
||||
char **random_paths = calloc(n, sizeof(char*));
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
int index = rand()%m;
|
||||
random_paths[i] = paths[index];
|
||||
if(i == 0) printf("%s\n", paths[index]);
|
||||
}
|
||||
data d;
|
||||
d.shallow = 0;
|
||||
d.X = load_image_paths(random_paths, n, h, w);
|
||||
int k = nh*nw*5;
|
||||
d.y = make_matrix(n, k);
|
||||
for(i = 0; i < n; ++i){
|
||||
int dx = rand()%32;
|
||||
int dy = rand()%32;
|
||||
fill_truth_detection(random_paths[i], d.y.vals[i], 224, 224, nh, nw, scale, dx, dy);
|
||||
|
||||
image a = float_to_image(h, w, 3, d.X.vals[i]);
|
||||
jitter_image(a,224,224,dy,dx);
|
||||
}
|
||||
free(random_paths);
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale)
|
||||
{
|
||||
char **random_paths = calloc(n, sizeof(char*));
|
||||
@ -168,6 +203,42 @@ data load_data_random(int n, char **paths, int m, char **labels, int k, int h, i
|
||||
return d;
|
||||
}
|
||||
|
||||
struct load_args{
|
||||
int n;
|
||||
char **paths;
|
||||
int m;
|
||||
char **labels;
|
||||
int k;
|
||||
int h;
|
||||
int w;
|
||||
data *d;
|
||||
};
|
||||
|
||||
void *load_in_thread(void *ptr)
|
||||
{
|
||||
struct load_args a = *(struct load_args*)ptr;
|
||||
*a.d = load_data_random(a.n, a.paths, a.m, a.labels, a.k, a.h, a.w);
|
||||
return 0;
|
||||
}
|
||||
|
||||
pthread_t load_data_random_thread(int n, char **paths, int m, char **labels, int k, int h, int w, data *d)
|
||||
{
|
||||
pthread_t thread;
|
||||
struct load_args *args = calloc(1, sizeof(struct load_args));
|
||||
args->n = n;
|
||||
args->paths = paths;
|
||||
args->m = m;
|
||||
args->labels = labels;
|
||||
args->k = k;
|
||||
args->h = h;
|
||||
args->w = w;
|
||||
args->d = d;
|
||||
if(pthread_create(&thread, 0, load_in_thread, args)) {
|
||||
error("Thread creation failed");
|
||||
}
|
||||
return thread;
|
||||
}
|
||||
|
||||
data load_categorical_data_csv(char *filename, int target, int k)
|
||||
{
|
||||
data d;
|
||||
|
@ -1,5 +1,6 @@
|
||||
#ifndef DATA_H
|
||||
#define DATA_H
|
||||
#include <pthread.h>
|
||||
|
||||
#include "matrix.h"
|
||||
#include "list.h"
|
||||
@ -13,8 +14,10 @@ typedef struct{
|
||||
|
||||
void free_data(data d);
|
||||
data load_data(char **paths, int n, char **labels, int k, int h, int w);
|
||||
pthread_t load_data_random_thread(int n, char **paths, int m, char **labels, int k, int h, int w, data *d);
|
||||
data load_data_random(int n, char **paths, int m, char **labels, int k, int h, int w);
|
||||
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
|
||||
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
|
||||
data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w);
|
||||
data load_cifar10_data(char *filename);
|
||||
data load_all_cifar10();
|
||||
|
24
src/image.c
24
src/image.c
@ -7,6 +7,16 @@ int windows = 0;
|
||||
void draw_box(image a, int x1, int y1, int x2, int y2)
|
||||
{
|
||||
int i, c;
|
||||
if(x1 < 0) x1 = 0;
|
||||
if(x1 >= a.w) x1 = a.w-1;
|
||||
if(x2 < 0) x2 = 0;
|
||||
if(x2 >= a.w) x2 = a.w-1;
|
||||
|
||||
if(y1 < 0) y1 = 0;
|
||||
if(y1 >= a.h) y1 = a.h-1;
|
||||
if(y2 < 0) y2 = 0;
|
||||
if(y2 >= a.h) y2 = a.h-1;
|
||||
|
||||
for(c = 0; c < a.c; ++c){
|
||||
for(i = x1; i < x2; ++i){
|
||||
a.data[i + y1*a.w + c*a.w*a.h] = (c==0)?1:-1;
|
||||
@ -21,6 +31,20 @@ void draw_box(image a, int x1, int y1, int x2, int y2)
|
||||
}
|
||||
}
|
||||
|
||||
void jitter_image(image a, int h, int w, int dh, int dw)
|
||||
{
|
||||
int i,j,k;
|
||||
for(k = 0; k < a.c; ++k){
|
||||
for(i = 0; i < h; ++i){
|
||||
for(j = 0; j < w; ++j){
|
||||
int src = j + dw + (i+dh)*a.w + k*a.w*a.h;
|
||||
int dst = j + i*w + k*w*h;
|
||||
a.data[dst] = a.data[src];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
image image_distance(image a, image b)
|
||||
{
|
||||
int i,j;
|
||||
|
@ -11,6 +11,7 @@ typedef struct {
|
||||
float *data;
|
||||
} image;
|
||||
|
||||
void jitter_image(image a, int h, int w, int dh, int dw);
|
||||
void draw_box(image a, int x1, int y1, int x2, int y2);
|
||||
image image_distance(image a, image b);
|
||||
void scale_image(image m, float s);
|
||||
|
56
src/matrix.c
56
src/matrix.c
@ -13,16 +13,24 @@ void free_matrix(matrix m)
|
||||
free(m.vals);
|
||||
}
|
||||
|
||||
float matrix_accuracy(matrix truth, matrix guess)
|
||||
float matrix_topk_accuracy(matrix truth, matrix guess, int k)
|
||||
{
|
||||
int k = truth.cols;
|
||||
int i;
|
||||
int count = 0;
|
||||
int *indexes = calloc(k, sizeof(int));
|
||||
int n = truth.cols;
|
||||
int i,j;
|
||||
int correct = 0;
|
||||
for(i = 0; i < truth.rows; ++i){
|
||||
int class = max_index(guess.vals[i], k);
|
||||
if(truth.vals[i][class]) ++count;
|
||||
top_k(guess.vals[i], n, k, indexes);
|
||||
for(j = 0; j < k; ++j){
|
||||
int class = indexes[j];
|
||||
if(truth.vals[i][class]){
|
||||
++correct;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return (float)count/truth.rows;
|
||||
free(indexes);
|
||||
return (float)correct/truth.rows;
|
||||
}
|
||||
|
||||
void matrix_add_matrix(matrix from, matrix to)
|
||||
@ -80,30 +88,30 @@ float *pop_column(matrix *m, int c)
|
||||
|
||||
matrix csv_to_matrix(char *filename)
|
||||
{
|
||||
FILE *fp = fopen(filename, "r");
|
||||
if(!fp) file_error(filename);
|
||||
FILE *fp = fopen(filename, "r");
|
||||
if(!fp) file_error(filename);
|
||||
|
||||
matrix m;
|
||||
m.cols = -1;
|
||||
|
||||
char *line;
|
||||
char *line;
|
||||
|
||||
int n = 0;
|
||||
int size = 1024;
|
||||
m.vals = calloc(size, sizeof(float*));
|
||||
while((line = fgetl(fp))){
|
||||
int n = 0;
|
||||
int size = 1024;
|
||||
m.vals = calloc(size, sizeof(float*));
|
||||
while((line = fgetl(fp))){
|
||||
if(m.cols == -1) m.cols = count_fields(line);
|
||||
if(n == size){
|
||||
size *= 2;
|
||||
m.vals = realloc(m.vals, size*sizeof(float*));
|
||||
}
|
||||
m.vals[n] = parse_fields(line, m.cols);
|
||||
free(line);
|
||||
++n;
|
||||
}
|
||||
m.vals = realloc(m.vals, n*sizeof(float*));
|
||||
if(n == size){
|
||||
size *= 2;
|
||||
m.vals = realloc(m.vals, size*sizeof(float*));
|
||||
}
|
||||
m.vals[n] = parse_fields(line, m.cols);
|
||||
free(line);
|
||||
++n;
|
||||
}
|
||||
m.vals = realloc(m.vals, n*sizeof(float*));
|
||||
m.rows = n;
|
||||
return m;
|
||||
return m;
|
||||
}
|
||||
|
||||
void print_matrix(matrix m)
|
||||
|
@ -11,7 +11,7 @@ void print_matrix(matrix m);
|
||||
|
||||
matrix csv_to_matrix(char *filename);
|
||||
matrix hold_out_matrix(matrix *m, int n);
|
||||
float matrix_accuracy(matrix truth, matrix guess);
|
||||
float matrix_topk_accuracy(matrix truth, matrix guess, int k);
|
||||
void matrix_add_matrix(matrix from, matrix to);
|
||||
|
||||
float *pop_column(matrix *m, int c);
|
||||
|
@ -323,6 +323,65 @@ void train_network(network net, data d)
|
||||
fprintf(stderr, "Accuracy: %f\n", (float)correct/d.X.rows);
|
||||
}
|
||||
|
||||
void set_learning_network(network *net, float rate, float momentum, float decay)
|
||||
{
|
||||
int i;
|
||||
net->learning_rate=rate;
|
||||
net->momentum = momentum;
|
||||
net->decay = decay;
|
||||
for(i = 0; i < net->n; ++i){
|
||||
if(net->types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer *layer = (convolutional_layer *)net->layers[i];
|
||||
layer->learning_rate=rate;
|
||||
layer->momentum = momentum;
|
||||
layer->decay = decay;
|
||||
}
|
||||
else if(net->types[i] == CONNECTED){
|
||||
connected_layer *layer = (connected_layer *)net->layers[i];
|
||||
layer->learning_rate=rate;
|
||||
layer->momentum = momentum;
|
||||
layer->decay = decay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void set_batch_network(network *net, int b)
|
||||
{
|
||||
net->batch = b;
|
||||
int i;
|
||||
for(i = 0; i < net->n; ++i){
|
||||
if(net->types[i] == CONVOLUTIONAL){
|
||||
convolutional_layer *layer = (convolutional_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == MAXPOOL){
|
||||
maxpool_layer *layer = (maxpool_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == CONNECTED){
|
||||
connected_layer *layer = (connected_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
} else if(net->types[i] == DROPOUT){
|
||||
dropout_layer *layer = (dropout_layer *) net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == FREEWEIGHT){
|
||||
freeweight_layer *layer = (freeweight_layer *) net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == SOFTMAX){
|
||||
softmax_layer *layer = (softmax_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
else if(net->types[i] == COST){
|
||||
cost_layer *layer = (cost_layer *)net->layers[i];
|
||||
layer->batch = b;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int get_network_input_size_layer(network net, int i)
|
||||
{
|
||||
if(net.types[i] == CONVOLUTIONAL){
|
||||
@ -586,15 +645,26 @@ void print_network(network net)
|
||||
float network_accuracy(network net, data d)
|
||||
{
|
||||
matrix guess = network_predict_data(net, d);
|
||||
float acc = matrix_accuracy(d.y, guess);
|
||||
float acc = matrix_topk_accuracy(d.y, guess,1);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
||||
float *network_accuracies(network net, data d)
|
||||
{
|
||||
static float acc[2];
|
||||
matrix guess = network_predict_data(net, d);
|
||||
acc[0] = matrix_topk_accuracy(d.y, guess,1);
|
||||
acc[1] = matrix_topk_accuracy(d.y, guess,5);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
||||
|
||||
float network_accuracy_multi(network net, data d, int n)
|
||||
{
|
||||
matrix guess = network_predict_data_multi(net, d, n);
|
||||
float acc = matrix_accuracy(d.y, guess);
|
||||
float acc = matrix_topk_accuracy(d.y, guess,1);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
@ -35,17 +35,22 @@ typedef struct {
|
||||
#endif
|
||||
} network;
|
||||
|
||||
#ifdef GPU
|
||||
#ifndef GPU
|
||||
typedef int cl_mem;
|
||||
#endif
|
||||
|
||||
cl_mem get_network_output_cl_layer(network net, int i);
|
||||
cl_mem get_network_delta_cl_layer(network net, int i);
|
||||
void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train);
|
||||
void backward_network_gpu(network net, cl_mem input);
|
||||
void update_network_gpu(network net);
|
||||
cl_mem get_network_output_cl_layer(network net, int i);
|
||||
cl_mem get_network_delta_cl_layer(network net, int i);
|
||||
float train_network_sgd_gpu(network net, data d, int n);
|
||||
float train_network_data_gpu(network net, data d, int n);
|
||||
float *network_predict_gpu(network net, float *input);
|
||||
float network_accuracy_gpu(network net, data d);
|
||||
#endif
|
||||
float *network_accuracies_gpu(network net, data d);
|
||||
|
||||
float *network_accuracies(network net, data d);
|
||||
|
||||
network make_network(int n, int batch);
|
||||
void forward_network(network net, float *input, float *truth, int train);
|
||||
@ -72,6 +77,8 @@ int get_predicted_class_network(network net);
|
||||
void print_network(network net);
|
||||
void visualize_network(network net);
|
||||
int resize_network(network net, int h, int w, int c);
|
||||
void set_batch_network(network *net, int b);
|
||||
void set_learning_network(network *net, float rate, float momentum, float decay);
|
||||
int get_network_input_size(network net);
|
||||
float get_network_cost(network net);
|
||||
|
||||
|
@ -182,22 +182,10 @@ float train_network_datum_gpu(network net, float *x, float *y)
|
||||
cl_write_array(*net.input_cl, x, x_size);
|
||||
cl_write_array(*net.truth_cl, y, y_size);
|
||||
}
|
||||
//printf("trans %f\n", sec(clock()-time));
|
||||
//time = clock();
|
||||
|
||||
forward_network_gpu(net, *net.input_cl, *net.truth_cl, 1);
|
||||
|
||||
//printf("forw %f\n", sec(clock()-time));
|
||||
//time = clock();
|
||||
backward_network_gpu(net, *net.input_cl);
|
||||
//printf("back %f\n", sec(clock()-time));
|
||||
//time = clock();
|
||||
|
||||
update_network_gpu(net);
|
||||
float error = get_network_cost(net);
|
||||
|
||||
//printf("updt %f\n", sec(clock()-time));
|
||||
//time = clock();
|
||||
return error;
|
||||
}
|
||||
|
||||
@ -302,11 +290,29 @@ matrix network_predict_data_gpu(network net, data test)
|
||||
float network_accuracy_gpu(network net, data d)
|
||||
{
|
||||
matrix guess = network_predict_data_gpu(net, d);
|
||||
float acc = matrix_accuracy(d.y, guess);
|
||||
float acc = matrix_topk_accuracy(d.y, guess,1);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
||||
float *network_accuracies_gpu(network net, data d)
|
||||
{
|
||||
static float acc[2];
|
||||
matrix guess = network_predict_data_gpu(net, d);
|
||||
acc[0] = matrix_topk_accuracy(d.y, guess,1);
|
||||
acc[1] = matrix_topk_accuracy(d.y, guess,5);
|
||||
free_matrix(guess);
|
||||
return acc;
|
||||
}
|
||||
|
||||
|
||||
#else
|
||||
void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train){}
|
||||
void backward_network_gpu(network net, cl_mem input){}
|
||||
void update_network_gpu(network net){}
|
||||
float train_network_sgd_gpu(network net, data d, int n){return 0;}
|
||||
float train_network_data_gpu(network net, data d, int n){return 0;}
|
||||
float *network_predict_gpu(network net, float *input){return 0;}
|
||||
float network_accuracy_gpu(network net, data d){return 0;}
|
||||
|
||||
#endif
|
||||
|
19
src/server.c
19
src/server.c
@ -15,7 +15,7 @@
|
||||
#include "connected_layer.h"
|
||||
#include "convolutional_layer.h"
|
||||
|
||||
#define SERVER_PORT 9876
|
||||
#define SERVER_PORT 9423
|
||||
#define STR(x) #x
|
||||
|
||||
int socket_setup(int server)
|
||||
@ -46,7 +46,7 @@ int socket_setup(int server)
|
||||
|
||||
typedef struct{
|
||||
int fd;
|
||||
int *counter;
|
||||
int counter;
|
||||
network net;
|
||||
} connection_info;
|
||||
|
||||
@ -85,6 +85,11 @@ void handle_connection(void *pointer)
|
||||
connection_info info = *(connection_info *) pointer;
|
||||
free(pointer);
|
||||
//printf("New Connection\n");
|
||||
if(info.counter%100==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "/home/pjreddie/net_%d.part", info.counter);
|
||||
save_network(info.net, buff);
|
||||
}
|
||||
int fd = info.fd;
|
||||
network net = info.net;
|
||||
int i;
|
||||
@ -134,7 +139,7 @@ void server_update(network net)
|
||||
while(1){
|
||||
connection_info *info = calloc(1, sizeof(connection_info));
|
||||
info->net = net;
|
||||
info->counter = &counter;
|
||||
info->counter = counter;
|
||||
pthread_t worker;
|
||||
int connection = accept(fd, (struct sockaddr *) &client, &client_size);
|
||||
if(!t) t=time(0);
|
||||
@ -142,10 +147,8 @@ void server_update(network net)
|
||||
pthread_create(&worker, NULL, (void *) &handle_connection, info);
|
||||
++counter;
|
||||
printf("%d\n", counter);
|
||||
if(counter == 1024) break;
|
||||
if(counter%1000==0) save_network(net, "cfg/nist.part");
|
||||
//if(counter == 1024) break;
|
||||
}
|
||||
printf("1024 epochs: %d seconds\n", time(0)-t);
|
||||
close(fd);
|
||||
}
|
||||
|
||||
@ -204,7 +207,9 @@ void client_update(network net, char *address)
|
||||
int num = layer.n*layer.c*layer.size*layer.size;
|
||||
read_all(fd, (char*) layer.filters, num*sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
push_convolutional_layer(layer);
|
||||
#endif
|
||||
}
|
||||
if(net.types[i] == CONNECTED){
|
||||
connected_layer layer = *(connected_layer *) net.layers[i];
|
||||
@ -212,7 +217,9 @@ void client_update(network net, char *address)
|
||||
read_all(fd, (char *)layer.biases, layer.outputs*sizeof(float));
|
||||
read_all(fd, (char *)layer.weights, layer.outputs*layer.inputs*sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
push_connected_layer(layer);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
//printf("Updated\n");
|
||||
|
55
src/utils.c
55
src/utils.c
@ -30,19 +30,16 @@ float sec(clock_t clocks)
|
||||
void top_k(float *a, int n, int k, int *index)
|
||||
{
|
||||
int i,j;
|
||||
float thresh = FLT_MAX;
|
||||
for(i = 0; i < k; ++i){
|
||||
float max = -FLT_MAX;
|
||||
int max_i = -1;
|
||||
for(j = 0; j < n; ++j){
|
||||
float val = a[j];
|
||||
if(val > max && val < thresh){
|
||||
max = val;
|
||||
max_i = j;
|
||||
for(j = 0; j < k; ++j) index[j] = 0;
|
||||
for(i = 0; i < n; ++i){
|
||||
int curr = i;
|
||||
for(j = 0; j < k; ++j){
|
||||
if(a[curr] > a[index[j]]){
|
||||
int swap = curr;
|
||||
curr = index[j];
|
||||
index[j] = swap;
|
||||
}
|
||||
}
|
||||
index[i] = max_i;
|
||||
thresh = max;
|
||||
}
|
||||
}
|
||||
|
||||
@ -260,14 +257,40 @@ int max_index(float *a, int n)
|
||||
return max_i;
|
||||
}
|
||||
|
||||
// From http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
#define TWO_PI 6.2831853071795864769252866
|
||||
float rand_normal()
|
||||
{
|
||||
int n = 12;
|
||||
int i;
|
||||
float sum= 0;
|
||||
for(i = 0; i < n; ++i) sum += (float)rand()/RAND_MAX;
|
||||
return sum-n/2.;
|
||||
static int haveSpare = 0;
|
||||
static double rand1, rand2;
|
||||
|
||||
if(haveSpare)
|
||||
{
|
||||
haveSpare = 0;
|
||||
return sqrt(rand1) * sin(rand2);
|
||||
}
|
||||
|
||||
haveSpare = 1;
|
||||
|
||||
rand1 = rand() / ((double) RAND_MAX);
|
||||
if(rand1 < 1e-100) rand1 = 1e-100;
|
||||
rand1 = -2 * log(rand1);
|
||||
rand2 = (rand() / ((double) RAND_MAX)) * TWO_PI;
|
||||
|
||||
return sqrt(rand1) * cos(rand2);
|
||||
}
|
||||
|
||||
/*
|
||||
float rand_normal()
|
||||
{
|
||||
int n = 12;
|
||||
int i;
|
||||
float sum= 0;
|
||||
for(i = 0; i < n; ++i) sum += (float)rand()/RAND_MAX;
|
||||
return sum-n/2.;
|
||||
}
|
||||
*/
|
||||
|
||||
float rand_uniform()
|
||||
{
|
||||
return (float)rand()/RAND_MAX;
|
||||
|
Loading…
Reference in New Issue
Block a user