mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
Merge pull request #1724 from eon129/master
Classifier now is working for CPU/GPU
This commit is contained in:
22
src/blas.c
22
src/blas.c
@ -244,6 +244,28 @@ void l1_cpu(int n, float *pred, float *truth, float *delta, float *error)
|
||||
}
|
||||
}
|
||||
|
||||
void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
float t = truth[i];
|
||||
float p = pred[i];
|
||||
error[i] = (t) ? -log(p) : 0;
|
||||
delta[i] = t-p;
|
||||
}
|
||||
}
|
||||
|
||||
void logistic_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < n; ++i){
|
||||
float t = truth[i];
|
||||
float p = pred[i];
|
||||
error[i] = -t*log(p) - (1-t)*log(1-p);
|
||||
delta[i] = t-p;
|
||||
}
|
||||
}
|
||||
|
||||
void l2_cpu(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
int i;
|
||||
|
@ -37,9 +37,12 @@ void weighted_sum_cpu(float *a, float *b, float *s, int num, float *c);
|
||||
|
||||
void softmax(float *input, int n, float temp, float *output, int stride);
|
||||
void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
|
||||
void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
|
||||
void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error);
|
||||
|
||||
#ifdef GPU
|
||||
#include "cuda.h"
|
||||
#include "tree.h"
|
||||
|
||||
void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY);
|
||||
void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
|
||||
@ -47,6 +50,7 @@ void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY);
|
||||
void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY);
|
||||
void scal_ongpu(int N, float ALPHA, float * X, int INCX);
|
||||
void supp_ongpu(int N, float ALPHA, float * X, int INCX);
|
||||
void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val);
|
||||
void mask_ongpu(int N, float * X, float mask_num, float * mask);
|
||||
void const_ongpu(int N, float ALPHA, float *X, int INCX);
|
||||
void pow_ongpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY);
|
||||
@ -71,6 +75,7 @@ void scale_bias_gpu(float *output, float *biases, int batch, int n, int size);
|
||||
void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
|
||||
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
|
||||
|
||||
void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error);
|
||||
void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error);
|
||||
void l2_gpu(int n, float *pred, float *truth, float *delta, float *error);
|
||||
void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc);
|
||||
@ -79,6 +84,7 @@ void mult_add_into_gpu(int num, float *a, float *b, float *c);
|
||||
|
||||
void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
|
||||
|
||||
void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output);
|
||||
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
|
||||
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
|
||||
void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t);
|
||||
@ -87,5 +93,7 @@ void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, fl
|
||||
|
||||
void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);
|
||||
|
||||
void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -7,6 +7,7 @@ extern "C" {
|
||||
#include "blas.h"
|
||||
#include "cuda.h"
|
||||
#include "utils.h"
|
||||
#include "tree.h"
|
||||
}
|
||||
|
||||
__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
|
||||
@ -419,7 +420,13 @@ __global__ void fill_kernel(int N, float ALPHA, float *X, int INCX)
|
||||
if(i < N) X[i*INCX] = ALPHA;
|
||||
}
|
||||
|
||||
__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
|
||||
__global__ void mask_kernel_new_api(int n, float *x, float mask_num, float *mask, float val)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (i < n && mask[i] == mask_num) x[i] = val;
|
||||
}
|
||||
|
||||
__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if(i < n && mask[i] == mask_num) x[i] = mask_num;
|
||||
@ -592,6 +599,12 @@ extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val)
|
||||
{
|
||||
mask_kernel_new_api <<<cuda_gridsize(N), BLOCK >>>(N, X, mask_num, mask, val);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
|
||||
{
|
||||
mask_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask);
|
||||
@ -687,6 +700,23 @@ extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, fl
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (i < n) {
|
||||
float t = truth[i];
|
||||
float p = pred[i];
|
||||
error[i] = (t) ? -log(p) : 0;
|
||||
delta[i] = t - p;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
softmax_x_ent_kernel << <cuda_gridsize(n), BLOCK >> >(n, pred, truth, delta, error);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
|
||||
{
|
||||
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
@ -784,6 +814,40 @@ extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float t
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__device__ void softmax_device_new_api(float *input, int n, float temp, int stride, float *output)
|
||||
{
|
||||
int i;
|
||||
float sum = 0;
|
||||
float largest = -INFINITY;
|
||||
for (i = 0; i < n; ++i) {
|
||||
int val = input[i*stride];
|
||||
largest = (val>largest) ? val : largest;
|
||||
}
|
||||
for (i = 0; i < n; ++i) {
|
||||
float e = expf(input[i*stride] / temp - largest / temp);
|
||||
sum += e;
|
||||
output[i*stride] = e;
|
||||
}
|
||||
for (i = 0; i < n; ++i) {
|
||||
output[i*stride] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_kernel_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
|
||||
{
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= batch*groups) return;
|
||||
int b = id / groups;
|
||||
int g = id % groups;
|
||||
softmax_device_new_api(input + b*batch_offset + g*group_offset, n, temp, stride, output + b*batch_offset + g*group_offset);
|
||||
}
|
||||
|
||||
extern "C" void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
|
||||
{
|
||||
softmax_kernel_new_api << <cuda_gridsize(batch*groups), BLOCK >> >(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
__global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
|
||||
{
|
||||
@ -814,4 +878,36 @@ extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stri
|
||||
size_t size = w*h*c*batch*stride*stride;
|
||||
upsample_kernel << <cuda_gridsize(size), BLOCK >> >(size, in, w, h, c, batch, stride, forward, scale, out);
|
||||
check_error(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
__global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset)
|
||||
{
|
||||
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
|
||||
if (id >= spatial*batch*groups) return;
|
||||
int s = id % spatial;
|
||||
id = id / spatial;
|
||||
int g = id % groups;
|
||||
int b = id / groups;
|
||||
int goff = group_offset[g] * spatial;
|
||||
int boff = b*stride;
|
||||
softmax_device_new_api(input + goff + boff + s, group_size[g], temp, spatial, output + goff + boff + s);
|
||||
}
|
||||
|
||||
extern "C" void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier)
|
||||
{
|
||||
int *tree_groups_size = cuda_make_int_array_new_api(hier.group_size, hier.groups);
|
||||
int *tree_groups_offset = cuda_make_int_array_new_api(hier.group_offset, hier.groups);
|
||||
/*
|
||||
static int *tree_groups_size = 0;
|
||||
static int *tree_groups_offset = 0;
|
||||
if(!tree_groups_size){
|
||||
tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
|
||||
tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
|
||||
}
|
||||
*/
|
||||
int num = spatial*batch*hier.groups;
|
||||
softmax_tree_kernel <<<cuda_gridsize(num), BLOCK >>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
|
||||
check_error(cudaPeekAtLastError());
|
||||
cuda_free((float *)tree_groups_size);
|
||||
cuda_free((float *)tree_groups_offset);
|
||||
}
|
14
src/cuda.c
14
src/cuda.c
@ -162,6 +162,20 @@ int *cuda_make_int_array(size_t n)
|
||||
return x_gpu;
|
||||
}
|
||||
|
||||
int *cuda_make_int_array_new_api(int *x, size_t n)
|
||||
{
|
||||
int *x_gpu;
|
||||
size_t size = sizeof(int)*n;
|
||||
cudaError_t status = cudaMalloc((void **)&x_gpu, size);
|
||||
check_error(status);
|
||||
if (x) {
|
||||
status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
|
||||
check_error(status);
|
||||
}
|
||||
if (!x_gpu) error("Cuda malloc failed\n");
|
||||
return x_gpu;
|
||||
}
|
||||
|
||||
void cuda_free(float *x_gpu)
|
||||
{
|
||||
//cudaStreamSynchronize(get_cuda_stream());
|
||||
|
@ -40,6 +40,7 @@ extern "C" {
|
||||
cublasHandle_t blas_handle();
|
||||
float *cuda_make_array(float *x, size_t n);
|
||||
int *cuda_make_int_array(size_t n);
|
||||
int *cuda_make_int_array_new_api(int *x, size_t n);
|
||||
void cuda_push_array(float *x_gpu, float *x, size_t n);
|
||||
YOLODLL_API void cuda_pull_array(float *x_gpu, float *x, size_t n);
|
||||
YOLODLL_API void cuda_set_device(int n);
|
||||
|
@ -83,6 +83,7 @@ struct layer{
|
||||
int side;
|
||||
int stride;
|
||||
int reverse;
|
||||
int spatial;
|
||||
int pad;
|
||||
int sqrt;
|
||||
int flip;
|
||||
@ -100,6 +101,7 @@ struct layer{
|
||||
float shift;
|
||||
float ratio;
|
||||
int focal_loss;
|
||||
int noloss;
|
||||
int softmax;
|
||||
int classes;
|
||||
int coords;
|
||||
@ -198,6 +200,7 @@ struct layer{
|
||||
int * input_sizes;
|
||||
float * delta;
|
||||
float * output;
|
||||
float * loss;
|
||||
float * squared;
|
||||
float * norms;
|
||||
|
||||
@ -289,6 +292,7 @@ struct layer{
|
||||
float * scale_updates_gpu;
|
||||
|
||||
float * output_gpu;
|
||||
float * loss_gpu;
|
||||
float * delta_gpu;
|
||||
float * rand_gpu;
|
||||
float * squared_gpu;
|
||||
|
17
src/parser.c
17
src/parser.c
@ -233,12 +233,17 @@ connected_layer parse_connected(list *options, size_params params)
|
||||
|
||||
softmax_layer parse_softmax(list *options, size_params params)
|
||||
{
|
||||
int groups = option_find_int_quiet(options, "groups",1);
|
||||
softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
|
||||
layer.temperature = option_find_float_quiet(options, "temperature", 1);
|
||||
char *tree_file = option_find_str(options, "tree", 0);
|
||||
if (tree_file) layer.softmax_tree = read_tree(tree_file);
|
||||
return layer;
|
||||
int groups = option_find_int_quiet(options, "groups", 1);
|
||||
softmax_layer layer = make_softmax_layer(params.batch, params.inputs, groups);
|
||||
layer.temperature = option_find_float_quiet(options, "temperature", 1);
|
||||
char *tree_file = option_find_str(options, "tree", 0);
|
||||
if (tree_file) layer.softmax_tree = read_tree(tree_file);
|
||||
layer.w = params.w;
|
||||
layer.h = params.h;
|
||||
layer.c = params.c;
|
||||
layer.spatial = option_find_float_quiet(options, "spatial", 0);
|
||||
layer.noloss = option_find_int_quiet(options, "noloss", 0);
|
||||
return layer;
|
||||
}
|
||||
|
||||
int *parse_yolo_mask(char *a, int *num)
|
||||
|
@ -1,12 +1,31 @@
|
||||
#include "softmax_layer.h"
|
||||
#include "blas.h"
|
||||
#include "cuda.h"
|
||||
#include "utils.h"
|
||||
#include "blas.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
|
||||
#define SECRET_NUM -1234
|
||||
|
||||
void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output)
|
||||
{
|
||||
int b;
|
||||
for (b = 0; b < batch; ++b) {
|
||||
int i;
|
||||
int count = 0;
|
||||
for (i = 0; i < hierarchy->groups; ++i) {
|
||||
int group_size = hierarchy->group_size[i];
|
||||
softmax(input + b*inputs + count, group_size, temp, output + b*inputs + count, 1);
|
||||
count += group_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
softmax_layer make_softmax_layer(int batch, int inputs, int groups)
|
||||
{
|
||||
assert(inputs%groups == 0);
|
||||
@ -17,8 +36,10 @@ softmax_layer make_softmax_layer(int batch, int inputs, int groups)
|
||||
l.groups = groups;
|
||||
l.inputs = inputs;
|
||||
l.outputs = inputs;
|
||||
l.loss = calloc(inputs*batch, sizeof(float));
|
||||
l.output = calloc(inputs*batch, sizeof(float));
|
||||
l.delta = calloc(inputs*batch, sizeof(float));
|
||||
l.cost = calloc(1, sizeof(float));
|
||||
|
||||
l.forward = forward_softmax_layer;
|
||||
l.backward = backward_softmax_layer;
|
||||
@ -27,45 +48,35 @@ softmax_layer make_softmax_layer(int batch, int inputs, int groups)
|
||||
l.backward_gpu = backward_softmax_layer_gpu;
|
||||
|
||||
l.output_gpu = cuda_make_array(l.output, inputs*batch);
|
||||
l.loss_gpu = cuda_make_array(l.loss, inputs*batch);
|
||||
l.delta_gpu = cuda_make_array(l.delta, inputs*batch);
|
||||
#endif
|
||||
return l;
|
||||
}
|
||||
|
||||
void softmax_tree(float *input, int batch, int inputs, float temp, tree *hierarchy, float *output)
|
||||
void forward_softmax_layer(const softmax_layer l, network_state net)
|
||||
{
|
||||
int b;
|
||||
for(b = 0; b < batch; ++b){
|
||||
if(l.softmax_tree){
|
||||
int i;
|
||||
int count = 0;
|
||||
for(i = 0; i < hierarchy->groups; ++i){
|
||||
int group_size = hierarchy->group_size[i];
|
||||
softmax(input+b*inputs + count, group_size, temp, output+b*inputs + count, 1);
|
||||
for (i = 0; i < l.softmax_tree->groups; ++i) {
|
||||
int group_size = l.softmax_tree->group_size[i];
|
||||
softmax_cpu(net.input + count, group_size, l.batch, l.inputs, 1, 0, 1, l.temperature, l.output + count);
|
||||
count += group_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void forward_softmax_layer(const softmax_layer l, network_state state)
|
||||
{
|
||||
int b;
|
||||
int inputs = l.inputs / l.groups;
|
||||
int batch = l.batch * l.groups;
|
||||
if(l.softmax_tree){
|
||||
softmax_tree(state.input, batch, inputs, l.temperature, l.softmax_tree, l.output);
|
||||
} else {
|
||||
for(b = 0; b < batch; ++b){
|
||||
softmax(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs, 1);
|
||||
}
|
||||
softmax_cpu(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output);
|
||||
}
|
||||
|
||||
if(net.truth && !l.noloss){
|
||||
softmax_x_ent_cpu(l.batch*l.inputs, l.output, net.truth, l.delta, l.loss);
|
||||
l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
|
||||
}
|
||||
}
|
||||
|
||||
void backward_softmax_layer(const softmax_layer l, network_state state)
|
||||
void backward_softmax_layer(const softmax_layer l, network_state net)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < l.inputs*l.batch; ++i){
|
||||
state.delta[i] += l.delta[i];
|
||||
}
|
||||
axpy_cpu(l.inputs*l.batch, 1, l.delta, 1, net.delta, 1);
|
||||
}
|
||||
|
||||
#ifdef GPU
|
||||
@ -75,26 +86,40 @@ void pull_softmax_layer_output(const softmax_layer layer)
|
||||
cuda_pull_array(layer.output_gpu, layer.output, layer.inputs*layer.batch);
|
||||
}
|
||||
|
||||
void forward_softmax_layer_gpu(const softmax_layer l, network_state state)
|
||||
void forward_softmax_layer_gpu(const softmax_layer l, network_state net)
|
||||
{
|
||||
int inputs = l.inputs / l.groups;
|
||||
int batch = l.batch * l.groups;
|
||||
if(l.softmax_tree){
|
||||
int i;
|
||||
int count = 0;
|
||||
for (i = 0; i < l.softmax_tree->groups; ++i) {
|
||||
int group_size = l.softmax_tree->group_size[i];
|
||||
softmax_gpu(state.input+count, group_size, inputs, batch, l.temperature, l.output_gpu + count);
|
||||
count += group_size;
|
||||
}
|
||||
softmax_tree_gpu(net.input, 1, l.batch, l.inputs, l.temperature, l.output_gpu, *l.softmax_tree);
|
||||
/*
|
||||
int i;
|
||||
int count = 0;
|
||||
for (i = 0; i < l.softmax_tree->groups; ++i) {
|
||||
int group_size = l.softmax_tree->group_size[i];
|
||||
softmax_gpu(net.input_gpu + count, group_size, l.batch, l.inputs, 1, 0, 1, l.temperature, l.output_gpu + count);
|
||||
count += group_size;
|
||||
}
|
||||
*/
|
||||
} else {
|
||||
softmax_gpu(state.input, inputs, inputs, batch, l.temperature, l.output_gpu);
|
||||
if(l.spatial){
|
||||
softmax_gpu_new_api(net.input, l.c, l.batch*l.c, l.inputs/l.c, l.w*l.h, 1, l.w*l.h, 1, l.output_gpu);
|
||||
}else{
|
||||
softmax_gpu_new_api(net.input, l.inputs/l.groups, l.batch, l.inputs, l.groups, l.inputs/l.groups, 1, l.temperature, l.output_gpu);
|
||||
}
|
||||
}
|
||||
if(net.truth && !l.noloss){
|
||||
softmax_x_ent_gpu(l.batch*l.inputs, l.output_gpu, net.truth, l.delta_gpu, l.loss_gpu);
|
||||
if(l.softmax_tree){
|
||||
mask_gpu_new_api(l.batch*l.inputs, l.delta_gpu, SECRET_NUM, net.truth, 0);
|
||||
mask_gpu_new_api(l.batch*l.inputs, l.loss_gpu, SECRET_NUM, net.truth, 0);
|
||||
}
|
||||
cuda_pull_array(l.loss_gpu, l.loss, l.batch*l.inputs);
|
||||
l.cost[0] = sum_array(l.loss, l.batch*l.inputs);
|
||||
}
|
||||
}
|
||||
|
||||
void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
|
||||
void backward_softmax_layer_gpu(const softmax_layer layer, network_state net)
|
||||
{
|
||||
axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
|
||||
axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, net.delta, 1);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user