mirror of
https://github.com/pjreddie/darknet.git
synced 2023-08-10 21:13:14 +03:00
idk man 🐍 stuff
This commit is contained in:
@@ -82,6 +82,27 @@ void reset_momentum(network net)
|
||||
#endif
|
||||
}
|
||||
|
||||
void reset_network_state(network net, int b)
|
||||
{
|
||||
int i;
|
||||
for (i = 0; i < net.n; ++i) {
|
||||
#ifdef GPU
|
||||
layer l = net.layers[i];
|
||||
if(l.state_gpu){
|
||||
fill_gpu(l.outputs, 0, l.state_gpu + l.outputs*b, 1);
|
||||
}
|
||||
if(l.h_gpu){
|
||||
fill_gpu(l.outputs, 0, l.h_gpu + l.outputs*b, 1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void reset_rnn(network *net)
|
||||
{
|
||||
reset_network_state(*net, 0);
|
||||
}
|
||||
|
||||
float get_current_rate(network net)
|
||||
{
|
||||
size_t batch_num = get_current_batch(net);
|
||||
@@ -302,6 +323,15 @@ float train_network(network net, data d)
|
||||
return (float)sum/(n*batch);
|
||||
}
|
||||
|
||||
void set_temp_network(network net, float t)
|
||||
{
|
||||
int i;
|
||||
for(i = 0; i < net.n; ++i){
|
||||
net.layers[i].temperature = t;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void set_batch_network(network *net, int b)
|
||||
{
|
||||
net->batch = b;
|
||||
|
||||
@@ -203,7 +203,7 @@ void forward_region_layer(const layer l, network net)
|
||||
int class_index = entry_index(l, b, n, l.coords + 1);
|
||||
int obj_index = entry_index(l, b, n, l.coords);
|
||||
float scale = l.output[obj_index];
|
||||
//l.delta[obj_index] = l.noobject_scale * (0 - l.output[obj_index]);
|
||||
l.delta[obj_index] = l.noobject_scale * (0 - l.output[obj_index]);
|
||||
float p = scale*get_hierarchy_probability(l.output + class_index, l.softmax_tree, class, l.w*l.h);
|
||||
if(p > maxp){
|
||||
maxp = p;
|
||||
@@ -213,8 +213,8 @@ void forward_region_layer(const layer l, network net)
|
||||
int class_index = entry_index(l, b, maxi, l.coords + 1);
|
||||
int obj_index = entry_index(l, b, maxi, l.coords);
|
||||
delta_region_class(l.output, l.delta, class_index, class, l.classes, l.softmax_tree, l.class_scale, l.w*l.h, &avg_cat);
|
||||
//if(l.output[obj_index] < .3) l.delta[obj_index] = l.object_scale * (.3 - l.output[obj_index]);
|
||||
//else l.delta[obj_index] = 0;
|
||||
if(l.output[obj_index] < .3) l.delta[obj_index] = l.object_scale * (.3 - l.output[obj_index]);
|
||||
else l.delta[obj_index] = 0;
|
||||
l.delta[obj_index] = 0;
|
||||
++class_count;
|
||||
onlyclass = 1;
|
||||
|
||||
Reference in New Issue
Block a user