-
Notifications
You must be signed in to change notification settings - Fork 48
/
collect_thresholds.py
84 lines (73 loc) · 3.05 KB
/
collect_thresholds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Collect the value range of different propertity of ps dataset."""
import argparse
import json
import math
import os
import matplotlib.pyplot as plt
import numpy as np
from data import MarkingPoint
from data.struct import calc_point_squre_dist, direction_diff
from prepare_dataset import generalize_marks
def get_parser():
"""Return argument parser for collecting thresholds."""
parser = argparse.ArgumentParser()
parser.add_argument('--label_directory', required=True,
help="The location of label directory.")
return parser
def collect_thresholds(args):
"""Collect range of value from ground truth to determine threshold."""
distances = []
separator_angles = []
bridge_angles = []
for label_file in os.listdir(args.label_directory):
print(label_file)
with open(os.path.join(args.label_directory, label_file), 'r') as file:
label = json.load(file)
marks = np.array(label['marks'])
slots = np.array(label['slots'])
if slots.size == 0:
continue
if len(marks.shape) < 2:
marks = np.expand_dims(marks, axis=0)
if len(slots.shape) < 2:
slots = np.expand_dims(slots, axis=0)
marks[:, 0:4] -= 300.5
marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]
for slot in slots:
mark_a = marks[slot[0] - 1]
mark_b = marks[slot[1] - 1]
distances.append(calc_point_squre_dist(mark_a, mark_b))
vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
vector_ab = vector_ab / np.linalg.norm(vector_ab)
ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
separator_direction = math.atan2(-vector_ab[0], vector_ab[1])
sangle = direction_diff(separator_direction, mark_a.direction)
if mark_a.shape > 0.5:
separator_angles.append(sangle)
else:
bangle = direction_diff(ab_bridge_direction, mark_a.direction)
if sangle < bangle:
separator_angles.append(sangle)
else:
bridge_angles.append(bangle)
bangle = direction_diff(ba_bridge_direction, mark_b.direction)
if mark_b.shape > 0.5:
bridge_angles.append(bangle)
else:
sangle = direction_diff(separator_direction, mark_b.direction)
if sangle < bangle:
separator_angles.append(sangle)
else:
bridge_angles.append(bangle)
distances = sorted(distances)
separator_angles = sorted(separator_angles)
bridge_angles = sorted(bridge_angles)
plt.figure()
plt.hist(distances, len(distances) // 10)
plt.figure()
plt.hist(separator_angles, len(separator_angles) // 10)
plt.figure()
plt.hist(bridge_angles, len(bridge_angles) // 3)
if __name__ == '__main__':
collect_thresholds(get_parser().parse_args())