diff --git a/src/coco.c b/src/coco.c index 8ad13834b01..a226edf46ec 100644 --- a/src/coco.c +++ b/src/coco.c @@ -294,7 +294,7 @@ void validate_coco_recall(char *cfgfile, char *weightfile) replace_image_to_label(path, labelpath); int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); + box_label *truth = read_boxes_class(labelpath, &num_labels, classes); for(k = 0; k < side*side*l.n; ++k){ if(probs[k][0] > thresh){ ++proposals; diff --git a/src/data.c b/src/data.c index 70e1b09b2f4..903eb29c060 100644 --- a/src/data.c +++ b/src/data.c @@ -191,7 +191,6 @@ matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int return X; } - box_label *read_boxes(char *filename, int *n) { box_label* boxes = (box_label*)xcalloc(1, sizeof(box_label)); @@ -214,7 +213,53 @@ box_label *read_boxes(char *filename, int *n) float x, y, h, w; int id; int count = 0; + while (fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5) { + boxes = (box_label*)xrealloc(boxes, (count + 1) * sizeof(box_label)); + boxes[count].track_id = count + img_hash; + //printf(" boxes[count].track_id = %d, count = %d \n", boxes[count].track_id, count); + boxes[count].id = id; + boxes[count].x = x; + boxes[count].y = y; + boxes[count].h = h; + boxes[count].w = w; + boxes[count].left = x - w / 2; + boxes[count].right = x + w / 2; + boxes[count].top = y - h / 2; + boxes[count].bottom = y + h / 2; + ++count; + } + fclose(file); + *n = count; + return boxes; +} + +box_label *read_boxes_class(char *filename, int *n, int max_num_classes) +{ + box_label* boxes = (box_label*)xcalloc(1, sizeof(box_label)); + FILE *file = fopen(filename, "r"); + if (!file) { + printf("Can't open label file. (This can be normal only if you use MSCOCO): %s \n", filename); + //file_error(filename); + FILE* fw = fopen("bad.list", "a"); + fwrite(filename, sizeof(char), strlen(filename), fw); + char *new_line = "\n"; + fwrite(new_line, sizeof(char), strlen(new_line), fw); + fclose(fw); + + *n = 0; + return boxes; + } + const int max_obj_img = 4000;// 30000; + const int img_hash = (custom_hash(filename) % max_obj_img)*max_obj_img; + //printf(" img_hash = %d, filename = %s; ", img_hash, filename); + float x, y, h, w; + int id; + int count = 0; while(fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5){ + if (id >= max_num_classes) { + printf("Warn class %d >= %d on %s \n", id, max_num_classes, filename); + continue; + } boxes = (box_label*)xrealloc(boxes, (count + 1) * sizeof(box_label)); boxes[count].track_id = count + img_hash; //printf(" boxes[count].track_id = %d, count = %d \n", boxes[count].track_id, count); @@ -297,7 +342,7 @@ void fill_truth_swag(char *path, float *truth, int classes, int flip, float dx, replace_image_to_label(path, labelpath); int count = 0; - box_label *boxes = read_boxes(labelpath, &count); + box_label *boxes = read_boxes_class(labelpath, &count, classes); randomize_boxes(boxes, count); correct_boxes(boxes, count, dx, dy, sx, sy, flip); float x,y,w,h; @@ -331,7 +376,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int replace_image_to_label(path, labelpath); int count = 0; - box_label *boxes = read_boxes(labelpath, &count); + box_label *boxes = read_boxes_class(labelpath, &count, classes); randomize_boxes(boxes, count); correct_boxes(boxes, count, dx, dy, sx, sy, flip); float x,y,w,h; @@ -376,7 +421,7 @@ int fill_truth_detection(const char *path, int num_boxes, int truth_size, float int count = 0; int i; - box_label *boxes = read_boxes(labelpath, &count); + box_label *boxes = read_boxes_class(labelpath, &count, classes); int min_w_h = 0; float lowest_w = 1.F / net_w; float lowest_h = 1.F / net_h; diff --git a/src/data.h b/src/data.h index 9f12343a0a9..2f4a6ba4941 100644 --- a/src/data.h +++ b/src/data.h @@ -95,6 +95,7 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h float aspect, float hue, float saturation, float exposure, int use_mixup, int use_blur, int show_imgs, float label_smooth_eps, int dontuse_opencv, int contrastive); data load_go(char *filename); +box_label *read_boxes_class(char *filename, int *n, int max_num_classes); box_label *read_boxes(char *filename, int *n); data load_cifar10_data(char *filename); data load_all_cifar10(); diff --git a/src/detector.c b/src/detector.c index 0fc36142904..6193d8f6d7b 100644 --- a/src/detector.c +++ b/src/detector.c @@ -886,6 +886,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile) int num_labels = 0; box_label *truth = read_boxes(labelpath, &num_labels); + printf("CI PASSA--\n"); for (k = 0; k < nboxes; ++k) { if (dets[k].objectness > thresh) { ++proposals; @@ -1082,7 +1083,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa char labelpath[4096]; replace_image_to_label(path, labelpath); int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); + box_label *truth = read_boxes_class(labelpath, &num_labels, classes); int j; for (j = 0; j < num_labels; ++j) { truth_classes_count[truth[j].id]++; @@ -1098,7 +1099,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa char labelpath_dif[4096]; replace_image_to_label(path_dif, labelpath_dif); - truth_dif = read_boxes(labelpath_dif, &num_labels_dif); + truth_dif = read_boxes_class(labelpath_dif, &num_labels_dif, classes); } const int checkpoint_detections_count = detections_count; @@ -1469,8 +1470,9 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int char labelpath[4096]; replace_image_to_label(path, labelpath); + int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); + box_label *truth = read_boxes_class(labelpath, &num_labels, classes); //printf(" new path: %s \n", labelpath); char *buff = (char*)xcalloc(6144, sizeof(char)); for (j = 0; j < num_labels; ++j) diff --git a/src/yolo.c b/src/yolo.c index ef68acabc51..6a91b2e5fd6 100644 --- a/src/yolo.c +++ b/src/yolo.c @@ -254,7 +254,7 @@ void validate_yolo_recall(char *cfgfile, char *weightfile) replace_image_to_label(path, labelpath); int num_labels = 0; - box_label *truth = read_boxes(labelpath, &num_labels); + box_label *truth = read_boxes_class(labelpath, &num_labels, classes); for(k = 0; k < side*side*l.n; ++k){ if(probs[k][0] > thresh){ ++proposals;