This commit is contained in:
Joseph Redmon 2016-05-12 13:36:11 -07:00
parent 9942d48412
commit 054e2b1954
5 changed files with 145 additions and 39 deletions

View File

@ -1,5 +1,5 @@
GPU=1
OPENCV=1
GPU=0
OPENCV=0
DEBUG=0
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20
@ -34,7 +34,7 @@ CFLAGS+= -DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo2.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o

View File

@ -65,6 +65,8 @@ float get_current_rate(network net)
return net.learning_rate * pow(net.gamma, batch_num);
case POLY:
return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
case RANDOM:
return net.learning_rate * pow(rand_uniform(0,1), net.power);
case SIG:
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
default:

View File

@ -7,7 +7,7 @@
#include "data.h"
typedef enum {
CONSTANT, STEP, EXP, POLY, STEPS, SIG
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
} learning_rate_policy;
typedef struct network{

View File

@ -432,6 +432,7 @@ route_layer parse_route(list *options, size_params params, network net)
learning_rate_policy get_policy(char *s)
{
if (strcmp(s, "random")==0) return RANDOM;
if (strcmp(s, "poly")==0) return POLY;
if (strcmp(s, "constant")==0) return CONSTANT;
if (strcmp(s, "step")==0) return STEP;
@ -497,7 +498,7 @@ void parse_net_options(list *options, network *net)
} else if (net->policy == SIG){
net->gamma = option_find_float(options, "gamma", 1);
net->step = option_find_int(options, "step", 1);
} else if (net->policy == POLY){
} else if (net->policy == POLY || net->policy == RANDOM){
net->power = option_find_float(options, "power", 1);
}
net->max_batches = option_find_int(options, "max_batches", 0);

171
src/rnn.c
View File

@ -13,6 +13,76 @@ typedef struct {
float *y;
} float_pair;
int *read_tokenized_data(char *filename, size_t *read)
{
size_t size = 512;
size_t count = 0;
FILE *fp = fopen(filename, "r");
int *d = calloc(size, sizeof(int));
int n, one;
one = fscanf(fp, "%d", &n);
while(one == 1){
++count;
if(count > size){
size = size*2;
d = realloc(d, size*sizeof(int));
}
d[count-1] = n;
one = fscanf(fp, "%d", &n);
}
fclose(fp);
d = realloc(d, count*sizeof(int));
*read = count;
return d;
}
char **read_tokens(char *filename, size_t *read)
{
size_t size = 512;
size_t count = 0;
FILE *fp = fopen(filename, "r");
char **d = calloc(size, sizeof(char *));
char *line;
while((line=fgetl(fp)) != 0){
++count;
if(count > size){
size = size*2;
d = realloc(d, size*sizeof(char *));
}
d[count-1] = line;
}
fclose(fp);
d = realloc(d, count*sizeof(char *));
*read = count;
return d;
}
float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
{
float *x = calloc(batch * steps * characters, sizeof(float));
float *y = calloc(batch * steps * characters, sizeof(float));
int i,j;
for(i = 0; i < batch; ++i){
for(j = 0; j < steps; ++j){
int curr = tokens[(offsets[i])%len];
int next = tokens[(offsets[i] + 1)%len];
x[(j*batch + i)*characters + curr] = 1;
y[(j*batch + i)*characters + next] = 1;
offsets[i] = (offsets[i] + 1) % len;
if(curr >= characters || curr < 0 || next >= characters || next < 0){
error("Bad char");
}
}
}
float_pair p;
p.x = x;
p.y = y;
return p;
}
float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
{
float *x = calloc(batch * steps * characters, sizeof(float));
@ -47,8 +117,8 @@ void reset_rnn_state(network net, int b)
{
int i;
for (i = 0; i < net.n; ++i) {
layer l = net.layers[i];
#ifdef GPU
layer l = net.layers[i];
if(l.state_gpu){
fill_ongpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
}
@ -56,19 +126,26 @@ void reset_rnn_state(network net, int b)
}
}
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear)
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
{
srand(time(0));
data_seed = time(0);
FILE *fp = fopen(filename, "rb");
unsigned char *text = 0;
int *tokens = 0;
size_t size;
if(tokenized){
tokens = read_tokenized_data(filename, &size);
} else {
FILE *fp = fopen(filename, "rb");
fseek(fp, 0, SEEK_END);
size_t size = ftell(fp);
fseek(fp, 0, SEEK_SET);
fseek(fp, 0, SEEK_END);
size = ftell(fp);
fseek(fp, 0, SEEK_SET);
unsigned char *text = calloc(size+1, sizeof(char));
fread(text, 1, size, fp);
fclose(fp);
text = calloc(size+1, sizeof(char));
fread(text, 1, size, fp);
fclose(fp);
}
char *backup_directory = "/home/pjreddie/backup/";
char *base = basecfg(cfgfile);
@ -97,7 +174,12 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear)
while(get_current_batch(net) < net.max_batches){
i += 1;
time=clock();
float_pair p = get_rnn_data(text, offsets, inputs, size, streams, steps);
float_pair p;
if(tokenized){
p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
}else{
p = get_rnn_data(text, offsets, inputs, size, streams, steps);
}
float loss = train_network_datum(net, p.x, p.y) / (batch);
free(p.x);
@ -133,8 +215,22 @@ void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear)
save_weights(net, buff);
}
void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed)
void print_symbol(int n, char **tokens){
if(tokens){
printf("%s ", tokens[n]);
} else {
printf("%c", n);
}
}
void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
{
char **tokens = 0;
if(token_file){
size_t n;
tokens = read_tokens(token_file, &n);
}
srand(rseed);
char *base = basecfg(cfgfile);
fprintf(stderr, "%s\n", base);
@ -147,38 +243,39 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t
int i, j;
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
unsigned char c;
int c = 0;
int len = strlen(seed);
float *input = calloc(inputs, sizeof(float));
/*
fill_cpu(inputs, 0, input, 1);
for(i = 0; i < 10; ++i){
network_predict(net, input);
}
fill_cpu(inputs, 0, input, 1);
*/
/*
fill_cpu(inputs, 0, input, 1);
for(i = 0; i < 10; ++i){
network_predict(net, input);
}
fill_cpu(inputs, 0, input, 1);
*/
for(i = 0; i < len-1; ++i){
c = seed[i];
input[(int)c] = 1;
input[c] = 1;
network_predict(net, input);
input[(int)c] = 0;
printf("%c", c);
input[c] = 0;
print_symbol(c, tokens);
}
c = seed[len-1];
if(len) c = seed[len-1];
print_symbol(c, tokens);
for(i = 0; i < num; ++i){
printf("%c", c);
input[(int)c] = 1;
input[c] = 1;
float *out = network_predict(net, input);
input[(int)c] = 0;
input[c] = 0;
for(j = 32; j < 127; ++j){
//printf("%d %c %f\n",j, j, out[j]);
}
for(j = 0; j < inputs; ++j){
//if (out[j] < .0001) out[j] = 0;
if (out[j] < .0001) out[j] = 0;
}
c = sample_array(out, inputs);
print_symbol(c, tokens);
}
printf("\n");
}
@ -195,6 +292,7 @@ void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
int inputs = get_network_input_size(net);
int count = 0;
int words = 1;
int c;
int len = strlen(seed);
float *input = calloc(inputs, sizeof(float));
@ -213,12 +311,13 @@ void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
if(next == EOF) break;
if(next < 0 || next >= 255) error("Out of range character");
++count;
if(next == ' ' || next == '\n' || next == '\t') ++words;
input[c] = 1;
float *out = network_predict(net, input);
input[c] = 0;
sum += log(out[next])/log2;
c = next;
printf("%d Perplexity: %f\n", count, pow(2, -sum/count));
printf("%d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, pow(2, -sum/count), pow(2, -sum/words));
}
}
@ -254,13 +353,15 @@ void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
network_predict(net, input);
input[(int)c] = 0;
}
c = ' ';
input[(int)c] = 1;
network_predict(net, input);
input[(int)c] = 0;
c = ' ';
input[(int)c] = 1;
network_predict(net, input);
input[(int)c] = 0;
layer l = net.layers[0];
#ifdef GPU
cuda_pull_array(l.output_gpu, l.output, l.outputs);
#endif
printf("%s", line);
for(i = 0; i < l.outputs; ++i){
printf(",%g", l.output[i]);
@ -281,11 +382,13 @@ void run_char_rnn(int argc, char **argv)
float temp = find_float_arg(argc, argv, "-temp", .7);
int rseed = find_int_arg(argc, argv, "-srand", time(0));
int clear = find_arg(argc, argv, "-clear");
int tokenized = find_arg(argc, argv, "-tokenized");
char *tokens = find_char_arg(argc, argv, "-tokens", 0);
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear);
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
}