/*
 * darknet/include/darknet.h
 * 2017-06-06 17:23:35 -07:00
 *
 * 572 lines
 * 10 KiB
 * C
 */

#ifndef DARKNET_API
#define DARKNET_API
#include <stdlib.h>
#include <pthread.h>
/* Declared here, defined in another translation unit.
 * NOTE(review): presumably the index of the GPU device in use -- confirm where it is set. */
extern int gpu_index;
/* Everything below is compiled only when the build defines GPU. */
#ifdef GPU
/* NOTE(review): presumably the thread-block size used when launching kernels -- confirm at use sites. */
#define BLOCK 512
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
/* cuDNN is pulled in only when both GPU and CUDNN are defined. */
#ifdef CUDNN
#include "cudnn.h"
#endif
#endif
#ifndef __cplusplus
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
#include "opencv2/core/version.hpp"
#if CV_MAJOR_VERSION == 3
#include "opencv2/videoio/videoio_c.h"
#endif
#endif
#endif
/*
 * Label hierarchy description.
 *
 * All array members are plain pointers; this header does not say who
 * allocates or frees them.
 * NOTE(review): `n` is presumably the length of the per-node arrays and
 * `groups` the length of the per-group arrays -- confirm in tree.c.
 */
typedef struct {
    int   *leaf;
    int    n;
    int   *parent;
    int   *child;
    int   *group;
    char **name;

    int    groups;
    int   *group_size;
    int   *group_offset;
} tree;
/*
 * Activation identifiers, stored in layer.activation.
 * The numeric values are spelled out; they equal the implicit
 * 0-based numbering of the enumerators.
 */
typedef enum {
    LOGISTIC = 0,
    RELU     = 1,
    RELIE    = 2,
    LINEAR   = 3,
    RAMP     = 4,
    TANH     = 5,
    PLSE     = 6,
    LEAKY    = 7,
    ELU      = 8,
    LOGGY    = 9,
    STAIR    = 10,
    HARDTAN  = 11,
    LHTAN    = 12
} ACTIVATION;
/*
 * Layer kind tag, stored in layer.type.
 * The numeric values are spelled out; they equal the implicit
 * 0-based numbering of the enumerators.
 */
typedef enum {
    CONVOLUTIONAL   = 0,
    DECONVOLUTIONAL = 1,
    CONNECTED       = 2,
    MAXPOOL         = 3,
    SOFTMAX         = 4,
    DETECTION       = 5,
    DROPOUT         = 6,
    CROP            = 7,
    ROUTE           = 8,
    COST            = 9,
    NORMALIZATION   = 10,
    AVGPOOL         = 11,
    LOCAL           = 12,
    SHORTCUT        = 13,
    ACTIVE          = 14,
    RNN             = 15,
    GRU             = 16,
    LSTM            = 17,
    CRNN            = 18,
    BATCHNORM       = 19,
    NETWORK         = 20,
    XNOR            = 21,
    REGION          = 22,
    REORG           = 23,
    BLANK           = 24
} LAYER_TYPE;
/*
 * Cost kind tag, stored in layer.cost_type.
 * The numeric values are spelled out; they equal the implicit
 * 0-based numbering of the enumerators.
 */
typedef enum {
    SSE    = 0,
    MASKED = 1,
    L1     = 2,
    SMOOTH = 3
} COST_TYPE;
/*
 * Forward declarations. struct layer and struct network refer to each other
 * (layer's function pointers take a struct network; network holds a layer
 * array pointer), so both tags and their typedef names are introduced before
 * either struct is defined.
 */
struct network;
typedef struct network network;
struct layer;
typedef struct layer layer;
/*
 * A single layer of a network.
 *
 * One struct type is shared by every LAYER_TYPE; `type` says which kind an
 * instance is. Member order is part of the binary layout and is kept as-is.
 * Ownership and lifetime of the pointer members are not stated in this
 * header.
 */
struct layer {
    /* Kind selectors. */
    LAYER_TYPE type;
    ACTIVATION activation;
    COST_TYPE  cost_type;

    /*
     * Per-layer function pointers. Note that both the layer and the network
     * are passed by value.
     */
    void (*forward)(struct layer, struct network);
    void (*backward)(struct layer, struct network);
    void (*update)(struct layer, int, float, float, float);
    void (*forward_gpu)(struct layer, struct network);
    void (*backward_gpu)(struct layer, struct network);
    void (*update_gpu)(struct layer, int, float, float, float);

    /* Scalar configuration and bookkeeping fields. */
    int batch_normalize;
    int shortcut;
    int batch;
    int forced;
    int flipped;
    int inputs;
    int outputs;
    int nweights;
    int nbiases;
    int extra;
    int truths;
    int h;
    int w;
    int c;
    int out_h;
    int out_w;
    int out_c;
    int n;
    int max_boxes;
    int groups;
    int size;
    int side;
    int stride;
    int reverse;
    int flatten;
    int spatial;
    int pad;
    int sqrt;   /* a member name, not the <math.h> function */
    int flip;
    int index;
    int binary;
    int xnor;
    int steps;
    int hidden;
    int truth;

    float smooth;
    float dot;
    float angle;
    float jitter;
    float saturation;
    float exposure;
    float shift;
    float ratio;
    float learning_rate_scale;

    int softmax;
    int classes;
    int coords;
    int background;
    int rescore;
    int objectness;
    int does_cost;
    int joint;
    int noadjust;
    int reorg;
    int log;    /* a member name, not the <math.h> function */

    int   adam;
    float B1;
    float B2;
    float eps;
    int   t;

    float alpha;
    float beta;
    float kappa;

    float coord_scale;
    float object_scale;
    float noobject_scale;
    float class_scale;
    int   bias_match;
    int   random;
    float thresh;
    int   classfix;
    int   absolute;

    int onlyforward;
    int stopbackward;
    int dontload;
    int dontloadscales;

    float temperature;
    float probability;
    float scale;

    /* Pointer members without a _gpu suffix. */
    char  *cweights;
    int   *indexes;
    int   *input_layers;
    int   *input_sizes;
    int   *map;
    float *rand;   /* a member name, not the <stdlib.h> function */
    float *cost;
    float *state;
    float *prev_state;
    float *forgot_state;
    float *forgot_delta;
    float *state_delta;

    float *concat;
    float *concat_delta;

    float *binary_weights;

    float *biases;
    float *bias_updates;

    float *scales;
    float *scale_updates;

    float *weights;
    float *weight_updates;

    float *delta;
    float *output;
    float *squared;
    float *norms;

    float *spatial_mean;
    float *mean;
    float *variance;

    float *mean_delta;
    float *variance_delta;

    float *rolling_mean;
    float *rolling_variance;

    float *x;
    float *x_norm;

    float *m;
    float *v;

    float *bias_m;
    float *bias_v;
    float *scale_m;
    float *scale_v;

    float *z_cpu;
    float *r_cpu;
    float *h_cpu;

    float *binary_input;

    /* Pointers to other layer objects. */
    struct layer *input_layer;
    struct layer *self_layer;
    struct layer *output_layer;

    struct layer *input_gate_layer;
    struct layer *state_gate_layer;
    struct layer *input_save_layer;
    struct layer *state_save_layer;
    struct layer *input_state_layer;
    struct layer *state_state_layer;

    struct layer *input_z_layer;
    struct layer *state_z_layer;

    struct layer *input_r_layer;
    struct layer *state_r_layer;

    struct layer *input_h_layer;
    struct layer *state_h_layer;

    struct layer *wz;
    struct layer *uz;
    struct layer *wr;
    struct layer *ur;
    struct layer *wh;
    struct layer *uh;
    struct layer *uo;
    struct layer *wo;
    struct layer *uf;
    struct layer *wf;
    struct layer *ui;
    struct layer *wi;
    struct layer *ug;
    struct layer *wg;

    tree *softmax_tree;

    size_t workspace_size;

#ifdef GPU
    /* Members that exist only when the build defines GPU. */
    int   *indexes_gpu;

    float *z_gpu;
    float *r_gpu;
    float *h_gpu;

    float *temp_gpu;
    float *temp2_gpu;
    float *temp3_gpu;

    float *dh_gpu;
    float *hh_gpu;
    float *prev_cell_gpu;
    float *cell_gpu;
    float *f_gpu;
    float *i_gpu;
    float *g_gpu;
    float *o_gpu;
    float *c_gpu;
    float *dc_gpu;

    float *m_gpu;
    float *v_gpu;
    float *bias_m_gpu;
    float *scale_m_gpu;
    float *bias_v_gpu;
    float *scale_v_gpu;

    float *prev_state_gpu;
    float *forgot_state_gpu;
    float *forgot_delta_gpu;
    float *state_gpu;
    float *state_delta_gpu;
    float *gate_gpu;
    float *gate_delta_gpu;
    float *save_gpu;
    float *save_delta_gpu;
    float *concat_gpu;
    float *concat_delta_gpu;

    float *binary_input_gpu;
    float *binary_weights_gpu;

    float *mean_gpu;
    float *variance_gpu;

    float *rolling_mean_gpu;
    float *rolling_variance_gpu;

    float *variance_delta_gpu;
    float *mean_delta_gpu;

    float *x_gpu;
    float *x_norm_gpu;
    float *weights_gpu;
    float *weight_updates_gpu;

    float *biases_gpu;
    float *bias_updates_gpu;

    float *scales_gpu;
    float *scale_updates_gpu;

    float *output_gpu;
    float *delta_gpu;
    float *rand_gpu;
    float *squared_gpu;
    float *norms_gpu;
#ifdef CUDNN
    /* cuDNN descriptor and algorithm-selection members (GPU + CUDNN builds only). */
    cudnnTensorDescriptor_t srcTensorDesc;
    cudnnTensorDescriptor_t dstTensorDesc;
    cudnnTensorDescriptor_t dsrcTensorDesc;
    cudnnTensorDescriptor_t ddstTensorDesc;
    cudnnTensorDescriptor_t normTensorDesc;
    cudnnFilterDescriptor_t weightDesc;
    cudnnFilterDescriptor_t dweightDesc;
    cudnnConvolutionDescriptor_t convDesc;
    cudnnConvolutionFwdAlgo_t fw_algo;
    cudnnConvolutionBwdDataAlgo_t bd_algo;
    cudnnConvolutionBwdFilterAlgo_t bf_algo;
#endif
#endif
};
/* Takes the layer by value; defined in another translation unit.
 * NOTE(review): presumably releases the buffers the layer points to -- confirm in the implementation. */
void free_layer(layer);
/*
 * Learning-rate schedule selector, stored in network.policy.
 * The numeric values are spelled out; they equal the implicit
 * 0-based numbering of the enumerators.
 */
typedef enum {
    CONSTANT = 0,
    STEP     = 1,
    EXP      = 2,
    POLY     = 3,
    STEPS    = 4,
    SIG      = 5,
    RANDOM   = 6
} learning_rate_policy;
/*
 * Top-level network state.
 *
 * The typedef name `network` is already introduced by the forward
 * declaration near the top of this header, so the struct is defined here
 * without repeating the typedef (the same pattern struct layer uses).
 * Repeating a typedef in the same scope is only permitted from C11 onward.
 *
 * Member order is part of the binary layout and is kept as-is. Ownership of
 * the pointer members is not stated in this header.
 * NOTE(review): `n` is presumably the number of entries in `layers` --
 * confirm in network.c.
 */
struct network {
    int    n;
    int    batch;
    int   *seen;
    float  epoch;
    int    subdivisions;
    float  momentum;
    float  decay;
    layer *layers;
    float *output;

    /* Learning-rate schedule selector and its scalar parameters. */
    learning_rate_policy policy;
    float  learning_rate;
    float  gamma;
    float  scale;
    float  power;
    int    time_steps;
    int    step;
    int    max_batches;
    float *scales;
    int   *steps;
    int    num_steps;
    int    burn_in;

    int    adam;
    float  B1;
    float  B2;
    float  eps;

    int    inputs;
    int    outputs;
    int    truths;
    int    notruth;
    int    h;
    int    w;
    int    c;
    int    max_crop;
    int    min_crop;
    int    center;
    float  angle;
    float  aspect;
    float  exposure;
    float  saturation;
    float  hue;

    /* Per-network member; distinct from the global `gpu_index` variable. */
    int    gpu_index;
    tree  *hierarchy;

    float *input;
    float *truth;
    float *delta;
    float *workspace;
    int    train;
    int    index;
    float *cost;

#ifdef GPU
    /* Members that exist only when the build defines GPU. */
    float *input_gpu;
    float *truth_gpu;
    float *delta_gpu;
    float *output_gpu;
#endif
};
/*
 * Bundle of scalar augmentation parameters: two ints followed by five
 * floats. How each value is applied is decided by the code that consumes
 * the struct, not by this header.
 */
typedef struct {
    int   w;
    int   h;
    float scale;
    float rad;
    float dx;
    float dy;
    float aspect;
} augment_args;
/*
 * Image handle: three int members plus a pointer to float data.
 * NOTE(review): h/w/c are presumably height, width and channel count, and
 * `data` presumably holds h*w*c floats -- confirm layout in image.c.
 */
typedef struct {
    int    h;
    int    w;
    int    c;
    float *data;
} image;
/*
 * Four floats describing a box, in the order x, y, w, h.
 * NOTE(review): whether (x, y) is a corner or the centre is not stated
 * here -- confirm in box.c.
 */
typedef struct {
    float x;
    float y;
    float w;
    float h;
} box;
/*
 * Matrix stored as an array of row pointers (`vals[i]` is a float *).
 * NOTE(review): presumably `rows` entries in `vals`, each `cols` floats
 * long -- confirm in matrix.c.
 */
typedef struct matrix {
    int     rows;
    int     cols;
    float **vals;
} matrix;
/*
 * A data set: two matrices held by value (X and y) plus box annotations.
 * NOTE(review): `shallow` presumably marks that the matrices' row storage is
 * borrowed rather than owned -- confirm against free_data().
 */
typedef struct {
    int    w;
    int    h;
    matrix X;
    matrix y;
    int    shallow;
    int   *num_boxes;
    box  **boxes;
} data;
/*
 * Data-set kind tag, stored in load_args.type.
 * The numeric values are spelled out; they equal the implicit
 * 0-based numbering of the enumerators.
 */
typedef enum {
    CLASSIFICATION_DATA     = 0,
    DETECTION_DATA          = 1,
    CAPTCHA_DATA            = 2,
    REGION_DATA             = 3,
    IMAGE_DATA              = 4,
    COMPARE_DATA            = 5,
    WRITING_DATA            = 6,
    SWAG_DATA               = 7,
    TAG_DATA                = 8,
    OLD_CLASSIFICATION_DATA = 9,
    STUDY_DATA              = 10,
    DET_DATA                = 11,
    SUPER_DATA              = 12,
    LETTERBOX_DATA          = 13,
    REGRESSION_DATA         = 14,
    SEGMENTATION_DATA       = 15
} data_type;
/*
 * Argument bundle passed by value to load_data() and returned by
 * get_base_args(). Member order is part of the binary layout and is kept
 * as-is. How each field is interpreted depends on `type` and on the loader
 * implementation, neither of which is visible in this header.
 */
typedef struct load_args {
    int     threads;
    char  **paths;
    char   *path;
    int     n;
    int     m;
    char  **labels;

    int     h;
    int     w;
    int     out_w;
    int     out_h;
    int     nh;
    int     nw;
    int     num_boxes;
    int     min;
    int     max;
    int     size;
    int     classes;
    int     background;
    int     scale;
    int     center;

    float   jitter;
    float   angle;
    float   aspect;
    float   saturation;
    float   exposure;
    float   hue;

    /* Pointers to data / image objects; ownership is not stated here. */
    data   *d;
    image  *im;
    image  *resized;

    data_type type;
    tree     *hierarchy;
} load_args;
/*
 * Labelled box record: an int id followed by eight floats, in the order
 * x, y, w, h, left, right, top, bottom.
 * NOTE(review): the two groups presumably describe the same box in two
 * forms (position/size and edges) -- confirm where labels are read.
 */
typedef struct {
    int   id;
    float x;
    float y;
    float w;
    float h;
    float left;
    float right;
    float top;
    float bottom;
} box_label;
/* Returns a network by value, built from two path-like strings and an int flag.
 * NOTE(review): presumably `cfg` is a config file path, `weights` a weights file
 * path, and `clear` resets training progress -- confirm in the implementation. */
network load_network(char *cfg, char *weights, int clear);
/* Returns a load_args value derived from the given network (passed by value).
 * Which fields are filled in is decided by the implementation, not visible here. */
load_args get_base_args(network net);
/* Takes the data set by value.
 * NOTE(review): presumably frees the storage `d` points to -- confirm, including how d.shallow is handled. */
void free_data(data d);
/*
 * Doubly linked list node carrying an untyped payload pointer.
 * The node type does not record what `val` points to or who owns it.
 */
typedef struct node {
    void        *val;
    struct node *next;
    struct node *prev;
} node;
/*
 * List header: an int size plus pointers to the front and back nodes.
 * NOTE(review): `size` is presumably the number of linked nodes -- confirm
 * in list.c.
 */
typedef struct list {
    int   size;
    node *front;
    node *back;
} list;
/* Takes the argument bundle by value and returns a pthread_t.
 * NOTE(review): presumably starts a loader thread that the caller must join -- confirm in data.c. */
pthread_t load_data(load_args args);
/* Returns a list pointer built from the named file.
 * NOTE(review): element type and ownership of the returned list are not visible here -- confirm. */
list *read_data_cfg(char *filename);
/* Returns a list pointer built from the named file.
 * NOTE(review): element type and ownership of the returned list are not visible here -- confirm. */
list *read_cfg(char *filename);
#include "activation_layer.h"
#include "activations.h"
#include "avgpool_layer.h"
#include "batchnorm_layer.h"
#include "blas.h"
#include "box.h"
#include "classifier.h"
#include "col2im.h"
#include "connected_layer.h"
#include "convolutional_layer.h"
#include "cost_layer.h"
#include "crnn_layer.h"
#include "crop_layer.h"
#include "cuda.h"
#include "data.h"
#include "deconvolutional_layer.h"
#include "demo.h"
#include "detection_layer.h"
#include "dropout_layer.h"
#include "gemm.h"
#include "gru_layer.h"
#include "lstm_layer.h"
#include "im2col.h"
#include "image.h"
#include "layer.h"
#include "list.h"
#include "local_layer.h"
#include "matrix.h"
#include "maxpool_layer.h"
#include "network.h"
#include "normalization_layer.h"
#include "option_list.h"
#include "parser.h"
#include "region_layer.h"
#include "reorg_layer.h"
#include "rnn_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "softmax_layer.h"
#include "stb_image.h"
#include "stb_image_write.h"
#include "tree.h"
#include "utils.h"
#endif