mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
this'll teach me to mess with maxpooling
This commit is contained in:
parent
e209b3bbbf
commit
b13f67bfdd
4
Makefile
4
Makefile
@ -57,8 +57,8 @@ CFLAGS+= -DCUDNN
|
||||
LDFLAGS+= -lcudnn
|
||||
endif
|
||||
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o yolo_layer.o
|
||||
EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o darknet.o
|
||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o detection_layer.o route_layer.o upsample_layer.o box.o normalization_layer.o avgpool_layer.o layer.o local_layer.o shortcut_layer.o logistic_layer.o activation_layer.o rnn_layer.o gru_layer.o crnn_layer.o demo.o batchnorm_layer.o region_layer.o reorg_layer.o tree.o lstm_layer.o l2norm_layer.o yolo_layer.o iseg_layer.o
|
||||
EXECOBJA=captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o instance-segmenter.o darknet.o
|
||||
ifeq ($(GPU), 1)
|
||||
LDFLAGS+= -lstdc++
|
||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o avgpool_layer_kernels.o
|
||||
|
@ -24,7 +24,6 @@ void demo_art(char *cfgfile, char *weightfile, int cam_index)
|
||||
while(1){
|
||||
image in = get_image_from_stream(cap);
|
||||
image in_s = resize_image(in, net->w, net->h);
|
||||
show_image(in, window);
|
||||
|
||||
float *p = network_predict(net, in_s.data);
|
||||
|
||||
@ -45,10 +44,9 @@ void demo_art(char *cfgfile, char *weightfile, int cam_index)
|
||||
}
|
||||
printf("]\n");
|
||||
|
||||
show_image(in, window, 1);
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
|
||||
cvWaitKey(1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
@ -645,6 +645,45 @@ void label_classifier(char *datacfg, char *filename, char *weightfile)
|
||||
}
|
||||
}
|
||||
|
||||
void csv_classifier(char *datacfg, char *cfgfile, char *weightfile)
|
||||
{
|
||||
int i,j;
|
||||
network *net = load_network(cfgfile, weightfile, 0);
|
||||
srand(time(0));
|
||||
|
||||
list *options = read_data_cfg(datacfg);
|
||||
|
||||
char *test_list = option_find_str(options, "test", "data/test.list");
|
||||
int top = option_find_int(options, "top", 1);
|
||||
|
||||
list *plist = get_paths(test_list);
|
||||
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
int m = plist->size;
|
||||
free_list(plist);
|
||||
int *indexes = calloc(top, sizeof(int));
|
||||
|
||||
for(i = 0; i < m; ++i){
|
||||
double time = what_time_is_it_now();
|
||||
char *path = paths[i];
|
||||
image im = load_image_color(path, 0, 0);
|
||||
image r = letterbox_image(im, net->w, net->h);
|
||||
float *predictions = network_predict(net, r.data);
|
||||
if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
|
||||
top_k(predictions, net->outputs, top, indexes);
|
||||
|
||||
printf("%s", path);
|
||||
for(j = 0; j < top; ++j){
|
||||
printf("\t%d", indexes[j]);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
free_image(im);
|
||||
free_image(r);
|
||||
|
||||
fprintf(stderr, "%lf seconds, %d images, %d total\n", what_time_is_it_now() - time, i+1, m);
|
||||
}
|
||||
}
|
||||
|
||||
void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer)
|
||||
{
|
||||
@ -869,8 +908,7 @@ void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_i
|
||||
}
|
||||
|
||||
if(1){
|
||||
show_image(out, "Threat");
|
||||
cvWaitKey(10);
|
||||
show_image(out, "Threat", 10);
|
||||
}
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
@ -922,7 +960,6 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
|
||||
|
||||
image in = get_image_from_stream(cap);
|
||||
image in_s = resize_image(in, net->w, net->h);
|
||||
show_image(in, "Threat Detection");
|
||||
|
||||
float *predictions = network_predict(net, in_s.data);
|
||||
top_predictions(net, top, indexes);
|
||||
@ -947,11 +984,10 @@ void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
|
||||
}
|
||||
}
|
||||
|
||||
show_image(in, "Threat Detection", 10);
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
|
||||
cvWaitKey(10);
|
||||
|
||||
gettimeofday(&tval_after, NULL);
|
||||
timersub(&tval_after, &tval_before, &tval_result);
|
||||
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
||||
@ -1036,12 +1072,10 @@ void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
|
||||
free_image(label);
|
||||
}
|
||||
|
||||
show_image(in, base);
|
||||
show_image(in, base, 10);
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
|
||||
cvWaitKey(10);
|
||||
|
||||
gettimeofday(&tval_after, NULL);
|
||||
timersub(&tval_after, &tval_before, &tval_result);
|
||||
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
||||
@ -1080,6 +1114,7 @@ void run_classifier(int argc, char **argv)
|
||||
else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
|
||||
else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
|
||||
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
|
||||
else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights);
|
||||
else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
|
||||
else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights);
|
||||
else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);
|
||||
|
@ -325,14 +325,10 @@ void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh)
|
||||
|
||||
draw_detections(im, dets, l.side*l.side*l.n, thresh, coco_classes, alphabet, 80);
|
||||
save_image(im, "prediction");
|
||||
show_image(im, "predictions");
|
||||
show_image(im, "predictions", 0);
|
||||
free_detections(dets, nboxes);
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
cvDestroyAllWindows();
|
||||
#endif
|
||||
if (filename) break;
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ extern void run_nightmare(int argc, char **argv);
|
||||
extern void run_classifier(int argc, char **argv);
|
||||
extern void run_regressor(int argc, char **argv);
|
||||
extern void run_segmenter(int argc, char **argv);
|
||||
extern void run_isegmenter(int argc, char **argv);
|
||||
extern void run_char_rnn(int argc, char **argv);
|
||||
extern void run_tag(int argc, char **argv);
|
||||
extern void run_cifar(int argc, char **argv);
|
||||
@ -452,6 +453,8 @@ int main(int argc, char **argv)
|
||||
run_classifier(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "regressor")){
|
||||
run_regressor(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "isegmenter")){
|
||||
run_isegmenter(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "segmenter")){
|
||||
run_segmenter(argc, argv);
|
||||
} else if (0 == strcmp(argv[1], "art")){
|
||||
|
@ -613,9 +613,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
|
||||
if(fullscreen){
|
||||
cvSetWindowProperty("predictions", CV_WND_PROP_FULLSCREEN, CV_WINDOW_FULLSCREEN);
|
||||
}
|
||||
show_image(im, "predictions");
|
||||
cvWaitKey(0);
|
||||
cvDestroyAllWindows();
|
||||
show_image(im, "predictions", 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
265
examples/instance-segmenter.c
Normal file
265
examples/instance-segmenter.c
Normal file
@ -0,0 +1,265 @@
|
||||
#include "darknet.h"
|
||||
#include <sys/time.h>
|
||||
#include <assert.h>
|
||||
|
||||
void train_isegmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int display)
|
||||
{
|
||||
int i;
|
||||
|
||||
float avg_loss = -1;
|
||||
char *base = basecfg(cfgfile);
|
||||
printf("%s\n", base);
|
||||
printf("%d\n", ngpus);
|
||||
network **nets = calloc(ngpus, sizeof(network*));
|
||||
|
||||
srand(time(0));
|
||||
int seed = rand();
|
||||
for(i = 0; i < ngpus; ++i){
|
||||
srand(seed);
|
||||
#ifdef GPU
|
||||
cuda_set_device(gpus[i]);
|
||||
#endif
|
||||
nets[i] = load_network(cfgfile, weightfile, clear);
|
||||
nets[i]->learning_rate *= ngpus;
|
||||
}
|
||||
srand(time(0));
|
||||
network *net = nets[0];
|
||||
image pred = get_network_image(net);
|
||||
|
||||
int div = net->w/pred.w;
|
||||
assert(pred.w * div == net->w);
|
||||
assert(pred.h * div == net->h);
|
||||
|
||||
int imgs = net->batch * net->subdivisions * ngpus;
|
||||
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
|
||||
list *options = read_data_cfg(datacfg);
|
||||
|
||||
char *backup_directory = option_find_str(options, "backup", "/backup/");
|
||||
char *train_list = option_find_str(options, "train", "data/train.list");
|
||||
|
||||
list *plist = get_paths(train_list);
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
int N = plist->size;
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net->w;
|
||||
args.h = net->h;
|
||||
args.threads = 32;
|
||||
args.scale = div;
|
||||
args.num_boxes = 90;
|
||||
|
||||
args.min = net->min_crop;
|
||||
args.max = net->max_crop;
|
||||
args.angle = net->angle;
|
||||
args.aspect = net->aspect;
|
||||
args.exposure = net->exposure;
|
||||
args.saturation = net->saturation;
|
||||
args.hue = net->hue;
|
||||
args.size = net->w;
|
||||
args.classes = 80;
|
||||
|
||||
args.paths = paths;
|
||||
args.n = imgs;
|
||||
args.m = N;
|
||||
args.type = ISEG_DATA;
|
||||
|
||||
data train;
|
||||
data buffer;
|
||||
pthread_t load_thread;
|
||||
args.d = &buffer;
|
||||
load_thread = load_data(args);
|
||||
|
||||
int epoch = (*net->seen)/N;
|
||||
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
||||
double time = what_time_is_it_now();
|
||||
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data(args);
|
||||
|
||||
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
||||
time = what_time_is_it_now();
|
||||
|
||||
float loss = 0;
|
||||
#ifdef GPU
|
||||
if(ngpus == 1){
|
||||
loss = train_network(net, train);
|
||||
} else {
|
||||
loss = train_networks(nets, ngpus, train, 4);
|
||||
}
|
||||
#else
|
||||
loss = train_network(net, train);
|
||||
#endif
|
||||
if(display){
|
||||
image tr = float_to_image(net->w/div, net->h/div, 80, train.y.vals[net->batch*(net->subdivisions-1)]);
|
||||
image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]);
|
||||
pred.c = 80;
|
||||
image mask = mask_to_rgb(tr);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
show_image(im, "input", 1);
|
||||
show_image(prmask, "pred", 1);
|
||||
show_image(mask, "truth", 100);
|
||||
free_image(mask);
|
||||
free_image(prmask);
|
||||
}
|
||||
if(avg_loss == -1) avg_loss = loss;
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
||||
free_data(train);
|
||||
if(*net->seen/N > epoch){
|
||||
epoch = *net->seen/N;
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
if(get_current_batch(net)%100 == 0){
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s.backup",backup_directory,base);
|
||||
save_weights(net, buff);
|
||||
}
|
||||
}
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s.weights", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
|
||||
free_network(net);
|
||||
free_ptrs((void**)paths, plist->size);
|
||||
free_list(plist);
|
||||
free(base);
|
||||
}
|
||||
|
||||
void predict_isegmenter(char *datafile, char *cfg, char *weights, char *filename)
|
||||
{
|
||||
network *net = load_network(cfg, weights, 0);
|
||||
set_batch_network(net, 1);
|
||||
srand(2222222);
|
||||
|
||||
clock_t time;
|
||||
char buff[256];
|
||||
char *input = buff;
|
||||
while(1){
|
||||
if(filename){
|
||||
strncpy(input, filename, 256);
|
||||
}else{
|
||||
printf("Enter Image Path: ");
|
||||
fflush(stdout);
|
||||
input = fgets(input, 256, stdin);
|
||||
if(!input) return;
|
||||
strtok(input, "\n");
|
||||
}
|
||||
image im = load_image_color(input, 0, 0);
|
||||
image sized = letterbox_image(im, net->w, net->h);
|
||||
|
||||
float *X = sized.data;
|
||||
time=clock();
|
||||
float *predictions = network_predict(net, X);
|
||||
image pred = get_network_image(net);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
printf("Predicted: %f\n", predictions[0]);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
show_image(sized, "orig", 1);
|
||||
show_image(prmask, "pred", 0);
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
free_image(prmask);
|
||||
if (filename) break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void demo_isegmenter(char *datacfg, char *cfg, char *weights, int cam_index, const char *filename)
|
||||
{
|
||||
#ifdef OPENCV
|
||||
printf("Classifier Demo\n");
|
||||
network *net = load_network(cfg, weights, 0);
|
||||
set_batch_network(net, 1);
|
||||
|
||||
srand(2222222);
|
||||
CvCapture * cap;
|
||||
|
||||
if(filename){
|
||||
cap = cvCaptureFromFile(filename);
|
||||
}else{
|
||||
cap = cvCaptureFromCAM(cam_index);
|
||||
}
|
||||
|
||||
if(!cap) error("Couldn't connect to webcam.\n");
|
||||
cvNamedWindow("Segmenter", CV_WINDOW_NORMAL);
|
||||
cvResizeWindow("Segmenter", 512, 512);
|
||||
float fps = 0;
|
||||
|
||||
while(1){
|
||||
struct timeval tval_before, tval_after, tval_result;
|
||||
gettimeofday(&tval_before, NULL);
|
||||
|
||||
image in = get_image_from_stream(cap);
|
||||
image in_s = letterbox_image(in, net->w, net->h);
|
||||
|
||||
network_predict(net, in_s.data);
|
||||
|
||||
printf("\033[2J");
|
||||
printf("\033[1;1H");
|
||||
printf("\nFPS:%.0f\n",fps);
|
||||
|
||||
image pred = get_network_image(net);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
show_image(prmask, "Segmenter", 10);
|
||||
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
free_image(prmask);
|
||||
|
||||
gettimeofday(&tval_after, NULL);
|
||||
timersub(&tval_after, &tval_before, &tval_result);
|
||||
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
||||
fps = .9*fps + .1*curr;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
void run_isegmenter(int argc, char **argv)
|
||||
{
|
||||
if(argc < 4){
|
||||
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
|
||||
return;
|
||||
}
|
||||
|
||||
char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
|
||||
int *gpus = 0;
|
||||
int gpu = 0;
|
||||
int ngpus = 0;
|
||||
if(gpu_list){
|
||||
printf("%s\n", gpu_list);
|
||||
int len = strlen(gpu_list);
|
||||
ngpus = 1;
|
||||
int i;
|
||||
for(i = 0; i < len; ++i){
|
||||
if (gpu_list[i] == ',') ++ngpus;
|
||||
}
|
||||
gpus = calloc(ngpus, sizeof(int));
|
||||
for(i = 0; i < ngpus; ++i){
|
||||
gpus[i] = atoi(gpu_list);
|
||||
gpu_list = strchr(gpu_list, ',')+1;
|
||||
}
|
||||
} else {
|
||||
gpu = gpu_index;
|
||||
gpus = &gpu;
|
||||
ngpus = 1;
|
||||
}
|
||||
|
||||
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
||||
int clear = find_arg(argc, argv, "-clear");
|
||||
int display = find_arg(argc, argv, "-display");
|
||||
char *data = argv[3];
|
||||
char *cfg = argv[4];
|
||||
char *weights = (argc > 5) ? argv[5] : 0;
|
||||
char *filename = (argc > 6) ? argv[6]: 0;
|
||||
if(0==strcmp(argv[2], "test")) predict_isegmenter(data, cfg, weights, filename);
|
||||
else if(0==strcmp(argv[2], "train")) train_isegmenter(data, cfg, weights, gpus, ngpus, clear, display);
|
||||
else if(0==strcmp(argv[2], "demo")) demo_isegmenter(data, cfg, weights, cam_index, filename);
|
||||
}
|
||||
|
||||
|
@ -460,13 +460,9 @@ void inter_dcgan(char *cfgfile, char *weightfile)
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
//char buff[256];
|
||||
sprintf(buff, "out%05d", c);
|
||||
show_image(out, "out");
|
||||
save_image(out, "out");
|
||||
save_image(out, buff);
|
||||
#ifdef OPENCV
|
||||
//cvWaitKey(0);
|
||||
#endif
|
||||
|
||||
show_image(out, "out", 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -499,11 +495,8 @@ void test_dcgan(char *cfgfile, char *weightfile)
|
||||
//yuv_to_rgb(out);
|
||||
normalize_image(out);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
show_image(out, "out");
|
||||
save_image(out, "out");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
#endif
|
||||
show_image(out, "out", 0);
|
||||
|
||||
free_image(im);
|
||||
}
|
||||
@ -639,11 +632,10 @@ void train_prog(char *cfg, char *weight, char *acfg, char *aweight, int clear, i
|
||||
if(display){
|
||||
image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]);
|
||||
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
||||
show_image(im, "gen");
|
||||
show_image(im2, "train");
|
||||
show_image(im, "gen", 1);
|
||||
show_image(im2, "train", 1);
|
||||
save_image(im, "gen");
|
||||
save_image(im2, "train");
|
||||
cvWaitKey(1);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -826,11 +818,10 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
|
||||
if(display){
|
||||
image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]);
|
||||
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
||||
show_image(im, "gen");
|
||||
show_image(im2, "train");
|
||||
show_image(im, "gen", 1);
|
||||
show_image(im2, "train", 1);
|
||||
save_image(im, "gen");
|
||||
save_image(im2, "train");
|
||||
cvWaitKey(1);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1010,9 +1001,8 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
if(display){
|
||||
image im = float_to_image(anet->w, anet->h, anet->c, gray.X.vals[0]);
|
||||
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
|
||||
show_image(im, "gen");
|
||||
show_image(im2, "train");
|
||||
cvWaitKey(1);
|
||||
show_image(im, "gen", 1);
|
||||
show_image(im2, "train", 1);
|
||||
}
|
||||
#endif
|
||||
free_data(merge);
|
||||
@ -1342,12 +1332,9 @@ void test_lsd(char *cfg, char *weights, char *filename, int gray)
|
||||
//yuv_to_rgb(out);
|
||||
constrain_image(out);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
show_image(out, "out");
|
||||
show_image(crop, "crop");
|
||||
save_image(out, "out");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
#endif
|
||||
show_image(out, "out", 1);
|
||||
show_image(crop, "crop", 0);
|
||||
|
||||
free_image(im);
|
||||
free_image(resized);
|
||||
|
@ -376,10 +376,7 @@ void run_nightmare(int argc, char **argv)
|
||||
if(reconstruct){
|
||||
reconstruct_picture(net, features, im, update, rate, momentum, lambda, smooth_size, 1);
|
||||
//if ((n+1)%30 == 0) rate *= .5;
|
||||
show_image(im, "reconstruction");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(10);
|
||||
#endif
|
||||
show_image(im, "reconstruction", 10);
|
||||
}else{
|
||||
int layer = max_layer + rand()%range - range/2;
|
||||
int octave = rand()%octaves;
|
||||
@ -400,8 +397,7 @@ void run_nightmare(int argc, char **argv)
|
||||
}
|
||||
printf("%d %s\n", e, buff);
|
||||
save_image(im, buff);
|
||||
//show_image(im, buff);
|
||||
//cvWaitKey(0);
|
||||
//show_image(im, buff, 0);
|
||||
|
||||
if(rotate){
|
||||
image rot = rotate_image(im, rotate);
|
||||
|
@ -179,7 +179,6 @@ void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
|
||||
image in = get_image_from_stream(cap);
|
||||
image crop = center_crop_image(in, net->w, net->h);
|
||||
grayscale_image_3c(crop);
|
||||
show_image(crop, "Regressor");
|
||||
|
||||
float *predictions = network_predict(net, crop.data);
|
||||
|
||||
@ -192,11 +191,10 @@ void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
|
||||
printf("%s: %f\n", names[i], predictions[i]);
|
||||
}
|
||||
|
||||
show_image(crop, "Regressor", 10);
|
||||
free_image(in);
|
||||
free_image(crop);
|
||||
|
||||
cvWaitKey(10);
|
||||
|
||||
gettimeofday(&tval_after, NULL);
|
||||
timersub(&tval_after, &tval_before, &tval_result);
|
||||
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
||||
|
@ -42,7 +42,6 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
printf("%d\n", plist->size);
|
||||
int N = plist->size;
|
||||
clock_t time;
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net->w;
|
||||
@ -73,14 +72,14 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
|
||||
|
||||
int epoch = (*net->seen)/N;
|
||||
while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
|
||||
time=clock();
|
||||
double time = what_time_is_it_now();
|
||||
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data(args);
|
||||
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
time=clock();
|
||||
printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
|
||||
time = what_time_is_it_now();
|
||||
|
||||
float loss = 0;
|
||||
#ifdef GPU
|
||||
@ -97,18 +96,15 @@ void train_segmenter(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
|
||||
image im = float_to_image(net->w, net->h, net->c, train.X.vals[net->batch*(net->subdivisions-1)]);
|
||||
image mask = mask_to_rgb(tr);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
show_image(im, "input");
|
||||
show_image(prmask, "pred");
|
||||
show_image(mask, "truth");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(100);
|
||||
#endif
|
||||
show_image(im, "input", 1);
|
||||
show_image(prmask, "pred", 1);
|
||||
show_image(mask, "truth", 100);
|
||||
free_image(mask);
|
||||
free_image(prmask);
|
||||
}
|
||||
if(avg_loss == -1) avg_loss = loss;
|
||||
avg_loss = avg_loss*.9 + loss*.1;
|
||||
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
|
||||
printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
|
||||
free_data(train);
|
||||
if(*net->seen/N > epoch){
|
||||
epoch = *net->seen/N;
|
||||
@ -159,13 +155,10 @@ void predict_segmenter(char *datafile, char *cfg, char *weights, char *filename)
|
||||
float *predictions = network_predict(net, X);
|
||||
image pred = get_network_image(net);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
show_image(sized, "orig");
|
||||
show_image(prmask, "pred");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
#endif
|
||||
printf("Predicted: %f\n", predictions[0]);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
show_image(sized, "orig", 1);
|
||||
show_image(prmask, "pred", 0);
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
free_image(prmask);
|
||||
@ -210,14 +203,12 @@ void demo_segmenter(char *datacfg, char *cfg, char *weights, int cam_index, cons
|
||||
|
||||
image pred = get_network_image(net);
|
||||
image prmask = mask_to_rgb(pred);
|
||||
show_image(prmask, "Segmenter");
|
||||
show_image(prmask, "Segmenter", 10);
|
||||
|
||||
free_image(in_s);
|
||||
free_image(in);
|
||||
free_image(prmask);
|
||||
|
||||
cvWaitKey(10);
|
||||
|
||||
gettimeofday(&tval_after, NULL);
|
||||
timersub(&tval_after, &tval_before, &tval_result);
|
||||
float curr = 1000000.f/((long int)tval_result.tv_usec);
|
||||
|
@ -93,7 +93,7 @@ void test_super(char *cfgfile, char *weightfile, char *filename)
|
||||
image out = get_network_image(net);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
save_image(out, "out");
|
||||
show_image(out, "out");
|
||||
show_image(out, "out", 0);
|
||||
|
||||
free_image(im);
|
||||
if (filename) break;
|
||||
|
@ -296,14 +296,10 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
|
||||
|
||||
draw_detections(im, dets, l.side*l.side*l.n, thresh, voc_names, alphabet, 20);
|
||||
save_image(im, "predictions");
|
||||
show_image(im, "predictions");
|
||||
show_image(im, "predictions", 0);
|
||||
free_detections(dets, nboxes);
|
||||
free_image(im);
|
||||
free_image(sized);
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
cvDestroyAllWindows();
|
||||
#endif
|
||||
if (filename) break;
|
||||
}
|
||||
}
|
||||
|
@ -86,6 +86,7 @@ typedef enum {
|
||||
XNOR,
|
||||
REGION,
|
||||
YOLO,
|
||||
ISEG,
|
||||
REORG,
|
||||
UPSAMPLE,
|
||||
LOGXENT,
|
||||
@ -166,6 +167,7 @@ struct layer{
|
||||
float ratio;
|
||||
float learning_rate_scale;
|
||||
float clip;
|
||||
int noloss;
|
||||
int softmax;
|
||||
int classes;
|
||||
int coords;
|
||||
@ -203,6 +205,7 @@ struct layer{
|
||||
int dontload;
|
||||
int dontsave;
|
||||
int dontloadscales;
|
||||
int numload;
|
||||
|
||||
float temperature;
|
||||
float probability;
|
||||
@ -213,6 +216,8 @@ struct layer{
|
||||
int * input_layers;
|
||||
int * input_sizes;
|
||||
int * map;
|
||||
int * counts;
|
||||
float ** sums;
|
||||
float * rand;
|
||||
float * cost;
|
||||
float * state;
|
||||
@ -540,7 +545,7 @@ typedef struct{
|
||||
} data;
|
||||
|
||||
typedef enum {
|
||||
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA
|
||||
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA, ISEG_DATA
|
||||
} data_type;
|
||||
|
||||
typedef struct load_args{
|
||||
@ -705,7 +710,7 @@ int resize_network(network *net, int w, int h);
|
||||
void free_matrix(matrix m);
|
||||
void test_resize(char *filename);
|
||||
void save_image(image p, const char *name);
|
||||
void show_image(image p, const char *name);
|
||||
int show_image(image p, const char *name, int ms);
|
||||
image copy_image(image p);
|
||||
void draw_box_width(image a, int x1, int y1, int x2, int y2, int w, float r, float g, float b);
|
||||
float get_current_rate(network *net);
|
||||
|
@ -151,7 +151,7 @@ void cudnn_convolutional_setup(layer *l)
|
||||
l->convDesc,
|
||||
l->dstTensorDesc,
|
||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
||||
4000000000,
|
||||
2000000000,
|
||||
&l->fw_algo);
|
||||
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
|
||||
l->weightDesc,
|
||||
@ -159,7 +159,7 @@ void cudnn_convolutional_setup(layer *l)
|
||||
l->convDesc,
|
||||
l->dsrcTensorDesc,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
||||
4000000000,
|
||||
2000000000,
|
||||
&l->bd_algo);
|
||||
cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
|
||||
l->srcTensorDesc,
|
||||
@ -167,7 +167,7 @@ void cudnn_convolutional_setup(layer *l)
|
||||
l->convDesc,
|
||||
l->dweightDesc,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
||||
4000000000,
|
||||
2000000000,
|
||||
&l->bf_algo);
|
||||
}
|
||||
#endif
|
||||
|
95
src/data.c
95
src/data.c
@ -361,6 +361,44 @@ box bound_image(image im)
|
||||
}
|
||||
|
||||
void fill_truth_iseg(char *path, int num_boxes, float *truth, int classes, int w, int h, augment_args aug, int flip, int mw, int mh)
|
||||
{
|
||||
char labelpath[4096];
|
||||
find_replace(path, "images", "mask", labelpath);
|
||||
find_replace(labelpath, "JPEGImages", "mask", labelpath);
|
||||
find_replace(labelpath, ".jpg", ".txt", labelpath);
|
||||
find_replace(labelpath, ".JPG", ".txt", labelpath);
|
||||
find_replace(labelpath, ".JPEG", ".txt", labelpath);
|
||||
FILE *file = fopen(labelpath, "r");
|
||||
if(!file) file_error(labelpath);
|
||||
char buff[32788];
|
||||
int id;
|
||||
int i = 0;
|
||||
int j;
|
||||
image part = make_image(w, h, 1);
|
||||
while((fscanf(file, "%d %s", &id, buff) == 2) && i < num_boxes){
|
||||
int n = 0;
|
||||
int *rle = read_intlist(buff, &n, 0);
|
||||
load_rle(part, rle, n);
|
||||
image sized = rotate_crop_image(part, aug.rad, aug.scale, aug.w, aug.h, aug.dx, aug.dy, aug.aspect);
|
||||
if(flip) flip_image(sized);
|
||||
|
||||
image mask = resize_image(sized, mw, mh);
|
||||
truth[i*(mw*mh+1)] = id;
|
||||
for(j = 0; j < mw*mh; ++j){
|
||||
truth[i*(mw*mh + 1) + 1 + j] = mask.data[j];
|
||||
}
|
||||
++i;
|
||||
|
||||
free_image(mask);
|
||||
free_image(sized);
|
||||
free(rle);
|
||||
}
|
||||
if(i < num_boxes) truth[i*(mw*mh+1)] = -1;
|
||||
fclose(file);
|
||||
free_image(part);
|
||||
}
|
||||
|
||||
void fill_truth_mask(char *path, int num_boxes, float *truth, int classes, int w, int h, augment_args aug, int flip, int mw, int mh)
|
||||
{
|
||||
char labelpath[4096];
|
||||
find_replace(path, "images", "mask", labelpath);
|
||||
@ -743,7 +781,47 @@ data load_data_seg(int n, char **paths, int m, int w, int h, int classes, int mi
|
||||
return d;
|
||||
}
|
||||
|
||||
data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int boxes, int coords, int min, int max, float angle, float aspect, float hue, float saturation, float exposure)
|
||||
data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int boxes, int div, int min, int max, float angle, float aspect, float hue, float saturation, float exposure)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
data d = {0};
|
||||
d.shallow = 0;
|
||||
|
||||
d.X.rows = n;
|
||||
d.X.vals = calloc(d.X.rows, sizeof(float*));
|
||||
d.X.cols = h*w*3;
|
||||
|
||||
d.y = make_matrix(n, (((w/div)*(h/div))+1)*boxes);
|
||||
|
||||
for(i = 0; i < n; ++i){
|
||||
image orig = load_image_color(random_paths[i], 0, 0);
|
||||
augment_args a = random_augment_args(orig, angle, aspect, min, max, w, h);
|
||||
image sized = rotate_crop_image(orig, a.rad, a.scale, a.w, a.h, a.dx, a.dy, a.aspect);
|
||||
|
||||
int flip = rand()%2;
|
||||
if(flip) flip_image(sized);
|
||||
random_distort_image(sized, hue, saturation, exposure);
|
||||
d.X.vals[i] = sized.data;
|
||||
//show_image(sized, "image");
|
||||
|
||||
fill_truth_iseg(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, w/div, h/div);
|
||||
|
||||
free_image(orig);
|
||||
|
||||
/*
|
||||
image rgb = mask_to_rgb(sized_m, classes);
|
||||
show_image(rgb, "part");
|
||||
show_image(sized, "orig");
|
||||
cvWaitKey(0);
|
||||
free_image(rgb);
|
||||
*/
|
||||
}
|
||||
free(random_paths);
|
||||
return d;
|
||||
}
|
||||
|
||||
data load_data_mask(int n, char **paths, int m, int w, int h, int classes, int boxes, int coords, int min, int max, float angle, float aspect, float hue, float saturation, float exposure)
|
||||
{
|
||||
char **random_paths = get_random_paths(paths, n, m);
|
||||
int i;
|
||||
@ -767,7 +845,7 @@ data load_data_iseg(int n, char **paths, int m, int w, int h, int classes, int b
|
||||
d.X.vals[i] = sized.data;
|
||||
//show_image(sized, "image");
|
||||
|
||||
fill_truth_iseg(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, 14, 14);
|
||||
fill_truth_mask(random_paths[i], boxes, d.y.vals[i], classes, orig.w, orig.h, a, flip, 14, 14);
|
||||
|
||||
free_image(orig);
|
||||
|
||||
@ -975,7 +1053,8 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, in
|
||||
float dh = jitter * orig.h;
|
||||
|
||||
float new_ar = (orig.w + rand_uniform(-dw, dw)) / (orig.h + rand_uniform(-dh, dh));
|
||||
float scale = rand_uniform(.25, 2);
|
||||
//float scale = rand_uniform(.25, 2);
|
||||
float scale = 1;
|
||||
|
||||
float nw, nh;
|
||||
|
||||
@ -1025,8 +1104,10 @@ void *load_thread(void *ptr)
|
||||
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
|
||||
} else if (a.type == WRITING_DATA){
|
||||
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
|
||||
} else if (a.type == ISEG_DATA){
|
||||
*a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.scale, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
|
||||
} else if (a.type == INSTANCE_DATA){
|
||||
*a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
|
||||
*a.d = load_data_mask(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
|
||||
} else if (a.type == SEGMENTATION_DATA){
|
||||
*a.d = load_data_seg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.scale);
|
||||
} else if (a.type == REGION_DATA){
|
||||
@ -1212,7 +1293,7 @@ data *tile_data(data orig, int divs, int size)
|
||||
{
|
||||
data *ds = calloc(divs*divs, sizeof(data));
|
||||
int i, j;
|
||||
#pragma omp parallel for
|
||||
#pragma omp parallel for
|
||||
for(i = 0; i < divs*divs; ++i){
|
||||
data d;
|
||||
d.shallow = 0;
|
||||
@ -1223,7 +1304,7 @@ data *tile_data(data orig, int divs, int size)
|
||||
d.X.vals = calloc(d.X.rows, sizeof(float*));
|
||||
|
||||
d.y = copy_matrix(orig.y);
|
||||
#pragma omp parallel for
|
||||
#pragma omp parallel for
|
||||
for(j = 0; j < orig.X.rows; ++j){
|
||||
int x = (i%divs) * orig.w / divs - (d.w - orig.w/divs)/2;
|
||||
int y = (i/divs) * orig.h / divs - (d.h - orig.h/divs)/2;
|
||||
@ -1247,7 +1328,7 @@ data resize_data(data orig, int w, int h)
|
||||
d.X.vals = calloc(d.X.rows, sizeof(float*));
|
||||
|
||||
d.y = copy_matrix(orig.y);
|
||||
#pragma omp parallel for
|
||||
#pragma omp parallel for
|
||||
for(i = 0; i < orig.X.rows; ++i){
|
||||
image im = float_to_image(orig.w, orig.h, 3, orig.X.vals[i]);
|
||||
d.X.vals[i] = resize_image(im, w, h).data;
|
||||
|
30
src/image.c
30
src/image.c
@ -572,7 +572,7 @@ void show_image_cv(image p, const char *name, IplImage *disp)
|
||||
}
|
||||
#endif
|
||||
|
||||
void show_image(image p, const char *name)
|
||||
int show_image(image p, const char *name, int ms)
|
||||
{
|
||||
#ifdef OPENCV
|
||||
IplImage *disp = cvCreateImage(cvSize(p.w,p.h), IPL_DEPTH_8U, p.c);
|
||||
@ -581,9 +581,13 @@ void show_image(image p, const char *name)
|
||||
show_image_cv(copy, name, disp);
|
||||
free_image(copy);
|
||||
cvReleaseImage(&disp);
|
||||
int c = cvWaitKey(ms);
|
||||
if (c != -1) c = c%256;
|
||||
return c;
|
||||
#else
|
||||
fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name);
|
||||
save_image(p, name);
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -727,7 +731,7 @@ void show_image_layers(image p, char *name)
|
||||
for(i = 0; i < p.c; ++i){
|
||||
sprintf(buff, "%s - Layer %d", name, i);
|
||||
image layer = get_image_layer(p, i);
|
||||
show_image(layer, buff);
|
||||
show_image(layer, buff, 1);
|
||||
free_image(layer);
|
||||
}
|
||||
}
|
||||
@ -735,7 +739,7 @@ void show_image_layers(image p, char *name)
|
||||
void show_image_collapsed(image p, char *name)
|
||||
{
|
||||
image c = collapse_image_layers(p, 1);
|
||||
show_image(c, name);
|
||||
show_image(c, name, 1);
|
||||
free_image(c);
|
||||
}
|
||||
|
||||
@ -1406,16 +1410,16 @@ void test_resize(char *filename)
|
||||
distort_image(c4, .1, .66666, 1.5);
|
||||
|
||||
|
||||
show_image(im, "Original");
|
||||
show_image(gray, "Gray");
|
||||
show_image(c1, "C1");
|
||||
show_image(c2, "C2");
|
||||
show_image(c3, "C3");
|
||||
show_image(c4, "C4");
|
||||
show_image(im, "Original", 1);
|
||||
show_image(gray, "Gray", 1);
|
||||
show_image(c1, "C1", 1);
|
||||
show_image(c2, "C2", 1);
|
||||
show_image(c3, "C3", 1);
|
||||
show_image(c4, "C4", 1);
|
||||
#ifdef OPENCV
|
||||
while(1){
|
||||
image aug = random_augment_image(im, 0, .75, 320, 448, 320, 320);
|
||||
show_image(aug, "aug");
|
||||
show_image(aug, "aug", 1);
|
||||
free_image(aug);
|
||||
|
||||
|
||||
@ -1430,7 +1434,7 @@ void test_resize(char *filename)
|
||||
float dhue = rand_uniform(-hue, hue);
|
||||
|
||||
distort_image(c, dhue, dsat, dexp);
|
||||
show_image(c, "rand");
|
||||
show_image(c, "rand", 1);
|
||||
printf("%f %f %f\n", dhue, dsat, dexp);
|
||||
free_image(c);
|
||||
cvWaitKey(0);
|
||||
@ -1585,7 +1589,7 @@ void show_image_normalized(image im, const char *name)
|
||||
{
|
||||
image c = copy_image(im);
|
||||
normalize_image(c);
|
||||
show_image(c, name);
|
||||
show_image(c, name, 1);
|
||||
free_image(c);
|
||||
}
|
||||
|
||||
@ -1603,7 +1607,7 @@ void show_images(image *ims, int n, char *window)
|
||||
*/
|
||||
normalize_image(m);
|
||||
save_image(m, window);
|
||||
show_image(m, window);
|
||||
show_image(m, window, 1);
|
||||
free_image(m);
|
||||
}
|
||||
|
||||
|
219
src/iseg_layer.c
Normal file
219
src/iseg_layer.c
Normal file
@ -0,0 +1,219 @@
|
||||
#include "iseg_layer.h"
|
||||
#include "activations.h"
|
||||
#include "blas.h"
|
||||
#include "box.h"
|
||||
#include "cuda.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
layer make_iseg_layer(int batch, int w, int h, int classes, int ids)
|
||||
{
|
||||
layer l = {0};
|
||||
l.type = ISEG;
|
||||
|
||||
l.h = h;
|
||||
l.w = w;
|
||||
l.c = classes + ids;
|
||||
l.out_w = l.w;
|
||||
l.out_h = l.h;
|
||||
l.out_c = l.c;
|
||||
l.classes = classes;
|
||||
l.batch = batch;
|
||||
l.extra = ids;
|
||||
l.cost = calloc(1, sizeof(float));
|
||||
l.outputs = h*w*l.c;
|
||||
l.inputs = l.outputs;
|
||||
l.truths = 90*(l.w*l.h+1);
|
||||
l.delta = calloc(batch*l.outputs, sizeof(float));
|
||||
l.output = calloc(batch*l.outputs, sizeof(float));
|
||||
|
||||
l.counts = calloc(90, sizeof(int));
|
||||
l.sums = calloc(90, sizeof(float*));
|
||||
if(ids){
|
||||
int i;
|
||||
for(i = 0; i < 90; ++i){
|
||||
l.sums[i] = calloc(ids, sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
l.forward = forward_iseg_layer;
|
||||
l.backward = backward_iseg_layer;
|
||||
#ifdef GPU
|
||||
l.forward_gpu = forward_iseg_layer_gpu;
|
||||
l.backward_gpu = backward_iseg_layer_gpu;
|
||||
l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
|
||||
l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
|
||||
#endif
|
||||
|
||||
fprintf(stderr, "iseg\n");
|
||||
srand(0);
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
void resize_iseg_layer(layer *l, int w, int h)
|
||||
{
|
||||
l->w = w;
|
||||
l->h = h;
|
||||
|
||||
l->outputs = h*w*l->c;
|
||||
l->inputs = l->outputs;
|
||||
|
||||
l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
|
||||
l->delta = realloc(l->delta, l->batch*l->outputs*sizeof(float));
|
||||
|
||||
#ifdef GPU
|
||||
cuda_free(l->delta_gpu);
|
||||
cuda_free(l->output_gpu);
|
||||
|
||||
l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs);
|
||||
l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
|
||||
#endif
|
||||
}
|
||||
|
||||
void forward_iseg_layer(const layer l, network net)
|
||||
{
|
||||
|
||||
double time = what_time_is_it_now();
|
||||
int i,b,j,k;
|
||||
int ids = l.extra;
|
||||
memcpy(l.output, net.input, l.outputs*l.batch*sizeof(float));
|
||||
memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
|
||||
|
||||
#ifndef GPU
|
||||
for (b = 0; b < l.batch; ++b){
|
||||
int index = b*l.outputs;
|
||||
activate_array(l.output + index, l.classes*l.w*l.h, LOGISTIC);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (b = 0; b < l.batch; ++b){
|
||||
// a priori, each pixel has no class
|
||||
for(i = 0; i < l.classes; ++i){
|
||||
for(k = 0; k < l.w*l.h; ++k){
|
||||
int index = b*l.outputs + i*l.w*l.h + k;
|
||||
l.delta[index] = 0 - l.output[index];
|
||||
}
|
||||
}
|
||||
|
||||
// a priori, embedding should be small magnitude
|
||||
for(i = 0; i < ids; ++i){
|
||||
for(k = 0; k < l.w*l.h; ++k){
|
||||
int index = b*l.outputs + (i+l.classes)*l.w*l.h + k;
|
||||
l.delta[index] = .1 * (0 - l.output[index]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
memset(l.counts, 0, 90*sizeof(float));
|
||||
for(i = 0; i < 90; ++i){
|
||||
l.counts[i] = 0;
|
||||
fill_cpu(ids, 0, l.sums[i], 1);
|
||||
|
||||
int c = net.truth[b*l.truths + i*(l.w*l.h+1)];
|
||||
if(c < 0) break;
|
||||
// add up metric embeddings for each instance
|
||||
for(k = 0; k < l.w*l.h; ++k){
|
||||
int index = b*l.outputs + c*l.w*l.h + k;
|
||||
float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k];
|
||||
if(v){
|
||||
l.delta[index] = v - l.output[index];
|
||||
axpy_cpu(ids, 1, l.output + b*l.outputs + l.classes*l.w*l.h + k, l.w*l.h, l.sums[i], 1);
|
||||
++l.counts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float *mse = calloc(90, sizeof(float));
|
||||
for(i = 0; i < 90; ++i){
|
||||
int c = net.truth[b*l.truths + i*(l.w*l.h+1)];
|
||||
if(c < 0) break;
|
||||
for(k = 0; k < l.w*l.h; ++k){
|
||||
float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k];
|
||||
if(v){
|
||||
int z;
|
||||
float sum = 0;
|
||||
for(z = 0; z < ids; ++z){
|
||||
int index = b*l.outputs + (l.classes + z)*l.w*l.h + k;
|
||||
sum += pow(l.sums[i][z]/l.counts[i] - l.output[index], 2);
|
||||
}
|
||||
mse[i] += sum;
|
||||
}
|
||||
}
|
||||
mse[i] /= l.counts[i];
|
||||
}
|
||||
|
||||
// Calculate average embedding
|
||||
for(i = 0; i < 90; ++i){
|
||||
if(!l.counts[i]) continue;
|
||||
scal_cpu(ids, 1.f/l.counts[i], l.sums[i], 1);
|
||||
if(b == 0 && net.gpu_index == 0){
|
||||
printf("%4d, %6.3f, ", l.counts[i], mse[i]);
|
||||
for(j = 0; j < ids/4; ++j){
|
||||
printf("%6.3f,", l.sums[i][j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
free(mse);
|
||||
|
||||
// Calculate embedding loss
|
||||
for(i = 0; i < 90; ++i){
|
||||
if(!l.counts[i]) continue;
|
||||
for(k = 0; k < l.w*l.h; ++k){
|
||||
float v = net.truth[b*l.truths + i*(l.w*l.h + 1) + 1 + k];
|
||||
if(v){
|
||||
for(j = 0; j < 90; ++j){
|
||||
if(!l.counts[j])continue;
|
||||
int z;
|
||||
for(z = 0; z < ids; ++z){
|
||||
int index = b*l.outputs + (l.classes + z)*l.w*l.h + k;
|
||||
float diff = l.sums[j][z] - l.output[index];
|
||||
if (j == i) l.delta[index] += diff < 0? -.1 : .1;
|
||||
else l.delta[index] += -(diff < 0? -.1 : .1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
|
||||
printf("took %lf sec\n", what_time_is_it_now() - time);
|
||||
}
|
||||
|
||||
void backward_iseg_layer(const layer l, network net)
|
||||
{
|
||||
axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, net.delta, 1);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
|
||||
void forward_iseg_layer_gpu(const layer l, network net)
|
||||
{
|
||||
copy_gpu(l.batch*l.inputs, net.input_gpu, 1, l.output_gpu, 1);
|
||||
int b;
|
||||
for (b = 0; b < l.batch; ++b){
|
||||
activate_array_gpu(l.output_gpu + b*l.outputs, l.classes*l.w*l.h, LOGISTIC);
|
||||
//if(l.extra) activate_array_gpu(l.output_gpu + b*l.outputs + l.classes*l.w*l.h, l.extra*l.w*l.h, LOGISTIC);
|
||||
}
|
||||
|
||||
cuda_pull_array(l.output_gpu, net.input, l.batch*l.inputs);
|
||||
forward_iseg_layer(l, net);
|
||||
cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
|
||||
}
|
||||
|
||||
void backward_iseg_layer_gpu(const layer l, network net)
|
||||
{
|
||||
int b;
|
||||
for (b = 0; b < l.batch; ++b){
|
||||
//if(l.extra) gradient_array_gpu(l.output_gpu + b*l.outputs + l.classes*l.w*l.h, l.extra*l.w*l.h, LOGISTIC, l.delta_gpu + b*l.outputs + l.classes*l.w*l.h);
|
||||
}
|
||||
axpy_gpu(l.batch*l.inputs, 1, l.delta_gpu, 1, net.delta_gpu, 1);
|
||||
}
|
||||
#endif
|
||||
|
19
src/iseg_layer.h
Normal file
19
src/iseg_layer.h
Normal file
@ -0,0 +1,19 @@
|
||||
#ifndef ISEG_LAYER_H
|
||||
#define ISEG_LAYER_H
|
||||
|
||||
#include "darknet.h"
|
||||
#include "layer.h"
|
||||
#include "network.h"
|
||||
|
||||
layer make_iseg_layer(int batch, int w, int h, int classes, int ids);
|
||||
void forward_iseg_layer(const layer l, network net);
|
||||
void backward_iseg_layer(const layer l, network net);
|
||||
void resize_iseg_layer(layer *l, int w, int h);
|
||||
int iseg_num_detections(layer l, float thresh);
|
||||
|
||||
#ifdef GPU
|
||||
void forward_iseg_layer_gpu(const layer l, network net);
|
||||
void backward_iseg_layer_gpu(layer l, network net);
|
||||
#endif
|
||||
|
||||
#endif
|
@ -27,8 +27,8 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
|
||||
l.w = w;
|
||||
l.c = c;
|
||||
l.pad = padding;
|
||||
l.out_w = (w + 2*padding - size)/stride + 1;
|
||||
l.out_h = (h + 2*padding - size)/stride + 1;
|
||||
l.out_w = (w + padding - size)/stride + 1;
|
||||
l.out_h = (h + padding - size)/stride + 1;
|
||||
l.out_c = c;
|
||||
l.outputs = l.out_h * l.out_w * l.out_c;
|
||||
l.inputs = h*w*c;
|
||||
@ -57,8 +57,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
|
||||
l->w = w;
|
||||
l->inputs = h*w*l->c;
|
||||
|
||||
l->out_w = (w + 2*l->pad - l->size)/l->stride + 1;
|
||||
l->out_h = (h + 2*l->pad - l->size)/l->stride + 1;
|
||||
l->out_w = (w + l->pad - l->size)/l->stride + 1;
|
||||
l->out_h = (h + l->pad - l->size)/l->stride + 1;
|
||||
l->outputs = l->out_w * l->out_h * l->c;
|
||||
int output_size = l->outputs * l->batch;
|
||||
|
||||
@ -79,8 +79,8 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
|
||||
void forward_maxpool_layer(const maxpool_layer l, network net)
|
||||
{
|
||||
int b,i,j,k,m,n;
|
||||
int w_offset = -l.pad;
|
||||
int h_offset = -l.pad;
|
||||
int w_offset = -l.pad/l.stride;
|
||||
int h_offset = -l.pad/l.stride;
|
||||
|
||||
int h = l.out_h;
|
||||
int w = l.out_w;
|
||||
|
@ -9,8 +9,8 @@ extern "C" {
|
||||
|
||||
__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
|
||||
{
|
||||
int h = (in_h + 2*pad - size)/stride + 1;
|
||||
int w = (in_w + 2*pad - size)/stride + 1;
|
||||
int h = (in_h + pad - size)/stride + 1;
|
||||
int w = (in_w + pad - size)/stride + 1;
|
||||
int c = in_c;
|
||||
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
@ -24,8 +24,8 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c
|
||||
id /= c;
|
||||
int b = id;
|
||||
|
||||
int w_offset = -pad;
|
||||
int h_offset = -pad;
|
||||
int w_offset = -pad/2;
|
||||
int h_offset = -pad/2;
|
||||
|
||||
int out_index = j + w*(i + h*(k + c*b));
|
||||
float max = -INFINITY;
|
||||
@ -49,8 +49,8 @@ __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c
|
||||
|
||||
__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
|
||||
{
|
||||
int h = (in_h + 2*pad - size)/stride + 1;
|
||||
int w = (in_w + 2*pad - size)/stride + 1;
|
||||
int h = (in_h + pad - size)/stride + 1;
|
||||
int w = (in_w + pad - size)/stride + 1;
|
||||
int c = in_c;
|
||||
int area = (size-1)/stride;
|
||||
|
||||
@ -66,8 +66,8 @@ __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_
|
||||
id /= in_c;
|
||||
int b = id;
|
||||
|
||||
int w_offset = -pad;
|
||||
int h_offset = -pad;
|
||||
int w_offset = -pad/2;
|
||||
int h_offset = -pad/2;
|
||||
|
||||
float d = 0;
|
||||
int l, m;
|
||||
|
38
src/parser.c
38
src/parser.c
@ -27,6 +27,7 @@
|
||||
#include "parser.h"
|
||||
#include "region_layer.h"
|
||||
#include "yolo_layer.h"
|
||||
#include "iseg_layer.h"
|
||||
#include "reorg_layer.h"
|
||||
#include "rnn_layer.h"
|
||||
#include "route_layer.h"
|
||||
@ -52,6 +53,7 @@ LAYER_TYPE string_to_layer_type(char * type)
|
||||
if (strcmp(type, "[detection]")==0) return DETECTION;
|
||||
if (strcmp(type, "[region]")==0) return REGION;
|
||||
if (strcmp(type, "[yolo]")==0) return YOLO;
|
||||
if (strcmp(type, "[iseg]")==0) return ISEG;
|
||||
if (strcmp(type, "[local]")==0) return LOCAL;
|
||||
if (strcmp(type, "[conv]")==0
|
||||
|| strcmp(type, "[convolutional]")==0) return CONVOLUTIONAL;
|
||||
@ -265,18 +267,19 @@ layer parse_connected(list *options, size_params params)
|
||||
return l;
|
||||
}
|
||||
|
||||
softmax_layer parse_softmax(list *options, size_params params)
|
||||
layer parse_softmax(list *options, size_params params)
|
||||
{
|
||||
int groups = option_find_int_quiet(options, "groups",1);
|
||||
softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
|
||||
layer.temperature = option_find_float_quiet(options, "temperature", 1);
|
||||
layer l = make_softmax_layer(params.batch, params.inputs, groups);
|
||||
l.temperature = option_find_float_quiet(options, "temperature", 1);
|
||||
char *tree_file = option_find_str(options, "tree", 0);
|
||||
if (tree_file) layer.softmax_tree = read_tree(tree_file);
|
||||
layer.w = params.w;
|
||||
layer.h = params.h;
|
||||
layer.c = params.c;
|
||||
layer.spatial = option_find_float_quiet(options, "spatial", 0);
|
||||
return layer;
|
||||
if (tree_file) l.softmax_tree = read_tree(tree_file);
|
||||
l.w = params.w;
|
||||
l.h = params.h;
|
||||
l.c = params.c;
|
||||
l.spatial = option_find_float_quiet(options, "spatial", 0);
|
||||
l.noloss = option_find_int_quiet(options, "noloss", 0);
|
||||
return l;
|
||||
}
|
||||
|
||||
int *parse_yolo_mask(char *a, int *num)
|
||||
@ -338,6 +341,15 @@ layer parse_yolo(list *options, size_params params)
|
||||
return l;
|
||||
}
|
||||
|
||||
layer parse_iseg(list *options, size_params params)
|
||||
{
|
||||
int classes = option_find_int(options, "classes", 20);
|
||||
int ids = option_find_int(options, "ids", 32);
|
||||
layer l = make_iseg_layer(params.batch, params.w, params.h, classes, ids);
|
||||
assert(l.outputs == params.inputs);
|
||||
return l;
|
||||
}
|
||||
|
||||
layer parse_region(list *options, size_params params)
|
||||
{
|
||||
int coords = option_find_int(options, "coords", 4);
|
||||
@ -472,7 +484,7 @@ maxpool_layer parse_maxpool(list *options, size_params params)
|
||||
{
|
||||
int stride = option_find_int(options, "stride",1);
|
||||
int size = option_find_int(options, "size",stride);
|
||||
int padding = option_find_int_quiet(options, "padding", (size-1)/2);
|
||||
int padding = option_find_int_quiet(options, "padding", size-1);
|
||||
|
||||
int batch,h,w,c;
|
||||
h = params.h;
|
||||
@ -791,6 +803,8 @@ network *parse_network_cfg(char *filename)
|
||||
l = parse_region(options, params);
|
||||
}else if(lt == YOLO){
|
||||
l = parse_yolo(options, params);
|
||||
}else if(lt == ISEG){
|
||||
l = parse_iseg(options, params);
|
||||
}else if(lt == DETECTION){
|
||||
l = parse_detection(options, params);
|
||||
}else if(lt == SOFTMAX){
|
||||
@ -829,6 +843,7 @@ network *parse_network_cfg(char *filename)
|
||||
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
|
||||
l.dontsave = option_find_int_quiet(options, "dontsave", 0);
|
||||
l.dontload = option_find_int_quiet(options, "dontload", 0);
|
||||
l.numload = option_find_int_quiet(options, "numload", 0);
|
||||
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
|
||||
l.learning_rate_scale = option_find_float_quiet(options, "learning_rate", 1);
|
||||
l.smooth = option_find_float_quiet(options, "smooth", 0);
|
||||
@ -1152,7 +1167,8 @@ void load_convolutional_weights(layer l, FILE *fp)
|
||||
//load_convolutional_weights_binary(l, fp);
|
||||
//return;
|
||||
}
|
||||
int num = l.nweights;
|
||||
if(l.numload) l.n = l.numload;
|
||||
int num = l.c/l.groups*l.n*l.size*l.size;
|
||||
fread(l.biases, sizeof(float), l.n, fp);
|
||||
if (l.batch_normalize && (!l.dontloadscales)){
|
||||
fread(l.scales, sizeof(float), l.n, fp);
|
||||
|
@ -50,7 +50,7 @@ void forward_softmax_layer(const softmax_layer l, network net)
|
||||
softmax_cpu(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output);
|
||||
}
|
||||
|
||||
if(net.truth){
|
||||
if(net.truth && !l.noloss){
|
||||
softmax_x_ent_cpu(l.batch*l.inputs, l.output, net.truth, l.delta, l.loss);
|
||||
l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
|
||||
}
|
||||
@ -88,7 +88,7 @@ void forward_softmax_layer_gpu(const softmax_layer l, network net)
|
||||
softmax_gpu(net.input_gpu, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output_gpu);
|
||||
}
|
||||
}
|
||||
if(net.truth){
|
||||
if(net.truth && !l.noloss){
|
||||
softmax_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth_gpu, l.delta_gpu, l.loss_gpu);
|
||||
if(l.softmax_tree){
|
||||
mask_gpu(l.batch*l.inputs, l.delta_gpu, SECRET_NUM, net.truth_gpu, 0);
|
||||
|
Loading…
Reference in New Issue
Block a user