Skip to content

Commit

Permalink
Precision-Recall Curve feature update (ultralytics#1206)
Browse files Browse the repository at this point in the history
* Precision-Recall Curve feature update

* sentinel value update
  • Loading branch information
glenn-jocher authored Oct 25, 2020
1 parent d51f727 commit 04982f9
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score

# AP from recall-precision curve
py.append(np.interp(px, recall[:, 0], precision[:, 0])) # precision at mAP@0.5
for j in range(tp.shape[1]):
ap[ci, j] = compute_ap(recall[:, j], precision[:, j])
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
if j == 0:
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5

# Compute F1 score (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)
Expand All @@ -304,7 +305,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re
py = np.stack(py, axis=1)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
Expand All @@ -327,22 +328,22 @@ def compute_ap(recall, precision):
"""

# Append sentinel values to beginning and end
mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
mpre = np.concatenate(([0.], precision, [0.]))
mrec = recall # np.concatenate(([0.], recall, [recall[-1] + 1E-3]))
mpre = precision # np.concatenate(([0.], precision, [0.]))

# Compute the precision envelope
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

# Integrate area under curve
method = 'interp' # methods: 'continuous', 'interp'
if method == 'interp':
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
x = np.linspace(0, 1, 1001) # 101-point interp (COCO)
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
else: # 'continuous'
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve

return ap
return ap, mpre, mrec


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
Expand Down Expand Up @@ -1259,7 +1260,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_


def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
# from utils.general import *; plot_results()
# from utils.general import *; plot_results(save_dir='runs/exp0')
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
fig, ax = plt.subplots(2, 5, figsize=(12, 6))
ax = ax.ravel()
Expand All @@ -1273,6 +1274,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
os.system(c)
else:
files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
for fi, f in enumerate(files):
try:
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
Expand Down

0 comments on commit 04982f9

Please sign in to comment.