for dan, anyone else don't use, 🗑️ 🔥

Joseph Redmon 2018-03-14 15:42:17 -07:00
parent b40bbdc7b2
commit 0b64cb4dd3
16 changed files with 305 additions and 69 deletions

Makefile

@@ -97,5 +97,5 @@ results:
.PHONY: clean
clean:
rm -rf $(OBJS) $(SLIB) $(ALIB) $(EXEC) $(EXECOBJ)
rm -rf $(OBJS) $(SLIB) $(ALIB) $(EXEC) $(EXECOBJ) $(OBJDIR)/*

cfg/imagenet22k.dataset

@@ -1,6 +1,7 @@
classes=21842
train = /data/imagenet/imagenet22k.train.list
valid = /data/imagenet/imagenet22k.valid.list
#valid = /data/imagenet/imagenet1k.valid.list
backup = /home/pjreddie/backup/
labels = data/imagenet.labels.list
names = data/imagenet.shortnames.list

examples/classifier.c

@@ -47,6 +47,8 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
int tag = option_find_int_quiet(options, "tag", 0);
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *train_list = option_find_str(options, "train", "data/train.list");
char *tree = option_find_str(options, "tree", 0);
if (tree) net->hierarchy = read_tree(tree);
int classes = option_find_int(options, "classes", 2);
char **labels;

examples/darknet.c

@@ -188,6 +188,25 @@ void partial(char *cfgfile, char *weightfile, char *outfile, int max)
save_weights_upto(net, outfile, max);
}
void print_weights(char *cfgfile, char *weightfile, int n)
{
gpu_index = -1;
network *net = load_network(cfgfile, weightfile, 1);
layer l = net->layers[n];
int i, j;
//printf("[");
for(i = 0; i < l.n; ++i){
//printf("[");
for(j = 0; j < l.size*l.size*l.c; ++j){
//if(j > 0) printf(",");
printf("%g ", l.weights[i*l.size*l.size*l.c + j]);
}
printf("\n");
//printf("]%s\n", (i == l.n-1)?"":",");
}
//printf("]");
}
void rescale_net(char *cfgfile, char *weightfile, char *outfile)
{
gpu_index = -1;
@@ -467,6 +486,8 @@ int main(int argc, char **argv)
oneoff(argv[2], argv[3], argv[4]);
} else if (0 == strcmp(argv[1], "oneoff2")){
oneoff2(argv[2], argv[3], argv[4], atoi(argv[5]));
} else if (0 == strcmp(argv[1], "print")){
print_weights(argv[2], argv[3], atoi(argv[4]));
} else if (0 == strcmp(argv[1], "partial")){
partial(argv[2], argv[3], argv[4], atoi(argv[5]));
} else if (0 == strcmp(argv[1], "average")){
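
Note on the new print subcommand registered above: it takes a cfg file, a weights file, and a layer index, loads the network, and dumps that layer's weights — for a convolutional layer, one row per filter (l.n rows), each with l.size*l.size*l.c space-separated %g values. A hypothetical invocation (the model paths here are placeholders, not from the commit):

    ./darknet print cfg/darknet19.cfg darknet19.weights 0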

examples/lsd.c

@@ -395,7 +395,7 @@ void slerp(float *start, float *end, float s, int n, float *out)
scale_array(out, n, 1./mag);
}
image random_unit_vector_image(w, h, c)
image random_unit_vector_image(int w, int h, int c)
{
image im = make_image(w, h, c);
int i;
@@ -480,13 +480,7 @@ void test_dcgan(char *cfgfile, char *weightfile)
char *input = buff;
int i, imlayer = 0;
for (i = 0; i < net->n; ++i) {
if (net->layers[i].out_c == 3) {
imlayer = i;
printf("%d\n", i);
break;
}
}
imlayer = net->n-1;
while(1){
image im = make_image(net->w, net->h, net->c);
@@ -494,8 +488,8 @@ void test_dcgan(char *cfgfile, char *weightfile)
for(i = 0; i < im.w*im.h*im.c; ++i){
im.data[i] = rand_normal();
}
float mag = mag_array(im.data, im.w*im.h*im.c);
scale_array(im.data, im.w*im.h*im.c, 1./mag);
//float mag = mag_array(im.data, im.w*im.h*im.c);
//scale_array(im.data, im.w*im.h*im.c, 1./mag);
float *X = im.data;
time=clock();
@@ -514,6 +508,173 @@ void test_dcgan(char *cfgfile, char *weightfile)
}
}
void set_network_alpha_beta(network *net, float alpha, float beta)
{
int i;
for(i = 0; i < net->n; ++i){
if(net->layers[i].type == SHORTCUT){
net->layers[i].alpha = alpha;
net->layers[i].beta = beta;
}
}
}
void train_prog(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
{
#ifdef GPU
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
char *base = basecfg(cfg);
char *abase = basecfg(acfg);
printf("%s\n", base);
network *gnet = load_network(cfg, weight, clear);
network *anet = load_network(acfg, aweight, clear);
int i, j, k;
layer imlayer = gnet->layers[gnet->n-1];
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", gnet->learning_rate, gnet->momentum, gnet->decay);
int imgs = gnet->batch*gnet->subdivisions;
i = *gnet->seen/imgs;
data train, buffer;
list *plist = get_paths(train_images);
char **paths = (char **)list_to_array(plist);
load_args args= get_base_args(anet);
args.paths = paths;
args.n = imgs;
args.m = plist->size;
args.d = &buffer;
args.type = CLASSIFICATION_DATA;
args.threads=16;
args.classes = 1;
char *ls[2] = {"imagenet", "zzzzzzzz"};
args.labels = ls;
pthread_t load_thread = load_data_in_thread(args);
clock_t time;
gnet->train = 1;
anet->train = 1;
int x_size = gnet->inputs*gnet->batch;
int y_size = gnet->truths*gnet->batch;
float *imerror = cuda_make_array(0, y_size);
float aloss_avg = -1;
if (maxbatch == 0) maxbatch = gnet->max_batches;
while (get_current_batch(gnet) < maxbatch) {
{
int cb = get_current_batch(gnet);
float alpha = (float) cb / (maxbatch/2);
if(alpha > 1) alpha = 1;
float beta = 1 - alpha;
printf("%f %f\n", alpha, beta);
set_network_alpha_beta(gnet, alpha, beta);
set_network_alpha_beta(anet, beta, alpha);
}
i += 1;
time=clock();
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data_in_thread(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
data gen = copy_data(train);
for (j = 0; j < imgs; ++j) {
train.y.vals[j][0] = 1;
gen.y.vals[j][0] = 0;
}
time=clock();
for (j = 0; j < gnet->subdivisions; ++j) {
get_next_batch(train, gnet->batch, j*gnet->batch, gnet->truth, 0);
int z;
for(z = 0; z < x_size; ++z){
gnet->input[z] = rand_normal();
}
/*
for(z = 0; z < gnet->batch; ++z){
float mag = mag_array(gnet->input + z*gnet->inputs, gnet->inputs);
scale_array(gnet->input + z*gnet->inputs, gnet->inputs, 1./mag);
}
*/
*gnet->seen += gnet->batch;
forward_network(gnet);
fill_gpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
fill_cpu(anet->truths*anet->batch, 1, anet->truth, 1);
copy_cpu(anet->inputs*anet->batch, imlayer.output, 1, anet->input, 1);
anet->delta_gpu = imerror;
forward_network(anet);
backward_network(anet);
float genaloss = *anet->cost / anet->batch;
scal_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
scal_gpu(imlayer.outputs*imlayer.batch, 0, gnet->layers[gnet->n-1].delta_gpu, 1);
axpy_gpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, gnet->layers[gnet->n-1].delta_gpu, 1);
backward_network(gnet);
for(k = 0; k < gnet->batch; ++k){
int index = j*gnet->batch + k;
copy_cpu(gnet->outputs, gnet->output + k*gnet->outputs, 1, gen.X.vals[index], 1);
}
}
harmless_update_network_gpu(anet);
data merge = concat_data(train, gen);
float aloss = train_network(anet, merge);
#ifdef OPENCV
if(display){
image im = float_to_image(anet->w, anet->h, anet->c, gen.X.vals[0]);
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
show_image(im, "gen");
show_image(im2, "train");
save_image(im, "gen");
save_image(im2, "train");
cvWaitKey(1);
}
#endif
update_network_gpu(gnet);
free_data(merge);
free_data(train);
free_data(gen);
if (aloss_avg < 0) aloss_avg = aloss;
aloss_avg = aloss_avg*.9 + aloss*.1;
printf("%d: adv: %f | adv_avg: %f, %f rate, %lf seconds, %d images\n", i, aloss, aloss_avg, get_current_rate(gnet), sec(clock()-time), i*imgs);
if(i%10000==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
save_weights(gnet, buff);
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
save_weights(anet, buff);
}
if(i%1000==0){
char buff[256];
sprintf(buff, "%s/%s.backup", backup_directory, base);
save_weights(gnet, buff);
sprintf(buff, "%s/%s.backup", backup_directory, abase);
save_weights(anet, buff);
}
}
char buff[256];
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
save_weights(gnet, buff);
#endif
}
void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear, int display, char *train_images, int maxbatch)
{
@@ -668,7 +829,7 @@ void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear,
show_image(im2, "train");
save_image(im, "gen");
save_image(im2, "train");
cvWaitKey(50);
cvWaitKey(1);
}
#endif
@@ -850,7 +1011,7 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
image im2 = float_to_image(anet->w, anet->h, anet->c, train.X.vals[0]);
show_image(im, "gen");
show_image(im2, "train");
cvWaitKey(50);
cvWaitKey(1);
}
#endif
free_data(merge);
@@ -1217,6 +1378,7 @@ void run_lsd(int argc, char **argv)
//else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear);
//else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear);
if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear, display, file, batches);
else if(0==strcmp(argv[2], "trainprog")) train_prog(cfg, weights, acfg, aweights, clear, display, file, batches);
else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear, display);
else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights);
else if(0==strcmp(argv[2], "inter")) inter_dcgan(cfg, weights);
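
A note on the trainprog path above: set_network_alpha_beta ramps alpha linearly from 0 to 1 over the first half of training (alpha = cb/(maxbatch/2), clamped to 1) with beta = 1 - alpha, and every SHORTCUT layer then computes out = alpha*out + beta*add (see shortcut_cpu below). At cb = maxbatch/4, for example, alpha = beta = 0.5 and each shortcut averages its two branches; the generator and discriminator get mirrored schedules. This reads like the layer fade-in trick from progressive GAN training, though the commit itself doesn't say so.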

include/darknet.h

@@ -51,6 +51,7 @@ typedef struct{
int *group_size;
int *group_offset;
} tree;
tree *read_tree(char *filename);
typedef enum{
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN
@@ -189,13 +190,17 @@ struct layer{
float class_scale;
int bias_match;
int random;
float ignore_thresh;
float truth_thresh;
float thresh;
float focus;
int classfix;
int absolute;
int onlyforward;
int stopbackward;
int dontload;
int dontsave;
int dontloadscales;
float temperature;
@@ -790,5 +795,6 @@ void normalize_array(float *a, int n);
int *read_intlist(char *s, int *n, int d);
size_t rand_size_t();
float rand_normal();
float rand_uniform(float min, float max);
#endif

src/blas.c

@@ -65,7 +65,7 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
}
}
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float s1, float s2, float *out)
{
int stride = w1/w2;
int sample = w2/w1;
@@ -84,7 +84,7 @@ void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2,
for(i = 0; i < minw; ++i){
int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
out[out_index] += add[add_index];
out[out_index] = s1*out[out_index] + s2*add[add_index];
}
}
}
@@ -331,7 +331,7 @@ void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, i
}
}
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
{
int i, j, k, b;
for(b = 0; b < batch; ++b){
@@ -340,8 +340,8 @@ void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int for
for(i = 0; i < w*stride; ++i){
int in_index = b*w*h*c + k*w*h + (j/stride)*w + i/stride;
int out_index = b*w*h*c + k*w*h + j*w + i;
if(forward) out[out_index] = in[in_index];
else in[in_index] += out[out_index];
if(forward) out[out_index] = scale*in[in_index];
else in[in_index] += scale*out[out_index];
}
}
}
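
A minimal standalone sketch of the new weighted-shortcut semantics (not from the commit; assumes matching shapes so stride = sample = 1, with made-up values):

    /* out[i] = s1*out[i] + s2*add[i]; with s1 = s2 = 1 this reduces to the
       old behavior out[i] += add[i]. */
    float out[4] = {1, 2, 3, 4};
    float add[4] = {10, 20, 30, 40};
    float s1 = 0.25f, s2 = 0.75f;
    int i;
    for(i = 0; i < 4; ++i){
        out[i] = s1*out[i] + s2*add[i];   /* out becomes {7.75, 15.5, 23.25, 31} */
    }

The upsample change is the same idea: the copied (forward) or accumulated (backward) value is multiplied by a per-layer scale on the way through.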

src/blas.h

@@ -20,7 +20,7 @@ void pow_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
void mul_cpu(int N, float *X, int INCX, float *Y, int INCY);
int test_gpu_blas();
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void shortcut_cpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float s1, float s2, float *out);
void mean_cpu(float *x, int batch, int filters, int spatial, float *mean);
void variance_cpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
@@ -42,7 +42,7 @@ void weighted_delta_cpu(float *a, float *b, float *s, float *da, float *db, floa
void softmax(float *input, int n, float temp, int stride, float *output);
void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out);
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
#ifdef GPU
#include "cuda.h"
@@ -72,7 +72,7 @@ void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *varianc
void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance);
void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean);
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float s1, float s2, float *out);
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates);
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
@@ -99,7 +99,7 @@ void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rat
void flatten_gpu(float *x, int spatial, int layers, int batch, int forward, float *out);
void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier);
void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out);
void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
#endif
#endif

src/blas_kernels.cu

@@ -708,7 +708,7 @@ extern "C" void fill_gpu(int N, float ALPHA, float * X, int INCX)
check_error(cudaPeekAtLastError());
}
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float s1, float s2, float *out)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
@@ -722,10 +722,11 @@ __global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stri
int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
out[out_index] += add[add_index];
out[out_index] = s1*out[out_index] + s2*add[add_index];
//out[out_index] += add[add_index];
}
extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float s1, float s2, float *out)
{
int minw = (w1 < w2) ? w1 : w2;
int minh = (h1 < h2) ? h1 : h2;
@@ -739,7 +740,7 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
if(sample < 1) sample = 1;
int size = batch * minw * minh * minc;
shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, s1, s2, out);
check_error(cudaPeekAtLastError());
}
@@ -1003,7 +1004,7 @@ extern "C" void softmax_gpu(float *input, int n, int batch, int batch_offset, in
}
__global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
__global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
{
size_t i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i >= N) return;
@@ -1023,12 +1024,12 @@ __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int bat
int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
if(forward) out[out_index] += x[in_index];
else atomicAdd(x+in_index, out[out_index]);
if(forward) out[out_index] += scale * x[in_index];
else atomicAdd(x+in_index, scale * out[out_index]);
}
extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float *out)
extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
{
size_t size = w*h*c*batch*stride*stride;
upsample_kernel<<<cuda_gridsize(size), BLOCK>>>(size, in, w, h, c, batch, stride, forward, out);
upsample_kernel<<<cuda_gridsize(size), BLOCK>>>(size, in, w, h, c, batch, stride, forward, scale, out);
check_error(cudaPeekAtLastError());
}

src/data.c

@@ -556,9 +556,18 @@ matrix load_regression_labels_paths(char **paths, int n, int k)
char labelpath[4096];
find_replace(paths[i], "images", "labels", labelpath);
find_replace(labelpath, "JPEGImages", "labels", labelpath);
find_replace(labelpath, ".jpg", ".txt", labelpath);
find_replace(labelpath, ".BMP", ".txt", labelpath);
find_replace(labelpath, ".JPEG", ".txt", labelpath);
find_replace(labelpath, ".JPG", ".txt", labelpath);
find_replace(labelpath, ".JPeG", ".txt", labelpath);
find_replace(labelpath, ".Jpeg", ".txt", labelpath);
find_replace(labelpath, ".PNG", ".txt", labelpath);
find_replace(labelpath, ".TIF", ".txt", labelpath);
find_replace(labelpath, ".bmp", ".txt", labelpath);
find_replace(labelpath, ".jpeg", ".txt", labelpath);
find_replace(labelpath, ".jpg", ".txt", labelpath);
find_replace(labelpath, ".png", ".txt", labelpath);
find_replace(labelpath, ".tif", ".txt", labelpath);
FILE *file = fopen(labelpath, "r");
for(j = 0; j < k; ++j){

src/parser.c

@@ -313,7 +313,8 @@ layer parse_region(list *options, size_params params)
l.jitter = option_find_float(options, "jitter", .2);
l.rescore = option_find_int_quiet(options, "rescore",0);
l.thresh = option_find_float(options, "thresh", .5);
l.ignore_thresh = option_find_float(options, "ignore_thresh", .5);
l.truth_thresh = option_find_float(options, "truth_thresh", 1);
l.classfix = option_find_int_quiet(options, "classfix", 0);
l.absolute = option_find_int_quiet(options, "absolute", 0);
l.random = option_find_int_quiet(options, "random", 0);
@@ -324,6 +325,7 @@ layer parse_region(list *options, size_params params)
l.mask_scale = option_find_float_quiet(options, "mask_scale", 1);
l.class_scale = option_find_float(options, "class_scale", 1);
l.bias_match = option_find_int_quiet(options, "bias_match",0);
l.focus = option_find_float_quiet(options, "focus", 0);
char *tree_file = option_find_str(options, "tree", 0);
if (tree_file) l.softmax_tree = read_tree(tree_file);
@@ -494,6 +496,8 @@ layer parse_shortcut(list *options, size_params params, network *net)
char *activation_s = option_find_str(options, "activation", "linear");
ACTIVATION activation = get_activation(activation_s);
s.activation = activation;
s.alpha = option_find_float_quiet(options, "alpha", 1);
s.beta = option_find_float_quiet(options, "beta", 1);
return s;
}
@@ -536,6 +540,7 @@ layer parse_upsample(list *options, size_params params, network *net)
int stride = option_find_int(options, "stride",2);
layer l = make_upsample_layer(params.batch, params.w, params.h, params.c, stride);
l.scale = option_find_float_quiet(options, "scale", 1);
return l;
}
@@ -778,6 +783,7 @@ network *parse_network_cfg(char *filename)
l.truth = option_find_int_quiet(options, "truth", 0);
l.onlyforward = option_find_int_quiet(options, "onlyforward", 0);
l.stopbackward = option_find_int_quiet(options, "stopbackward", 0);
l.dontsave = option_find_int_quiet(options, "dontsave", 0);
l.dontload = option_find_int_quiet(options, "dontload", 0);
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
l.learning_rate_scale = option_find_float_quiet(options, "learning_rate", 1);
@@ -961,6 +967,7 @@ void save_weights_upto(network *net, char *filename, int cutoff)
int i;
for(i = 0; i < net->n && i < cutoff; ++i){
layer l = net->layers[i];
if (l.dontsave) continue;
if(l.type == CONVOLUTIONAL || l.type == DECONVOLUTIONAL){
save_convolutional_weights(l, fp);
} if(l.type == CONNECTED){
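
Taken together, the parser changes above add several new per-layer cfg keys. A hypothetical cfg fragment exercising them (values are just the defaults shown above; the surrounding keys like from= and stride= are the usual darknet ones, assumed here):

    [shortcut]
    from=-3
    activation=linear
    alpha=1
    beta=1

    [upsample]
    stride=2
    scale=1

    [region]
    ignore_thresh=.5
    truth_thresh=1
    focus=0

Any layer can also set dontsave=1, which makes save_weights_upto skip it, mirroring the existing dontload.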

src/region_layer.c

@@ -124,7 +124,7 @@ void delta_region_mask(float *truth, float *x, int n, int index, float *delta, i
}
void delta_region_class(float *output, float *delta, int index, int class, int classes, tree *hier, float scale, int stride, float *avg_cat, int tag)
void delta_region_class(float *output, float *delta, int index, int class, int classes, tree *hier, float scale, int stride, float *avg_cat, int tag, float focus)
{
int i, n;
if(hier){
@@ -140,15 +140,30 @@ void delta_region_class(float *output, float *delta, int index, int class, int c
class = hier->parent[class];
}
*avg_cat += pred;
if(avg_cat) *avg_cat += pred;
} else {
if (delta[index] && tag){
delta[index + stride*class] = scale * (1 - output[index + stride*class]);
if(focus){
float y = -1;
float p = output[index + stride*class];
float lg = p > .0000000001 ? log(p) : -10;
delta[index + stride*class] = y * pow(1-p, focus) * (focus*p*lg + p - 1);
}else{
delta[index + stride*class] = scale * (1 - output[index + stride*class]);
if(avg_cat) *avg_cat += output[index + stride*class];
}
return;
}
for(n = 0; n < classes; ++n){
delta[index + stride*n] = scale * (((n == class)?1 : 0) - output[index + stride*n]);
if(n == class) *avg_cat += output[index + stride*n];
if(focus){
float y = (n == class) ? -1 : 1;
float p = (n == class) ? output[index + stride*n] : 1 - output[index + stride*n];
float lg = p > .0000000001 ? log(p) : -10;
delta[index + stride*n] = y * pow(1-p, focus) * (focus*p*lg + p - 1);
}else{
delta[index + stride*n] = scale * (((n == class)?1 : 0) - output[index + stride*n]);
}
if(n == class && avg_cat) *avg_cat += output[index + stride*n];
}
}
}
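
The focus branch above appears to be the gradient of the focal loss (Lin et al., 2017) taken through the logistic activation. A sketch of the check, writing $p$ for the predicted probability of the target class and $\gamma$ for focus:

$$\mathrm{FL}(p) = -(1-p)^{\gamma}\log p, \qquad p = \sigma(z), \qquad \frac{dp}{dz} = p(1-p)$$

$$\frac{d\,\mathrm{FL}}{dz} = (1-p)^{\gamma}\left(\gamma\,p\log p + p - 1\right), \qquad \delta = -\frac{d\,\mathrm{FL}}{dz}$$

which is y * pow(1-p, focus) * (focus*p*lg + p - 1) with y = -1 for the target class; the non-target classes use the mirrored case p -> 1-p, y = +1, and clamping lg at -10 guards against log(0).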
@@ -204,6 +219,7 @@ void forward_region_layer(const layer l, network net)
if(!net.train) return;
float avg_iou = 0;
float recall = 0;
float recall75 = 0;
float avg_cat = 0;
float avg_obj = 0;
float avg_anyobj = 0;
@@ -233,7 +249,7 @@ void forward_region_layer(const layer l, network net)
}
int class_index = entry_index(l, b, maxi, l.coords + 1);
int obj_index = entry_index(l, b, maxi, l.coords);
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat, !l.softmax);
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat, !l.softmax, l.focus);
if(l.output[obj_index] < .3) l.delta[obj_index] = l.object_scale * (.3 - l.output[obj_index]);
else l.delta[obj_index] = 0;
l.delta[obj_index] = 0;
@@ -250,32 +266,44 @@ void forward_region_layer(const layer l, network net)
int box_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, 0);
box pred = get_region_box(l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.w*l.h);
float best_iou = 0;
int best_t = 0;
for(t = 0; t < l.max_boxes; ++t){
box truth = float_to_box(net.truth + t*(l.coords + 1) + b*l.truths, 1);
if(!truth.x) break;
float iou = box_iou(pred, truth);
if (iou > best_iou) {
best_iou = iou;
best_t = t;
}
}
int obj_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, l.coords);
avg_anyobj += l.output[obj_index];
l.delta[obj_index] = l.noobject_scale * (0 - l.output[obj_index]);
if(l.background) l.delta[obj_index] = l.noobject_scale * (1 - l.output[obj_index]);
if (best_iou > l.thresh) {
if (best_iou > l.ignore_thresh) {
l.delta[obj_index] = 0;
}
if (best_iou > l.truth_thresh) {
l.delta[obj_index] = l.object_scale * (1 - l.output[obj_index]);
/*
if(*(net.seen) < 12800){
box truth = {0};
truth.x = (i + .5)/l.w;
truth.y = (j + .5)/l.h;
truth.w = l.biases[2*l.mask[n]]/net.w;
truth.h = l.biases[2*l.mask[n]+1]/net.h;
delta_region_box(truth, l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.delta, .01, l.w*l.h);
int class = net.truth[best_t*(l.coords + 1) + b*l.truths + l.coords];
if (l.map) class = l.map[class];
int class_index = entry_index(l, b, n*l.w*l.h + j*l.w + i, l.coords + 1);
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, 0, !l.softmax, l.focus);
box truth = float_to_box(net.truth + best_t*(l.coords + 1) + b*l.truths, 1);
delta_region_box(truth, l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.delta, l.coord_scale*(2-truth.w*truth.h), l.w*l.h);
}
*/
/*
if(*(net.seen) < 12800){
box truth = {0};
truth.x = (i + .5)/l.w;
truth.y = (j + .5)/l.h;
truth.w = l.biases[2*l.mask[n]]/net.w;
truth.h = l.biases[2*l.mask[n]+1]/net.h;
delta_region_box(truth, l.output, l.biases, l.mask[n], box_index, i, j, l.w, l.h, net.w, net.h, l.delta, .01, l.w*l.h);
}
*/
}
}
}
@@ -309,12 +337,13 @@ void forward_region_layer(const layer l, network net)
//printf("%d %d\n", best_n, mask_n);
if(mask_n >= 0){
int box_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 0);
float iou = delta_region_box(truth, l.output, l.biases, best_n, box_index, i, j, l.w, l.h, net.w, net.h, l.delta, l.coord_scale * (2 - truth.w*truth.h), l.w*l.h);
float iou = delta_region_box(truth, l.output, l.biases, best_n, box_index, i, j, l.w, l.h, net.w, net.h, l.delta, l.coord_scale*(2-truth.w*truth.h), l.w*l.h);
if(l.coords > 4){
int mask_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, 4);
delta_region_mask(net.truth + t*(l.coords + 1) + b*l.truths + 5, l.output, l.coords - 4, mask_index, l.delta, l.w*l.h, l.mask_scale);
}
if(iou > .5) recall += 1;
if(iou > .75) recall75 += 1;
avg_iou += iou;
//l.delta[best_index + 4] = iou - l.output[best_index + 4];
@@ -331,7 +360,7 @@ void forward_region_layer(const layer l, network net)
int class = net.truth[t*(l.coords + 1) + b*l.truths + l.coords];
if (l.map) class = l.map[class];
int class_index = entry_index(l, b, mask_n*l.w*l.h + j*l.w + i, l.coords + 1);
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat, !l.softmax);
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat, !l.softmax, l.focus);
++count;
++class_count;
}
@@ -339,7 +368,7 @@ void forward_region_layer(const layer l, network net)
}
//printf("\n");
*(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
printf("Region %d Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f, count: %d\n", net.index, avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
printf("Region %d Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, .5R: %f, .75R: %f, count: %d\n", net.index, avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, recall75/count, count);
}
void backward_region_layer(const layer l, network net)
@@ -576,7 +605,7 @@ void backward_region_layer_gpu(const layer l, network net)
for (b = 0; b < l.batch; ++b){
for(n = 0; n < l.n; ++n){
int index = entry_index(l, b, n*l.w*l.h, 0);
gradient_array_gpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC, l.delta_gpu + index);
//gradient_array_gpu(l.output_gpu + index, 2*l.w*l.h, LOGISTIC, l.delta_gpu + index);
if(l.coords > 4){
index = entry_index(l, b, n*l.w*l.h, 4);
gradient_array_gpu(l.output_gpu + index, (l.coords - 4)*l.w*l.h, LOGISTIC, l.delta_gpu + index);

src/shortcut_layer.c

@@ -62,29 +62,29 @@ void resize_shortcut_layer(layer *l, int w, int h)
void forward_shortcut_layer(const layer l, network net)
{
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);
shortcut_cpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.output);
shortcut_cpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.alpha, l.beta, l.output);
activate_array(l.output, l.outputs*l.batch, l.activation);
}
void backward_shortcut_layer(const layer l, network net)
{
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
axpy_cpu(l.outputs*l.batch, 1, l.delta, 1, net.delta, 1);
shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, net.layers[l.index].delta);
axpy_cpu(l.outputs*l.batch, l.alpha, l.delta, 1, net.delta, 1);
shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, 1, l.beta, net.layers[l.index].delta);
}
#ifdef GPU
void forward_shortcut_layer_gpu(const layer l, network net)
{
copy_gpu(l.outputs*l.batch, net.input_gpu, 1, l.output_gpu, 1);
shortcut_gpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
shortcut_gpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.alpha, l.beta, l.output_gpu);
activate_array_gpu(l.output_gpu, l.outputs*l.batch, l.activation);
}
void backward_shortcut_layer_gpu(const layer l, network net)
{
gradient_array_gpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
axpy_gpu(l.outputs*l.batch, 1, l.delta_gpu, 1, net.delta_gpu, 1);
shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, net.layers[l.index].delta_gpu);
axpy_gpu(l.outputs*l.batch, l.alpha, l.delta_gpu, 1, net.delta_gpu, 1);
shortcut_gpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta_gpu, l.w, l.h, l.c, 1, l.beta, net.layers[l.index].delta_gpu);
}
#endif

src/tree.h

@@ -2,7 +2,6 @@
#define TREE_H
#include "darknet.h"
tree *read_tree(char *filename);
int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride);
float get_hierarchy_probability(float *x, tree *hier, int c, int stride);

src/upsample_layer.c

@@ -69,18 +69,18 @@ void forward_upsample_layer(const layer l, network net)
{
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
if(l.reverse){
upsample_cpu(l.output, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input);
upsample_cpu(l.output, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, l.scale, net.input);
}else{
upsample_cpu(net.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output);
upsample_cpu(net.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.scale, l.output);
}
}
void backward_upsample_layer(const layer l, network net)
{
if(l.reverse){
upsample_cpu(l.delta, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta);
upsample_cpu(l.delta, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, l.scale, net.delta);
}else{
upsample_cpu(net.delta, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta);
upsample_cpu(net.delta, l.w, l.h, l.c, l.batch, l.stride, 0, l.scale, l.delta);
}
}
@@ -89,18 +89,18 @@ void forward_upsample_layer_gpu(const layer l, network net)
{
fill_gpu(l.outputs*l.batch, 0, l.output_gpu, 1);
if(l.reverse){
upsample_gpu(l.output_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, net.input_gpu);
upsample_gpu(l.output_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 0, l.scale, net.input_gpu);
}else{
upsample_gpu(net.input_gpu, l.w, l.h, l.c, l.batch, l.stride, 1, l.output_gpu);
upsample_gpu(net.input_gpu, l.w, l.h, l.c, l.batch, l.stride, 1, l.scale, l.output_gpu);
}
}
void backward_upsample_layer_gpu(const layer l, network net)
{
if(l.reverse){
upsample_gpu(l.delta_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, net.delta_gpu);
upsample_gpu(l.delta_gpu, l.out_w, l.out_h, l.c, l.batch, l.stride, 1, l.scale, net.delta_gpu);
}else{
upsample_gpu(net.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, l.delta_gpu);
upsample_gpu(net.delta_gpu, l.w, l.h, l.c, l.batch, l.stride, 0, l.scale, l.delta_gpu);
}
}
#endif

src/utils.h

@@ -40,7 +40,6 @@ float *parse_fields(char *line, int n);
void translate_array(float *a, int n, float s);
float constrain(float min, float max, float a);
int constrain_int(int a, int min, int max);
float rand_uniform(float min, float max);
float rand_scale(float s);
int rand_int(int min, int max);
void mean_arrays(float **a, int n, int els, float *avg);