hope i didn't break anything

This commit is contained in:
Joseph Redmon 2016-06-02 15:25:24 -07:00
parent 881d6ee9b6
commit ec3d050a76
17 changed files with 834 additions and 550 deletions

View File

@ -3,7 +3,7 @@ CUDNN=0
OPENCV=0
DEBUG=0
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20
ARCH= --gpu-architecture=compute_52 --gpu-code=compute_52
VPATH=./src/
EXEC=darknet

View File

@ -14,7 +14,6 @@ power=4
max_batches=500000
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
@ -26,7 +25,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
@ -38,7 +36,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
@ -50,7 +47,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
@ -62,7 +58,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
@ -74,7 +69,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
@ -86,7 +80,6 @@ size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1

9
cfg/imagenet1k.dataset Normal file
View File

@ -0,0 +1,9 @@
classes=1000
labels = data/inet.labels.list
names = data/shortnames.txt
train = /data/imagenet/imagenet1k.train.list
valid = /data/imagenet/imagenet1k.valid.list
top=5
test = /Users/pjreddie/Documents/sites/selfie/paths.list
backup = /home/pjreddie/backup/

View File

@ -38,7 +38,7 @@ list *read_data_cfg(char *filename)
return options;
}
void train_classifier(char *datacfg, char *cfgfile, char *weightfile)
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
{
data_seed = time(0);
srand(time(0));
@ -49,6 +49,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile)
if(weightfile){
load_weights(&net, weightfile);
}
if(clear) *net.seen = 0;
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = net.batch;
@ -96,7 +97,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile)
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
/*
/*
int u;
for(u = 0; u < net.batch; ++u){
image im = float_to_image(net.w, net.h, 3, train.X.vals[u]);
@ -116,7 +117,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile)
sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
save_weights(net, buff);
}
if(*net.seen%100 == 0){
if(get_current_batch(net)%100 == 0){
char buff[256];
sprintf(buff, "%s/%s.backup",backup_directory,base);
save_weights(net, buff);
@ -378,8 +379,8 @@ void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
//cvWaitKey(0);
float *pred = network_predict(net, crop.data);
if(resized.data != im.data) free_image(resized);
free_image(im);
free_image(resized);
free_image(crop);
top_k(pred, classes, topk, indexes);
@ -441,7 +442,7 @@ void validate_classifier_multi(char *datacfg, char *filename, char *weightfile)
flip_image(r);
p = network_predict(net, r.data);
axpy_cpu(classes, 1, p, 1, pred, 1);
free_image(r);
if(r.data != im.data) free_image(r);
}
free_image(im);
top_k(pred, classes, topk, indexes);
@ -501,6 +502,46 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
}
}
void label_classifier(char *datacfg, char *filename, char *weightfile)
{
int i;
network net = parse_network_cfg(filename);
set_batch_network(&net, 1);
if(weightfile){
load_weights(&net, weightfile);
}
srand(time(0));
list *options = read_data_cfg(datacfg);
char *label_list = option_find_str(options, "names", "data/labels.list");
char *test_list = option_find_str(options, "test", "data/train.list");
int classes = option_find_int(options, "classes", 2);
char **labels = get_labels(label_list);
list *plist = get_paths(test_list);
char **paths = (char **)list_to_array(plist);
int m = plist->size;
free_list(plist);
for(i = 0; i < m; ++i){
image im = load_image_color(paths[i], 0, 0);
image resized = resize_min(im, net.w);
image crop = crop_image(resized, (resized.w - net.w)/2, (resized.h - net.h)/2, net.w, net.h);
float *pred = network_predict(net, crop.data);
if(resized.data != im.data) free_image(resized);
free_image(im);
free_image(crop);
int ind = max_index(pred, classes);
printf("%s\n", labels[ind]);
}
}
void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer)
{
int curr = 0;
@ -649,6 +690,7 @@ void run_classifier(int argc, char **argv)
}
int cam_index = find_int_arg(argc, argv, "-c", 0);
int clear = find_arg(argc, argv, "-clear");
char *data = argv[3];
char *cfg = argv[4];
char *weights = (argc > 5) ? argv[5] : 0;
@ -656,9 +698,10 @@ void run_classifier(int argc, char **argv)
char *layer_s = (argc > 7) ? argv[7]: 0;
int layer = layer_s ? atoi(layer_s) : -1;
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights);
else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);

View File

@ -161,6 +161,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
l.filter_updates_gpu);
if(state.delta){
if(l.binary || l.xnor) swap_binary(&l);
cudnnConvolutionBackwardData(cudnn_handle(),
&one,
l.filterDesc,
@ -174,6 +175,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
&one,
l.dsrcTensorDesc,
state.delta);
if(l.binary || l.xnor) swap_binary(&l);
}
#else

View File

@ -88,8 +88,8 @@ image get_convolutional_delta(convolutional_layer l)
return float_to_image(w,h,c,l.delta);
}
#ifdef CUDNN
size_t get_workspace_size(layer l){
#ifdef CUDNN
size_t most = 0;
size_t s = 0;
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
@ -117,8 +117,10 @@ size_t get_workspace_size(layer l){
&s);
if (s > most) most = s;
return most;
#else
return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
#endif
}
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
{
@ -154,8 +156,6 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = l.w * l.h * l.c;
l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
l.workspace_size = out_h*out_w*size*size*c*sizeof(float);
l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
@ -255,10 +255,9 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0,
&l.bf_algo);
#endif
#endif
l.workspace_size = get_workspace_size(l);
#endif
#endif
l.activation = activation;
fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@ -315,8 +314,6 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
l->outputs = l->out_h * l->out_w * l->out_c;
l->inputs = l->w * l->h * l->c;
l->col_image = realloc(l->col_image,
out_h*out_w*l->size*l->size*l->c*sizeof(float));
l->output = realloc(l->output,
l->batch*out_h * out_w * l->n*sizeof(float));
l->delta = realloc(l->delta,
@ -328,7 +325,43 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
#ifdef CUDNN
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
int padding = l->pad ? l->size/2 : 0;
cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->filterDesc,
l->convDesc,
l->dstTensorDesc,
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0,
&l->fw_algo);
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
l->filterDesc,
l->ddstTensorDesc,
l->convDesc,
l->dsrcTensorDesc,
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0,
&l->bd_algo);
cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->ddstTensorDesc,
l->convDesc,
l->dfilterDesc,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0,
&l->bf_algo);
#endif
#endif
l->workspace_size = get_workspace_size(*l);
}
void add_bias(float *output, float *biases, int batch, int n, int size)
@ -386,7 +419,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
int n = out_h*out_w;
char *a = l.cfilters;
float *b = l.col_image;
float *b = state.workspace;
float *c = l.output;
for(i = 0; i < l.batch; ++i){
@ -407,7 +440,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
int n = out_h*out_w;
float *a = l.filters;
float *b = l.col_image;
float *b = state.workspace;
float *c = l.output;
for(i = 0; i < l.batch; ++i){
@ -439,7 +472,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
float *b = l.col_image;
float *b = state.workspace;
float *c = l.filter_updates;
float *im = state.input+i*l.c*l.h*l.w;
@ -451,11 +484,11 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
if(state.delta){
a = l.filters;
b = l.delta + i*m*k;
c = l.col_image;
c = state.workspace;
gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
col2im_cpu(state.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
}
}
}

View File

@ -270,6 +270,8 @@ int main(int argc, char **argv)
run_dice(argc, argv);
} else if (0 == strcmp(argv[1], "writing")){
run_writing(argc, argv);
} else if (0 == strcmp(argv[1], "3d")){
composite_3d(argv[2], argv[3], argv[4]);
} else if (0 == strcmp(argv[1], "test")){
test_resize(argv[2]);
} else if (0 == strcmp(argv[1], "captcha")){

View File

@ -271,7 +271,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int
free(boxes);
}
void fill_truth_detection(char *path, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
{
char *labelpath = find_replace(path, "images", "labels");
labelpath = find_replace(labelpath, "JPEGImages", "labels");
@ -283,7 +283,7 @@ void fill_truth_detection(char *path, float *truth, int classes, int flip, float
box_label *boxes = read_boxes(labelpath, &count);
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
if(count > 17) count = 17;
if(count > num_boxes) count = num_boxes;
float x,y,w,h;
int id;
int i;
@ -297,11 +297,11 @@ void fill_truth_detection(char *path, float *truth, int classes, int flip, float
if (w < .01 || h < .01) continue;
truth[i*5] = id;
truth[i*5+2] = x;
truth[i*5+3] = y;
truth[i*5+4] = w;
truth[i*5+5] = h;
truth[i*5+0] = id;
truth[i*5+1] = x;
truth[i*5+2] = y;
truth[i*5+3] = w;
truth[i*5+4] = h;
}
free(boxes);
}
@ -601,7 +601,7 @@ data load_data_swag(char **paths, int n, int classes, float jitter)
return d;
}
data load_data_detection(int n, int boxes, char **paths, int m, int w, int h, int classes, float jitter)
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter)
{
char **random_paths = get_random_paths(paths, n, m);
int i;
@ -643,7 +643,7 @@ data load_data_detection(int n, int boxes, char **paths, int m, int w, int h, in
if(flip) flip_image(sized);
d.X.vals[i] = sized.data;
fill_truth_detection(random_paths[i], d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy);
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy);
free_image(orig);
free_image(cropped);
@ -669,12 +669,12 @@ void *load_thread(void *ptr)
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size);
} else if (a.type == STUDY_DATA){
*a.d = load_data_study(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.num_boxes, a.paths, a.m, a.classes, a.w, a.h, a.background);
} else if (a.type == WRITING_DATA){
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
} else if (a.type == COMPARE_DATA){

View File

@ -70,7 +70,7 @@ void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
data load_data(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, int boxes, char **paths, int m, int w, int h, int classes, float jitter);
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter);
data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size);
data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size);
data load_data_study(char **paths, int n, int m, char **labels, int k, int min, int max, int size);

View File

@ -491,6 +491,8 @@ void show_image_cv(image p, const char *name)
int r = j + dy;
int c = i + dx;
float val = 0;
r = constrain_int(r, 0, im.h-1);
c = constrain_int(c, 0, im.w-1);
if (r >= 0 && r < im.h && c >= 0 && c < im.w) {
val = get_pixel(im, c, r, k);
}
@ -501,8 +503,75 @@ void show_image_cv(image p, const char *name)
return cropped;
}
image resize_min(image im, int min)
{
int best_3d_shift_r(image a, image b, int min, int max)
{
if(min == max) return min;
int mid = floor((min + max) / 2.);
image c1 = crop_image(b, 0, mid, b.w, b.h);
image c2 = crop_image(b, 0, mid+1, b.w, b.h);
float d1 = dist_array(c1.data, a.data, a.w*a.h*a.c, 10);
float d2 = dist_array(c2.data, a.data, a.w*a.h*a.c, 10);
free_image(c1);
free_image(c2);
if(d1 < d2) return best_3d_shift_r(a, b, min, mid);
else return best_3d_shift_r(a, b, mid+1, max);
}
int best_3d_shift(image a, image b, int min, int max)
{
int i;
int best = 0;
float best_distance = FLT_MAX;
for(i = min; i <= max; i += 2){
image c = crop_image(b, 0, i, b.w, b.h);
float d = dist_array(c.data, a.data, a.w*a.h*a.c, 100);
if(d < best_distance){
best_distance = d;
best = i;
}
printf("%d %f\n", i, d);
free_image(c);
}
return best;
}
void composite_3d(char *f1, char *f2, char *out)
{
if(!out) out = "out";
image a = load_image(f1, 0,0,0);
image b = load_image(f2, 0,0,0);
int shift = best_3d_shift_r(a, b, -a.h/100, a.h/100);
image c1 = crop_image(b, 10, shift, b.w, b.h);
float d1 = dist_array(c1.data, a.data, a.w*a.h*a.c, 100);
image c2 = crop_image(b, -10, shift, b.w, b.h);
float d2 = dist_array(c2.data, a.data, a.w*a.h*a.c, 100);
if(d2 < d1){
image swap = a;
a = b;
b = swap;
shift = -shift;
printf("swapped, %d\n", shift);
}
else{
printf("%d\n", shift);
}
image c = crop_image(b, 0, shift, a.w, a.h);
int i;
for(i = 0; i < c.w*c.h; ++i){
c.data[i] = a.data[i];
}
#ifdef OPENCV
save_image_jpg(c, out);
#else
save_image(c, out);
#endif
}
image resize_min(image im, int min)
{
int w = im.w;
int h = im.h;
if(w < h){
@ -515,10 +584,10 @@ void show_image_cv(image p, const char *name)
if(w == im.w && h == im.h) return im;
image resized = resize_image(im, w, h);
return resized;
}
}
image random_crop_image(image im, int low, int high, int size)
{
image random_crop_image(image im, int low, int high, int size)
{
int r = rand_int(low, high);
image resized = resize_min(im, r);
int dx = rand_int(0, resized.w - size);
@ -527,21 +596,21 @@ void show_image_cv(image p, const char *name)
if(resized.data != im.data) free_image(resized);
return crop;
}
}
float three_way_max(float a, float b, float c)
{
float three_way_max(float a, float b, float c)
{
return (a > b) ? ( (a > c) ? a : c) : ( (b > c) ? b : c) ;
}
}
float three_way_min(float a, float b, float c)
{
float three_way_min(float a, float b, float c)
{
return (a < b) ? ( (a < c) ? a : c) : ( (b < c) ? b : c) ;
}
}
// http://www.cs.rit.edu/~ncs/color/t_convert.html
void rgb_to_hsv(image im)
{
// http://www.cs.rit.edu/~ncs/color/t_convert.html
void rgb_to_hsv(image im)
{
assert(im.c == 3);
int i, j;
float r, g, b;
@ -574,10 +643,10 @@ void show_image_cv(image p, const char *name)
set_pixel(im, i, j, 2, v);
}
}
}
}
void hsv_to_rgb(image im)
{
void hsv_to_rgb(image im)
{
assert(im.c == 3);
int i, j;
float r, g, b;
@ -615,10 +684,10 @@ void show_image_cv(image p, const char *name)
set_pixel(im, i, j, 2, b);
}
}
}
}
image grayscale_image(image im)
{
image grayscale_image(image im)
{
assert(im.c == 3);
int i, j, k;
image gray = make_image(im.w, im.h, 1);
@ -631,20 +700,20 @@ void show_image_cv(image p, const char *name)
}
}
return gray;
}
}
image threshold_image(image im, float thresh)
{
image threshold_image(image im, float thresh)
{
int i;
image t = make_image(im.w, im.h, im.c);
for(i = 0; i < im.w*im.h*im.c; ++i){
t.data[i] = im.data[i]>thresh ? 1 : 0;
}
return t;
}
}
image blend_image(image fore, image back, float alpha)
{
image blend_image(image fore, image back, float alpha)
{
assert(fore.w == back.w && fore.h == back.h && fore.c == back.c);
image blend = make_image(fore.w, fore.h, fore.c);
int i, j, k;
@ -658,10 +727,10 @@ void show_image_cv(image p, const char *name)
}
}
return blend;
}
}
void scale_image_channel(image im, int c, float v)
{
void scale_image_channel(image im, int c, float v)
{
int i, j;
for(j = 0; j < im.h; ++j){
for(i = 0; i < im.w; ++i){
@ -670,10 +739,10 @@ void show_image_cv(image p, const char *name)
set_pixel(im, i, j, c, pix);
}
}
}
}
image binarize_image(image im)
{
image binarize_image(image im)
{
image c = copy_image(im);
int i;
for(i = 0; i < im.w * im.h * im.c; ++i){
@ -681,34 +750,34 @@ void show_image_cv(image p, const char *name)
else c.data[i] = 0;
}
return c;
}
}
void saturate_image(image im, float sat)
{
void saturate_image(image im, float sat)
{
rgb_to_hsv(im);
scale_image_channel(im, 1, sat);
hsv_to_rgb(im);
constrain_image(im);
}
}
void exposure_image(image im, float sat)
{
void exposure_image(image im, float sat)
{
rgb_to_hsv(im);
scale_image_channel(im, 2, sat);
hsv_to_rgb(im);
constrain_image(im);
}
}
void saturate_exposure_image(image im, float sat, float exposure)
{
void saturate_exposure_image(image im, float sat, float exposure)
{
rgb_to_hsv(im);
scale_image_channel(im, 1, sat);
scale_image_channel(im, 2, exposure);
hsv_to_rgb(im);
constrain_image(im);
}
}
/*
/*
image saturate_image(image im, float sat)
{
image gray = grayscale_image(im);
@ -725,8 +794,8 @@ void show_image_cv(image p, const char *name)
}
*/
float bilinear_interpolate(image im, float x, float y, int c)
{
float bilinear_interpolate(image im, float x, float y, int c)
{
int ix = (int) floorf(x);
int iy = (int) floorf(y);
@ -738,10 +807,10 @@ void show_image_cv(image p, const char *name)
(1-dy) * dx * get_pixel_extend(im, ix+1, iy, c) +
dy * dx * get_pixel_extend(im, ix+1, iy+1, c);
return val;
}
}
image resize_image(image im, int w, int h)
{
image resize_image(image im, int w, int h)
{
image resized = make_image(w, h, im.c);
image part = make_image(w, im.h, im.c);
int r, c, k;
@ -782,12 +851,12 @@ void show_image_cv(image p, const char *name)
free_image(part);
return resized;
}
}
#include "cuda.h"
void test_resize(char *filename)
{
void test_resize(char *filename)
{
image im = load_image(filename, 0,0, 3);
float mag = mag_array(im.data, im.w*im.h*im.c);
printf("L2 Norm: %f\n", mag);
@ -836,11 +905,11 @@ void show_image_cv(image p, const char *name)
#ifdef OPENCV
cvWaitKey(0);
#endif
}
}
#ifdef OPENCV
image ipl_to_image(IplImage* src)
{
image ipl_to_image(IplImage* src)
{
unsigned char *data = (unsigned char *)src->imageData;
int h = src->height;
int w = src->width;
@ -857,10 +926,10 @@ void show_image_cv(image p, const char *name)
}
}
return out;
}
}
image load_image_cv(char *filename, int channels)
{
image load_image_cv(char *filename, int channels)
{
IplImage* src = 0;
int flag = -1;
if (channels == 0) flag = -1;
@ -883,13 +952,13 @@ void show_image_cv(image p, const char *name)
cvReleaseImage(&src);
rgbgr_image(out);
return out;
}
}
#endif
image load_image_stb(char *filename, int channels)
{
image load_image_stb(char *filename, int channels)
{
int w, h, c;
unsigned char *data = stbi_load(filename, &w, &h, &c, channels);
if (!data) {
@ -910,10 +979,10 @@ void show_image_cv(image p, const char *name)
}
free(data);
return im;
}
}
image load_image(char *filename, int w, int h, int c)
{
image load_image(char *filename, int w, int h, int c)
{
#ifdef OPENCV
image out = load_image_cv(filename, c);
#else
@ -926,46 +995,46 @@ void show_image_cv(image p, const char *name)
out = resized;
}
return out;
}
}
image load_image_color(char *filename, int w, int h)
{
image load_image_color(char *filename, int w, int h)
{
return load_image(filename, w, h, 3);
}
}
image get_image_layer(image m, int l)
{
image get_image_layer(image m, int l)
{
image out = make_image(m.w, m.h, 1);
int i;
for(i = 0; i < m.h*m.w; ++i){
out.data[i] = m.data[i+l*m.h*m.w];
}
return out;
}
}
float get_pixel(image m, int x, int y, int c)
{
float get_pixel(image m, int x, int y, int c)
{
assert(x < m.w && y < m.h && c < m.c);
return m.data[c*m.h*m.w + y*m.w + x];
}
float get_pixel_extend(image m, int x, int y, int c)
{
}
float get_pixel_extend(image m, int x, int y, int c)
{
if(x < 0 || x >= m.w || y < 0 || y >= m.h || c < 0 || c >= m.c) return 0;
return get_pixel(m, x, y, c);
}
void set_pixel(image m, int x, int y, int c, float val)
{
}
void set_pixel(image m, int x, int y, int c, float val)
{
assert(x < m.w && y < m.h && c < m.c);
m.data[c*m.h*m.w + y*m.w + x] = val;
}
void add_pixel(image m, int x, int y, int c, float val)
{
}
void add_pixel(image m, int x, int y, int c, float val)
{
assert(x < m.w && y < m.h && c < m.c);
m.data[c*m.h*m.w + y*m.w + x] += val;
}
}
void print_image(image m)
{
void print_image(image m)
{
int i, j, k;
for(i =0 ; i < m.c; ++i){
for(j =0 ; j < m.h; ++j){
@ -979,10 +1048,10 @@ void show_image_cv(image p, const char *name)
printf("\n");
}
printf("\n");
}
}
image collapse_images_vert(image *ims, int n)
{
image collapse_images_vert(image *ims, int n)
{
int color = 1;
int border = 1;
int h,w,c;
@ -1014,10 +1083,10 @@ void show_image_cv(image p, const char *name)
free_image(copy);
}
return filters;
}
}
image collapse_images_horz(image *ims, int n)
{
image collapse_images_horz(image *ims, int n)
{
int color = 1;
int border = 1;
int h,w,c;
@ -1050,18 +1119,18 @@ void show_image_cv(image p, const char *name)
free_image(copy);
}
return filters;
}
}
void show_image_normalized(image im, const char *name)
{
void show_image_normalized(image im, const char *name)
{
image c = copy_image(im);
normalize_image(c);
show_image(c, name);
free_image(c);
}
}
void show_images(image *ims, int n, char *window)
{
void show_images(image *ims, int n, char *window)
{
image m = collapse_images_vert(ims, n);
/*
int w = 448;
@ -1078,9 +1147,9 @@ void show_image_cv(image p, const char *name)
show_image(sized, window);
free_image(sized);
free_image(m);
}
}
void free_image(image m)
{
void free_image(image m)
{
free(m.data);
}
}

View File

@ -44,6 +44,7 @@ void saturate_exposure_image(image im, float sat, float exposure);
void hsv_to_rgb(image im);
void rgbgr_image(image im);
void constrain_image(image im);
void composite_3d(char *f1, char *f2, char *out);
image grayscale_image(image im);
image threshold_image(image im, float thresh);

View File

@ -50,6 +50,7 @@ struct layer{
int h,w,c;
int out_h, out_w, out_c;
int n;
int max_boxes;
int groups;
int size;
int side;

View File

@ -137,6 +137,7 @@ network make_network(int n)
void forward_network(network net, network_state state)
{
state.workspace = net.workspace;
int i;
for(i = 0; i < net.n; ++i){
state.index = i;
@ -400,6 +401,7 @@ int resize_network(network *net, int w, int h)
net->w = w;
net->h = h;
int inputs = 0;
size_t workspace_size = 0;
//fprintf(stderr, "Resizing to %d x %d...", w, h);
//fflush(stderr);
for (i = 0; i < net->n; ++i){
@ -419,12 +421,20 @@ int resize_network(network *net, int w, int h)
}else{
error("Cannot resize this type of layer");
}
if(l.workspace_size > workspace_size) workspace_size = l.workspace_size;
inputs = l.outputs;
net->layers[i] = l;
w = l.out_w;
h = l.out_h;
if(l.type == AVGPOOL) break;
}
#ifdef GPU
cuda_free(net->workspace);
net->workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
#else
free(net->workspace);
net->workspace = calloc(1, (workspace_size-1)/sizeof(float)+1);
#endif
//fprintf(stderr, " Done!\n");
return 0;
}

View File

@ -257,6 +257,7 @@ detection_layer parse_detection(list *options, size_params params)
layer.softmax = option_find_int(options, "softmax", 0);
layer.sqrt = option_find_int(options, "sqrt", 0);
layer.max_boxes = option_find_int_quiet(options, "max",30);
layer.coord_scale = option_find_float(options, "coord_scale", 1);
layer.forced = option_find_int(options, "forced", 0);
layer.object_scale = option_find_float(options, "object_scale", 1);
@ -600,8 +601,11 @@ network parse_network_cfg(char *filename)
net.outputs = get_network_output_size(net);
net.output = get_network_output(net);
if(workspace_size){
//printf("%ld\n", workspace_size);
#ifdef GPU
net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
#else
net.workspace = calloc(1, workspace_size);
#endif
}
return net;

100
src/rnn.c
View File

@ -280,6 +280,104 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t
printf("\n");
}
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
{
char **tokens = 0;
if(token_file){
size_t n;
tokens = read_tokens(token_file, &n);
}
srand(rseed);
char *base = basecfg(cfgfile);
fprintf(stderr, "%s\n", base);
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
int inputs = get_network_input_size(net);
int i, j;
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
int c = 0;
int len = strlen(seed);
float *input = calloc(inputs, sizeof(float));
float *out;
while((c = getc(stdin)) != EOF){
input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
for(i = 0; i < num; ++i){
for(j = 0; j < inputs; ++j){
if (out[j] < .0001) out[j] = 0;
}
int next = sample_array(out, inputs);
if(c == '.' && next == '\n') break;
c = next;
print_symbol(c, tokens);
input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
printf("\n");
}
void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
{
char *base = basecfg(cfgfile);
fprintf(stderr, "%s\n", base);
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
int inputs = get_network_input_size(net);
int count = 0;
int words = 1;
int c;
int len = strlen(seed);
float *input = calloc(inputs, sizeof(float));
int i;
for(i = 0; i < len; ++i){
c = seed[i];
input[(int)c] = 1;
network_predict(net, input);
input[(int)c] = 0;
}
float sum = 0;
c = getc(stdin);
float log2 = log(2);
int in = 0;
while(c != EOF){
int next = getc(stdin);
if(next == EOF) break;
if(next < 0 || next >= 255) error("Out of range character");
input[c] = 1;
float *out = network_predict(net, input);
input[c] = 0;
if(c == '.' && next == '\n') in = 0;
if(!in) {
if(c == '>' && next == '>'){
in = 1;
++words;
}
c = next;
continue;
}
++count;
sum += log(out[next])/log2;
c = next;
printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
}
}
void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
{
char *base = basecfg(cfgfile);
@ -389,6 +487,8 @@ void run_char_rnn(int argc, char **argv)
char *weights = (argc > 4) ? argv[4] : 0;
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, seed, temp, rseed, tokens);
}

View File

@ -424,6 +424,13 @@ float variance_array(float *a, int n)
return variance;
}
int constrain_int(int a, int min, int max)
{
if (a < min) return min;
if (a > max) return max;
return a;
}
float constrain(float min, float max, float a)
{
if (a < min) return min;
@ -431,6 +438,14 @@ float constrain(float min, float max, float a)
return a;
}
float dist_array(float *a, float *b, int n, int sub)
{
int i;
float sum = 0;
for(i = 0; i < n; i += sub) sum += pow(a[i]-b[i], 2);
return sqrt(sum);
}
float mse_array(float *a, int n)
{
int i;

View File

@ -36,6 +36,7 @@ void scale_array(float *a, int n, float s);
void translate_array(float *a, int n, float s);
int max_index(float *a, int n);
float constrain(float min, float max, float a);
int constrain_int(int a, int min, int max);
float mse_array(float *a, int n);
float rand_normal();
size_t rand_size_t();
@ -46,6 +47,7 @@ float mean_array(float *a, int n);
void mean_arrays(float **a, int n, int els, float *avg);
float variance_array(float *a, int n);
float mag_array(float *a, int n);
float dist_array(float *a, float *b, int n, int sub);
float **one_hot_encode(float *a, int n, int k);
float sec(clock_t clocks);
int find_int_arg(int argc, char **argv, char *arg, int def);