mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
opencv is hell. this is why we can't have nice things.
This commit is contained in:
parent
38802ef56a
commit
179ed8ec76
@ -1,6 +1,10 @@
|
||||
[net]
|
||||
batch=64
|
||||
subdivisions=8
|
||||
# Testing
|
||||
batch=1
|
||||
subdivisions=1
|
||||
# Training
|
||||
# batch=64
|
||||
# subdivisions=8
|
||||
height=416
|
||||
width=416
|
||||
channels=3
|
||||
|
@ -1,6 +1,10 @@
|
||||
[net]
|
||||
batch=64
|
||||
subdivisions=8
|
||||
# Testing
|
||||
batch=1
|
||||
subdivisions=1
|
||||
# Training
|
||||
# batch=64
|
||||
# subdivisions=8
|
||||
height=416
|
||||
width=416
|
||||
channels=3
|
||||
|
@ -6,12 +6,6 @@
|
||||
#include "classifier.h"
|
||||
#include <sys/time.h>
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
#endif
|
||||
|
||||
|
||||
void demo_art(char *cfgfile, char *weightfile, int cam_index)
|
||||
{
|
||||
#ifdef OPENCV
|
||||
|
@ -4,10 +4,6 @@
|
||||
#include "option_list.h"
|
||||
#include "blas.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_cifar(char *cfgfile, char *weightfile)
|
||||
{
|
||||
srand(time(0));
|
||||
|
@ -8,11 +8,6 @@
|
||||
#include "cuda.h"
|
||||
#include <sys/time.h>
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
#endif
|
||||
|
||||
float *get_regression_values(char **labels, int n)
|
||||
{
|
||||
float *v = calloc(n, sizeof(float));
|
||||
|
@ -8,10 +8,6 @@
|
||||
#include "box.h"
|
||||
#include "demo.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
char *coco_classes[] = {"person","bicycle","car","motorcycle","airplane","bus","train","truck","boat","traffic light","fire hydrant","stop sign","parking meter","bench","bird","cat","dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack","umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat","baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass","cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli","carrot","hot dog","pizza","donut","cake","chair","couch","potted plant","bed","dining table","toilet","tv","laptop","mouse","remote","keyboard","cell phone","microwave","oven","toaster","sink","refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush"};
|
||||
|
||||
int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
|
||||
|
@ -97,6 +97,12 @@ connected_layer make_connected_layer(int batch, int inputs, int outputs, ACTIVAT
|
||||
|
||||
l.x_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||
l.x_norm_gpu = cuda_make_array(l.output, l.batch*outputs);
|
||||
#ifdef CUDNN
|
||||
cudnnCreateTensorDescriptor(&l.normTensorDesc);
|
||||
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
||||
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||
cudnnSetTensor4dDescriptor(l.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l.out_c, 1, 1);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
l.activation = activation;
|
||||
@ -213,11 +219,11 @@ void statistics_connected_layer(layer l)
|
||||
printf("Scales ");
|
||||
print_statistics(l.scales, l.outputs);
|
||||
/*
|
||||
printf("Rolling Mean ");
|
||||
print_statistics(l.rolling_mean, l.outputs);
|
||||
printf("Rolling Variance ");
|
||||
print_statistics(l.rolling_variance, l.outputs);
|
||||
*/
|
||||
printf("Rolling Mean ");
|
||||
print_statistics(l.rolling_mean, l.outputs);
|
||||
printf("Rolling Variance ");
|
||||
print_statistics(l.rolling_variance, l.outputs);
|
||||
*/
|
||||
}
|
||||
printf("Biases ");
|
||||
print_statistics(l.biases, l.outputs);
|
||||
|
@ -8,10 +8,6 @@
|
||||
#include "blas.h"
|
||||
#include "connected_layer.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
extern void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top);
|
||||
extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh);
|
||||
extern void run_voxel(int argc, char **argv);
|
||||
|
@ -1017,7 +1017,7 @@ void get_next_batch(data d, int n, int offset, float *X, float *y)
|
||||
for(j = 0; j < n; ++j){
|
||||
int index = offset + j;
|
||||
memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));
|
||||
memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
|
||||
if(y) memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,9 +12,6 @@
|
||||
#define FRAMES 3
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#include "opencv2/imgproc/imgproc_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
|
||||
static char **demo_names;
|
||||
static image **demo_alphabet;
|
||||
|
@ -8,9 +8,6 @@
|
||||
#include "option_list.h"
|
||||
#include "blas.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
static int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
|
||||
|
||||
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
||||
|
10
src/go.c
10
src/go.c
@ -6,10 +6,6 @@
|
||||
#include "data.h"
|
||||
#include <unistd.h>
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
int inverted = 1;
|
||||
int noi = 1;
|
||||
static const int nind = 2;
|
||||
@ -125,7 +121,7 @@ data random_go_moves(moves m, int n)
|
||||
}
|
||||
|
||||
|
||||
void train_go(char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
||||
void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ngpus, int clear)
|
||||
{
|
||||
int i;
|
||||
float avg_loss = -1;
|
||||
@ -150,7 +146,7 @@ void train_go(char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
||||
char *backup_directory = "/home/pjreddie/backup/";
|
||||
|
||||
char buff[256];
|
||||
moves m = load_go_moves("/home/pjreddie/backup/go.train");
|
||||
moves m = load_go_moves(filename);
|
||||
//moves m = load_go_moves("games.txt");
|
||||
|
||||
int N = m.n;
|
||||
@ -909,7 +905,7 @@ void run_go(int argc, char **argv)
|
||||
char *c2 = (argc > 5) ? argv[5] : 0;
|
||||
char *w2 = (argc > 6) ? argv[6] : 0;
|
||||
int multi = find_arg(argc, argv, "-multi");
|
||||
if(0==strcmp(argv[2], "train")) train_go(cfg, weights, gpus, ngpus, clear);
|
||||
if(0==strcmp(argv[2], "train")) train_go(cfg, weights, c2, gpus, ngpus, clear);
|
||||
else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi, c2);
|
||||
else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
|
||||
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
|
||||
|
@ -70,6 +70,15 @@ layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_no
|
||||
*(l.state_h_layer) = make_connected_layer(batch*steps, outputs, outputs, LINEAR, batch_normalize);
|
||||
l.state_h_layer->batch = batch;
|
||||
|
||||
#ifdef CUDNN
|
||||
cudnnSetTensor4dDescriptor(l.input_z_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.input_z_layer->out_c, l.input_z_layer->out_h, l.input_z_layer->out_w);
|
||||
cudnnSetTensor4dDescriptor(l.input_h_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.input_h_layer->out_c, l.input_h_layer->out_h, l.input_h_layer->out_w);
|
||||
cudnnSetTensor4dDescriptor(l.input_r_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.input_r_layer->out_c, l.input_r_layer->out_h, l.input_r_layer->out_w);
|
||||
cudnnSetTensor4dDescriptor(l.state_z_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.state_z_layer->out_c, l.state_z_layer->out_h, l.state_z_layer->out_w);
|
||||
cudnnSetTensor4dDescriptor(l.state_h_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.state_h_layer->out_c, l.state_h_layer->out_h, l.state_h_layer->out_w);
|
||||
cudnnSetTensor4dDescriptor(l.state_r_layer->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch, l.state_r_layer->out_c, l.state_r_layer->out_h, l.state_r_layer->out_w);
|
||||
#endif
|
||||
|
||||
l.batch_normalize = batch_normalize;
|
||||
|
||||
|
||||
|
@ -10,12 +10,6 @@
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#include "stb_image_write.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#include "opencv2/imgproc/imgproc_c.h"
|
||||
#endif
|
||||
|
||||
|
||||
int windows = 0;
|
||||
|
||||
float colors[6][3] = { {1,0,1}, {0,0,1},{0,1,1},{0,1,0},{1,1,0},{1,0,0} };
|
||||
|
18
src/image.h
18
src/image.h
@ -8,6 +8,17 @@
|
||||
#include <math.h>
|
||||
#include "box.h"
|
||||
|
||||
#ifndef __cplusplus
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#include "opencv2/imgproc/imgproc_c.h"
|
||||
#include "opencv2/core/version.hpp"
|
||||
#if CV_MAJOR_VERSION == 3
|
||||
#include "opencv2/videoio/videoio_c.h"
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
int h;
|
||||
int w;
|
||||
@ -15,6 +26,13 @@ typedef struct {
|
||||
float *data;
|
||||
} image;
|
||||
|
||||
#ifndef __cplusplus
|
||||
#ifdef OPENCV
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
image ipl_to_image(IplImage* src);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
float get_color(int c, int x, int max);
|
||||
void flip_image(image a);
|
||||
void draw_box(image a, int x1, int y1, int x2, int y2, float r, float g, float b);
|
||||
|
287
src/lsd.c
287
src/lsd.c
@ -4,10 +4,6 @@
|
||||
#include "parser.h"
|
||||
#include "blas.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_lsd3(char *fcfg, char *fweight, char *gcfg, char *gweight, char *acfg, char *aweight, int clear)
|
||||
{
|
||||
#ifdef GPU
|
||||
@ -75,9 +71,7 @@ void train_lsd3(char *fcfg, char *fweight, char *gcfg, char *gweight, char *acfg
|
||||
float *y = calloc(y_size, sizeof(float));
|
||||
|
||||
float *ones = cuda_make_array(0, anet.batch);
|
||||
float *zeros = cuda_make_array(0, anet.batch);
|
||||
fill_ongpu(anet.batch, .99, ones, 1);
|
||||
fill_ongpu(anet.batch, .01, zeros, 1);
|
||||
fill_ongpu(anet.batch, .9, ones, 1);
|
||||
|
||||
network_state astate = {0};
|
||||
astate.index = 0;
|
||||
@ -145,7 +139,7 @@ void train_lsd3(char *fcfg, char *fweight, char *gcfg, char *gweight, char *acfg
|
||||
float *delta = imlayer.delta_gpu;
|
||||
fill_ongpu(x_size, 0, delta, 1);
|
||||
scal_ongpu(x_size, 100, astate.delta, 1);
|
||||
scal_ongpu(x_size, .00001, fstate.delta, 1);
|
||||
scal_ongpu(x_size, .001, fstate.delta, 1);
|
||||
axpy_ongpu(x_size, 1, fstate.delta, 1, delta, 1);
|
||||
axpy_ongpu(x_size, 1, astate.delta, 1, delta, 1);
|
||||
|
||||
@ -165,7 +159,8 @@ void train_lsd3(char *fcfg, char *fweight, char *gcfg, char *gweight, char *acfg
|
||||
for(k = 0; k < gnet.batch; ++k){
|
||||
int index = j*gnet.batch + k;
|
||||
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, generated.X.vals[index], 1);
|
||||
generated.y.vals[index][0] = .01;
|
||||
generated.y.vals[index][0] = .1;
|
||||
style.y.vals[index][0] = .9;
|
||||
}
|
||||
}
|
||||
|
||||
@ -346,7 +341,7 @@ void train_pix2pix(char *cfg, char *weight, char *acfg, char *aweight, int clear
|
||||
|
||||
backward_network_gpu(net, gstate);
|
||||
|
||||
scal_ongpu(imlayer.outputs, 100, imerror, 1);
|
||||
scal_ongpu(imlayer.outputs, 1000, imerror, 1);
|
||||
|
||||
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs));
|
||||
printf("features %f\n", cuda_mag_array(net.layers[net.n-1].delta_gpu, imlayer.outputs));
|
||||
@ -399,6 +394,217 @@ void train_pix2pix(char *cfg, char *weight, char *acfg, char *aweight, int clear
|
||||
#endif
|
||||
}
|
||||
|
||||
void test_dcgan(char *cfgfile, char *weightfile)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
load_weights(&net, weightfile);
|
||||
}
|
||||
set_batch_network(&net, 1);
|
||||
srand(2222222);
|
||||
|
||||
clock_t time;
|
||||
char buff[256];
|
||||
char *input = buff;
|
||||
int i, imlayer = 0;
|
||||
|
||||
for (i = 0; i < net.n; ++i) {
|
||||
if (net.layers[i].out_c == 3) {
|
||||
imlayer = i;
|
||||
printf("%d\n", i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
while(1){
|
||||
image im = make_image(net.w, net.h, net.c);
|
||||
int i;
|
||||
for(i = 0; i < im.w*im.h*im.c; ++i){
|
||||
im.data[i] = rand_normal();
|
||||
}
|
||||
|
||||
float *X = im.data;
|
||||
time=clock();
|
||||
network_predict(net, X);
|
||||
image out = get_network_image_layer(net, imlayer);
|
||||
//yuv_to_rgb(out);
|
||||
constrain_image(out);
|
||||
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
|
||||
show_image(out, "out");
|
||||
save_image(out, "out");
|
||||
#ifdef OPENCV
|
||||
cvWaitKey(0);
|
||||
#endif
|
||||
|
||||
free_image(im);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void train_dcgan(char *cfg, char *weight, char *acfg, char *aweight, int clear)
|
||||
{
|
||||
#ifdef GPU
|
||||
//char *train_images = "/home/pjreddie/data/coco/train1.txt";
|
||||
//char *train_images = "/home/pjreddie/data/coco/trainvalno5k.txt";
|
||||
char *train_images = "/home/pjreddie/data/imagenet/imagenet1k.train.list";
|
||||
char *backup_directory = "/home/pjreddie/backup/";
|
||||
srand(time(0));
|
||||
char *base = basecfg(cfg);
|
||||
char *abase = basecfg(acfg);
|
||||
printf("%s\n", base);
|
||||
network net = load_network(cfg, weight, clear);
|
||||
network anet = load_network(acfg, aweight, clear);
|
||||
|
||||
int i, j, k;
|
||||
layer imlayer = {0};
|
||||
for (i = 0; i < net.n; ++i) {
|
||||
if (net.layers[i].out_c == 3) {
|
||||
imlayer = net.layers[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
|
||||
int imgs = net.batch*net.subdivisions;
|
||||
i = *net.seen/imgs;
|
||||
data train, buffer;
|
||||
|
||||
|
||||
list *plist = get_paths(train_images);
|
||||
//int N = plist->size;
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
|
||||
load_args args= get_base_args(anet);
|
||||
args.paths = paths;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.d = &buffer;
|
||||
args.type = CLASSIFICATION_DATA;
|
||||
args.classes = 2;
|
||||
char *ls[2] = {"imagenet", "zzzzzzzz"};
|
||||
args.labels = ls;
|
||||
|
||||
pthread_t load_thread = load_data_in_thread(args);
|
||||
clock_t time;
|
||||
|
||||
network_state gstate = {0};
|
||||
gstate.index = 0;
|
||||
gstate.net = net;
|
||||
int x_size = get_network_input_size(net)*net.batch;
|
||||
int y_size = get_network_output_size(net)*net.batch;
|
||||
gstate.input = cuda_make_array(0, x_size);
|
||||
gstate.truth = cuda_make_array(0, y_size);
|
||||
gstate.train = 1;
|
||||
float *input = calloc(x_size, sizeof(float));
|
||||
float *y = calloc(y_size, sizeof(float));
|
||||
float *imerror = cuda_make_array(0, y_size);
|
||||
|
||||
network_state astate = {0};
|
||||
astate.index = 0;
|
||||
astate.net = anet;
|
||||
int ay_size = get_network_output_size(anet)*anet.batch;
|
||||
astate.input = 0;
|
||||
astate.truth = 0;
|
||||
astate.delta = 0;
|
||||
astate.train = 1;
|
||||
|
||||
float *ones_gpu = cuda_make_array(0, ay_size);
|
||||
fill_ongpu(ay_size, .1, ones_gpu, 1);
|
||||
fill_ongpu(ay_size/2, .9, ones_gpu, 2);
|
||||
|
||||
float aloss_avg = -1;
|
||||
|
||||
//data generated = copy_data(train);
|
||||
|
||||
while (get_current_batch(net) < net.max_batches) {
|
||||
i += 1;
|
||||
time=clock();
|
||||
pthread_join(load_thread, 0);
|
||||
train = buffer;
|
||||
load_thread = load_data_in_thread(args);
|
||||
|
||||
printf("Loaded: %lf seconds\n", sec(clock()-time));
|
||||
|
||||
data gen = copy_data(train);
|
||||
for(j = 0; j < imgs; ++j){
|
||||
train.y.vals[j][0] = .9;
|
||||
train.y.vals[j][1] = .1;
|
||||
gen.y.vals[j][0] = .1;
|
||||
gen.y.vals[j][1] = .9;
|
||||
}
|
||||
time=clock();
|
||||
|
||||
for(j = 0; j < net.subdivisions; ++j){
|
||||
get_next_batch(train, net.batch, j*net.batch, y, 0);
|
||||
int z;
|
||||
for(z = 0; z < x_size; ++z){
|
||||
input[z] = rand_normal();
|
||||
}
|
||||
|
||||
cuda_push_array(gstate.input, input, x_size);
|
||||
cuda_push_array(gstate.truth, y, y_size);
|
||||
*net.seen += net.batch;
|
||||
forward_network_gpu(net, gstate);
|
||||
|
||||
fill_ongpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
|
||||
astate.input = imlayer.output_gpu;
|
||||
astate.delta = imerror;
|
||||
astate.truth = ones_gpu;
|
||||
forward_network_gpu(anet, astate);
|
||||
backward_network_gpu(anet, astate);
|
||||
|
||||
scal_ongpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
|
||||
scal_ongpu(imlayer.outputs*imlayer.batch, .001, net.layers[net.n-1].delta_gpu, 1);
|
||||
|
||||
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
|
||||
printf("features %f\n", cuda_mag_array(net.layers[net.n-1].delta_gpu, imlayer.outputs*imlayer.batch));
|
||||
|
||||
axpy_ongpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, net.layers[net.n-1].delta_gpu, 1);
|
||||
|
||||
backward_network_gpu(net, gstate);
|
||||
|
||||
cuda_pull_array(imlayer.output_gpu, imlayer.output, x_size);
|
||||
for(k = 0; k < net.batch; ++k){
|
||||
int index = j*net.batch + k;
|
||||
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, gen.X.vals[index], 1);
|
||||
gen.y.vals[index][0] = .1;
|
||||
}
|
||||
}
|
||||
harmless_update_network_gpu(anet);
|
||||
|
||||
data merge = concat_data(train, gen);
|
||||
randomize_data(merge);
|
||||
float aloss = train_network(anet, merge);
|
||||
|
||||
update_network_gpu(net);
|
||||
free_data(merge);
|
||||
free_data(train);
|
||||
free_data(gen);
|
||||
if (aloss_avg < 0) aloss_avg = aloss;
|
||||
aloss_avg = aloss_avg*.9 + aloss*.1;
|
||||
|
||||
printf("%d: adv: %f | adv_avg: %f, %f rate, %lf seconds, %d images\n", i, aloss, aloss_avg, get_current_rate(net), sec(clock()-time), i*imgs);
|
||||
if(i%1000==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
|
||||
save_weights(net, buff);
|
||||
sprintf(buff, "%s/%s_%d.weights", backup_directory, abase, i);
|
||||
save_weights(anet, buff);
|
||||
}
|
||||
if(i%100==0){
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s.backup", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
sprintf(buff, "%s/%s.backup", backup_directory, abase);
|
||||
save_weights(anet, buff);
|
||||
}
|
||||
}
|
||||
char buff[256];
|
||||
sprintf(buff, "%s/%s_final.weights", backup_directory, base);
|
||||
save_weights(net, buff);
|
||||
#endif
|
||||
}
|
||||
|
||||
void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int clear)
|
||||
{
|
||||
#ifdef GPU
|
||||
@ -432,25 +638,15 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
//int N = plist->size;
|
||||
char **paths = (char **)list_to_array(plist);
|
||||
|
||||
load_args args = {0};
|
||||
args.w = net.w;
|
||||
args.h = net.h;
|
||||
load_args args= get_base_args(net);
|
||||
args.paths = paths;
|
||||
args.n = imgs;
|
||||
args.m = plist->size;
|
||||
args.d = &buffer;
|
||||
|
||||
args.min = net.min_crop;
|
||||
args.max = net.max_crop;
|
||||
args.angle = net.angle;
|
||||
args.aspect = net.aspect;
|
||||
args.exposure = net.exposure;
|
||||
args.saturation = net.saturation;
|
||||
args.hue = net.hue;
|
||||
args.size = net.w;
|
||||
args.type = CLASSIFICATION_DATA;
|
||||
args.classes = 1;
|
||||
char *ls[1] = {"imagenet"};
|
||||
args.classes = 2;
|
||||
char *ls[2] = {"imagenet", "zzzzzzz"};
|
||||
args.labels = ls;
|
||||
|
||||
pthread_t load_thread = load_data_in_thread(args);
|
||||
@ -478,9 +674,10 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
astate.delta = 0;
|
||||
astate.train = 1;
|
||||
|
||||
float *imerror = cuda_make_array(0, imlayer.outputs);
|
||||
float *imerror = cuda_make_array(0, imlayer.outputs*imlayer.batch);
|
||||
float *ones_gpu = cuda_make_array(0, ay_size);
|
||||
fill_ongpu(ay_size, .99, ones_gpu, 1);
|
||||
fill_ongpu(ay_size, .1, ones_gpu, 1);
|
||||
fill_ongpu(ay_size/2, .9, ones_gpu, 2);
|
||||
|
||||
float aloss_avg = -1;
|
||||
float gloss_avg = -1;
|
||||
@ -500,17 +697,17 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
for(j = 0; j < imgs; ++j){
|
||||
image gim = float_to_image(net.w, net.h, net.c, gray.X.vals[j]);
|
||||
grayscale_image_3c(gim);
|
||||
train.y.vals[j][0] = .99;
|
||||
|
||||
image yim = float_to_image(net.w, net.h, net.c, train.X.vals[j]);
|
||||
//rgb_to_yuv(yim);
|
||||
train.y.vals[j][0] = .9;
|
||||
train.y.vals[j][1] = .1;
|
||||
gray.y.vals[j][0] = .1;
|
||||
gray.y.vals[j][1] = .9;
|
||||
}
|
||||
time=clock();
|
||||
float gloss = 0;
|
||||
|
||||
for(j = 0; j < net.subdivisions; ++j){
|
||||
get_next_batch(train, net.batch, j*net.batch, pixs, y);
|
||||
get_next_batch(gray, net.batch, j*net.batch, graypixs, y);
|
||||
get_next_batch(train, net.batch, j*net.batch, pixs, 0);
|
||||
get_next_batch(gray, net.batch, j*net.batch, graypixs, 0);
|
||||
cuda_push_array(gstate.input, graypixs, x_size);
|
||||
cuda_push_array(gstate.truth, pixs, x_size);
|
||||
/*
|
||||
@ -523,23 +720,24 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
*net.seen += net.batch;
|
||||
forward_network_gpu(net, gstate);
|
||||
|
||||
fill_ongpu(imlayer.outputs, 0, imerror, 1);
|
||||
fill_ongpu(imlayer.outputs*imlayer.batch, 0, imerror, 1);
|
||||
astate.input = imlayer.output_gpu;
|
||||
astate.delta = imerror;
|
||||
astate.truth = ones_gpu;
|
||||
forward_network_gpu(anet, astate);
|
||||
backward_network_gpu(anet, astate);
|
||||
|
||||
scal_ongpu(imlayer.outputs, .1, net.layers[net.n-1].delta_gpu, 1);
|
||||
scal_ongpu(imlayer.outputs*imlayer.batch, 1./1000., net.layers[net.n-1].delta_gpu, 1);
|
||||
|
||||
scal_ongpu(imlayer.outputs*imlayer.batch, 1, imerror, 1);
|
||||
|
||||
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs*imlayer.batch));
|
||||
printf("features %f\n", cuda_mag_array(net.layers[net.n-1].delta_gpu, imlayer.outputs*imlayer.batch));
|
||||
|
||||
axpy_ongpu(imlayer.outputs*imlayer.batch, 1, imerror, 1, net.layers[net.n-1].delta_gpu, 1);
|
||||
|
||||
backward_network_gpu(net, gstate);
|
||||
|
||||
scal_ongpu(imlayer.outputs, 100, imerror, 1);
|
||||
|
||||
printf("realness %f\n", cuda_mag_array(imerror, imlayer.outputs));
|
||||
printf("features %f\n", cuda_mag_array(net.layers[net.n-1].delta_gpu, imlayer.outputs));
|
||||
|
||||
axpy_ongpu(imlayer.outputs, 1, imerror, 1, imlayer.delta_gpu, 1);
|
||||
|
||||
gloss += get_network_cost(net) /(net.subdivisions*net.batch);
|
||||
|
||||
@ -547,7 +745,6 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
for(k = 0; k < net.batch; ++k){
|
||||
int index = j*net.batch + k;
|
||||
copy_cpu(imlayer.outputs, imlayer.output + k*imlayer.outputs, 1, gray.X.vals[index], 1);
|
||||
gray.y.vals[index][0] = .01;
|
||||
}
|
||||
}
|
||||
harmless_update_network_gpu(anet);
|
||||
@ -557,7 +754,6 @@ void train_colorizer(char *cfg, char *weight, char *acfg, char *aweight, int cle
|
||||
float aloss = train_network(anet, merge);
|
||||
|
||||
update_network_gpu(net);
|
||||
update_network_gpu(anet);
|
||||
free_data(merge);
|
||||
free_data(train);
|
||||
free_data(gray);
|
||||
@ -840,7 +1036,7 @@ void train_lsd(char *cfgfile, char *weightfile, int clear)
|
||||
save_weights(net, buff);
|
||||
}
|
||||
|
||||
void test_lsd(char *cfgfile, char *weightfile, char *filename)
|
||||
void test_lsd(char *cfgfile, char *weightfile, char *filename, int gray)
|
||||
{
|
||||
network net = parse_network_cfg(cfgfile);
|
||||
if(weightfile){
|
||||
@ -875,7 +1071,7 @@ void test_lsd(char *cfgfile, char *weightfile, char *filename)
|
||||
image im = load_image_color(input, 0, 0);
|
||||
image resized = resize_min(im, net.w);
|
||||
image crop = crop_image(resized, (resized.w - net.w)/2, (resized.h - net.h)/2, net.w, net.h);
|
||||
//grayscale_image_3c(crop);
|
||||
if(gray) grayscale_image_3c(crop);
|
||||
|
||||
float *X = crop.data;
|
||||
time=clock();
|
||||
@ -916,8 +1112,11 @@ void run_lsd(int argc, char **argv)
|
||||
if(0==strcmp(argv[2], "train")) train_lsd(cfg, weights, clear);
|
||||
else if(0==strcmp(argv[2], "train2")) train_lsd2(cfg, weights, acfg, aweights, clear);
|
||||
else if(0==strcmp(argv[2], "traincolor")) train_colorizer(cfg, weights, acfg, aweights, clear);
|
||||
else if(0==strcmp(argv[2], "traingan")) train_dcgan(cfg, weights, acfg, aweights, clear);
|
||||
else if(0==strcmp(argv[2], "gan")) test_dcgan(cfg, weights);
|
||||
else if(0==strcmp(argv[2], "train3")) train_lsd3(argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], clear);
|
||||
else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename);
|
||||
else if(0==strcmp(argv[2], "test")) test_lsd(cfg, weights, filename, 0);
|
||||
else if(0==strcmp(argv[2], "color")) test_lsd(cfg, weights, filename, 1);
|
||||
/*
|
||||
else if(0==strcmp(argv[2], "valid")) validate_lsd(cfg, weights);
|
||||
*/
|
||||
|
@ -8,7 +8,6 @@ extern "C" {
|
||||
#include <assert.h>
|
||||
|
||||
#include "network.h"
|
||||
#include "image.h"
|
||||
#include "data.h"
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
|
@ -1,13 +1,8 @@
|
||||
|
||||
#include "network.h"
|
||||
#include "parser.h"
|
||||
#include "blas.h"
|
||||
#include "utils.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
// ./darknet nightmare cfg/extractor.recon.cfg ~/trained/yolo-coco.conv frame6.png -reconstruct -iters 500 -i 3 -lambda .1 -rate .01 -smooth 2
|
||||
|
||||
float abs_mean(float *x, int n)
|
||||
|
@ -7,11 +7,6 @@
|
||||
#include "cuda.h"
|
||||
#include <sys/time.h>
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
#endif
|
||||
|
||||
void train_regressor(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
|
||||
{
|
||||
int i;
|
||||
@ -185,7 +180,6 @@ void demo_regressor(char *datacfg, char *cfgfile, char *weightfile, int cam_inde
|
||||
cvNamedWindow("Regressor", CV_WINDOW_NORMAL);
|
||||
cvResizeWindow("Regressor", 512, 512);
|
||||
float fps = 0;
|
||||
int i;
|
||||
|
||||
while(1){
|
||||
struct timeval tval_before, tval_after, tval_result;
|
||||
|
@ -4,10 +4,6 @@
|
||||
#include "blas.h"
|
||||
#include "parser.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
float *x;
|
||||
float *y;
|
||||
|
@ -5,7 +5,6 @@
|
||||
#include "blas.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
image ipl_to_image(IplImage* src);
|
||||
|
||||
|
@ -3,10 +3,6 @@
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_super(char *cfgfile, char *weightfile, int clear)
|
||||
{
|
||||
char *train_images = "/data/imagenet/imagenet1k.train.list";
|
||||
|
@ -5,10 +5,6 @@
|
||||
#include "parser.h"
|
||||
#include "box.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_swag(char *cfgfile, char *weightfile)
|
||||
{
|
||||
char *train_images = "data/voc.0712.trainval";
|
||||
|
@ -2,10 +2,6 @@
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_tag(char *cfgfile, char *weightfile, int clear)
|
||||
{
|
||||
srand(time(0));
|
||||
|
@ -3,11 +3,6 @@
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
image get_image_from_stream(CvCapture *cap);
|
||||
#endif
|
||||
|
||||
void extract_voxel(char *lfile, char *rfile, char *prefix)
|
||||
{
|
||||
#ifdef OPENCV
|
||||
|
@ -2,10 +2,6 @@
|
||||
#include "utils.h"
|
||||
#include "parser.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
void train_writing(char *cfgfile, char *weightfile)
|
||||
{
|
||||
char *backup_directory = "/home/pjreddie/backup/";
|
||||
|
@ -6,10 +6,6 @@
|
||||
#include "box.h"
|
||||
#include "demo.h"
|
||||
|
||||
#ifdef OPENCV
|
||||
#include "opencv2/highgui/highgui_c.h"
|
||||
#endif
|
||||
|
||||
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
|
||||
|
||||
void train_yolo(char *cfgfile, char *weightfile)
|
||||
|
Loading…
Reference in New Issue
Block a user