Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed training with data tagged on more classes than configuration.… #7126

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/coco.c
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ void validate_coco_recall(char *cfgfile, char *weightfile)
replace_image_to_label(path, labelpath);

int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
box_label *truth = read_boxes_class(labelpath, &num_labels, classes);
for(k = 0; k < side*side*l.n; ++k){
if(probs[k][0] > thresh){
++proposals;
Expand Down
53 changes: 49 additions & 4 deletions src/data.c
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int
return X;
}


box_label *read_boxes(char *filename, int *n)
{
box_label* boxes = (box_label*)xcalloc(1, sizeof(box_label));
Expand All @@ -214,7 +213,53 @@ box_label *read_boxes(char *filename, int *n)
float x, y, h, w;
int id;
int count = 0;
while (fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5) {
boxes = (box_label*)xrealloc(boxes, (count + 1) * sizeof(box_label));
boxes[count].track_id = count + img_hash;
//printf(" boxes[count].track_id = %d, count = %d \n", boxes[count].track_id, count);
boxes[count].id = id;
boxes[count].x = x;
boxes[count].y = y;
boxes[count].h = h;
boxes[count].w = w;
boxes[count].left = x - w / 2;
boxes[count].right = x + w / 2;
boxes[count].top = y - h / 2;
boxes[count].bottom = y + h / 2;
++count;
}
fclose(file);
*n = count;
return boxes;
}

box_label *read_boxes_class(char *filename, int *n, int max_num_classes)
{
box_label* boxes = (box_label*)xcalloc(1, sizeof(box_label));
FILE *file = fopen(filename, "r");
if (!file) {
printf("Can't open label file. (This can be normal only if you use MSCOCO): %s \n", filename);
//file_error(filename);
FILE* fw = fopen("bad.list", "a");
fwrite(filename, sizeof(char), strlen(filename), fw);
char *new_line = "\n";
fwrite(new_line, sizeof(char), strlen(new_line), fw);
fclose(fw);

*n = 0;
return boxes;
}
const int max_obj_img = 4000;// 30000;
const int img_hash = (custom_hash(filename) % max_obj_img)*max_obj_img;
//printf(" img_hash = %d, filename = %s; ", img_hash, filename);
float x, y, h, w;
int id;
int count = 0;
while(fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5){
if (id >= max_num_classes) {
printf("Warn class %d >= %d on %s \n", id, max_num_classes, filename);
continue;
}
boxes = (box_label*)xrealloc(boxes, (count + 1) * sizeof(box_label));
boxes[count].track_id = count + img_hash;
//printf(" boxes[count].track_id = %d, count = %d \n", boxes[count].track_id, count);
Expand Down Expand Up @@ -297,7 +342,7 @@ void fill_truth_swag(char *path, float *truth, int classes, int flip, float dx,
replace_image_to_label(path, labelpath);

int count = 0;
box_label *boxes = read_boxes(labelpath, &count);
box_label *boxes = read_boxes_class(labelpath, &count, classes);
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
float x,y,w,h;
Expand Down Expand Up @@ -331,7 +376,7 @@ void fill_truth_region(char *path, float *truth, int classes, int num_boxes, int
replace_image_to_label(path, labelpath);

int count = 0;
box_label *boxes = read_boxes(labelpath, &count);
box_label *boxes = read_boxes_class(labelpath, &count, classes);
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
float x,y,w,h;
Expand Down Expand Up @@ -376,7 +421,7 @@ int fill_truth_detection(const char *path, int num_boxes, int truth_size, float

int count = 0;
int i;
box_label *boxes = read_boxes(labelpath, &count);
box_label *boxes = read_boxes_class(labelpath, &count, classes);
int min_w_h = 0;
float lowest_w = 1.F / net_w;
float lowest_h = 1.F / net_h;
Expand Down
1 change: 1 addition & 0 deletions src/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
float aspect, float hue, float saturation, float exposure, int use_mixup, int use_blur, int show_imgs, float label_smooth_eps, int dontuse_opencv, int contrastive);
data load_go(char *filename);

box_label *read_boxes_class(char *filename, int *n, int max_num_classes);
box_label *read_boxes(char *filename, int *n);
data load_cifar10_data(char *filename);
data load_all_cifar10();
Expand Down
8 changes: 5 additions & 3 deletions src/detector.c
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ void validate_detector_recall(char *datacfg, char *cfgfile, char *weightfile)

int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
printf("CI PASSA--\n");
for (k = 0; k < nboxes; ++k) {
if (dets[k].objectness > thresh) {
++proposals;
Expand Down Expand Up @@ -1082,7 +1083,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char labelpath[4096];
replace_image_to_label(path, labelpath);
int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
box_label *truth = read_boxes_class(labelpath, &num_labels, classes);
int j;
for (j = 0; j < num_labels; ++j) {
truth_classes_count[truth[j].id]++;
Expand All @@ -1098,7 +1099,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
char labelpath_dif[4096];
replace_image_to_label(path_dif, labelpath_dif);

truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
truth_dif = read_boxes_class(labelpath_dif, &num_labels_dif, classes);
}

const int checkpoint_detections_count = detections_count;
Expand Down Expand Up @@ -1469,8 +1470,9 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int
char labelpath[4096];
replace_image_to_label(path, labelpath);


int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
box_label *truth = read_boxes_class(labelpath, &num_labels, classes);
//printf(" new path: %s \n", labelpath);
char *buff = (char*)xcalloc(6144, sizeof(char));
for (j = 0; j < num_labels; ++j)
Expand Down
2 changes: 1 addition & 1 deletion src/yolo.c
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void validate_yolo_recall(char *cfgfile, char *weightfile)
replace_image_to_label(path, labelpath);

int num_labels = 0;
box_label *truth = read_boxes(labelpath, &num_labels);
box_label *truth = read_boxes_class(labelpath, &num_labels, classes);
for(k = 0; k < side*side*l.n; ++k){
if(probs[k][0] > thresh){
++proposals;
Expand Down