mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
304 lines
8.9 KiB
C
304 lines
8.9 KiB
C
|
#include <stdio.h>
|
||
|
|
||
|
#include "network.h"
|
||
|
#include "detection_layer.h"
|
||
|
#include "cost_layer.h"
|
||
|
#include "utils.h"
|
||
|
#include "parser.h"
|
||
|
#include "box.h"
|
||
|
|
||
|
void train_compare(char *cfgfile, char *weightfile)
|
||
|
{
|
||
|
data_seed = time(0);
|
||
|
srand(time(0));
|
||
|
float avg_loss = -1;
|
||
|
char *base = basecfg(cfgfile);
|
||
|
char *backup_directory = "/home/pjreddie/backup/";
|
||
|
printf("%s\n", base);
|
||
|
network net = parse_network_cfg(cfgfile);
|
||
|
if(weightfile){
|
||
|
load_weights(&net, weightfile);
|
||
|
}
|
||
|
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||
|
int imgs = 1024;
|
||
|
list *plist = get_paths("data/compare.train.list");
|
||
|
char **paths = (char **)list_to_array(plist);
|
||
|
int N = plist->size;
|
||
|
printf("%d\n", N);
|
||
|
clock_t time;
|
||
|
pthread_t load_thread;
|
||
|
data train;
|
||
|
data buffer;
|
||
|
|
||
|
load_args args = {0};
|
||
|
args.w = net.w;
|
||
|
args.h = net.h;
|
||
|
args.paths = paths;
|
||
|
args.classes = 20;
|
||
|
args.n = imgs;
|
||
|
args.m = N;
|
||
|
args.d = &buffer;
|
||
|
args.type = COMPARE_DATA;
|
||
|
|
||
|
load_thread = load_data_in_thread(args);
|
||
|
int epoch = *net.seen/N;
|
||
|
int i = 0;
|
||
|
while(1){
|
||
|
++i;
|
||
|
time=clock();
|
||
|
pthread_join(load_thread, 0);
|
||
|
train = buffer;
|
||
|
|
||
|
load_thread = load_data_in_thread(args);
|
||
|
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||
|
time=clock();
|
||
|
float loss = train_network(net, train);
|
||
|
if(avg_loss == -1) avg_loss = loss;
|
||
|
avg_loss = avg_loss*.9 + loss*.1;
|
||
|
printf("%.3f: %f, %f avg, %lf seconds, %d images\n", (float)*net.seen/N, loss, avg_loss, sec(clock()-time), *net.seen);
|
||
|
free_data(train);
|
||
|
if(i%100 == 0){
|
||
|
char buff[256];
|
||
|
sprintf(buff, "%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
|
||
|
save_weights(net, buff);
|
||
|
}
|
||
|
if(*net.seen/N > epoch){
|
||
|
epoch = *net.seen/N;
|
||
|
i = 0;
|
||
|
char buff[256];
|
||
|
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
||
|
save_weights(net, buff);
|
||
|
if(epoch%22 == 0) net.learning_rate *= .1;
|
||
|
}
|
||
|
}
|
||
|
pthread_join(load_thread, 0);
|
||
|
free_data(buffer);
|
||
|
free_network(net);
|
||
|
free_ptrs((void**)paths, plist->size);
|
||
|
free_list(plist);
|
||
|
free(base);
|
||
|
}
|
||
|
|
||
|
void validate_compare(char *filename, char *weightfile)
|
||
|
{
|
||
|
int i = 0;
|
||
|
network net = parse_network_cfg(filename);
|
||
|
if(weightfile){
|
||
|
load_weights(&net, weightfile);
|
||
|
}
|
||
|
srand(time(0));
|
||
|
|
||
|
list *plist = get_paths("data/compare.val.list");
|
||
|
//list *plist = get_paths("data/compare.val.old");
|
||
|
char **paths = (char **)list_to_array(plist);
|
||
|
int N = plist->size/2;
|
||
|
free_list(plist);
|
||
|
|
||
|
clock_t time;
|
||
|
int correct = 0;
|
||
|
int total = 0;
|
||
|
int splits = 10;
|
||
|
int num = (i+1)*N/splits - i*N/splits;
|
||
|
|
||
|
data val, buffer;
|
||
|
|
||
|
load_args args = {0};
|
||
|
args.w = net.w;
|
||
|
args.h = net.h;
|
||
|
args.paths = paths;
|
||
|
args.classes = 20;
|
||
|
args.n = num;
|
||
|
args.m = 0;
|
||
|
args.d = &buffer;
|
||
|
args.type = COMPARE_DATA;
|
||
|
|
||
|
pthread_t load_thread = load_data_in_thread(args);
|
||
|
for(i = 1; i <= splits; ++i){
|
||
|
time=clock();
|
||
|
|
||
|
pthread_join(load_thread, 0);
|
||
|
val = buffer;
|
||
|
|
||
|
num = (i+1)*N/splits - i*N/splits;
|
||
|
char **part = paths+(i*N/splits);
|
||
|
if(i != splits){
|
||
|
args.paths = part;
|
||
|
load_thread = load_data_in_thread(args);
|
||
|
}
|
||
|
printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
|
||
|
|
||
|
time=clock();
|
||
|
matrix pred = network_predict_data(net, val);
|
||
|
int j,k;
|
||
|
for(j = 0; j < val.y.rows; ++j){
|
||
|
for(k = 0; k < 20; ++k){
|
||
|
if(val.y.vals[j][k*2] != val.y.vals[j][k*2+1]){
|
||
|
++total;
|
||
|
if((val.y.vals[j][k*2] < val.y.vals[j][k*2+1]) == (pred.vals[j][k*2] < pred.vals[j][k*2+1])){
|
||
|
++correct;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
free_matrix(pred);
|
||
|
printf("%d: Acc: %f, %lf seconds, %d images\n", i, (float)correct/total, sec(clock()-time), val.X.rows);
|
||
|
free_data(val);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
typedef struct {
|
||
|
network net;
|
||
|
char *filename;
|
||
|
int class;
|
||
|
float elo;
|
||
|
} sortable_bbox;
|
||
|
|
||
|
/* Running count of bbox_comparator calls; printed by SortMaster3000 and BattleRoyaleWithCheese. */
int total_compares = 0;
|
||
|
|
||
|
int elo_comparator(const void*a, const void *b)
|
||
|
{
|
||
|
sortable_bbox box1 = *(sortable_bbox*)a;
|
||
|
sortable_bbox box2 = *(sortable_bbox*)b;
|
||
|
if(box1.elo == box2.elo) return 0;
|
||
|
if(box1.elo > box2.elo) return -1;
|
||
|
return 1;
|
||
|
}
|
||
|
|
||
|
int bbox_comparator(const void *a, const void *b)
|
||
|
{
|
||
|
++total_compares;
|
||
|
sortable_bbox box1 = *(sortable_bbox*)a;
|
||
|
sortable_bbox box2 = *(sortable_bbox*)b;
|
||
|
network net = box1.net;
|
||
|
int class = box1.class;
|
||
|
|
||
|
image im1 = load_image_color(box1.filename, net.w, net.h);
|
||
|
image im2 = load_image_color(box2.filename, net.w, net.h);
|
||
|
float *X = calloc(net.w*net.h*net.c, sizeof(float));
|
||
|
memcpy(X, im1.data, im1.w*im1.h*im1.c);
|
||
|
memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c);
|
||
|
float *predictions = network_predict(net, X);
|
||
|
|
||
|
free_image(im1);
|
||
|
free_image(im2);
|
||
|
free(X);
|
||
|
if (predictions[class*2] > predictions[class*2+1]){
|
||
|
return 1;
|
||
|
}
|
||
|
return -1;
|
||
|
}
|
||
|
|
||
|
/*
 * Play one match between two boxes and update both ratings with the Elo
 * formula (K = 32, 400-point logistic scale).
 *
 * The winner is decided by bbox_comparator: a positive result is a win for a,
 * and a negative result is a win for b.
 */
void bbox_fight(sortable_bbox *a, sortable_bbox *b)
{
    const int k_factor = 32;
    int outcome = bbox_comparator(a, b);

    /* Expected scores before the match. */
    float expected_a = 1./(1+pow(10, (b->elo - a->elo)/400.));
    float expected_b = 1./(1+pow(10, (a->elo - b->elo)/400.));

    /* Actual scores: 1 for the winner, 0 otherwise. */
    float actual_a = (outcome > 0) ? 1. : 0.;
    float actual_b = (outcome < 0) ? 1. : 0.;

    a->elo += k_factor*(actual_a - expected_a);
    b->elo += k_factor*(actual_b - expected_b);
}
|
||
|
|
||
|
void SortMaster3000(char *filename, char *weightfile)
|
||
|
{
|
||
|
int i = 0;
|
||
|
network net = parse_network_cfg(filename);
|
||
|
if(weightfile){
|
||
|
load_weights(&net, weightfile);
|
||
|
}
|
||
|
srand(time(0));
|
||
|
set_batch_network(&net, 1);
|
||
|
|
||
|
list *plist = get_paths("data/compare.sort.list");
|
||
|
//list *plist = get_paths("data/compare.val.old");
|
||
|
char **paths = (char **)list_to_array(plist);
|
||
|
int N = plist->size;
|
||
|
free_list(plist);
|
||
|
sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
|
||
|
printf("Sorting %d boxes...\n", N);
|
||
|
for(i = 0; i < N; ++i){
|
||
|
boxes[i].filename = paths[i];
|
||
|
boxes[i].net = net;
|
||
|
boxes[i].class = 7;
|
||
|
boxes[i].elo = 1500;
|
||
|
}
|
||
|
clock_t time=clock();
|
||
|
qsort(boxes, N, sizeof(sortable_bbox), bbox_comparator);
|
||
|
for(i = 0; i < N; ++i){
|
||
|
printf("%s\n", boxes[i].filename);
|
||
|
}
|
||
|
printf("Sorted in %d compares, %f secs\n", total_compares, sec(clock()-time));
|
||
|
}
|
||
|
|
||
|
void BattleRoyaleWithCheese(char *filename, char *weightfile)
|
||
|
{
|
||
|
int i = 0;
|
||
|
network net = parse_network_cfg(filename);
|
||
|
if(weightfile){
|
||
|
load_weights(&net, weightfile);
|
||
|
}
|
||
|
srand(time(0));
|
||
|
set_batch_network(&net, 1);
|
||
|
|
||
|
list *plist = get_paths("data/compare.sort.list");
|
||
|
//list *plist = get_paths("data/compare.val.old");
|
||
|
char **paths = (char **)list_to_array(plist);
|
||
|
int N = plist->size;
|
||
|
free_list(plist);
|
||
|
sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
|
||
|
printf("Battling %d boxes...\n", N);
|
||
|
for(i = 0; i < N; ++i){
|
||
|
boxes[i].filename = paths[i];
|
||
|
boxes[i].net = net;
|
||
|
boxes[i].class = 7;
|
||
|
boxes[i].elo = 1500;
|
||
|
}
|
||
|
int round;
|
||
|
clock_t time=clock();
|
||
|
for(round = 1; round <= 40; ++round){
|
||
|
clock_t round_time=clock();
|
||
|
printf("Round: %d\n", round);
|
||
|
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||
|
sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
|
||
|
for(i = 0; i < N/2; ++i){
|
||
|
bbox_fight(boxes+i*2, boxes+i*2+1);
|
||
|
}
|
||
|
if(round >= 4){
|
||
|
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||
|
if(round == 4){
|
||
|
N = N/2;
|
||
|
}else{
|
||
|
N = (N*9/10)/2*2;
|
||
|
}
|
||
|
}
|
||
|
printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
|
||
|
}
|
||
|
qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
|
||
|
for(i = 0; i < N; ++i){
|
||
|
printf("%s %f\n", boxes[i].filename, boxes[i].elo);
|
||
|
}
|
||
|
printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
|
||
|
}
|
||
|
|
||
|
/*
 * Command-line entry for the compare tool:
 *   <prog> <cmd> train|valid|sort|battle <cfg> [weights]
 * Dispatches to train_compare, validate_compare, SortMaster3000 or
 * BattleRoyaleWithCheese. Unknown sub-commands are reported on stderr.
 */
void run_compare(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/sort/battle] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_compare(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_compare(cfg, weights);
    else if(0==strcmp(argv[2], "sort")) SortMaster3000(cfg, weights);
    else if(0==strcmp(argv[2], "battle")) BattleRoyaleWithCheese(cfg, weights);
    else fprintf(stderr, "unknown compare command: %s (expected train/valid/sort/battle)\n", argv[2]);
}
|