Better partial function

This commit is contained in:
Joseph Redmon 2015-07-20 16:16:26 -07:00
parent 23c08be144
commit 38bd6ae6ba
3 changed files with 8 additions and 3 deletions

View File

@ -36,7 +36,7 @@ void partial(char *cfgfile, char *weightfile, char *outfile, int max)
load_weights_upto(&net, weightfile, max); load_weights_upto(&net, weightfile, max);
} }
net.seen = 0; net.seen = 0;
save_weights(net, outfile); save_weights_upto(net, outfile, max);
} }
#include "convolutional_layer.h" #include "convolutional_layer.h"

View File

@ -500,7 +500,7 @@ list *read_cfg(char *filename)
return sections; return sections;
} }
void save_weights(network net, char *filename) void save_weights_upto(network net, char *filename, int cutoff)
{ {
fprintf(stderr, "Saving weights to %s\n", filename); fprintf(stderr, "Saving weights to %s\n", filename);
FILE *fp = fopen(filename, "w"); FILE *fp = fopen(filename, "w");
@ -512,7 +512,7 @@ void save_weights(network net, char *filename)
fwrite(&net.seen, sizeof(int), 1, fp); fwrite(&net.seen, sizeof(int), 1, fp);
int i; int i;
for(i = 0; i < net.n; ++i){ for(i = 0; i < net.n && i < cutoff; ++i){
layer l = net.layers[i]; layer l = net.layers[i];
if(l.type == CONVOLUTIONAL){ if(l.type == CONVOLUTIONAL){
#ifdef GPU #ifdef GPU
@ -546,6 +546,10 @@ void save_weights(network net, char *filename)
} }
fclose(fp); fclose(fp);
} }
void save_weights(network net, char *filename)
{
save_weights_upto(net, filename, net.n);
}
void load_weights_upto(network *net, char *filename, int cutoff) void load_weights_upto(network *net, char *filename, int cutoff)
{ {

View File

@ -5,6 +5,7 @@
network parse_network_cfg(char *filename); network parse_network_cfg(char *filename);
void save_network(network net, char *filename); void save_network(network net, char *filename);
void save_weights(network net, char *filename); void save_weights(network net, char *filename);
void save_weights_upto(network net, char *filename, int cutoff);
void load_weights(network *net, char *filename); void load_weights(network *net, char *filename);
void load_weights_upto(network *net, char *filename, int cutoff); void load_weights_upto(network *net, char *filename, int cutoff);