mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
art, cudnn
This commit is contained in:
9
Makefile
9
Makefile
@ -1,4 +1,5 @@
|
|||||||
GPU=0
|
GPU=0
|
||||||
|
CUDNN=0
|
||||||
OPENCV=0
|
OPENCV=0
|
||||||
DEBUG=0
|
DEBUG=0
|
||||||
|
|
||||||
@ -34,7 +35,13 @@ CFLAGS+= -DGPU
|
|||||||
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
|
||||||
endif
|
endif
|
||||||
|
|
||||||
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
|
ifeq ($(CUDNN), 1)
|
||||||
|
COMMON+= -DCUDNN
|
||||||
|
CFLAGS+= -DCUDNN
|
||||||
|
LDFLAGS+= -lcudnn
|
||||||
|
endif
|
||||||
|
|
||||||
|
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o art.o
|
||||||
ifeq ($(GPU), 1)
|
ifeq ($(GPU), 1)
|
||||||
LDFLAGS+= -lstdc++
|
LDFLAGS+= -lstdc++
|
||||||
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
|
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
|
||||||
|
76
src/art.c
Normal file
76
src/art.c
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
#include "network.h"
|
||||||
|
#include "utils.h"
|
||||||
|
#include "parser.h"
|
||||||
|
#include "option_list.h"
|
||||||
|
#include "blas.h"
|
||||||
|
#include "classifier.h"
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
|
#ifdef OPENCV
|
||||||
|
#include "opencv2/highgui/highgui_c.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
void demo_art(char *cfgfile, char *weightfile, int cam_index)
|
||||||
|
{
|
||||||
|
#ifdef OPENCV
|
||||||
|
network net = parse_network_cfg(cfgfile);
|
||||||
|
if(weightfile){
|
||||||
|
load_weights(&net, weightfile);
|
||||||
|
}
|
||||||
|
set_batch_network(&net, 1);
|
||||||
|
|
||||||
|
srand(2222222);
|
||||||
|
CvCapture * cap;
|
||||||
|
|
||||||
|
cap = cvCaptureFromCAM(cam_index);
|
||||||
|
|
||||||
|
char *window = "ArtJudgementBot9000!!!";
|
||||||
|
if(!cap) error("Couldn't connect to webcam.\n");
|
||||||
|
cvNamedWindow(window, CV_WINDOW_NORMAL);
|
||||||
|
cvResizeWindow(window, 512, 512);
|
||||||
|
int i;
|
||||||
|
int idx[] = {37, 401, 434};
|
||||||
|
int n = sizeof(idx)/sizeof(idx[0]);
|
||||||
|
|
||||||
|
while(1){
|
||||||
|
image in = get_image_from_stream(cap);
|
||||||
|
image in_s = resize_image(in, net.w, net.h);
|
||||||
|
show_image(in, window);
|
||||||
|
|
||||||
|
float *p = network_predict(net, in_s.data);
|
||||||
|
|
||||||
|
printf("\033[2J");
|
||||||
|
printf("\033[1;1H");
|
||||||
|
|
||||||
|
float score = 0;
|
||||||
|
for(i = 0; i < n; ++i){
|
||||||
|
float s = p[idx[i]];
|
||||||
|
if (s > score) score = s;
|
||||||
|
}
|
||||||
|
score = score;
|
||||||
|
printf("I APPRECIATE THIS ARTWORK: %10.7f%%\n", score*100);
|
||||||
|
printf("[");
|
||||||
|
int upper = 30;
|
||||||
|
for(i = 0; i < upper; ++i){
|
||||||
|
printf("%s", ((i+.5) < score*upper) ? "\u2588" : " ");
|
||||||
|
}
|
||||||
|
printf("]\n");
|
||||||
|
|
||||||
|
free_image(in_s);
|
||||||
|
free_image(in);
|
||||||
|
|
||||||
|
cvWaitKey(1);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void run_art(int argc, char **argv)
|
||||||
|
{
|
||||||
|
int cam_index = find_int_arg(argc, argv, "-c", 0);
|
||||||
|
char *cfg = argv[2];
|
||||||
|
char *weights = argv[3];
|
||||||
|
demo_art(cfg, weights, cam_index);
|
||||||
|
}
|
||||||
|
|
@ -85,7 +85,6 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
|
|
||||||
if(l.xnor){
|
if(l.xnor){
|
||||||
binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
|
binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
|
||||||
//binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
|
|
||||||
swap_binary(&l);
|
swap_binary(&l);
|
||||||
for(i = 0; i < l.batch; ++i){
|
for(i = 0; i < l.batch; ++i){
|
||||||
binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
|
binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
|
||||||
@ -93,13 +92,31 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
state.input = l.binary_input_gpu;
|
state.input = l.binary_input_gpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef CUDNN
|
||||||
|
float one = 1;
|
||||||
|
cudnnConvolutionForward(cudnn_handle(),
|
||||||
|
&one,
|
||||||
|
l.srcTensorDesc,
|
||||||
|
state.input,
|
||||||
|
l.filterDesc,
|
||||||
|
l.filters_gpu,
|
||||||
|
l.convDesc,
|
||||||
|
l.fw_algo,
|
||||||
|
state.workspace,
|
||||||
|
l.workspace_size,
|
||||||
|
&one,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.output_gpu);
|
||||||
|
|
||||||
|
#else
|
||||||
for(i = 0; i < l.batch; ++i){
|
for(i = 0; i < l.batch; ++i){
|
||||||
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
|
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
|
||||||
float * a = l.filters_gpu;
|
float * a = l.filters_gpu;
|
||||||
float * b = l.col_image_gpu;
|
float * b = state.workspace;
|
||||||
float * c = l.output_gpu;
|
float * c = l.output_gpu;
|
||||||
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
|
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (l.batch_normalize) {
|
if (l.batch_normalize) {
|
||||||
forward_batchnorm_layer_gpu(l, state);
|
forward_batchnorm_layer_gpu(l, state);
|
||||||
@ -113,7 +130,6 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
|||||||
|
|
||||||
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
|
||||||
{
|
{
|
||||||
int i;
|
|
||||||
int m = l.n;
|
int m = l.n;
|
||||||
int n = l.size*l.size*l.c;
|
int n = l.size*l.size*l.c;
|
||||||
int k = convolutional_out_height(l)*
|
int k = convolutional_out_height(l)*
|
||||||
@ -128,26 +144,61 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(l.xnor) state.input = l.binary_input_gpu;
|
if(l.xnor) state.input = l.binary_input_gpu;
|
||||||
|
#ifdef CUDNN
|
||||||
|
float one = 1;
|
||||||
|
cudnnConvolutionBackwardFilter(cudnn_handle(),
|
||||||
|
&one,
|
||||||
|
l.srcTensorDesc,
|
||||||
|
state.input,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.delta_gpu,
|
||||||
|
l.convDesc,
|
||||||
|
l.bf_algo,
|
||||||
|
state.workspace,
|
||||||
|
l.workspace_size,
|
||||||
|
&one,
|
||||||
|
l.dfilterDesc,
|
||||||
|
l.filter_updates_gpu);
|
||||||
|
|
||||||
|
if(state.delta){
|
||||||
|
cudnnConvolutionBackwardData(cudnn_handle(),
|
||||||
|
&one,
|
||||||
|
l.filterDesc,
|
||||||
|
l.filters_gpu,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.delta_gpu,
|
||||||
|
l.convDesc,
|
||||||
|
l.bd_algo,
|
||||||
|
state.workspace,
|
||||||
|
l.workspace_size,
|
||||||
|
&one,
|
||||||
|
l.dsrcTensorDesc,
|
||||||
|
state.delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
int i;
|
||||||
for(i = 0; i < l.batch; ++i){
|
for(i = 0; i < l.batch; ++i){
|
||||||
float * a = l.delta_gpu;
|
float * a = l.delta_gpu;
|
||||||
float * b = l.col_image_gpu;
|
float * b = state.workspace;
|
||||||
float * c = l.filter_updates_gpu;
|
float * c = l.filter_updates_gpu;
|
||||||
|
|
||||||
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
|
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
|
||||||
gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
|
gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
|
||||||
|
|
||||||
if(state.delta){
|
if(state.delta){
|
||||||
if(l.binary || l.xnor) swap_binary(&l);
|
if(l.binary || l.xnor) swap_binary(&l);
|
||||||
float * a = l.filters_gpu;
|
float * a = l.filters_gpu;
|
||||||
float * b = l.delta_gpu;
|
float * b = l.delta_gpu;
|
||||||
float * c = l.col_image_gpu;
|
float * c = state.workspace;
|
||||||
|
|
||||||
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
|
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
|
||||||
|
|
||||||
col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
|
col2im_ongpu(state.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
|
||||||
if(l.binary || l.xnor) swap_binary(&l);
|
if(l.binary || l.xnor) swap_binary(&l);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void pull_convolutional_layer(convolutional_layer layer)
|
void pull_convolutional_layer(convolutional_layer layer)
|
||||||
|
@ -88,6 +88,38 @@ image get_convolutional_delta(convolutional_layer l)
|
|||||||
return float_to_image(w,h,c,l.delta);
|
return float_to_image(w,h,c,l.delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef CUDNN
|
||||||
|
size_t get_workspace_size(layer l){
|
||||||
|
size_t most = 0;
|
||||||
|
size_t s = 0;
|
||||||
|
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
|
||||||
|
l.srcTensorDesc,
|
||||||
|
l.filterDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
l.fw_algo,
|
||||||
|
&s);
|
||||||
|
if (s > most) most = s;
|
||||||
|
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
|
||||||
|
l.srcTensorDesc,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dfilterDesc,
|
||||||
|
l.bf_algo,
|
||||||
|
&s);
|
||||||
|
if (s > most) most = s;
|
||||||
|
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
|
||||||
|
l.filterDesc,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dsrcTensorDesc,
|
||||||
|
l.bd_algo,
|
||||||
|
&s);
|
||||||
|
if (s > most) most = s;
|
||||||
|
return most;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
|
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
@ -156,7 +188,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
|||||||
l.scales_gpu = cuda_make_array(l.scales, n);
|
l.scales_gpu = cuda_make_array(l.scales, n);
|
||||||
l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
|
l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
|
||||||
|
|
||||||
l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
|
l.workspace_size = out_h*out_w*size*size*c;
|
||||||
l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
|
l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
|
||||||
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
||||||
|
|
||||||
@ -182,6 +214,50 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
|
|||||||
l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
||||||
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
|
||||||
}
|
}
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
|
||||||
|
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
|
||||||
|
cudnnCreateFilterDescriptor(&l.filterDesc);
|
||||||
|
cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
|
||||||
|
cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
|
||||||
|
cudnnCreateFilterDescriptor(&l.dfilterDesc);
|
||||||
|
cudnnCreateConvolutionDescriptor(&l.convDesc);
|
||||||
|
cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
|
||||||
|
cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||||
|
cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
|
||||||
|
|
||||||
|
cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
|
||||||
|
cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
|
||||||
|
cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
|
||||||
|
int padding = l.pad ? l.size/2 : 0;
|
||||||
|
cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
|
||||||
|
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
|
||||||
|
l.srcTensorDesc,
|
||||||
|
l.filterDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dstTensorDesc,
|
||||||
|
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
|
||||||
|
0,
|
||||||
|
&l.fw_algo);
|
||||||
|
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
|
||||||
|
l.filterDesc,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dsrcTensorDesc,
|
||||||
|
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
|
||||||
|
0,
|
||||||
|
&l.bd_algo);
|
||||||
|
cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
|
||||||
|
l.srcTensorDesc,
|
||||||
|
l.ddstTensorDesc,
|
||||||
|
l.convDesc,
|
||||||
|
l.dfilterDesc,
|
||||||
|
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
|
||||||
|
0,
|
||||||
|
&l.bf_algo);
|
||||||
|
l.workspace_size = get_workspace_size(l);
|
||||||
|
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
l.activation = activation;
|
l.activation = activation;
|
||||||
|
|
||||||
@ -247,11 +323,9 @@ void resize_convolutional_layer(convolutional_layer *l, int w, int h)
|
|||||||
l->batch*out_h * out_w * l->n*sizeof(float));
|
l->batch*out_h * out_w * l->n*sizeof(float));
|
||||||
|
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
cuda_free(l->col_image_gpu);
|
|
||||||
cuda_free(l->delta_gpu);
|
cuda_free(l->delta_gpu);
|
||||||
cuda_free(l->output_gpu);
|
cuda_free(l->output_gpu);
|
||||||
|
|
||||||
l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
|
|
||||||
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
|
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
|
||||||
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
|
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
|
||||||
#endif
|
#endif
|
||||||
@ -299,12 +373,12 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
|
|||||||
|
|
||||||
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
|
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
|
||||||
/*
|
/*
|
||||||
if(l.binary){
|
if(l.binary){
|
||||||
binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
|
binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
|
||||||
binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
|
binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
|
||||||
swap_binary(&l);
|
swap_binary(&l);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
if(l.binary){
|
if(l.binary){
|
||||||
int m = l.n;
|
int m = l.n;
|
||||||
|
25
src/cuda.c
25
src/cuda.c
@ -46,6 +46,19 @@ dim3 cuda_gridsize(size_t n){
|
|||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnHandle_t cudnn_handle()
|
||||||
|
{
|
||||||
|
static int init = 0;
|
||||||
|
static cudnnHandle_t handle;
|
||||||
|
if(!init) {
|
||||||
|
cudnnCreate(&handle);
|
||||||
|
init = 1;
|
||||||
|
}
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
cublasHandle_t blas_handle()
|
cublasHandle_t blas_handle()
|
||||||
{
|
{
|
||||||
static int init = 0;
|
static int init = 0;
|
||||||
@ -57,7 +70,7 @@ cublasHandle_t blas_handle()
|
|||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
float *cuda_make_array(float *x, int n)
|
float *cuda_make_array(float *x, size_t n)
|
||||||
{
|
{
|
||||||
float *x_gpu;
|
float *x_gpu;
|
||||||
size_t size = sizeof(float)*n;
|
size_t size = sizeof(float)*n;
|
||||||
@ -71,7 +84,7 @@ float *cuda_make_array(float *x, int n)
|
|||||||
return x_gpu;
|
return x_gpu;
|
||||||
}
|
}
|
||||||
|
|
||||||
void cuda_random(float *x_gpu, int n)
|
void cuda_random(float *x_gpu, size_t n)
|
||||||
{
|
{
|
||||||
static curandGenerator_t gen;
|
static curandGenerator_t gen;
|
||||||
static int init = 0;
|
static int init = 0;
|
||||||
@ -84,7 +97,7 @@ void cuda_random(float *x_gpu, int n)
|
|||||||
check_error(cudaPeekAtLastError());
|
check_error(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
float cuda_compare(float *x_gpu, float *x, int n, char *s)
|
float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
|
||||||
{
|
{
|
||||||
float *tmp = calloc(n, sizeof(float));
|
float *tmp = calloc(n, sizeof(float));
|
||||||
cuda_pull_array(x_gpu, tmp, n);
|
cuda_pull_array(x_gpu, tmp, n);
|
||||||
@ -97,7 +110,7 @@ float cuda_compare(float *x_gpu, float *x, int n, char *s)
|
|||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
int *cuda_make_int_array(int n)
|
int *cuda_make_int_array(size_t n)
|
||||||
{
|
{
|
||||||
int *x_gpu;
|
int *x_gpu;
|
||||||
size_t size = sizeof(int)*n;
|
size_t size = sizeof(int)*n;
|
||||||
@ -112,14 +125,14 @@ void cuda_free(float *x_gpu)
|
|||||||
check_error(status);
|
check_error(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cuda_push_array(float *x_gpu, float *x, int n)
|
void cuda_push_array(float *x_gpu, float *x, size_t n)
|
||||||
{
|
{
|
||||||
size_t size = sizeof(float)*n;
|
size_t size = sizeof(float)*n;
|
||||||
cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||||
check_error(status);
|
check_error(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cuda_pull_array(float *x_gpu, float *x, int n)
|
void cuda_pull_array(float *x_gpu, float *x, size_t n)
|
||||||
{
|
{
|
||||||
size_t size = sizeof(float)*n;
|
size_t size = sizeof(float)*n;
|
||||||
cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
|
cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
|
||||||
|
20
src/cuda.h
20
src/cuda.h
@ -11,16 +11,24 @@ extern int gpu_index;
|
|||||||
#include "curand.h"
|
#include "curand.h"
|
||||||
#include "cublas_v2.h"
|
#include "cublas_v2.h"
|
||||||
|
|
||||||
|
#ifdef CUDNN
|
||||||
|
#include "cudnn.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
void check_error(cudaError_t status);
|
void check_error(cudaError_t status);
|
||||||
cublasHandle_t blas_handle();
|
cublasHandle_t blas_handle();
|
||||||
float *cuda_make_array(float *x, int n);
|
float *cuda_make_array(float *x, size_t n);
|
||||||
int *cuda_make_int_array(int n);
|
int *cuda_make_int_array(size_t n);
|
||||||
void cuda_push_array(float *x_gpu, float *x, int n);
|
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
||||||
void cuda_pull_array(float *x_gpu, float *x, int n);
|
void cuda_pull_array(float *x_gpu, float *x, size_t n);
|
||||||
void cuda_free(float *x_gpu);
|
void cuda_free(float *x_gpu);
|
||||||
void cuda_random(float *x_gpu, int n);
|
void cuda_random(float *x_gpu, size_t n);
|
||||||
float cuda_compare(float *x_gpu, float *x, int n, char *s);
|
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
|
||||||
dim3 cuda_gridsize(size_t n);
|
dim3 cuda_gridsize(size_t n);
|
||||||
|
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnHandle_t cudnn_handle();
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -26,6 +26,7 @@ extern void run_vid_rnn(int argc, char **argv);
|
|||||||
extern void run_tag(int argc, char **argv);
|
extern void run_tag(int argc, char **argv);
|
||||||
extern void run_cifar(int argc, char **argv);
|
extern void run_cifar(int argc, char **argv);
|
||||||
extern void run_go(int argc, char **argv);
|
extern void run_go(int argc, char **argv);
|
||||||
|
extern void run_art(int argc, char **argv);
|
||||||
|
|
||||||
void change_rate(char *filename, float scale, float add)
|
void change_rate(char *filename, float scale, float add)
|
||||||
{
|
{
|
||||||
@ -259,6 +260,8 @@ int main(int argc, char **argv)
|
|||||||
run_coco(argc, argv);
|
run_coco(argc, argv);
|
||||||
} else if (0 == strcmp(argv[1], "classifier")){
|
} else if (0 == strcmp(argv[1], "classifier")){
|
||||||
run_classifier(argc, argv);
|
run_classifier(argc, argv);
|
||||||
|
} else if (0 == strcmp(argv[1], "art")){
|
||||||
|
run_art(argc, argv);
|
||||||
} else if (0 == strcmp(argv[1], "tag")){
|
} else if (0 == strcmp(argv[1], "tag")){
|
||||||
run_tag(argc, argv);
|
run_tag(argc, argv);
|
||||||
} else if (0 == strcmp(argv[1], "compare")){
|
} else if (0 == strcmp(argv[1], "compare")){
|
||||||
|
13
src/layer.h
13
src/layer.h
@ -2,6 +2,7 @@
|
|||||||
#define BASE_LAYER_H
|
#define BASE_LAYER_H
|
||||||
|
|
||||||
#include "activations.h"
|
#include "activations.h"
|
||||||
|
#include "stddef.h"
|
||||||
|
|
||||||
struct layer;
|
struct layer;
|
||||||
typedef struct layer layer;
|
typedef struct layer layer;
|
||||||
@ -157,6 +158,8 @@ struct layer{
|
|||||||
struct layer *input_h_layer;
|
struct layer *input_h_layer;
|
||||||
struct layer *state_h_layer;
|
struct layer *state_h_layer;
|
||||||
|
|
||||||
|
size_t workspace_size;
|
||||||
|
|
||||||
#ifdef GPU
|
#ifdef GPU
|
||||||
float *z_gpu;
|
float *z_gpu;
|
||||||
float *r_gpu;
|
float *r_gpu;
|
||||||
@ -207,6 +210,16 @@ struct layer{
|
|||||||
float * rand_gpu;
|
float * rand_gpu;
|
||||||
float * squared_gpu;
|
float * squared_gpu;
|
||||||
float * norms_gpu;
|
float * norms_gpu;
|
||||||
|
#ifdef CUDNN
|
||||||
|
cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
|
||||||
|
cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
|
||||||
|
cudnnFilterDescriptor_t filterDesc;
|
||||||
|
cudnnFilterDescriptor_t dfilterDesc;
|
||||||
|
cudnnConvolutionDescriptor_t convDesc;
|
||||||
|
cudnnConvolutionFwdAlgo_t fw_algo;
|
||||||
|
cudnnConvolutionBwdDataAlgo_t bd_algo;
|
||||||
|
cudnnConvolutionBwdFilterAlgo_t bf_algo;
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ typedef enum {
|
|||||||
} learning_rate_policy;
|
} learning_rate_policy;
|
||||||
|
|
||||||
typedef struct network{
|
typedef struct network{
|
||||||
|
float *workspace;
|
||||||
int n;
|
int n;
|
||||||
int batch;
|
int batch;
|
||||||
int *seen;
|
int *seen;
|
||||||
@ -49,6 +50,7 @@ typedef struct network_state {
|
|||||||
float *truth;
|
float *truth;
|
||||||
float *input;
|
float *input;
|
||||||
float *delta;
|
float *delta;
|
||||||
|
float *workspace;
|
||||||
int train;
|
int train;
|
||||||
int index;
|
int index;
|
||||||
network net;
|
network net;
|
||||||
|
@ -41,6 +41,7 @@ float * get_network_output_gpu(network net);
|
|||||||
|
|
||||||
void forward_network_gpu(network net, network_state state)
|
void forward_network_gpu(network net, network_state state)
|
||||||
{
|
{
|
||||||
|
state.workspace = net.workspace;
|
||||||
int i;
|
int i;
|
||||||
for(i = 0; i < net.n; ++i){
|
for(i = 0; i < net.n; ++i){
|
||||||
state.index = i;
|
state.index = i;
|
||||||
@ -93,6 +94,7 @@ void forward_network_gpu(network net, network_state state)
|
|||||||
|
|
||||||
void backward_network_gpu(network net, network_state state)
|
void backward_network_gpu(network net, network_state state)
|
||||||
{
|
{
|
||||||
|
state.workspace = net.workspace;
|
||||||
int i;
|
int i;
|
||||||
float * original_input = state.input;
|
float * original_input = state.input;
|
||||||
float * original_delta = state.delta;
|
float * original_delta = state.delta;
|
||||||
|
@ -524,6 +524,7 @@ network parse_network_cfg(char *filename)
|
|||||||
params.batch = net.batch;
|
params.batch = net.batch;
|
||||||
params.time_steps = net.time_steps;
|
params.time_steps = net.time_steps;
|
||||||
|
|
||||||
|
size_t workspace_size = 0;
|
||||||
n = n->next;
|
n = n->next;
|
||||||
int count = 0;
|
int count = 0;
|
||||||
free_section(s);
|
free_section(s);
|
||||||
@ -584,6 +585,7 @@ network parse_network_cfg(char *filename)
|
|||||||
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
|
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
|
||||||
option_unused(options);
|
option_unused(options);
|
||||||
net.layers[count] = l;
|
net.layers[count] = l;
|
||||||
|
if (l.workspace_size > workspace_size) workspace_size = l.workspace_size;
|
||||||
free_section(s);
|
free_section(s);
|
||||||
n = n->next;
|
n = n->next;
|
||||||
++count;
|
++count;
|
||||||
@ -597,6 +599,11 @@ network parse_network_cfg(char *filename)
|
|||||||
free_list(sections);
|
free_list(sections);
|
||||||
net.outputs = get_network_output_size(net);
|
net.outputs = get_network_output_size(net);
|
||||||
net.output = get_network_output(net);
|
net.output = get_network_output(net);
|
||||||
|
if(workspace_size){
|
||||||
|
#ifdef GPU
|
||||||
|
net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
return net;
|
return net;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user