Added LSTM sequence detector, and blur data augmentation (for OpenCV only)
(commit in the darknet mirror of https://github.com/pjreddie/darknet.git)
Makefile: 2 changes
@@ -118,7 +118,7 @@ LDFLAGS+= -L/usr/local/zed/lib -lsl_core -lsl_input -lsl_zed
#-lstdc++ -D_GLIBCXX_USE_CXX11_ABI=0
endif
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
@@ -14,7 +14,7 @@ More details: http://pjreddie.com/darknet/yolo/
* [Requirements (and how to install dependencies)](#requirements)
* [Pre-trained models](#pre-trained-models)
* [Explanations in issues](https://github.com/AlexeyAB/darknet/issues?q=is%3Aopen+is%3Aissue+label%3AExplanations)
* [Yolo v3 in other frameworks (TensorFlow, OpenVINO, OpenCV-dnn, ...)](#yolo-v3-in-other-frameworks)
* [Yolo v3 in other frameworks (TensorFlow, PyTorch, OpenVINO, OpenCV-dnn,...)](#yolo-v3-in-other-frameworks)

0. [Improvements in this repository](#improvements-in-this-repository)
1. [How to use](#how-to-use-on-the-command-line)
@@ -75,9 +75,9 @@ You can get cfg-files by path: `darknet/cfg/`
#### Yolo v3 in other frameworks

* **TensorFlow:** convert `yolov3.weights`/`cfg` files to `yolov3.ckpt`/`pb/meta`: by using [mystic123](https://github.com/mystic123/tensorflow-yolo-v3) or [jinyu121](https://github.com/jinyu121/DW2TF) projects, and [TensorFlow-lite](https://www.tensorflow.org/lite/guide/get_started#2_convert_the_model_format)
* **Intel OpenVINO:** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model)
* **OpenCV-dnn** is very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150)
* **PyTorch > ONNX > CoreML > iOS** how to convert cfg/weights-files to pt-file: [ultralytics/yolov3](https://github.com/ultralytics/yolov3#darknet-conversion)
* **Intel OpenVINO 2019 R1:** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model)
* **OpenCV-dnn** is a very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150)
* **PyTorch > ONNX > CoreML > iOS** how to convert cfg/weights-files to pt-file: [ultralytics/yolov3](https://github.com/ultralytics/yolov3#darknet-conversion) and [iOS App](https://itunes.apple.com/app/id1452689527)

##### Examples of results
@@ -140,6 +140,8 @@
<CompileAs>Default</CompileAs>
<UndefinePreprocessorDefinitions>NDEBUG</UndefinePreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalUsingDirectories>
</AdditionalUsingDirectories>
</ClCompile>
<Link>
<GenerateDebugInformation>true</GenerateDebugInformation>
@@ -183,6 +185,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@@ -248,6 +251,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />
@@ -189,6 +189,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@@ -254,6 +255,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />
@@ -33,6 +33,9 @@ darknet.exe partial cfg/yolov3-spp.cfg yolov3-spp.weights yolov3-spp.conv.85 85
darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15

darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.14 14

darknet.exe partial cfg/yolo9000.cfg yolo9000.weights yolo9000.conv.22 22
@@ -187,6 +187,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@@ -254,6 +255,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />
@@ -173,6 +173,7 @@
<ClCompile Include="..\..\src\compare.c" />
<ClCompile Include="..\..\src\connected_layer.c" />
<ClCompile Include="..\..\src\convolutional_layer.c" />
<ClCompile Include="..\..\src\conv_lstm_layer.c" />
<ClCompile Include="..\..\src\cost_layer.c" />
<ClCompile Include="..\..\src\cpu_gemm.c" />
<ClCompile Include="..\..\src\crnn_layer.c" />
@@ -240,6 +241,7 @@
<ClInclude Include="..\..\src\col2im.h" />
<ClInclude Include="..\..\src\connected_layer.h" />
<ClInclude Include="..\..\src\convolutional_layer.h" />
<ClInclude Include="..\..\src\conv_lstm_layer.h" />
<ClInclude Include="..\..\src\cost_layer.h" />
<ClInclude Include="..\..\src\crnn_layer.h" />
<ClInclude Include="..\..\src\crop_layer.h" />
@@ -32,7 +32,6 @@
#endif
#endif

#define NFRAMES 3
#define SECRET_NUM -1234

#ifdef GPU
@@ -136,6 +135,7 @@ typedef enum {
RNN,
GRU,
LSTM,
CONV_LSTM,
CRNN,
BATCHNORM,
NETWORK,
@@ -208,6 +208,7 @@ struct layer {
int index;
int binary;
int xnor;
int peephole;
int use_bin_output;
int steps;
int hidden;
@@ -354,6 +355,7 @@ struct layer {
float *z_cpu;
float *r_cpu;
float *h_cpu;
float *stored_h_cpu;
float * prev_state_cpu;

float *temp_cpu;
@@ -369,6 +371,7 @@ struct layer {
float *g_cpu;
float *o_cpu;
float *c_cpu;
float *stored_c_cpu;
float *dc_cpu;

float *binary_input;
@@ -407,10 +410,13 @@ struct layer {
struct layer *uh;
struct layer *uo;
struct layer *wo;
struct layer *vo;
struct layer *uf;
struct layer *wf;
struct layer *vf;
struct layer *ui;
struct layer *wi;
struct layer *vi;
struct layer *ug;
struct layer *wg;

@@ -424,6 +430,7 @@ struct layer {
float *z_gpu;
float *r_gpu;
float *h_gpu;
float *stored_h_gpu;

float *temp_gpu;
float *temp2_gpu;
@@ -432,12 +439,16 @@ struct layer {
float *dh_gpu;
float *hh_gpu;
float *prev_cell_gpu;
float *prev_state_gpu;
float *last_prev_state_gpu;
float *last_prev_cell_gpu;
float *cell_gpu;
float *f_gpu;
float *i_gpu;
float *g_gpu;
float *o_gpu;
float *c_gpu;
float *stored_c_gpu;
float *dc_gpu;

// adam
@@ -451,7 +462,6 @@ struct layer {
float * combine_gpu;
float * combine_delta_gpu;

float * prev_state_gpu;
float * forgot_state_gpu;
float * forgot_delta_gpu;
float * state_gpu;
@@ -571,6 +581,7 @@ typedef struct network {
float min_ratio;
int center;
int flip; // horizontal flip 50% probability augmentation for classifier training (default = 1)
int blur;
float angle;
float aspect;
float exposure;
@@ -579,6 +590,8 @@ typedef struct network {
int random;
int track;
int augment_speed;
int sequential_subdivisions;
int current_subdivision;
int try_fix_nan;

int gpu_index;
@@ -713,6 +726,7 @@ typedef struct load_args {
int show_imgs;
float jitter;
int flip;
int blur;
float angle;
float aspect;
float saturation;
@@ -446,8 +446,8 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
#ifdef CUDNN_HALF
l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weights, c*n*size*size / 2);
l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weight_updates, c*n*size*size / 2);
l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2 + 1); //cuda_make_array(l.weights, c*n*size*size / 2);
l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2 + 1); //cuda_make_array(l.weight_updates, c*n*size*size / 2);
#endif

l.biases_gpu = cuda_make_array(l.biases, n);
@@ -85,6 +85,8 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.delta_gpu = l.output_layer->delta_gpu;
#endif

l.bflops = l.input_layer->bflops + l.self_layer->bflops + l.output_layer->bflops;

return l;
}

@@ -128,6 +130,16 @@ void resize_crnn_layer(layer *l, int w, int h)
#endif
}

void free_state_crnn(layer l)
{
int i;
for (i = 0; i < l.outputs * l.batch; ++i) l.self_layer->output[i] = rand_uniform(-1, 1);

#ifdef GPU
cuda_push_array(l.self_layer->output_gpu, l.self_layer->output, l.outputs * l.batch);
#endif // GPU
}

void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay)
{
update_convolutional_layer(*(l.input_layer), batch, learning_rate, momentum, decay);
@@ -11,6 +11,7 @@ extern "C" {
#endif
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor);
void resize_crnn_layer(layer *l, int w, int h);
void free_state_crnn(layer l);

void forward_crnn_layer(layer l, network_state state);
void backward_crnn_layer(layer l, network_state state);
src/data.c: 28 changes
@@ -231,6 +231,15 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float
boxes[i].h = 999999;
continue;
}
if ((boxes[i].x + boxes[i].w / 2) < 0 || (boxes[i].y + boxes[i].h / 2) < 0 ||
(boxes[i].x - boxes[i].w / 2) > 1 || (boxes[i].y - boxes[i].h / 2) > 1)
{
boxes[i].x = 999999;
boxes[i].y = 999999;
boxes[i].w = 999999;
boxes[i].h = 999999;
continue;
}
boxes[i].left = boxes[i].left * sx - dx;
boxes[i].right = boxes[i].right * sx - dx;
boxes[i].top = boxes[i].top * sy - dy;
@@ -378,7 +387,7 @@ void fill_truth_detection(const char *path, int num_boxes, float *truth, int cla
continue;
}
if (x == 999999 || y == 999999) {
printf("\n Wrong annotation: x = 0, y = 0 \n");
printf("\n Wrong annotation: x = 0, y = 0, < 0 or > 1 \n");
sprintf(buff, "echo %s \"Wrong annotation: x = 0 or y = 0\" >> bad_label.list", labelpath);
system(buff);
++sub;
@@ -769,7 +778,7 @@ static box float_to_box_stride(float *f, int stride)

#include "http_stream.h"

data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
{
c = c ? c : 3;
@@ -785,7 +794,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
d.X.cols = h*w*c;

float r1 = 0, r2 = 0, r3 = 0, r4 = 0;
float dhue = 0, dsat = 0, dexp = 0, flip = 0;
float dhue = 0, dsat = 0, dexp = 0, flip = 0, blur = 0;
int augmentation_calculated = 0;

d.y = make_matrix(n, 5*boxes);
@@ -819,6 +828,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
dexp = rand_scale(exposure);

flip = use_flip ? random_gen() % 2 : 0;
blur = rand_int(0, 1) ? (use_blur) : 0;
}

int pleft = rand_precalc_random(-dw, dw, r1);
@@ -835,10 +845,12 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
float dx = ((float)pleft/ow)/sx;
float dy = ((float)ptop /oh)/sy;

image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp);
d.X.vals[i] = ai.data;
fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h);

fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h);
image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp,
blur, boxes, d.y.vals[i]);

d.X.vals[i] = ai.data;

if(show_imgs)
{
@@ -869,7 +881,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
return d;
}
#else // OPENCV
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs)
{
c = c ? c : 3;
@@ -989,7 +1001,7 @@ void *load_thread(void *ptr)
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter,
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.jitter,
a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
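The new check in correct_boxes() above rejects a label when, after cropping and shifting, the box no longer overlaps the unit image square at all (darknet boxes store a relative center x,y and size w,h). A minimal standalone sketch of that predicate, with hypothetical names, not code from the commit:

```cpp
// Sketch of the validity test added to correct_boxes(); the struct and function names are illustrative.
struct relative_box { float x, y, w, h; };   // center x,y and size w,h, relative to image size

static bool box_completely_outside_image(const relative_box &b)
{
    // right/bottom edge left of or above the image, or left/top edge beyond it
    return (b.x + b.w / 2) < 0 || (b.y + b.h / 2) < 0 ||
           (b.x - b.w / 2) > 1 || (b.y - b.h / 2) > 1;
}
```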
@@ -86,7 +86,7 @@ void print_letters(float *pred, int n);
data load_data_captcha(char **paths, int n, int m, int k, int w, int h);
data load_data_captcha_encode(char **paths, int n, int m, int w, int h);
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h);
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter,
data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter,
float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs);
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure);
@@ -37,6 +37,8 @@ static int demo_ext_output = 0;
static long long int frame_id = 0;
static int demo_json_port = -1;

#define NFRAMES 3

static float* predictions[NFRAMES];
static int demo_index = 0;
static image images[NFRAMES];
@@ -47,6 +47,15 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
for (k = 0; k < net_map.n; ++k) {
free_layer(net_map.layers[k]);
}

char *name_list = option_find_str(options, "names", "data/names.list");
int names_size = 0;
char **names = get_labels_custom(name_list, &names_size);
if (net_map.layers[net_map.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net_map.layers[net_map.n - 1].classes, cfgfile);
if (net_map.layers[net_map.n - 1].classes > names_size) getchar();
}
}

srand(time(0));
@@ -119,6 +128,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
args.threads = 64; // 16 or 64

args.angle = net.angle;
args.blur = net.blur;
args.exposure = net.exposure;
args.saturation = net.saturation;
args.hue = net.hue;
@@ -137,7 +147,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
if (net.track) {
args.track = net.track;
args.augment_speed = net.augment_speed;
args.threads = net.subdivisions * ngpus; // 2 * ngpus;
if (net.sequential_subdivisions) args.threads = net.sequential_subdivisions * ngpus;
else args.threads = net.subdivisions * ngpus; // 2 * ngpus;
args.mini_batch = net.batch / net.time_steps;
printf("\n Tracking! batch = %d, subdiv = %d, time_steps = %d, mini_batch = %d \n", net.batch, net.subdivisions, net.time_steps, args.mini_batch);
}
@@ -223,7 +234,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
calc_map_for_each = fmax(calc_map_for_each, 100);
int next_map_calc = iter_map + calc_map_for_each;
next_map_calc = fmax(next_map_calc, net.burn_in);
next_map_calc = fmax(next_map_calc, 1000);
next_map_calc = fmax(next_map_calc, 400);
if (calc_map) {
printf("\n (next mAP calculation at %d iterations) ", next_map_calc);
if (mean_average_precision > 0) printf("\n Last accuracy mAP@0.5 = %2.2f %% ", mean_average_precision * 100);
@@ -638,7 +649,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *valid_images = option_find_str(options, "valid", "data/train.txt");
char *difficult_valid_images = option_find_str(options, "difficult", NULL);
char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list);
int names_size = 0;
char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
//char *mapf = option_find_str(options, "map", 0);
//int *map = 0;
//if (mapf) map = read_map(mapf);
@@ -650,6 +662,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char *train_images = option_find_str(options, "train", "data/train.txt");
valid_images = option_find_str(options, "valid", train_images);
net = *existing_net;
remember_network_recurrent_state(*existing_net);
free_network_recurrent_state(*existing_net);
}
else {
net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
@@ -660,6 +674,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
fuse_conv_batchnorm(net);
calculate_binary_weights(net);
}
if (net.layers[net.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net.layers[net.n - 1].classes, cfgfile);
getchar();
}
srand(time(0));
printf("\n calculation mAP (mean average precision)...\n");

@@ -1053,6 +1072,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa

if (existing_net) {
//set_batch_network(&net, initial_batch);
//free_network_recurrent_state(*existing_net);
restore_network_recurrent_state(*existing_net);
}
else {
free_network(net);
@@ -1220,7 +1241,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int

if (show) {
#ifdef OPENCV
//show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height);
show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height);
#endif // OPENCV
}
free(rel_width_height_array);
@@ -1125,9 +1125,20 @@ void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_im
// ====================================================================
// Data augmentation
// ====================================================================
static box float_to_box_stride(float *f, int stride)
{
box b = { 0 };
b.x = f[0];
b.y = f[1 * stride];
b.w = f[2 * stride];
b.h = f[3 * stride];
return b;
}

image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp)
float jitter, float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth)
{
image out;
try {
@@ -1192,6 +1203,31 @@ image image_data_augmentation(mat_cv* mat, int w, int h,
//cv::imshow(window_name.str(), sized);
//cv::waitKey(0);

if (blur) {
cv::Mat dst(sized.size(), sized.type());
if(blur == 1) cv::GaussianBlur(sized, dst, cv::Size(31, 31), 0);
else cv::GaussianBlur(sized, dst, cv::Size((blur / 2) * 2 + 1, (blur / 2) * 2 + 1), 0);
cv::Rect img_rect(0, 0, sized.cols, sized.rows);
//std::cout << " blur num_boxes = " << num_boxes << std::endl;

if (blur == 1) {
int t;
for (t = 0; t < num_boxes; ++t) {
box b = float_to_box_stride(truth + t*(4 + 1), 1);
if (!b.x) break;
int left = (b.x - b.w / 2.)*sized.cols;
int width = b.w*sized.cols;
int top = (b.y - b.h / 2.)*sized.rows;
int height = b.h*sized.rows;
cv::Rect roi(left, top, width, height);
roi = roi & img_rect;

sized(roi).copyTo(dst(roi));
}
}
dst.copyTo(sized);
}

// Mat -> image
out = mat_to_image(sized);
}
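The blur branch added above works on the already cropped and resized frame: the whole frame is Gaussian-blurred, and when blur == 1 the labeled object rectangles are copied back unblurred, so only the background loses detail (blur > 1 blurs everything with a kernel of roughly that size). A self-contained OpenCV sketch of the same idea, assuming the boxes have already been converted to pixel cv::Rects (none of these names exist in the repository):

```cpp
#include <opencv2/opencv.hpp>
#include <vector>

// Sketch: blur the background of 'img' while keeping the listed object regions sharp.
cv::Mat blur_background(const cv::Mat &img, const std::vector<cv::Rect> &objects, int blur)
{
    cv::Mat dst(img.size(), img.type());
    if (blur == 1) cv::GaussianBlur(img, dst, cv::Size(31, 31), 0);
    else           cv::GaussianBlur(img, dst, cv::Size((blur / 2) * 2 + 1, (blur / 2) * 2 + 1), 0);

    if (blur == 1) {
        const cv::Rect img_rect(0, 0, img.cols, img.rows);
        for (const cv::Rect &r : objects) {
            cv::Rect roi = r & img_rect;                 // clip the box to the image
            if (roi.area() > 0) img(roi).copyTo(dst(roi)); // restore the unblurred object
        }
    }
    return dst;
}
```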
@@ -95,7 +95,8 @@ void draw_train_loss(mat_cv* img, int img_size, float avg_loss, float max_img_lo
// Data augmentation
image image_data_augmentation(mat_cv* mat, int w, int h,
int pleft, int ptop, int swidth, int sheight, int flip,
float jitter, float dhue, float dsat, float dexp);
float jitter, float dhue, float dsat, float dexp,
int blur, int num_boxes, float *truth);

// Show Anchors
void show_acnhors(int number_of_boxes, int num_of_clusters, float *rel_width_height_array, model anchors_data, int width, int height);
src/layer.c: 85 changes
@@ -2,22 +2,40 @@
#include "dark_cuda.h"
#include <stdlib.h>

void free_sublayer(layer *l)
{
if (l) {
free_layer(*l);
free(l);
}
}

void free_layer(layer l)
{
// free layers: input_layer, self_layer, output_layer, ...
if (l.type == CONV_LSTM) {
if (l.peephole) {
free_sublayer(l.vf);
free_sublayer(l.vi);
free_sublayer(l.vo);
}
else {
free(l.vf);
free(l.vi);
free(l.vo);
}
free_sublayer(l.wf);
free_sublayer(l.wi);
free_sublayer(l.wg);
free_sublayer(l.wo);
free_sublayer(l.uf);
free_sublayer(l.ui);
free_sublayer(l.ug);
free_sublayer(l.uo);
}
if (l.type == CRNN) {
if (l.input_layer) {
free_layer(*l.input_layer);
free(l.input_layer);
}
if (l.self_layer) {
free_layer(*l.self_layer);
free(l.self_layer);
}
if (l.output_layer) {
free_layer(*l.output_layer);
free(l.output_layer);
}
free_sublayer(l.input_layer);
free_sublayer(l.self_layer);
free_sublayer(l.output_layer);
l.output = NULL;
l.delta = NULL;
#ifdef GPU
@@ -83,21 +101,36 @@ void free_layer(layer l)
if (l.v) free(l.v);
if (l.z_cpu) free(l.z_cpu);
if (l.r_cpu) free(l.r_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.binary_input) free(l.binary_input);
if (l.bin_re_packed_input) free(l.bin_re_packed_input);
if (l.t_bit_input) free(l.t_bit_input);
if (l.loss) free(l.loss);

// CONV-LSTM
if (l.f_cpu) free(l.f_cpu);
if (l.i_cpu) free(l.i_cpu);
if (l.g_cpu) free(l.g_cpu);
if (l.o_cpu) free(l.o_cpu);
if (l.c_cpu) free(l.c_cpu);
if (l.h_cpu) free(l.h_cpu);
if (l.temp_cpu) free(l.temp_cpu);
if (l.temp2_cpu) free(l.temp2_cpu);
if (l.temp3_cpu) free(l.temp3_cpu);
if (l.dc_cpu) free(l.dc_cpu);
if (l.dh_cpu) free(l.dh_cpu);
if (l.prev_state_cpu) free(l.prev_state_cpu);
if (l.prev_cell_cpu) free(l.prev_cell_cpu);
if (l.stored_c_cpu) free(l.stored_c_cpu);
if (l.stored_h_cpu) free(l.stored_h_cpu);
if (l.cell_cpu) free(l.cell_cpu);

#ifdef GPU
if (l.indexes_gpu) cuda_free((float *)l.indexes_gpu);

if (l.z_gpu) cuda_free(l.z_gpu);
if (l.r_gpu) cuda_free(l.r_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.m_gpu) cuda_free(l.m_gpu);
if (l.v_gpu) cuda_free(l.v_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.forgot_state_gpu) cuda_free(l.forgot_state_gpu);
if (l.forgot_delta_gpu) cuda_free(l.forgot_delta_gpu);
if (l.state_gpu) cuda_free(l.state_gpu);
@@ -137,5 +170,25 @@ void free_layer(layer l)
if (l.rand_gpu) cuda_free(l.rand_gpu);
if (l.squared_gpu) cuda_free(l.squared_gpu);
if (l.norms_gpu) cuda_free(l.norms_gpu);

// CONV-LSTM
if (l.f_gpu) cuda_free(l.f_gpu);
if (l.i_gpu) cuda_free(l.i_gpu);
if (l.g_gpu) cuda_free(l.g_gpu);
if (l.o_gpu) cuda_free(l.o_gpu);
if (l.c_gpu) cuda_free(l.c_gpu);
if (l.h_gpu) cuda_free(l.h_gpu);
if (l.temp_gpu) cuda_free(l.temp_gpu);
if (l.temp2_gpu) cuda_free(l.temp2_gpu);
if (l.temp3_gpu) cuda_free(l.temp3_gpu);
if (l.dc_gpu) cuda_free(l.dc_gpu);
if (l.dh_gpu) cuda_free(l.dh_gpu);
if (l.prev_state_gpu) cuda_free(l.prev_state_gpu);
if (l.prev_cell_gpu) cuda_free(l.prev_cell_gpu);
if (l.stored_c_gpu) cuda_free(l.stored_c_gpu);
if (l.stored_h_gpu) cuda_free(l.stored_h_gpu);
if (l.last_prev_state_gpu) cuda_free(l.last_prev_state_gpu);
if (l.last_prev_cell_gpu) cuda_free(l.last_prev_cell_gpu);
if (l.cell_gpu) cuda_free(l.cell_gpu);
#endif
}
@@ -95,6 +95,7 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n

l.forward = forward_lstm_layer;
l.update = update_lstm_layer;
l.backward = backward_lstm_layer;

l.prev_state_cpu = (float*)calloc(batch*outputs, sizeof(float));
l.prev_cell_cpu = (float*)calloc(batch*outputs, sizeof(float));
@@ -12,6 +12,7 @@ extern "C" {
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize);

void forward_lstm_layer(layer l, network_state state);
void backward_lstm_layer(layer l, network_state state);
void update_lstm_layer(layer l, int batch, float learning_rate, float momentum, float decay);

#ifdef GPU
@@ -15,6 +15,7 @@
#include "gru_layer.h"
#include "rnn_layer.h"
#include "crnn_layer.h"
#include "conv_lstm_layer.h"
#include "local_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
@@ -315,6 +316,7 @@ float train_network_sgd(network net, data d, int n)
float sum = 0;
for(i = 0; i < n; ++i){
get_random_batch(d, batch, X, y);
net.current_subdivision = i;
float err = train_network_datum(net, X, y);
sum += err;
}
@@ -340,6 +342,7 @@ float train_network_waitkey(network net, data d, int wait_key)
float sum = 0;
for(i = 0; i < n; ++i){
get_next_batch(d, batch, i*batch, X, y);
net.current_subdivision = i;
float err = train_network_datum(net, X, y);
sum += err;
if(wait_key) wait_key_cv(5);
@@ -1111,3 +1114,31 @@ network combine_train_valid_networks(network net_train, network net_map)
}
return net_combined;
}

void free_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) free_state_conv_lstm(net.layers[k]);
if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}


void remember_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) remember_state_conv_lstm(net.layers[k]);
//if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}

void restore_network_recurrent_state(network net)
{
int k;
for (k = 0; k < net.n; ++k) {
if (net.layers[k].type == CONV_LSTM) restore_state_conv_lstm(net.layers[k]);
if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]);
}
}
@@ -163,6 +163,9 @@ int get_network_background(network net);
//LIB_API void calculate_binary_weights(network net);
network combine_train_valid_networks(network net_train, network net_map);
void copy_weights_net(network net_train, network *net_map);
void free_network_recurrent_state(network net);
void remember_network_recurrent_state(network net);
void restore_network_recurrent_state(network net);

#ifdef __cplusplus
}
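The three helpers declared above are used by validate_detector_map() in the detector.c hunk earlier in this commit. A call-pattern sketch only; it adds no new API, and the wrapper function below is hypothetical:

```cpp
#include "darknet.h"   // network struct (assumes the include path points at include/darknet.h)

extern "C" {
    // declared in src/network.h by this commit
    void remember_network_recurrent_state(network net);
    void free_network_recurrent_state(network net);
    void restore_network_recurrent_state(network net);
}

// Hypothetical wrapper mirroring validate_detector_map(): save the recurrent state of
// CONV_LSTM/CRNN layers, reset the live state, run validation, then restore the state
// so the interrupted training sequence continues where it left off.
static void validate_without_disturbing_training_state(network net)
{
    remember_network_recurrent_state(net);  // stash the CONV_LSTM h/c state
    free_network_recurrent_state(net);      // reset/re-randomize live state (CONV_LSTM + CRNN)
    /* ... run mAP / detection on net here ... */
    restore_network_recurrent_state(net);   // bring the training-time state back
}
```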
@@ -171,6 +171,22 @@ void forward_backward_network_gpu(network net, float *x, float *y)
cuda_convert_f32_to_f16(l.self_layer->weights_gpu, l.self_layer->nweights, l.self_layer->weights_gpu16);
cuda_convert_f32_to_f16(l.output_layer->weights_gpu, l.output_layer->nweights, l.output_layer->weights_gpu16);
}
else if (l.type == CONV_LSTM && l.wf->weights_gpu && l.wf->weights_gpu16) {
assert((l.wf->c * l.wf->n * l.wf->size * l.wf->size) > 0);
if (l.peephole) {
cuda_convert_f32_to_f16(l.vf->weights_gpu, l.vf->nweights, l.vf->weights_gpu16);
cuda_convert_f32_to_f16(l.vi->weights_gpu, l.vi->nweights, l.vi->weights_gpu16);
cuda_convert_f32_to_f16(l.vo->weights_gpu, l.vo->nweights, l.vo->weights_gpu16);
}
cuda_convert_f32_to_f16(l.wf->weights_gpu, l.wf->nweights, l.wf->weights_gpu16);
cuda_convert_f32_to_f16(l.wi->weights_gpu, l.wi->nweights, l.wi->weights_gpu16);
cuda_convert_f32_to_f16(l.wg->weights_gpu, l.wg->nweights, l.wg->weights_gpu16);
cuda_convert_f32_to_f16(l.wo->weights_gpu, l.wo->nweights, l.wo->weights_gpu16);
cuda_convert_f32_to_f16(l.uf->weights_gpu, l.uf->nweights, l.uf->weights_gpu16);
cuda_convert_f32_to_f16(l.ui->weights_gpu, l.ui->nweights, l.ui->weights_gpu16);
cuda_convert_f32_to_f16(l.ug->weights_gpu, l.ug->nweights, l.ug->weights_gpu16);
cuda_convert_f32_to_f16(l.uo->weights_gpu, l.uo->nweights, l.uo->weights_gpu16);
}
}
}
#endif
src/parser.c: 58 changes
@@ -20,6 +20,7 @@
#include "list.h"
#include "local_layer.h"
#include "lstm_layer.h"
#include "conv_lstm_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "option_list.h"
@@ -61,6 +62,7 @@ LAYER_TYPE string_to_layer_type(char * type)
if (strcmp(type, "[crnn]")==0) return CRNN;
if (strcmp(type, "[gru]")==0) return GRU;
if (strcmp(type, "[lstm]")==0) return LSTM;
if (strcmp(type, "[conv_lstm]") == 0) return CONV_LSTM;
if (strcmp(type, "[rnn]")==0) return RNN;
if (strcmp(type, "[conn]")==0
|| strcmp(type, "[connected]")==0) return CONNECTED;
@@ -239,6 +241,29 @@ layer parse_lstm(list *options, size_params params)
return l;
}

layer parse_conv_lstm(list *options, size_params params)
{
// a ConvLSTM with a larger transitional kernel should be able to capture faster motions
int size = option_find_int_quiet(options, "size", 3);
int stride = option_find_int_quiet(options, "stride", 1);
int pad = option_find_int_quiet(options, "pad", 0);
int padding = option_find_int_quiet(options, "padding", 0);
if (pad) padding = size / 2;

int output_filters = option_find_int(options, "output", 1);
char *activation_s = option_find_str(options, "activation", "LINEAR");
ACTIVATION activation = get_activation(activation_s);
int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
int peephole = option_find_int_quiet(options, "peephole", 1);

layer l = make_conv_lstm_layer(params.batch, params.w, params.h, params.c, output_filters, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor);

l.shortcut = option_find_int_quiet(options, "shortcut", 0);

return l;
}

connected_layer parse_connected(list *options, size_params params)
{
int output = option_find_int(options, "output",1);
@@ -647,6 +672,7 @@ void parse_net_options(list *options, network *net)
net->time_steps = option_find_int_quiet(options, "time_steps",1);
net->track = option_find_int_quiet(options, "track", 0);
net->augment_speed = option_find_int_quiet(options, "augment_speed", 2);
net->sequential_subdivisions = option_find_int_quiet(options, "sequential_subdivisions", 0);
net->try_fix_nan = option_find_int_quiet(options, "try_fix_nan", 0);
net->batch /= subdivs;
net->batch *= net->time_steps;
@@ -666,6 +692,7 @@ void parse_net_options(list *options, network *net)
net->max_crop = option_find_int_quiet(options, "max_crop",net->w*2);
net->min_crop = option_find_int_quiet(options, "min_crop",net->w);
net->flip = option_find_int_quiet(options, "flip", 1);
net->blur = option_find_int_quiet(options, "blur", 0);

net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);
@@ -789,6 +816,8 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l = parse_gru(options, params);
}else if(lt == LSTM){
l = parse_lstm(options, params);
}else if (lt == CONV_LSTM) {
l = parse_conv_lstm(options, params);
}else if(lt == CRNN){
l = parse_crnn(options, params);
}else if(lt == CONNECTED){
@@ -1076,6 +1105,20 @@ void save_weights_upto(network net, char *filename, int cutoff)
save_connected_weights(*(l.ui), fp);
save_connected_weights(*(l.ug), fp);
save_connected_weights(*(l.uo), fp);
} if (l.type == CONV_LSTM) {
if (l.peephole) {
save_convolutional_weights(*(l.vf), fp);
save_convolutional_weights(*(l.vi), fp);
save_convolutional_weights(*(l.vo), fp);
}
save_convolutional_weights(*(l.wf), fp);
save_convolutional_weights(*(l.wi), fp);
save_convolutional_weights(*(l.wg), fp);
save_convolutional_weights(*(l.wo), fp);
save_convolutional_weights(*(l.uf), fp);
save_convolutional_weights(*(l.ui), fp);
save_convolutional_weights(*(l.ug), fp);
save_convolutional_weights(*(l.uo), fp);
} if(l.type == CRNN){
save_convolutional_weights(*(l.input_layer), fp);
save_convolutional_weights(*(l.self_layer), fp);
@@ -1298,6 +1341,21 @@ void load_weights_upto(network *net, char *filename, int cutoff)
load_connected_weights(*(l.ug), fp, transpose);
load_connected_weights(*(l.uo), fp, transpose);
}
if (l.type == CONV_LSTM) {
if (l.peephole) {
load_convolutional_weights(*(l.vf), fp);
load_convolutional_weights(*(l.vi), fp);
load_convolutional_weights(*(l.vo), fp);
}
load_convolutional_weights(*(l.wf), fp);
load_convolutional_weights(*(l.wi), fp);
load_convolutional_weights(*(l.wg), fp);
load_convolutional_weights(*(l.wo), fp);
load_convolutional_weights(*(l.uf), fp);
load_convolutional_weights(*(l.ui), fp);
load_convolutional_weights(*(l.ug), fp);
load_convolutional_weights(*(l.uo), fp);
}
if(l.type == LOCAL){
int locations = l.out_w*l.out_h;
int size = l.size*l.size*l.c*l.n*locations;
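Based only on the option names and defaults read by parse_conv_lstm() and parse_net_options() above, a [conv_lstm] section and the related [net] options could look roughly like the sketch below; the values are illustrative and no such cfg ships with this commit.

```ini
# illustrative sketch only - not a cfg file from the repository
[net]
batch=64
subdivisions=8
# recurrent / augmentation options read by parse_net_options() in this commit
time_steps=4
track=1
augment_speed=2
sequential_subdivisions=8
blur=1
try_fix_nan=1

# options read by parse_conv_lstm()
[conv_lstm]
batch_normalize=1
size=3
pad=1
output=256
peephole=1
activation=leaky
```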
@@ -22,6 +22,7 @@ extern "C" {
#include <algorithm>
#include <cmath>

#define NFRAMES 3

//static Detector* detector = NULL;
static std::unique_ptr<Detector> detector;