This commit is contained in:
Joseph Redmon 2015-06-10 00:11:41 -07:00
parent 7fe80a2bb5
commit cbc9984a17
21 changed files with 7440 additions and 770 deletions

View File

@ -1,5 +1,7 @@
GPU=1
OPENCV=1
DEBUG=0
ARCH= -arch=sm_52
VPATH=./src/
@ -9,9 +11,9 @@ OBJDIR=./obj/
CC=gcc
NVCC=nvcc
OPTS=-Ofast
LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors
LDFLAGS= -lm -pthread -lstdc++
COMMON= -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors
ifeq ($(DEBUG), 1)
OPTS=-O0 -g
@ -19,9 +21,16 @@ endif
CFLAGS+=$(OPTS)
ifeq ($(OPENCV), 1)
COMMON+= -DOPENCV
CFLAGS+= -DOPENCV
LDFLAGS+= `pkg-config --libs opencv`
COMMON+= `pkg-config --cflags opencv`
endif
ifeq ($(GPU), 1)
COMMON+=-DGPU
CFLAGS+=-DGPU
COMMON+= -DGPU
CFLAGS+= -DGPU
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif

View File

@ -1,204 +0,0 @@
[net]
batch=64
subdivisions=4
height=448
width=448
channels=3
learning_rate=0.01
momentum=0.9
decay=0.0005
seen = 0
[crop]
crop_width=448
crop_height=448
flip=0
angle=0
saturation = 2
exposure = 2
[convolutional]
filters=64
size=7
stride=2
pad=1
activation=ramp
[convolutional]
filters=192
size=3
stride=2
pad=1
activation=ramp
[convolutional]
filters=128
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=3
stride=2
pad=1
activation=ramp
[convolutional]
filters=128
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=128
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=512
size=3
stride=2
pad=1
activation=ramp
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=512
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=512
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=512
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=512
size=3
stride=1
pad=1
activation=ramp
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=1024
size=3
stride=2
pad=1
activation=ramp
[convolutional]
filters=512
size=1
stride=1
pad=1
activation=ramp
[convolutional]
filters=1024
size=3
stride=1
pad=1
activation=ramp
[convolutional]
size=3
stride=1
pad=1
filters=1024
activation=ramp
[convolutional]
size=3
stride=2
pad=1
filters=1024
activation=ramp
[convolutional]
size=3
stride=1
pad=1
filters=1024
activation=ramp
[convolutional]
size=3
stride=1
pad=1
filters=1024
activation=ramp
[connected]
output=4096
activation=ramp
[dropout]
probability=.5
[connected]
output=1225
activation=logistic
[detection]
classes=20
coords=4
rescore=0
nuisance = 1
background=1

View File

@ -1,6 +1,6 @@
[net]
batch=64
subdivisions=4
batch=1
subdivisions=1
height=448
width=448
channels=3
@ -199,6 +199,7 @@ activation=logistic
[detection]
classes=20
coords=4
rescore=1
nuisance = 0
rescore=0
joint=1
objectness = 0
background=0

BIN
data/dog.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 159 KiB

BIN
data/eagle.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 139 KiB

BIN
data/horses.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 130 KiB

BIN
data/person.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

View File

@ -62,7 +62,9 @@ void decode_captcha(char *cfgfile, char *weightfile)
float *predictions = network_predict(net, X);
image out = float_to_image(300, 57, 1, predictions);
show_image(out, "decoded");
#ifdef OPENCV
cvWaitKey(0);
#endif
free_image(im);
}
}

View File

@ -227,6 +227,15 @@ image get_convolutional_filter(convolutional_layer l, int i)
return float_to_image(w,h,c,l.filters+i*h*w*c);
}
void rgbgr_filters(convolutional_layer l)
{
int i;
for(i = 0; i < l.n; ++i){
image im = get_convolutional_filter(l, i);
if (im.c == 3) rgbgr_image(im);
}
}
image *get_filters(convolutional_layer l)
{
image *filters = calloc(l.n, sizeof(image));

View File

@ -114,9 +114,9 @@ __global__ void levels_image_kernel(float *image, float *rand, int batch, int w,
size_t offset = id * h * w * 3;
image += offset;
float r = image[x + w*(y + h*2)];
float r = image[x + w*(y + h*0)];
float g = image[x + w*(y + h*1)];
float b = image[x + w*(y + h*0)];
float b = image[x + w*(y + h*2)];
float3 rgb = make_float3(r,g,b);
if(train){
float3 hsv = rgb_to_hsv_kernel(rgb);
@ -124,9 +124,9 @@ __global__ void levels_image_kernel(float *image, float *rand, int batch, int w,
hsv.z *= exposure;
rgb = hsv_to_rgb_kernel(hsv);
}
image[x + w*(y + h*2)] = rgb.x*scale + translate;
image[x + w*(y + h*0)] = rgb.x*scale + translate;
image[x + w*(y + h*1)] = rgb.y*scale + translate;
image[x + w*(y + h*0)] = rgb.z*scale + translate;
image[x + w*(y + h*2)] = rgb.z*scale + translate;
}
__global__ void forward_crop_layer_kernel(float *input, float *rand, int size, int c, int h, int w, int crop_height, int crop_width, int train, int flip, float angle, float *output)

View File

@ -73,6 +73,25 @@ void partial(char *cfgfile, char *weightfile, char *outfile, int max)
save_weights(net, outfile);
}
#include "convolutional_layer.h"
void rgbgr_filters(convolutional_layer l);
void rgbgr_net(char *cfgfile, char *weightfile, char *outfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
int i;
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.type == CONVOLUTIONAL){
rgbgr_filters(l);
break;
}
}
save_weights(net, outfile);
}
void visualize(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
@ -80,11 +99,14 @@ void visualize(char *cfgfile, char *weightfile)
load_weights(&net, weightfile);
}
visualize_network(net);
#ifdef OPENCV
cvWaitKey(0);
#endif
}
int main(int argc, char **argv)
{
//test_resize("data/cat.png");
//test_box();
//test_convolutional_layer();
if(argc < 2){
@ -114,6 +136,8 @@ int main(int argc, char **argv)
run_captcha(argc, argv);
} else if (0 == strcmp(argv[1], "change")){
change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
} else if (0 == strcmp(argv[1], "rgbgr")){
rgbgr_net(argv[2], argv[3], argv[4]);
} else if (0 == strcmp(argv[1], "partial")){
partial(argv[2], argv[3], argv[4], atoi(argv[5]));
} else if (0 == strcmp(argv[1], "visualize")){

View File

@ -69,7 +69,7 @@ matrix load_image_paths_gray(char **paths, int n, int w, int h)
X.cols = 0;
for(i = 0; i < n; ++i){
image im = load_image(paths[i], w, h);
image im = load_image(paths[i], w, h, 1);
X.vals[i] = im.data;
X.cols = im.h*im.w*im.c;
}

View File

@ -6,31 +6,24 @@
char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
char *inet_class_names[] = {"bg", "accordion", "airplane", "ant", "antelope", "apple", "armadillo", "artichoke", "axe", "baby bed", "backpack", "bagel", "balance beam", "banana", "band aid", "banjo", "baseball", "basketball", "bathing cap", "beaker", "bear", "bee", "bell pepper", "bench", "bicycle", "binder", "bird", "bookshelf", "bow tie", "bow", "bowl", "brassiere", "burrito", "bus", "butterfly", "camel", "can opener", "car", "cart", "cattle", "cello", "centipede", "chain saw", "chair", "chime", "cocktail shaker", "coffee maker", "computer keyboard", "computer mouse", "corkscrew", "cream", "croquet ball", "crutch", "cucumber", "cup or mug", "diaper", "digital clock", "dishwasher", "dog", "domestic cat", "dragonfly", "drum", "dumbbell", "electric fan", "elephant", "face powder", "fig", "filing cabinet", "flower pot", "flute", "fox", "french horn", "frog", "frying pan", "giant panda", "goldfish", "golf ball", "golfcart", "guacamole", "guitar", "hair dryer", "hair spray", "hamburger", "hammer", "hamster", "harmonica", "harp", "hat with a wide brim", "head cabbage", "helmet", "hippopotamus", "horizontal bar", "horse", "hotdog", "iPod", "isopod", "jellyfish", "koala bear", "ladle", "ladybug", "lamp", "laptop", "lemon", "lion", "lipstick", "lizard", "lobster", "maillot", "maraca", "microphone", "microwave", "milk can", "miniskirt", "monkey", "motorcycle", "mushroom", "nail", "neck brace", "oboe", "orange", "otter", "pencil box", "pencil sharpener", "perfume", "person", "piano", "pineapple", "ping-pong ball", "pitcher", "pizza", "plastic bag", "plate rack", "pomegranate", "popsicle", "porcupine", "power drill", "pretzel", "printer", "puck", "punching bag", "purse", "rabbit", "racket", "ray", "red panda", "refrigerator", "remote control", "rubber eraser", "rugby ball", "ruler", "salt or pepper shaker", "saxophone", "scorpion", "screwdriver", "seal", "sheep", "ski", "skunk", "snail", "snake", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sofa", "spatula", "squirrel", "starfish", "stethoscope", "stove", "strainer", "strawberry", "stretcher", "sunglasses", "swimming trunks", "swine", "syringe", "table", "tape player", "tennis ball", "tick", "tie", "tiger", "toaster", "traffic light", "train", "trombone", "trumpet", "turtle", "tv or monitor", "unicycle", "vacuum", "violin", "volleyball", "waffle iron", "washer", "water bottle", "watercraft", "whale", "wine bottle", "zebra"};
#define AMNT 3
void draw_detection(image im, float *box, int side, char *label)
void draw_detection(image im, float *box, int side, int bg, char *label)
{
int classes = 20;
int elems = 4+classes;
int elems = 4+classes+bg;
int j;
int r, c;
for(r = 0; r < side; ++r){
for(c = 0; c < side; ++c){
j = (r*side + c) * elems;
//printf("%d\n", j);
//printf("Prob: %f\n", box[j]);
j = (r*side + c) * elems + bg;
int class = max_index(box+j, classes);
if(box[j+class] > .05){
//int z;
//for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
if(box[j+class] > .2){
printf("%f %s\n", box[j+class], class_names[class]);
float red = get_color(0,class,classes);
float green = get_color(1,class,classes);
float blue = get_color(2,class,classes);
//float maxheight = distance_from_edge(r, side);
//float maxwidth = distance_from_edge(c, side);
j += classes;
float x = box[j+0];
float y = box[j+1];
@ -40,7 +33,6 @@ void draw_detection(image im, float *box, int side, char *label)
float h = box[j+3]; //*maxheight;
h = h*h;
w = w*w;
//printf("coords %f %f %f %f\n", x, y, w, h);
int left = (x-w/2)*im.w;
int right = (x+w/2)*im.w;
@ -52,184 +44,9 @@ void draw_detection(image im, float *box, int side, char *label)
}
}
}
//printf("Done\n");
show_image(im, label);
}
void draw_localization(image im, float *box)
{
int classes = 20;
int class;
for(class = 0; class < classes; ++class){
//int z;
//for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
float red = get_color(0,class,classes);
float green = get_color(1,class,classes);
float blue = get_color(2,class,classes);
int j = class*4;
float x = box[j+0];
float y = box[j+1];
float w = box[j+2]; //*maxheight;
float h = box[j+3]; //*maxwidth;
//printf("coords %f %f %f %f\n", x, y, w, h);
int left = (x-w/2)*im.w;
int right = (x+w/2)*im.w;
int top = (y-h/2)*im.h;
int bot = (y+h/2)*im.h;
draw_box(im, left, top, right, bot, red, green, blue);
}
//printf("Done\n");
}
void train_localization(char *cfgfile, char *weightfile)
{
srand(time(0));
data_seed = time(0);
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 128;
int classes = 20;
int i = net.seen/imgs;
data train, buffer;
char **paths;
list *plist;
plist = get_paths("/home/pjreddie/data/voc/loc.2012val.txt");
paths = (char **)list_to_array(plist);
pthread_t load_thread = load_data_localization_thread(imgs, paths, plist->size, classes, net.w, net.h, &buffer);
clock_t time;
while(1){
i += 1;
time=clock();
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data_localization_thread(imgs, paths, plist->size, classes, net.w, net.h, &buffer);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
//TODO
#ifdef GPU
float *out = get_network_output_gpu(net);
#else
float *out = get_network_output(net);
#endif
image im = float_to_image(net.w, net.h, 3, train.X.vals[127]);
image copy = copy_image(im);
draw_localization(copy, &(out[63*80]));
draw_localization(copy, train.y.vals[127]);
show_image(copy, "box");
cvWaitKey(0);
free_image(copy);
net.seen += imgs;
if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
if(i%100==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
save_weights(net, buff);
}
free_data(train);
}
}
void train_detection_teststuff(char *cfgfile, char *weightfile)
{
srand(time(0));
data_seed = time(0);
int imgnet = 0;
char *base = basecfg(cfgfile);
printf("%s\n", base);
float avg_loss = -1;
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
detection_layer layer = get_network_detection_layer(net);
net.learning_rate = 0;
net.decay = 0;
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 128;
int i = net.seen/imgs;
data train, buffer;
int classes = layer.classes;
int background = layer.background;
int side = sqrt(get_detection_layer_locations(layer));
char **paths;
list *plist;
if (imgnet){
plist = get_paths("/home/pjreddie/data/imagenet/det.train.list");
}else{
plist = get_paths("/home/pjreddie/data/voc/val_2012.txt");
//plist = get_paths("/home/pjreddie/data/voc/no_2007_test.txt");
//plist = get_paths("/home/pjreddie/data/coco/trainval.txt");
//plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt");
}
paths = (char **)list_to_array(plist);
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
clock_t time;
cost_layer clayer = net.layers[net.n-1];
while(1){
i += 1;
time=clock();
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
/*
image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
image copy = copy_image(im);
draw_detection(copy, train.y.vals[114], 7);
free_image(copy);
*/
int z;
int count = 0;
float sx, sy, sw, sh;
sx = sy = sw = sh = 0;
for(z = 0; z < clayer.batch*clayer.inputs; z += 24){
if(clayer.delta[z+20]){
++count;
sx += fabs(clayer.delta[z+20])*64;
sy += fabs(clayer.delta[z+21])*64;
sw += fabs(clayer.delta[z+22])*448;
sh += fabs(clayer.delta[z+23])*448;
}
}
printf("Avg error: %f, %f, %f x %f\n", sx/count, sy/count, sw/count, sh/count);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
net.seen += imgs;
if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
if(i == 100){
//net.learning_rate *= 10;
}
if(i%100==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
save_weights(net, buff);
}
free_data(train);
}
}
void train_detection(char *cfgfile, char *weightfile)
{
srand(time(0));
@ -249,7 +66,7 @@ void train_detection(char *cfgfile, char *weightfile)
data train, buffer;
int classes = layer.classes;
int background = layer.background;
int background = (layer.background || layer.objectness);
int side = sqrt(get_detection_layer_locations(layer));
char **paths;
@ -301,7 +118,7 @@ void train_detection(char *cfgfile, char *weightfile)
}
}
void predict_detections(network net, data d, float threshold, int offset, int classes, int nuisance, int background, int num_boxes, int per_box)
void predict_detections(network net, data d, float threshold, int offset, int classes, int objectness, int background, int num_boxes, int per_box)
{
matrix pred = network_predict_data(net, d);
int j, k, class;
@ -311,16 +128,16 @@ void predict_detections(network net, data d, float threshold, int offset, int cl
int index = k/per_box;
int row = index / num_boxes;
int col = index % num_boxes;
if (nuisance) scale = 1.-pred.vals[j][k];
if (objectness) scale = 1.-pred.vals[j][k];
for (class = 0; class < classes; ++class){
int ci = k+classes+background+nuisance;
int ci = k+classes+(background || objectness);
float x = (pred.vals[j][ci + 0] + col)/num_boxes;
float y = (pred.vals[j][ci + 1] + row)/num_boxes;
float w = pred.vals[j][ci + 2]; // distance_from_edge(row, num_boxes);
float h = pred.vals[j][ci + 3]; // distance_from_edge(col, num_boxes);
w = pow(w, 2);
h = pow(h, 2);
float prob = scale*pred.vals[j][k+class+background+nuisance];
float prob = scale*pred.vals[j][k+class+(background || objectness)];
if(prob < threshold) continue;
printf("%d %d %f %f %f %f %f\n", offset + j, class, prob, x, y, w, h);
}
@ -339,19 +156,15 @@ void validate_detection(char *cfgfile, char *weightfile)
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
//list *plist = get_paths("/home/pjreddie/data/voc/test_2007.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/val_2012.txt");
list *plist = get_paths("/home/pjreddie/data/voc/test.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/val.expanded.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
char **paths = (char **)list_to_array(plist);
int classes = layer.classes;
int nuisance = layer.nuisance;
int background = (layer.background && !nuisance);
int objectness = layer.objectness;
int background = layer.background;
int num_boxes = sqrt(get_detection_layer_locations(layer));
int per_box = 4+classes+background+nuisance;
int per_box = 4+classes+(background || objectness);
int num_output = num_boxes*num_boxes*per_box;
int m = plist->size;
@ -372,9 +185,7 @@ void validate_detection(char *cfgfile, char *weightfile)
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
//clock_t time;
for(i = nthreads; i <= splits; i += nthreads){
//time=clock();
for(t = 0; t < nthreads; ++t){
pthread_join(thr[t], 0);
val[t] = buf[t];
@ -385,223 +196,22 @@ void validate_detection(char *cfgfile, char *weightfile)
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
//fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
fprintf(stderr, "%d\n", i);
for(t = 0; t < nthreads; ++t){
predict_detections(net, val[t], .001, (i-nthreads+t)*m/splits, classes, nuisance, background, num_boxes, per_box);
predict_detections(net, val[t], .001, (i-nthreads+t)*m/splits, classes, objectness, background, num_boxes, per_box);
free_data(val[t]);
}
}
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
void do_mask(network net, data d, int offset, int classes, int nuisance, int background, int num_boxes, int per_box)
{
matrix pred = network_predict_data(net, d);
int j, k;
for(j = 0; j < pred.rows; ++j){
printf("%d ", offset + j);
for(k = 0; k < pred.cols; k += per_box){
float scale = 1.-pred.vals[j][k];
printf("%f ", scale);
}
printf("\n");
}
free_matrix(pred);
}
void mask_detection(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
detection_layer layer = get_network_detection_layer(net);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
list *plist = get_paths("/home/pjreddie/data/voc/test_2007.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/val_2012.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/test.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/val.expanded.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
char **paths = (char **)list_to_array(plist);
int classes = layer.classes;
int nuisance = layer.nuisance;
int background = (layer.background && !nuisance);
int num_boxes = sqrt(get_detection_layer_locations(layer));
int per_box = 4+classes+background+nuisance;
int num_output = num_boxes*num_boxes*per_box;
int m = plist->size;
int i = 0;
int splits = 100;
int nthreads = 4;
int t;
data *val = calloc(nthreads, sizeof(data));
data *buf = calloc(nthreads, sizeof(data));
pthread_t *thr = calloc(nthreads, sizeof(data));
for(t = 0; t < nthreads; ++t){
int num = (i+1+t)*m/splits - (i+t)*m/splits;
char **part = paths+((i+t)*m/splits);
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
clock_t time;
for(i = nthreads; i <= splits; i += nthreads){
time=clock();
for(t = 0; t < nthreads; ++t){
pthread_join(thr[t], 0);
val[t] = buf[t];
}
for(t = 0; t < nthreads && i < splits; ++t){
int num = (i+1+t)*m/splits - (i+t)*m/splits;
char **part = paths+((i+t)*m/splits);
thr[t] = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &(buf[t]));
}
fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
for(t = 0; t < nthreads; ++t){
do_mask(net, val[t], (i-nthreads+t)*m/splits, classes, nuisance, background, num_boxes, per_box);
free_data(val[t]);
}
time=clock();
}
}
void validate_detection_post(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
network post = parse_network_cfg("cfg/localize.cfg");
load_weights(&post, "/home/pjreddie/imagenet_backup/localize_1000.weights");
set_batch_network(&post, 1);
detection_layer layer = get_network_detection_layer(net);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
//list *plist = get_paths("/home/pjreddie/data/voc/test_2007.txt");
list *plist = get_paths("/home/pjreddie/data/voc/val_2012.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/test.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/val.expanded.txt");
//list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
char **paths = (char **)list_to_array(plist);
int classes = layer.classes;
int nuisance = layer.nuisance;
int background = (layer.background && !nuisance);
int num_boxes = sqrt(get_detection_layer_locations(layer));
int per_box = 4+classes+background+nuisance;
int m = plist->size;
int i = 0;
float threshold = .01;
clock_t time = clock();
for(i = 0; i < m; ++i){
image im = load_image_color(paths[i], 0, 0);
if(i % 100 == 0) {
fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
time = clock();
}
image sized = resize_image(im, net.w, net.h);
float *out = network_predict(net, sized.data);
free_image(sized);
int k, class;
//show_image(im, "original");
int num_output = num_boxes*num_boxes*per_box;
//image cp1 = copy_image(im);
//draw_detection(cp1, out, 7, "before");
for(k = 0; k < num_output; k += per_box){
float *post_out = 0;
float scale = 1.;
int index = k/per_box;
int row = index / num_boxes;
int col = index % num_boxes;
if (nuisance) scale = 1.-out[k];
for (class = 0; class < classes; ++class){
int ci = k+classes+background+nuisance;
float x = (out[ci + 0] + col)/num_boxes;
float y = (out[ci + 1] + row)/num_boxes;
float w = out[ci + 2]; //* distance_from_edge(row, num_boxes);
float h = out[ci + 3]; //* distance_from_edge(col, num_boxes);
w = w*w;
h = h*h;
float prob = scale*out[k+class+background+nuisance];
if (prob >= threshold) {
x *= im.w;
y *= im.h;
w *= im.w;
h *= im.h;
w += 32;
h += 32;
int left = (x - w/2);
int top = (y - h/2);
int right = (x + w/2);
int bot = (y+h/2);
if (left < 0) left = 0;
if (right > im.w) right = im.w;
if (top < 0) top = 0;
if (bot > im.h) bot = im.h;
image crop = crop_image(im, left, top, right-left, bot-top);
image resize = resize_image(crop, post.w, post.h);
if (!post_out){
post_out = network_predict(post, resize.data);
}
/*
draw_localization(resize, post_out);
show_image(resize, "second");
fprintf(stderr, "%s\n", class_names[class]);
cvWaitKey(0);
*/
int index = 4*class;
float px = post_out[index+0];
float py = post_out[index+1];
float pw = post_out[index+2];
float ph = post_out[index+3];
px = (px * crop.w + left) / im.w;
py = (py * crop.h + top) / im.h;
pw = (pw * crop.w) / im.w;
ph = (ph * crop.h) / im.h;
out[ci + 0] = px*num_boxes - col;
out[ci + 1] = py*num_boxes - row;
out[ci + 2] = sqrt(pw);
out[ci + 3] = sqrt(ph);
/*
show_image(crop, "cropped");
cvWaitKey(0);
*/
free_image(crop);
free_image(resize);
printf("%d %d %f %f %f %f %f\n", i, class, prob, px, py, pw, ph);
}
}
}
/*
image cp2 = copy_image(im);
draw_detection(cp2, out, 7, "after");
cvWaitKey(0);
*/
}
}
void test_detection(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
detection_layer layer = get_network_detection_layer(net);
int im_size = 448;
set_batch_network(&net, 1);
srand(2222222);
@ -617,10 +227,12 @@ void test_detection(char *cfgfile, char *weightfile)
time=clock();
float *predictions = network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time));
draw_detection(im, predictions, 7, "YOLO#SWAG#BLAZEIT");
draw_detection(im, predictions, 7, layer.background || layer.objectness, "predictions");
free_image(im);
free_image(sized);
#ifdef OPENCV
cvWaitKey(0);
#endif
}
}
@ -635,9 +247,5 @@ void run_detection(int argc, char **argv)
char *weights = (argc > 4) ? argv[4] : 0;
if(0==strcmp(argv[2], "test")) test_detection(cfg, weights);
else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights);
else if(0==strcmp(argv[2], "teststuff")) train_detection_teststuff(cfg, weights);
else if(0==strcmp(argv[2], "trainloc")) train_localization(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights);
else if(0==strcmp(argv[2], "mask")) mask_detection(cfg, weights);
else if(0==strcmp(argv[2], "validpost")) validate_detection_post(cfg, weights);
}

View File

@ -10,15 +10,15 @@
int get_detection_layer_locations(detection_layer l)
{
return l.inputs / (l.classes+l.coords+l.rescore+l.background);
return l.inputs / (l.classes+l.coords+l.joint+(l.background || l.objectness));
}
int get_detection_layer_output_size(detection_layer l)
{
return get_detection_layer_locations(l)*(l.background + l.classes + l.coords);
return get_detection_layer_locations(l)*((l.background || l.objectness) + l.classes + l.coords);
}
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance)
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness)
{
detection_layer l = {0};
l.type = DETECTION;
@ -28,7 +28,8 @@ detection_layer make_detection_layer(int batch, int inputs, int classes, int coo
l.classes = classes;
l.coords = coords;
l.rescore = rescore;
l.nuisance = nuisance;
l.objectness = objectness;
l.joint = joint;
l.cost = calloc(1, sizeof(float));
l.does_cost=1;
l.background = background;
@ -47,28 +48,6 @@ detection_layer make_detection_layer(int batch, int inputs, int classes, int coo
return l;
}
void dark_zone(detection_layer l, int class, int start, network_state state)
{
int index = start+l.background+class;
int size = l.classes+l.coords+l.background;
int location = (index%(7*7*size)) / size ;
int r = location / 7;
int c = location % 7;
int dr, dc;
for(dr = -1; dr <= 1; ++dr){
for(dc = -1; dc <= 1; ++dc){
if(!(dr || dc)) continue;
if((r + dr) > 6 || (r + dr) < 0) continue;
if((c + dc) > 6 || (c + dc) < 0) continue;
int di = (dr*7 + dc) * size;
if(state.truth[index+di]) continue;
l.output[index + di] = 0;
//if(!state.truth[start+di]) continue;
//l.output[start + di] = 1;
}
}
}
typedef struct{
float dx, dy, dw, dh;
} dbox;
@ -258,24 +237,6 @@ void test_box()
wiou = ((1-wiou)*(1-wiou) - iou)/(.00001);
hiou = ((1-hiou)*(1-hiou) - iou)/(.00001);
printf("manual %f %f %f %f\n", xiou, yiou, wiou, hiou);
/*
while(count++ < 300){
dbox d = diou(a, b);
printf("%f %f %f %f\n", a.x, a.y, a.w, a.h);
a.x += .1*d.dx;
a.w += .1*d.dw;
a.y += .1*d.dy;
a.h += .1*d.dh;
printf("inter: %f\n", box_intersection(a, b));
printf("union: %f\n", box_union(a, b));
printf("IOU: %f\n", box_iou(a, b));
if(d.dx==0 && d.dw==0 && d.dy==0 && d.dh==0) {
printf("break!!!\n");
break;
}
}
*/
}
dbox diou(box a, box b)
@ -308,10 +269,10 @@ void forward_detection_layer(const detection_layer l, network_state state)
int locations = get_detection_layer_locations(l);
int i,j;
for(i = 0; i < l.batch*locations; ++i){
int mask = (!state.truth || state.truth[out_i + l.background + l.classes + 2]);
int mask = (!state.truth || state.truth[out_i + (l.background || l.objectness) + l.classes + 2]);
float scale = 1;
if(l.rescore) scale = state.input[in_i++];
else if(l.nuisance){
if(l.joint) scale = state.input[in_i++];
else if(l.objectness){
l.output[out_i++] = 1-state.input[in_i++];
scale = mask;
}
@ -320,7 +281,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
for(j = 0; j < l.classes; ++j){
l.output[out_i++] = scale*state.input[in_i++];
}
if(l.nuisance){
if(l.objectness){
}else if(l.background){
softmax_array(l.output + out_i - l.classes-l.background, l.classes+l.background, l.output + out_i - l.classes-l.background);
@ -337,7 +298,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
int size = get_detection_layer_output_size(l) * l.batch;
memset(l.delta, 0, size * sizeof(float));
for (i = 0; i < l.batch*locations; ++i) {
int classes = l.nuisance+l.classes;
int classes = l.objectness+l.classes;
int offset = i*(classes+l.coords);
for (j = offset; j < offset+classes; ++j) {
*(l.cost) += pow(state.truth[j] - l.output[j], 2);
@ -372,7 +333,7 @@ void forward_detection_layer(const detection_layer l, network_state state)
l.delta[j+1] = 4 * (state.truth[j+1] - l.output[j+1]);
l.delta[j+2] = 4 * (state.truth[j+2] - l.output[j+2]);
l.delta[j+3] = 4 * (state.truth[j+3] - l.output[j+3]);
if(0){
if(l.rescore){
for (j = offset; j < offset+classes; ++j) {
if(state.truth[j]) state.truth[j] = iou;
l.delta[j] = state.truth[j] - l.output[j];
@ -392,21 +353,21 @@ void backward_detection_layer(const detection_layer l, network_state state)
for(i = 0; i < l.batch*locations; ++i){
float scale = 1;
float latent_delta = 0;
if(l.rescore) scale = state.input[in_i++];
else if (l.nuisance) state.delta[in_i++] = -l.delta[out_i++];
if(l.joint) scale = state.input[in_i++];
else if (l.objectness) state.delta[in_i++] = -l.delta[out_i++];
else if (l.background) state.delta[in_i++] = scale*l.delta[out_i++];
for(j = 0; j < l.classes; ++j){
latent_delta += state.input[in_i]*l.delta[out_i];
state.delta[in_i++] = scale*l.delta[out_i++];
}
if (l.nuisance) {
if (l.objectness) {
}else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
for(j = 0; j < l.coords; ++j){
state.delta[in_i++] = l.delta[out_i++];
}
if(l.rescore) state.delta[in_i-l.coords-l.classes-l.rescore-l.background] = latent_delta;
if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] = latent_delta;
}
}

View File

@ -6,7 +6,7 @@
typedef layer detection_layer;
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int rescore, int background, int nuisance);
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness);
void forward_detection_layer(const detection_layer l, network_state state);
void backward_detection_layer(const detection_layer l, network_state state);
int get_detection_layer_output_size(detection_layer l);

View File

@ -3,6 +3,11 @@
#include <stdio.h>
#include <math.h>
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
int windows = 0;
float colors[6][3] = { {1,0,1}, {0,0,1},{0,1,1},{0,1,0},{1,1,0},{1,0,0} };
@ -159,10 +164,22 @@ image copy_image(image p)
return copy;
}
void show_image(image p, char *name)
void rgbgr_image(image im)
{
int i;
for(i = 0; i < im.w*im.h; ++i){
float swap = im.data[i];
im.data[i] = im.data[i+im.w*im.h*2];
im.data[i+im.w*im.h*2] = swap;
}
}
#ifdef OPENCV
void show_image_cv(image p, char *name)
{
int x,y,k;
image copy = copy_image(p);
rgbgr_image(copy);
//normalize_image(copy);
char buff[256];
@ -197,8 +214,36 @@ void show_image(image p, char *name)
cvShowImage(buff, disp);
cvReleaseImage(&disp);
}
#endif
void save_image(image p, char *name)
void show_image(image p, char *name)
{
#ifdef OPENCV
show_image_cv(p, name);
#else
fprintf(stderr, "Not compiled with OpenCV, saving to %s.png instead\n", name);
save_image(p, name);
#endif
}
void save_image(image im, char *name)
{
char buff[256];
//sprintf(buff, "%s (%d)", name, windows);
sprintf(buff, "%s.png", name);
unsigned char *data = calloc(im.w*im.h*im.c, sizeof(char));
int i,k;
for(k = 0; k < im.c; ++k){
for(i = 0; i < im.w*im.h; ++i){
data[i*im.c+k] = (unsigned char) (255*im.data[i + k*im.w*im.h]);
}
}
int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c);
if(!success) fprintf(stderr, "Failed to write image %s\n", buff);
}
/*
void save_image_cv(image p, char *name)
{
int x,y,k;
image copy = copy_image(p);
@ -221,6 +266,7 @@ void save_image(image p, char *name)
cvSaveImage(buff, disp,0);
cvReleaseImage(&disp);
}
*/
void show_image_layers(image p, char *name)
{
@ -296,26 +342,6 @@ void scale_image(image m, float s)
for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s;
}
image ipl_to_image(IplImage* src)
{
unsigned char *data = (unsigned char *)src->imageData;
int h = src->height;
int w = src->width;
int c = src->nChannels;
int step = src->widthStep;
image out = make_image(w, h, c);
int i, j, k, count=0;;
for(k= 0; k < c; ++k){
for(i = 0; i < h; ++i){
for(j = 0; j < w; ++j){
out.data[count++] = data[i*step + j*c + k]/255.;
}
}
}
return out;
}
image crop_image(image im, int dx, int dy, int w, int h)
{
image cropped = make_image(w, h, im.c);
@ -355,9 +381,9 @@ void rgb_to_hsv(image im)
float h, s, v;
for(j = 0; j < im.h; ++j){
for(i = 0; i < im.w; ++i){
r = get_pixel(im, i , j, 2);
r = get_pixel(im, i , j, 0);
g = get_pixel(im, i , j, 1);
b = get_pixel(im, i , j, 0);
b = get_pixel(im, i , j, 2);
float max = three_way_max(r,g,b);
float min = three_way_min(r,g,b);
float delta = max - min;
@ -417,9 +443,9 @@ void hsv_to_rgb(image im)
r = v; g = p; b = q;
}
}
set_pixel(im, i, j, 2, r);
set_pixel(im, i, j, 0, r);
set_pixel(im, i, j, 1, g);
set_pixel(im, i, j, 0, b);
set_pixel(im, i, j, 2, b);
}
}
}
@ -429,7 +455,7 @@ image grayscale_image(image im)
assert(im.c == 3);
int i, j, k;
image gray = make_image(im.w, im.h, im.c);
float scale[] = {0.114, 0.587, 0.299};
float scale[] = {0.587, 0.299, 0.114};
for(k = 0; k < im.c; ++k){
for(j = 0; j < im.h; ++j){
for(i = 0; i < im.w; ++i){
@ -497,21 +523,21 @@ void saturate_exposure_image(image im, float sat, float exposure)
}
/*
image saturate_image(image im, float sat)
{
image gray = grayscale_image(im);
image blend = blend_image(im, gray, sat);
free_image(gray);
constrain_image(blend);
return blend;
}
image saturate_image(image im, float sat)
{
image gray = grayscale_image(im);
image blend = blend_image(im, gray, sat);
free_image(gray);
constrain_image(blend);
return blend;
}
image brightness_image(image im, float b)
{
image bright = make_image(im.w, im.h, im.c);
return bright;
}
*/
image brightness_image(image im, float b)
{
image bright = make_image(im.w, im.h, im.c);
return bright;
}
*/
float billinear_interpolate(image im, float x, float y, int c)
{
@ -550,7 +576,7 @@ image resize_image(image im, int w, int h)
void test_resize(char *filename)
{
image im = load_image(filename, 0,0);
image im = load_image(filename, 0,0, 3);
image small = resize_image(im, 65, 63);
image big = resize_image(im, 513, 512);
image crop = crop_image(im, 50, 10, 100, 100);
@ -562,11 +588,12 @@ void test_resize(char *filename)
image sat2 = copy_image(im);
saturate_image(sat2, 2);
exposure_image(sat2, 2);
image sat5 = copy_image(im);
saturate_image(sat5, 2);
exposure_image(sat5, .5);
saturate_image(sat5, 5);
image sat10 = copy_image(im);
saturate_image(sat10, 10);
image exp2 = copy_image(im);
saturate_image(exp2, .5);
@ -580,8 +607,7 @@ void test_resize(char *filename)
show_image(gray, "gray");
show_image(sat2, "sat2");
show_image(sat5, "sat5");
show_image(exp2, "exp2");
show_image(exp5, "exp5");
show_image(sat10, "sat10");
/*
show_image(small, "smaller");
show_image(big, "bigger");
@ -591,44 +617,100 @@ void test_resize(char *filename)
show_image(rot2, "rot2");
show_image(test, "test");
*/
#ifdef OPENCV
cvWaitKey(0);
#endif
}
#ifdef OPENCV
image ipl_to_image(IplImage* src)
{
unsigned char *data = (unsigned char *)src->imageData;
int h = src->height;
int w = src->width;
int c = src->nChannels;
int step = src->widthStep;
image out = make_image(w, h, c);
int i, j, k, count=0;;
for(k= 0; k < c; ++k){
for(i = 0; i < h; ++i){
for(j = 0; j < w; ++j){
out.data[count++] = data[i*step + j*c + k]/255.;
}
}
}
return out;
}
image load_image_cv(char *filename, int channels)
{
IplImage* src = 0;
int flag = -1;
if (channels == 0) flag = -1;
else if (channels == 1) flag = 0;
else if (channels == 3) flag = 1;
else {
fprintf(stderr, "OpenCV can't force load with %d channels\n", channels);
}
if( (src = cvLoadImage(filename, flag)) == 0 )
{
printf("Cannot load file image %s\n", filename);
exit(0);
}
image out = ipl_to_image(src);
cvReleaseImage(&src);
rgbgr_image(out);
return out;
}
#endif
image load_image_stb(char *filename, int channels)
{
int w, h, c;
unsigned char *data = stbi_load(filename, &w, &h, &c, channels);
if (!data) {
printf("Cannot load file image %s\n", filename);
exit(0);
}
if(channels) c = channels;
int i,j,k;
image im = make_image(w, h, c);
for(k = 0; k < c; ++k){
for(j = 0; j < h; ++j){
for(i = 0; i < w; ++i){
int dst_index = i + w*j + w*h*k;
int src_index = k + c*i + c*w*j;
im.data[dst_index] = (float)data[src_index]/255.;
}
}
}
free(data);
return im;
}
image load_image(char *filename, int w, int h, int c)
{
#ifdef OPENCV
image out = load_image_cv(filename, c);
#else
image out = load_image_stb(filename, c);
#endif
if((h && w) && (h != out.h || w != out.w)){
image resized = resize_image(out, w, h);
free_image(out);
out = resized;
}
return out;
}
image load_image_color(char *filename, int w, int h)
{
IplImage* src = 0;
if( (src = cvLoadImage(filename, 1)) == 0 )
{
printf("Cannot load file image %s\n", filename);
exit(0);
}
image out = ipl_to_image(src);
cvReleaseImage(&src);
if((h && w) && (h != out.h || w != out.w)){
//printf("resize\n");
image resized = resize_image(out, w, h);
free_image(out);
out = resized;
}
return out;
}
image load_image(char *filename, int w, int h)
{
IplImage* src = 0;
if( (src = cvLoadImage(filename,-1)) == 0 )
{
printf("Cannot load file image %s\n", filename);
exit(0);
}
image out = ipl_to_image(src);
if((h && w) && (h != out.h || w != out.w)){
image resized = resize_image(out, w, h);
free_image(out);
out = resized;
}
cvReleaseImage(&src);
return out;
return load_image(filename, w, h, 3);
}
image get_image_layer(image m, int l)

View File

@ -2,8 +2,17 @@
#define IMAGE_H
#include <stdlib.h>
#include <stdio.h>
#include <float.h>
#include <string.h>
#include <math.h>
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
#endif
typedef struct {
int h;
int w;
@ -26,6 +35,7 @@ void saturate_image(image im, float sat);
void exposure_image(image im, float sat);
void saturate_exposure_image(image im, float sat, float exposure);
void hsv_to_rgb(image im);
void rgbgr_image(image im);
image collapse_image_layers(image source, int border);
image collapse_images_horz(image *ims, int n);
@ -43,11 +53,9 @@ image make_image(int w, int h, int c);
image make_empty_image(int w, int h, int c);
image float_to_image(int w, int h, int c, float *data);
image copy_image(image p);
image load_image(char *filename, int w, int h);
image load_image(char *filename, int w, int h, int c);
image load_image_color(char *filename, int w, int h);
image ipl_to_image(IplImage* src);
float get_pixel(image m, int x, int y, int c);
float get_pixel_extend(image m, int x, int y, int c);
void set_pixel(image m, int x, int y, int c, float val);

View File

@ -44,8 +44,10 @@ typedef struct {
int coords;
int background;
int rescore;
int nuisance;
int objectness;
int does_cost;
int joint;
float probability;
float scale;
int *indexes;

View File

@ -164,10 +164,11 @@ detection_layer parse_detection(list *options, size_params params)
{
int coords = option_find_int(options, "coords", 1);
int classes = option_find_int(options, "classes", 1);
int rescore = option_find_int(options, "rescore", 1);
int nuisance = option_find_int(options, "nuisance", 0);
int rescore = option_find_int(options, "rescore", 0);
int joint = option_find_int(options, "joint", 0);
int objectness = option_find_int(options, "objectness", 0);
int background = option_find_int(options, "background", 1);
detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, rescore, background, nuisance);
detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, joint, rescore, background, objectness);
option_unused(options);
return layer;
}

6437
src/stb_image.h Normal file

File diff suppressed because it is too large Load Diff

730
src/stb_image_write.h Normal file
View File

@ -0,0 +1,730 @@
/* stb_image_write - v0.98 - public domain - http://nothings.org/stb/stb_image_write.h
writes out PNG/BMP/TGA images to C stdio - Sean Barrett 2010
no warranty implied; use at your own risk
Before #including,
#define STB_IMAGE_WRITE_IMPLEMENTATION
in the file that you want to have the implementation.
Will probably not work correctly with strict-aliasing optimizations.
ABOUT:
This header file is a library for writing images to C stdio. It could be
adapted to write to memory or a general streaming interface; let me know.
The PNG output is not optimal; it is 20-50% larger than the file
written by a decent optimizing implementation. This library is designed
for source code compactness and simplicitly, not optimal image file size
or run-time performance.
BUILDING:
You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.
You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace
malloc,realloc,free.
You can define STBIW_MEMMOVE() to replace memmove()
USAGE:
There are four functions, one for each image file format:
int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
int stbi_write_hdr(char const *filename, int w, int h, int comp, const void *data);
Each function returns 0 on failure and non-0 on success.
The functions create an image file defined by the parameters. The image
is a rectangle of pixels stored from left-to-right, top-to-bottom.
Each pixel contains 'comp' channels of data stored interleaved with 8-bits
per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is
monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.
The *data pointer points to the first byte of the top-left-most pixel.
For PNG, "stride_in_bytes" is the distance in bytes from the first byte of
a row of pixels to the first byte of the next row of pixels.
PNG creates output files with the same number of components as the input.
The BMP format expands Y to RGB in the file format and does not
output alpha.
PNG supports writing rectangles of data even when the bytes storing rows of
data are not consecutive in memory (e.g. sub-rectangles of a larger image),
by supplying the stride between the beginning of adjacent rows. The other
formats do not. (Thus you cannot write a native-format BMP through the BMP
writer, both because it is in BGR order and because it may have padding
at the end of the line.)
HDR expects linear float data. Since the format is always 32-bit rgb(e)
data, alpha (if provided) is discarded, and for monochrome data it is
replicated across all three channels.
CREDITS:
PNG/BMP/TGA
Sean Barrett
HDR
Baldur Karlsson
TGA monochrome:
Jean-Sebastien Guay
misc enhancements:
Tim Kelsey
bugfixes:
github:Chribba
*/
#ifndef INCLUDE_STB_IMAGE_WRITE_H
#define INCLUDE_STB_IMAGE_WRITE_H
#ifdef __cplusplus
extern "C" {
#endif
extern int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
extern int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
extern int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
extern int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
#ifdef __cplusplus
}
#endif
#endif//INCLUDE_STB_IMAGE_WRITE_H
#ifdef STB_IMAGE_WRITE_IMPLEMENTATION
#include <stdarg.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && defined(STBIW_REALLOC)
// ok
#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC)
// ok
#else
#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC."
#endif
#ifndef STBIW_MALLOC
#define STBIW_MALLOC(sz) malloc(sz)
#define STBIW_REALLOC(p,sz) realloc(p,sz)
#define STBIW_FREE(p) free(p)
#endif
#ifndef STBIW_MEMMOVE
#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
#endif
#ifndef STBIW_ASSERT
#include <assert.h>
#define STBIW_ASSERT(x) assert(x)
#endif
typedef unsigned int stbiw_uint32;
typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1];
static void writefv(FILE *f, const char *fmt, va_list v)
{
while (*fmt) {
switch (*fmt++) {
case ' ': break;
case '1': { unsigned char x = (unsigned char) va_arg(v, int); fputc(x,f); break; }
case '2': { int x = va_arg(v,int); unsigned char b[2];
b[0] = (unsigned char) x; b[1] = (unsigned char) (x>>8);
fwrite(b,2,1,f); break; }
case '4': { stbiw_uint32 x = va_arg(v,int); unsigned char b[4];
b[0]=(unsigned char)x; b[1]=(unsigned char)(x>>8);
b[2]=(unsigned char)(x>>16); b[3]=(unsigned char)(x>>24);
fwrite(b,4,1,f); break; }
default:
STBIW_ASSERT(0);
return;
}
}
}
static void write3(FILE *f, unsigned char a, unsigned char b, unsigned char c)
{
unsigned char arr[3];
arr[0] = a, arr[1] = b, arr[2] = c;
fwrite(arr, 3, 1, f);
}
static void write_pixels(FILE *f, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)
{
unsigned char bg[3] = { 255, 0, 255}, px[3];
stbiw_uint32 zero = 0;
int i,j,k, j_end;
if (y <= 0)
return;
if (vdir < 0)
j_end = -1, j = y-1;
else
j_end = y, j = 0;
for (; j != j_end; j += vdir) {
for (i=0; i < x; ++i) {
unsigned char *d = (unsigned char *) data + (j*x+i)*comp;
if (write_alpha < 0)
fwrite(&d[comp-1], 1, 1, f);
switch (comp) {
case 1: fwrite(d, 1, 1, f);
break;
case 2: if (expand_mono)
write3(f, d[0],d[0],d[0]); // monochrome bmp
else
fwrite(d, 1, 1, f); // monochrome TGA
break;
case 4:
if (!write_alpha) {
// composite against pink background
for (k=0; k < 3; ++k)
px[k] = bg[k] + ((d[k] - bg[k]) * d[3])/255;
write3(f, px[1-rgb_dir],px[1],px[1+rgb_dir]);
break;
}
/* FALLTHROUGH */
case 3:
write3(f, d[1-rgb_dir],d[1],d[1+rgb_dir]);
break;
}
if (write_alpha > 0)
fwrite(&d[comp-1], 1, 1, f);
}
fwrite(&zero,scanline_pad,1,f);
}
}
static int outfile(char const *filename, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)
{
FILE *f;
if (y < 0 || x < 0) return 0;
f = fopen(filename, "wb");
if (f) {
va_list v;
va_start(v, fmt);
writefv(f, fmt, v);
va_end(v);
write_pixels(f,rgb_dir,vdir,x,y,comp,data,alpha,pad,expand_mono);
fclose(f);
}
return f != NULL;
}
int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)
{
int pad = (-x*3) & 3;
return outfile(filename,-1,-1,x,y,comp,1,(void *) data,0,pad,
"11 4 22 4" "4 44 22 444444",
'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header
40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header
}
int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)
{
int has_alpha = (comp == 2 || comp == 4);
int colorbytes = has_alpha ? comp-1 : comp;
int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3
return outfile(filename, -1,-1, x, y, comp, 0, (void *) data, has_alpha, 0,
"111 221 2222 11", 0,0,format, 0,0,0, 0,0,x,y, (colorbytes+has_alpha)*8, has_alpha*8);
}
// *************************************************************************************************
// Radiance RGBE HDR writer
// by Baldur Karlsson
#define stbiw__max(a, b) ((a) > (b) ? (a) : (b))
void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)
{
int exponent;
float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2]));
if (maxcomp < 1e-32) {
rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;
} else {
float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;
rgbe[0] = (unsigned char)(linear[0] * normalize);
rgbe[1] = (unsigned char)(linear[1] * normalize);
rgbe[2] = (unsigned char)(linear[2] * normalize);
rgbe[3] = (unsigned char)(exponent + 128);
}
}
void stbiw__write_run_data(FILE *f, int length, unsigned char databyte)
{
unsigned char lengthbyte = (unsigned char) (length+128);
STBIW_ASSERT(length+128 <= 255);
fwrite(&lengthbyte, 1, 1, f);
fwrite(&databyte, 1, 1, f);
}
void stbiw__write_dump_data(FILE *f, int length, unsigned char *data)
{
unsigned char lengthbyte = (unsigned char )(length & 0xff);
STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code
fwrite(&lengthbyte, 1, 1, f);
fwrite(data, length, 1, f);
}
void stbiw__write_hdr_scanline(FILE *f, int width, int comp, unsigned char *scratch, const float *scanline)
{
unsigned char scanlineheader[4] = { 2, 2, 0, 0 };
unsigned char rgbe[4];
float linear[3];
int x;
scanlineheader[2] = (width&0xff00)>>8;
scanlineheader[3] = (width&0x00ff);
/* skip RLE for images too small or large */
if (width < 8 || width >= 32768) {
for (x=0; x < width; x++) {
switch (comp) {
case 4: /* fallthrough */
case 3: linear[2] = scanline[x*comp + 2];
linear[1] = scanline[x*comp + 1];
linear[0] = scanline[x*comp + 0];
break;
case 2: /* fallthrough */
case 1: linear[0] = linear[1] = linear[2] = scanline[x*comp + 0];
break;
}
stbiw__linear_to_rgbe(rgbe, linear);
fwrite(rgbe, 4, 1, f);
}
} else {
int c,r;
/* encode into scratch buffer */
for (x=0; x < width; x++) {
switch(comp) {
case 4: /* fallthrough */
case 3: linear[2] = scanline[x*comp + 2];
linear[1] = scanline[x*comp + 1];
linear[0] = scanline[x*comp + 0];
break;
case 2: /* fallthrough */
case 1: linear[0] = linear[1] = linear[2] = scanline[x*comp + 0];
break;
}
stbiw__linear_to_rgbe(rgbe, linear);
scratch[x + width*0] = rgbe[0];
scratch[x + width*1] = rgbe[1];
scratch[x + width*2] = rgbe[2];
scratch[x + width*3] = rgbe[3];
}
fwrite(scanlineheader, 4, 1, f);
/* RLE each component separately */
for (c=0; c < 4; c++) {
unsigned char *comp = &scratch[width*c];
x = 0;
while (x < width) {
// find first run
r = x;
while (r+2 < width) {
if (comp[r] == comp[r+1] && comp[r] == comp[r+2])
break;
++r;
}
if (r+2 >= width)
r = width;
// dump up to first run
while (x < r) {
int len = r-x;
if (len > 128) len = 128;
stbiw__write_dump_data(f, len, &comp[x]);
x += len;
}
// if there's a run, output it
if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd
// find next byte after run
while (r < width && comp[r] == comp[x])
++r;
// output run up to r
while (x < r) {
int len = r-x;
if (len > 127) len = 127;
stbiw__write_run_data(f, len, comp[x]);
x += len;
}
}
}
}
}
}
int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)
{
int i;
FILE *f;
if (y <= 0 || x <= 0 || data == NULL) return 0;
f = fopen(filename, "wb");
if (f) {
/* Each component is stored separately. Allocate scratch space for full output scanline. */
unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);
fprintf(f, "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n" );
fprintf(f, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n" , y, x);
for(i=0; i < y; i++)
stbiw__write_hdr_scanline(f, x, comp, scratch, data + comp*i*x);
STBIW_FREE(scratch);
fclose(f);
}
return f != NULL;
}
/////////////////////////////////////////////////////////
// PNG
// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()
#define stbiw__sbraw(a) ((int *) (a) - 2)
#define stbiw__sbm(a) stbiw__sbraw(a)[0]
#define stbiw__sbn(a) stbiw__sbraw(a)[1]
#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a))
#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0)
#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))
#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))
#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0)
#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)
static void *stbiw__sbgrowf(void **arr, int increment, int itemsize)
{
int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1;
void *p = STBIW_REALLOC(*arr ? stbiw__sbraw(*arr) : 0, itemsize * m + sizeof(int)*2);
STBIW_ASSERT(p);
if (p) {
if (!*arr) ((int *) p)[1] = 0;
*arr = (void *) ((int *) p + 2);
stbiw__sbm(*arr) = m;
}
return *arr;
}
static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)
{
while (*bitcount >= 8) {
stbiw__sbpush(data, (unsigned char) *bitbuffer);
*bitbuffer >>= 8;
*bitcount -= 8;
}
return data;
}
static int stbiw__zlib_bitrev(int code, int codebits)
{
int res=0;
while (codebits--) {
res = (res << 1) | (code & 1);
code >>= 1;
}
return res;
}
static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)
{
int i;
for (i=0; i < limit && i < 258; ++i)
if (a[i] != b[i]) break;
return i;
}
static unsigned int stbiw__zhash(unsigned char *data)
{
stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);
hash ^= hash << 3;
hash += hash >> 5;
hash ^= hash << 4;
hash += hash >> 17;
hash ^= hash << 25;
hash += hash >> 6;
return hash;
}
#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))
#define stbiw__zlib_add(code,codebits) \
(bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())
#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)
// default huffman tables
#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8)
#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9)
#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7)
#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8)
#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))
#define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))
#define stbiw__ZHASH 16384
unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)
{
static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };
static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 };
static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };
static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };
unsigned int bitbuf=0;
int i,j, bitcount=0;
unsigned char *out = NULL;
unsigned char **hash_table[stbiw__ZHASH]; // 64KB on the stack!
if (quality < 5) quality = 5;
stbiw__sbpush(out, 0x78); // DEFLATE 32K window
stbiw__sbpush(out, 0x5e); // FLEVEL = 1
stbiw__zlib_add(1,1); // BFINAL = 1
stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman
for (i=0; i < stbiw__ZHASH; ++i)
hash_table[i] = NULL;
i=0;
while (i < data_len-3) {
// hash next 3 bytes of data to be compressed
int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;
unsigned char *bestloc = 0;
unsigned char **hlist = hash_table[h];
int n = stbiw__sbcount(hlist);
for (j=0; j < n; ++j) {
if (hlist[j]-data > i-32768) { // if entry lies within window
int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);
if (d >= best) best=d,bestloc=hlist[j];
}
}
// when hash table entry is too long, delete half the entries
if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {
STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);
stbiw__sbn(hash_table[h]) = quality;
}
stbiw__sbpush(hash_table[h],data+i);
if (bestloc) {
// "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal
h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);
hlist = hash_table[h];
n = stbiw__sbcount(hlist);
for (j=0; j < n; ++j) {
if (hlist[j]-data > i-32767) {
int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);
if (e > best) { // if next match is better, bail on current match
bestloc = NULL;
break;
}
}
}
}
if (bestloc) {
int d = (int) (data+i - bestloc); // distance back
STBIW_ASSERT(d <= 32767 && best <= 258);
for (j=0; best > lengthc[j+1]-1; ++j);
stbiw__zlib_huff(j+257);
if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);
for (j=0; d > distc[j+1]-1; ++j);
stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);
if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);
i += best;
} else {
stbiw__zlib_huffb(data[i]);
++i;
}
}
// write out final bytes
for (;i < data_len; ++i)
stbiw__zlib_huffb(data[i]);
stbiw__zlib_huff(256); // end of block
// pad with 0 bits to byte boundary
while (bitcount)
stbiw__zlib_add(0,1);
for (i=0; i < stbiw__ZHASH; ++i)
(void) stbiw__sbfree(hash_table[i]);
{
// compute adler32 on input
unsigned int i=0, s1=1, s2=0, blocklen = data_len % 5552;
int j=0;
while (j < data_len) {
for (i=0; i < blocklen; ++i) s1 += data[j+i], s2 += s1;
s1 %= 65521, s2 %= 65521;
j += blocklen;
blocklen = 5552;
}
stbiw__sbpush(out, (unsigned char) (s2 >> 8));
stbiw__sbpush(out, (unsigned char) s2);
stbiw__sbpush(out, (unsigned char) (s1 >> 8));
stbiw__sbpush(out, (unsigned char) s1);
}
*out_len = stbiw__sbn(out);
// make returned pointer freeable
STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);
return (unsigned char *) stbiw__sbraw(out);
}
unsigned int stbiw__crc32(unsigned char *buffer, int len)
{
static unsigned int crc_table[256];
unsigned int crc = ~0u;
int i,j;
if (crc_table[1] == 0)
for(i=0; i < 256; i++)
for (crc_table[i]=i, j=0; j < 8; ++j)
crc_table[i] = (crc_table[i] >> 1) ^ (crc_table[i] & 1 ? 0xedb88320 : 0);
for (i=0; i < len; ++i)
crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];
return ~crc;
}
#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=(unsigned char)(a),(o)[1]=(unsigned char)(b),(o)[2]=(unsigned char)(c),(o)[3]=(unsigned char)(d),(o)+=4)
#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v));
#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])
static void stbiw__wpcrc(unsigned char **data, int len)
{
unsigned int crc = stbiw__crc32(*data - len - 4, len+4);
stbiw__wp32(*data, crc);
}
static unsigned char stbiw__paeth(int a, int b, int c)
{
int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c);
if (pa <= pb && pa <= pc) return (unsigned char) a;
if (pb <= pc) return (unsigned char) b;
return (unsigned char) c;
}
unsigned char *stbi_write_png_to_mem(unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)
{
int ctype[5] = { -1, 0, 4, 2, 6 };
unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };
unsigned char *out,*o, *filt, *zlib;
signed char *line_buffer;
int i,j,k,p,zlen;
if (stride_bytes == 0)
stride_bytes = x * n;
filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;
line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }
for (j=0; j < y; ++j) {
static int mapping[] = { 0,1,2,3,4 };
static int firstmap[] = { 0,1,0,5,6 };
int *mymap = j ? mapping : firstmap;
int best = 0, bestval = 0x7fffffff;
for (p=0; p < 2; ++p) {
for (k= p?best:0; k < 5; ++k) {
int type = mymap[k],est=0;
unsigned char *z = pixels + stride_bytes*j;
for (i=0; i < n; ++i)
switch (type) {
case 0: line_buffer[i] = z[i]; break;
case 1: line_buffer[i] = z[i]; break;
case 2: line_buffer[i] = z[i] - z[i-stride_bytes]; break;
case 3: line_buffer[i] = z[i] - (z[i-stride_bytes]>>1); break;
case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-stride_bytes],0)); break;
case 5: line_buffer[i] = z[i]; break;
case 6: line_buffer[i] = z[i]; break;
}
for (i=n; i < x*n; ++i) {
switch (type) {
case 0: line_buffer[i] = z[i]; break;
case 1: line_buffer[i] = z[i] - z[i-n]; break;
case 2: line_buffer[i] = z[i] - z[i-stride_bytes]; break;
case 3: line_buffer[i] = z[i] - ((z[i-n] + z[i-stride_bytes])>>1); break;
case 4: line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-stride_bytes], z[i-stride_bytes-n]); break;
case 5: line_buffer[i] = z[i] - (z[i-n]>>1); break;
case 6: line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;
}
}
if (p) break;
for (i=0; i < x*n; ++i)
est += abs((signed char) line_buffer[i]);
if (est < bestval) { bestval = est; best = k; }
}
}
// when we get here, best contains the filter type, and line_buffer contains the data
filt[j*(x*n+1)] = (unsigned char) best;
STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);
}
STBIW_FREE(line_buffer);
zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, 8); // increase 8 to get smaller but use more memory
STBIW_FREE(filt);
if (!zlib) return 0;
// each tag requires 12 bytes of overhead
out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12);
if (!out) return 0;
*out_len = 8 + 12+13 + 12+zlen + 12;
o=out;
STBIW_MEMMOVE(o,sig,8); o+= 8;
stbiw__wp32(o, 13); // header length
stbiw__wptag(o, "IHDR");
stbiw__wp32(o, x);
stbiw__wp32(o, y);
*o++ = 8;
*o++ = (unsigned char) ctype[n];
*o++ = 0;
*o++ = 0;
*o++ = 0;
stbiw__wpcrc(&o,13);
stbiw__wp32(o, zlen);
stbiw__wptag(o, "IDAT");
STBIW_MEMMOVE(o, zlib, zlen);
o += zlen;
STBIW_FREE(zlib);
stbiw__wpcrc(&o, zlen);
stbiw__wp32(o,0);
stbiw__wptag(o, "IEND");
stbiw__wpcrc(&o,0);
STBIW_ASSERT(o == out + *out_len);
return out;
}
int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)
{
FILE *f;
int len;
unsigned char *png = stbi_write_png_to_mem((unsigned char *) data, stride_bytes, x, y, comp, &len);
if (!png) return 0;
f = fopen(filename, "wb");
if (!f) { STBIW_FREE(png); return 0; }
fwrite(png, 1, len, f);
fclose(f);
STBIW_FREE(png);
return 1;
}
#endif // STB_IMAGE_WRITE_IMPLEMENTATION
/* Revision history
0.98 (2015-04-08)
added STBIW_MALLOC, STBIW_ASSERT etc
0.97 (2015-01-18)
fixed HDR asserts, rewrote HDR rle logic
0.96 (2015-01-17)
add HDR output
fix monochrome BMP
0.95 (2014-08-17)
add monochrome TGA output
0.94 (2014-05-31)
rename private functions to avoid conflicts with stb_image.h
0.93 (2014-05-27)
warning fixes
0.92 (2010-08-01)
casts to unsigned char to fix warnings
0.91 (2010-07-17)
first public release
0.90 first internal release
*/