Skip to content

Commit

Permalink
Merge pull request #20 from mohwald/enhance/pytorch
Browse files Browse the repository at this point in the history
Feature: add YOLOv8 inference and finetunning
  • Loading branch information
kshitijrajsharma authored Oct 22, 2024
2 parents d0e8476 + e14199f commit ac462ab
Show file tree
Hide file tree
Showing 15 changed files with 506 additions and 28 deletions.
7 changes: 7 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ RUN pip install --global-option=build_ext --global-option="-I/usr/include/gdal"
COPY docker/ramp/docker-requirements.txt docker-requirements.txt
RUN pip install -r docker-requirements.txt

# Install ultralytics for YOLO, FastSAM, etc. together with pytorch and other dependencies
# For exact pytorch+cuda versions, see https://pytorch.org/get-started/previous-versions/
RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install ultralytics==8.1.6

# pip install solaris -- try with tmp-free build
# COPY docker/ramp/solaris /tmp/solaris
COPY docker/solaris/solaris /tmp/solaris/solaris
Expand Down Expand Up @@ -56,3 +61,5 @@ RUN unzip checkpoint.tf.zip -d ramp-code/ramp

# Copy test_app.py
COPY test_app.py ./test_app.py
COPY test_yolo.py ./test_yolo.py
COPY Package_Test.ipynb ./Package_Test.ipynb
2 changes: 1 addition & 1 deletion hot_fair_utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .georeferencing import georeference
from .inference import predict
from .inference import predict, evaluate
from .postprocessing import polygonize, vectorize
from .preprocessing import preprocess
from .training import train
Expand Down
1 change: 1 addition & 0 deletions hot_fair_utilities/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .predict import predict
from .evaluate import evaluate
55 changes: 55 additions & 0 deletions hot_fair_utilities/inference/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Patched from ramp-code.scripts.calculate_accuracy.iou created for ramp project by carolyn.johnston@dev.global

from pathlib import Path
import geopandas as gpd

from ramp.utils.eval_utils import get_iou_accuracy_metrics


def evaluate(test_path, truth_path, filter_area_m2=None, iou_threshold=0.5, verbose=False):
"""
Calculate precision/recall/F1-score based on intersection-over-union accuracy evaluation protocol defined by RAMP.
The predicted masks will be georeferenced with EPSG:3857 as CRS
Args:
test_path: Path where the weights of the model can be found.
truth_path: Path of the directory where the images are stored.
filter_area_m2: Minimum area of buildings to analyze in m^2.
iou_threshold: (float, 0<threshold<1) above which value of IoU of a detection is considered to be accurate
verbose: Bool, more statistics are printed when turned on.
Example::
evaluate(
"data/prediction.geojson",
"data/labels.geojson"
)
"""

test_path, truth_path = Path(test_path), Path(truth_path)
truth_df, test_df = gpd.read_file(str(truth_path)), gpd.read_file(str(test_path))
metrics = get_iou_accuracy_metrics(test_df, truth_df, filter_area_m2, iou_threshold)

n_detections = metrics['n_detections']
n_truth = metrics["n_truth"]
n_truepos = metrics['true_pos']
n_falsepos = n_detections - n_truepos
n_falseneg = n_truth - n_truepos
agg_precision = n_truepos / n_detections
agg_recall = n_truepos / n_truth
agg_f1 = 2 * n_truepos / (n_truth + n_detections)

if verbose:
print(f"Detections: {n_detections}")
print(f"Truth buildings: {n_truth}")
print(f"True positives: {n_truepos}")
print(f"False positives: {n_falsepos}")
print(f"False negatives: {n_falseneg}")
print(f"Precision IoU@p: {agg_precision}")
print(f"Recall IoU@p: {agg_recall}")
print(f"F1 IoU@p: {agg_f1}")

return {
"precision": agg_precision,
"recall": agg_recall,
"f1": agg_f1,
}
71 changes: 46 additions & 25 deletions hot_fair_utilities/inference/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@

# Third party imports
import numpy as np
import torch
from tensorflow import keras
from ultralytics import YOLO

from ..georeferencing import georeference
from ..utils import remove_files
from .utils import open_images, save_mask
from .utils import open_images, save_mask, initialize_model

BATCH_SIZE = 8
IMAGE_SIZE = 256
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def predict(
checkpoint_path: str, input_path: str, prediction_path: str, confidence: float = 0.5
checkpoint_path: str, input_path: str, prediction_path: str, confidence: float = 0.5, remove_images=True
) -> None:
"""Predict building footprints for aerial images given a model checkpoint.
Expand All @@ -32,6 +34,7 @@ def predict(
input_path: Path of the directory where the images are stored.
prediction_path: Path of the directory where the predicted images will go.
confidence: Threshold probability for filtering out low-confidence predictions.
remove_images: Bool indicating whether delete prediction images after they were georeferenced.
Example::
Expand All @@ -43,39 +46,57 @@ def predict(
"""
start = time.time()
print(f"Using : {checkpoint_path}")
model = keras.models.load_model(checkpoint_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = initialize_model(checkpoint_path, device=device)
print(f"It took {round(time.time()-start)} sec to load model")
start = time.time()

os.makedirs(prediction_path, exist_ok=True)
image_paths = glob(f"{input_path}/*.png")

for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
images = open_images(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)

preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
image_paths = glob(f"{input_path}/*.png") + glob(f"{input_path}/*.tif")

if isinstance(model, keras.Model):
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
images = open_images(image_batch)
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)

preds = model.predict(images)
preds = np.argmax(preds, axis=-1)
preds = np.expand_dims(preds, axis=-1)
preds = np.where(
preds > confidence, 1, 0
) # Filter out low confidence predictions

for idx, path in enumerate(image_batch):
save_mask(
preds[idx],
str(f"{prediction_path}/{Path(path).stem}.png"),
)
elif isinstance(model, YOLO):
for idx in range(0, len(image_paths), BATCH_SIZE):
batch = image_paths[idx:idx + BATCH_SIZE]
for i, r in enumerate(model(batch, stream=True, conf=confidence, verbose=False)):
if r.masks is None:
preds = np.zeros((IMAGE_SIZE, IMAGE_SIZE,), dtype=np.float32)
else:
preds = r.masks.data.max(dim=0)[0] # dim=0 means to take only footprint
preds = torch.where(preds > confidence, torch.tensor(1), torch.tensor(0))
preds = preds.detach().cpu().numpy()
save_mask(preds, str(f"{prediction_path}/{Path(batch[i]).stem}.png"))
else:
raise RuntimeError("Loaded model is not supported")

print(
f"It took {round(time.time()-start)} sec to predict with {confidence} Confidence Threshold"
)
keras.backend.clear_session()
if isinstance(model, keras.Model):
keras.backend.clear_session()
del model
start = time.time()

georeference(prediction_path, prediction_path, is_mask=True)
print(f"It took {round(time.time()-start)} sec to georeference")

remove_files(f"{prediction_path}/*.xml")
remove_files(f"{prediction_path}/*.png")
if remove_images:
remove_files(f"{prediction_path}/*.xml")
remove_files(f"{prediction_path}/*.png")
17 changes: 17 additions & 0 deletions hot_fair_utilities/inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List

import numpy as np
import torch
from PIL import Image
from tensorflow import keras
from ultralytics import YOLO


IMAGE_SIZE = 256

Expand All @@ -25,3 +28,17 @@ def save_mask(mask: np.ndarray, filename: str) -> None:
reshaped_mask = mask.reshape((IMAGE_SIZE, IMAGE_SIZE)) * 255
result = Image.fromarray(reshaped_mask.astype(np.uint8))
result.save(filename)


def initialize_model(path, device=None):
"""Loads either keras or yolo model."""
if not isinstance(path, str): # probably loaded model
return path

if path.endswith('.pt'): # YOLO
if not device:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YOLO(path).to(device)
else:
model = keras.models.load_model(path)
return model
Empty file.
47 changes: 47 additions & 0 deletions hot_fair_utilities/model/yolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import ultralytics

from ultralytics.utils import RANK


#
# Binary cross entropy with p_c
#

class YOLOSegWithPosWeight(ultralytics.YOLO):

def train(self, trainer=None, pc=1.0, **kwargs):
return super().train(trainer, **{**kwargs, "pose": pc}) # Hide pc inside pose (pose est loss weight arg)

@property
def task_map(self):
map = super().task_map
map['segment']['model'] = SegmentationModelWithPosWeight
map['segment']['trainer'] = SegmentationTrainerWithPosWeight
return map


class SegmentationTrainerWithPosWeight(ultralytics.models.yolo.segment.train.SegmentationTrainer):

def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO segmentation model."""
model = SegmentationModelWithPosWeight(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model


class SegmentationModelWithPosWeight(ultralytics.models.yolo.segment.train.SegmentationModel):

def init_criterion(self):
return v8SegmentationLossWithPosWeight(model=self)


class v8SegmentationLossWithPosWeight(ultralytics.utils.loss.v8SegmentationLoss):

def __init__(self, model):
super().__init__(model)
pc = model.args.pose # hidden in pose arg (used in different task)
pos_weight = torch.full((model.nc,), pc).to(self.device)
self.bce = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
1 change: 1 addition & 0 deletions hot_fair_utilities/postprocessing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def tiles_from_directory(dir_path):
"""
for path in glob(f"{dir_path}/*"):
_, *tile_info = re.split("-", Path(path).stem)
tile_info[-1] = tile_info[-1].replace(".mask", "") # resolve OAM-x-y-z.mask.tif
x, y, z = map(int, tile_info)
tile = mercantile.Tile(x=x, y=y, z=z)
yield tile, path
Expand Down
8 changes: 8 additions & 0 deletions hot_fair_utilities/preprocessing/multimasks_from_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tqdm import tqdm



def get_rasterio_shape_and_transform(image_path):
# get the image shape and the affine transform to pass into df_to_px_mask.
with rio.open(image_path) as rio_dset:
Expand All @@ -41,10 +42,12 @@ def multimasks_from_polygons(
Real-world width (in meters)= Pixel width×Resolution (meters per pixel)
Args:
in_poly_dir (str): Path to directory containing geojson files.
in_chip_dir (str): Path to directory containing image chip files with names matching geojson files.
out_mask_dir (str): Path to directory containing output SDT masks.
input_contact_spacing (int, optional): Pixels that are closer to two different polygons than contact_spacing will be labeled with the contact mask.
input_boundary_width (int, optional): Width in pixel of boundary inner buffer around building footprints
Expand All @@ -66,6 +69,7 @@ def multimasks_from_polygons(
# construct the output mask file names from the chip file names.
# these will have the same base filenames as the chip files,
# with a mask.tif extension in place of the .tif extension.

mask_paths = [
construct_mask_filepath(out_mask_dir, chip_path) for chip_path in chip_paths
]
Expand Down Expand Up @@ -98,6 +102,7 @@ def multimasks_from_polygons(

if crs_is_metric(gdf):
meters = True

boundary_width = min(reference_im.res) * input_boundary_width
contact_spacing = min(reference_im.res) * input_contact_spacing

Expand All @@ -112,6 +117,7 @@ def multimasks_from_polygons(
gdf_poly = gdf.explode(ignore_index=True)

# multi_mask is a one-hot, channels-last encoded mask

onehot_multi_mask = df_to_px_mask(
df=gdf_poly,
out_file=mask_path,
Expand All @@ -126,6 +132,7 @@ def multimasks_from_polygons(
meters=meters,
)


# convert onehot_multi_mask to a sparse encoded mask
# of shape (1,H,W) for compatibility with rasterio writer
sparse_multi_mask = multimask_to_sparse_multimask(onehot_multi_mask)
Expand All @@ -135,6 +142,7 @@ def multimasks_from_polygons(
with rio.open(chip_path, "r") as src:
meta = src.meta.copy()
meta.update(count=sparse_multi_mask.shape[0])

meta.update(dtype="uint8")
meta.update(nodata=None)
with rio.open(mask_path, "w", **meta) as dst:
Expand Down
5 changes: 4 additions & 1 deletion hot_fair_utilities/preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .fix_labels import fix_labels
from .multimasks_from_polygons import multimasks_from_polygons
from .reproject_labels import reproject_labels_to_epsg3857
from .multimasks_from_polygons import multimasks_from_polygons


def preprocess(
Expand Down Expand Up @@ -43,6 +44,7 @@ def preprocess(
If rasterize=False, rasterize_options will be ignored.
georeference_images: Whether to georeference the OAM images.
multimasks: Whether to additionally output multimask labels.
input_contact_spacing (int, optional): Pixels that are closer to two different polygons than contact_spacing will be labeled with the contact mask.
input_boundary_width (int, optional): Width in pixel of boundary inner buffer around building footprints
Expand Down Expand Up @@ -96,6 +98,7 @@ def preprocess(
os.remove(f"{output_path}/labels_epsg3857.geojson")

if multimasks:

assert os.path.isdir(
f"{output_path}/chips"
), "Chips do not exist. Set georeference_images=True."
Expand All @@ -105,4 +108,4 @@ def preprocess(
f"{output_path}/multimasks",
input_contact_spacing=input_contact_spacing,
input_boundary_width=input_boundary_width,
)
)
Loading

0 comments on commit ac462ab

Please sign in to comment.