This commit is contained in:
Joseph Redmon
2015-05-08 10:33:47 -07:00
parent e7688a05a1
commit 0cbfa46461
7 changed files with 86 additions and 10 deletions

View File

@ -4,7 +4,6 @@
#include "image.h"
#include "data.h"
#include "utils.h"
#include "params.h"
#include "crop_layer.h"
#include "connected_layer.h"
@ -16,6 +15,7 @@
#include "normalization_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
char *get_layer_string(LAYER_TYPE a)
{
@ -40,6 +40,8 @@ char *get_layer_string(LAYER_TYPE a)
return "crop";
case COST:
return "cost";
case ROUTE:
return "route";
default:
break;
}
@ -99,6 +101,9 @@ void forward_network(network net, network_state state)
else if(net.types[i] == DROPOUT){
forward_dropout_layer(*(dropout_layer *)net.layers[i], state);
}
else if(net.types[i] == ROUTE){
forward_route_layer(*(route_layer *)net.layers[i], net);
}
state.input = get_network_output_layer(net, i);
}
}
@ -143,6 +148,8 @@ float *get_network_output_layer(network net, int i)
return ((crop_layer *)net.layers[i]) -> output;
} else if(net.types[i] == NORMALIZATION){
return ((normalization_layer *)net.layers[i]) -> output;
} else if(net.types[i] == ROUTE){
return ((route_layer *)net.layers[i]) -> output;
}
return 0;
}
@ -177,6 +184,8 @@ float *get_network_delta_layer(network net, int i)
} else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.delta;
} else if(net.types[i] == ROUTE){
return ((route_layer *)net.layers[i]) -> delta;
}
return 0;
}
@ -247,10 +256,12 @@ void backward_network(network net, network_state state)
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
backward_connected_layer(layer, state);
}
else if(net.types[i] == COST){
} else if(net.types[i] == COST){
cost_layer layer = *(cost_layer *)net.layers[i];
backward_cost_layer(layer, state);
} else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
backward_route_layer(layer, net);
}
}
}
@ -369,6 +380,10 @@ void set_batch_network(network *net, int b)
crop_layer *layer = (crop_layer *)net->layers[i];
layer->batch = b;
}
else if(net->types[i] == ROUTE){
route_layer *layer = (route_layer *)net->layers[i];
layer->batch = b;
}
}
}
@ -445,12 +460,17 @@ int get_network_output_size_layer(network net, int i)
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.inputs;
}
else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
return layer.outputs;
}
fprintf(stderr, "Can't find output size\n");
return 0;
}
int resize_network(network net, int h, int w, int c)
{
fprintf(stderr, "Might be broken, careful!!");
int i;
for (i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
@ -540,6 +560,10 @@ image get_network_image_layer(network net, int i)
crop_layer layer = *(crop_layer *)net.layers[i];
return get_crop_image(layer);
}
else if(net.types[i] == ROUTE){
route_layer layer = *(route_layer *)net.layers[i];
return get_network_image_layer(net, layer.input_layers[0]);
}
return make_empty_image(0,0,0);
}