Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

armory 0.18.1 release candidate #1977

Merged
merged 133 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
133 commits
Select commit Hold shift + click to select a range
6e809a3
Merge remote-tracking branch 'upstream/develop' into develop
christopherwoodall Mar 10, 2023
085a23f
Merge remote-tracking branch 'origin/develop' into develop
christopherwoodall Mar 23, 2023
afd9844
Merge remote-tracking branch 'origin/develop' into develop
christopherwoodall Apr 17, 2023
8140dce
Merge remote-tracking branch 'origin/develop' into develop
christopherwoodall Apr 24, 2023
aa04cb6
Merge remote-tracking branch 'origin/develop' into develop
May 24, 2023
f3285a2
add new metric to compute mAP at different distances from patch
swsuggs Jun 9, 2023
97593ea
load new metric in carla od scenario
swsuggs Jun 9, 2023
cd65243
add missing class to __init__.py
swsuggs Jun 9, 2023
c77d796
check that scenario has y_patch_metadata
swsuggs Jun 9, 2023
3b56ce3
split iou function into modular components for reusability
swsuggs Jun 12, 2023
2a403ca
add generalized iou function
swsuggs Jun 12, 2023
dc75791
use giou for distance based carla metric
swsuggs Jun 14, 2023
e6401ed
remove untimely variable existence check. and update metric name
swsuggs Jun 14, 2023
c473e4d
adding helper functions for more accurate GIoU
swsuggs Jun 16, 2023
02f39d8
update GIoU function with more accurate components
swsuggs Jun 16, 2023
de3ef44
update giou-based mAP function
swsuggs Jun 16, 2023
48caefe
code and comment improvements
swsuggs Jun 21, 2023
24d3129
updated CARLA object detection attacks so that region perturbed in de…
yusong-tan Jun 22, 2023
e6c12c2
expand metric to report map by max giou as well as min
swsuggs Jun 22, 2023
002388f
update scenario with new metric name
swsuggs Jun 22, 2023
b902437
formatting
swsuggs Jun 22, 2023
d81d035
adding option to perform random targeted attacks for CARLA object det…
yusong-tan Jun 24, 2023
bfd58b3
histogram version (not working)
swsuggs Jun 30, 2023
e5e93a8
fix histogram version
swsuggs Jul 3, 2023
0d8e9f8
json formatting
yusong-tan Jul 3, 2023
638b35d
minor correction
swsuggs Jul 5, 2023
68b11a1
additional validation on input list
swsuggs Jul 5, 2023
de93cd5
adding baseline results for carla mot test set
swsuggs Jul 5, 2023
eaef737
Merge pull request #1954 from yusong-tan/carla_attack_update
swsuggs Jul 5, 2023
d9ce743
update art
swsuggs Jul 5, 2023
6666d97
Merge pull request #1956 from swsuggs/art-bump-1.15
Jul 5, 2023
2d079a3
Merge pull request #1955 from swsuggs/carla-mot-test-baselines
deprit Jul 5, 2023
2b2453f
save copy of original config for results json
swsuggs Jul 10, 2023
4ee4623
Merge pull request #1957 from swsuggs/1333-save-original-config
Jul 10, 2023
ad4a737
mute giou output from log
swsuggs Jul 13, 2023
b7cf77c
Merge branch 'develop' of https://github.com/twosixlabs/armory into d…
swsuggs Jul 13, 2023
ff69028
remove unused import
swsuggs Jul 13, 2023
c5989f9
add support for pre-computed fairness majority masks
f4str Jul 13, 2023
c9e5be5
fix docker bugs
f4str Jul 13, 2023
d8ab040
minor update to datasets
yusong-tan Jul 13, 2023
09c881e
Merge branch 'twosixlabs:develop' into develop
Jul 18, 2023
487736c
add support for `.readthedocs.yaml`
Jul 18, 2023
6445f61
update ` PyYAML` requirements
Jul 18, 2023
426b9e5
only install `pyyaml` when using python version > 3.10
Jul 18, 2023
0e7dea9
allow more generic list of inputs to task_meter functions
swsuggs Jul 19, 2023
51432cb
metrics logger add_custom_task function for metrics with unique inputs
swsuggs Jul 19, 2023
c2274cc
formatting
swsuggs Jul 19, 2023
6d45b13
stub for plotting utility
swsuggs Jul 19, 2023
cae1fd9
restructure models dir and add init
swsuggs Jul 19, 2023
3cddebf
find correct cfg path
swsuggs Jul 19, 2023
f1b611f
update json configs
swsuggs Jul 19, 2023
469e889
remove unused dependency
Jul 19, 2023
d4ca344
added CARLA overhead object detection test data
yusong-tan Jul 19, 2023
7e2634c
remove `readthedocs`
Jul 19, 2023
14943dd
remove `bandit`
Jul 19, 2023
b4d629e
update default split
swsuggs Jul 19, 2023
2a8ef75
update url checksum file with url
swsuggs Jul 19, 2023
135a6f8
update url
swsuggs Jul 19, 2023
1cb1e6b
update cached checksum
swsuggs Jul 19, 2023
192d61e
update docs
swsuggs Jul 19, 2023
bf7b758
update pinned dependency
Jul 20, 2023
04903a2
patch `pyyaml` issues in CI until fixed in upstream
Jul 20, 2023
54f6fdb
fix formatting in CI yaml
Jul 20, 2023
1cbe91e
updated CARLA MOT datasets with more accurate green screen coordinate…
yusong-tan Jul 20, 2023
917fbcb
adding directory for majority mask data and helper function for loadi…
swsuggs Jul 20, 2023
dbb4c4d
url checksums
swsuggs Jul 21, 2023
08e7c26
dataset urls
swsuggs Jul 21, 2023
ec5db25
cached checksums
swsuggs Jul 21, 2023
41299c8
updated CARLA MOT test dataset to fix RGB/instance segmentation misal…
yusong-tan Jul 22, 2023
39127d5
resolve merge conflict
yusong-tan Jul 22, 2023
dc1a1d1
Merge branch 'master' into develop
mwartell Jul 24, 2023
f57afa3
Merge branch 'develop' of github.com:twosixlabs/armory into develop
mwartell Jul 24, 2023
075c50f
Merge pull request #1967 from twosixlabs/bring-tag-to-develop-from-ma…
Jul 24, 2023
b72f51c
Merge pull request #1964 from f4str/majority_masks
swsuggs Jul 25, 2023
18f2f66
update CI
Jul 26, 2023
6c30bcd
update references to `readthedocs.com`
Jul 26, 2023
5765baa
Merge pull request #1963 from christopherwoodall/update-read-the-docs…
mwartell Jul 26, 2023
153c8b5
Merge pull request #1966 from yusong-tan/carla_od_test_data
swsuggs Jul 26, 2023
fd70889
standardize threshold keys
swsuggs Jul 26, 2023
7787237
plot functions for visualizing giou output
swsuggs Jul 26, 2023
1574986
add tensorflow specific requirements to conda
Jul 26, 2023
2ecfe23
update cached checksum
swsuggs Jul 26, 2023
a9d8713
descriptive docstring
swsuggs Jul 26, 2023
6a92644
formatting
swsuggs Jul 26, 2023
11ee9bb
updating CARLA attacks with the option to use Adam optimizer
yusong-tan Jul 28, 2023
f53741d
Merge pull request #1969 from yusong-tan/carla_attack_efficiency_update
swsuggs Jul 31, 2023
e4e0fb6
resort imports
Jul 31, 2023
88ef7b1
fix E721
Jul 31, 2023
01d437e
reformat with black
Jul 31, 2023
fc40d51
Merge pull request #1970 from christopherwoodall/formatting
Jul 31, 2023
f191bed
Merge branch 'develop' into fix-yolo-cfg-path
swsuggs Jul 31, 2023
266c127
Merge pull request #1965 from swsuggs/fix-yolo-cfg-path
swsuggs Jul 31, 2023
7e4cd48
od poisoning test triggers
swsuggs Aug 2, 2023
10377fd
add min/max size kwargs to targeted configs
swsuggs Aug 2, 2023
c9ccb09
update targeted configs for adam optimizer
swsuggs Aug 3, 2023
399ac98
fix typo
swsuggs Aug 3, 2023
d6ace3a
format
swsuggs Aug 3, 2023
08552ab
remove np.object references
swsuggs Aug 7, 2023
fc6a6bf
update `yolo` Dockerfile
Aug 7, 2023
ec08e05
Merge pull request #1973 from swsuggs/triggers-and-config-kwargs
swsuggs Aug 7, 2023
feb658f
migrate patch to deepspeech
Aug 7, 2023
1a374a2
Merge pull request #1974 from swsuggs/1971-np-object
swsuggs Aug 7, 2023
ce5943d
update hydra install
Aug 7, 2023
af36875
pin tensorflow version in pyproject.toml
Aug 7, 2023
0d77928
pin tensorflow to version 2.10.0
Aug 7, 2023
266dd03
remove `--upgrade` tag
Aug 7, 2023
0d3f081
Merge pull request #1960 from yusong-tan/carla_update_datasets
swsuggs Aug 7, 2023
f34e4a8
update dataset test for new dataset splits
swsuggs Aug 7, 2023
8402c9c
Adding plot util function to .
jprokos26 Aug 7, 2023
d5b0bf2
Merge pull request #1975 from swsuggs/carla-e2e-test-update
Aug 7, 2023
60c4985
update environment markers
Aug 7, 2023
da30341
format environment markers
Aug 7, 2023
40c729b
loosen version ranges
Aug 7, 2023
fff1323
expand versions once more
Aug 7, 2023
819500f
update `python_version`
Aug 7, 2023
e82aeb3
bump `scikit-learn` version
Aug 7, 2023
cb8749b
update base requirements in `environment.yaml`
Aug 7, 2023
6a5a5f8
fix type checking
swsuggs Aug 8, 2023
90ae5b3
fix type checking
swsuggs Aug 8, 2023
cab8d3a
Add error message for plot utility
swsuggs Aug 8, 2023
7d5b016
formatting
swsuggs Aug 8, 2023
3d82490
pin pytorch to version 1.12 in `pyproject`
Aug 8, 2023
cf2943a
update deepspeech requirements
Aug 8, 2023
2fc4635
merge develop
swsuggs Aug 8, 2023
b3da636
isort
swsuggs Aug 8, 2023
325f489
Merge pull request #1968 from christopherwoodall/update-conda-environ…
swsuggs Aug 8, 2023
2c100e1
update base image
Aug 8, 2023
a450f30
minor update for older pyplot version
swsuggs Aug 8, 2023
4fa3d04
update workflow
Aug 8, 2023
11675f2
add installation to release
Aug 8, 2023
317e6a2
Merge branch 'develop' into update-base-image
Aug 8, 2023
94d44dd
Merge pull request #1946 from swsuggs/distance-based-carla-metric
swsuggs Aug 8, 2023
1e7e415
Merge pull request #1976 from christopherwoodall/update-base-image
mwartell Aug 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions .github/workflows/1-scan-lint-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,6 @@ jobs:
fi


- name: 🦹‍♂️ Scanning with Bandit
run: |
bandit \
-v \
-f txt \
-r ./armory \
-c "pyproject.toml" \
--output /tmp/artifacts/bandit_scan.txt \
|| $( exit 0 ); echo $?


- name: 🖋️ mypy Type Checking
run: |
python3 -m pip install mypy
Expand Down
35 changes: 35 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,41 @@ jobs:
dist/*.whl


release-base-image:
name: Build and Release Base Image
needs: [release-wheel]
runs-on: ubuntu-latest
steps:
- name: 🐍 Setup Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9

- name: 📩 Checkout Armory w/ full depth(for tags and SCM)
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: 🌎 Setup Build Environment
run: |
pip install pip>=22.2.2
pip install .
armory configure --use-defaults

- name: ☁️ Login to DockerHub
if: ${{ env.RELEASE_DRY_RUN == 'false' }}
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}

- name: 🔨 Build and Push Base Image
run: |
echo "Building Base Image"
sed -i 's/\r$//' docker/build-base.sh
bash docker/build-base.sh


release-docker:
name: Build and Release Docker Images
needs: [release-wheel]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ Agency (DARPA).
[python-url]: https://pypi.org/project/armory-testbed
[license-badge]: https://img.shields.io/badge/License-MIT-yellow.svg
[license-url]: https://opensource.org/licenses/MIT
[docs-badge]: https://readthedocs.org/projects/armory/badge/
[docs-url]: https://readthedocs.org/projects/armory/
[docs-badge]: https://github.com/twosixlabs/armory/docs/assets/docs-badge.svg
[docs-url]: https://github.com/twosixlabs/armory/docs
[style-badge]: https://img.shields.io/badge/code%20style-black-000000.svg
[style-url]: https://github.com/ambv/black
10 changes: 9 additions & 1 deletion armory/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

import armory
from armory import arguments, paths
from armory.cli.tools import log_current_branch, rgb_depth_convert
from armory.cli.tools import (
log_current_branch,
plot_mAP_by_giou_with_patch_cli,
rgb_depth_convert,
)
from armory.configuration import load_global_config, save_config
from armory.eval import Evaluator
import armory.logs
Expand Down Expand Up @@ -722,6 +726,10 @@ def exec(command_args, prog, description):
UTILS_COMMANDS = {
"get-branch": (log_current_branch, "log the current git branch of armory"),
"rgb-convert": (rgb_depth_convert, "converts rgb depth images to another format"),
"plot-mAP-by-giou": (
plot_mAP_by_giou_with_patch_cli,
"Visualize the output of the metric 'object_detection_AP_per_class_by_giou_from_patch.'",
),
}


Expand Down
43 changes: 42 additions & 1 deletion armory/art_experimental/attacks/carla_mot_adversarial_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,45 @@ def __init__(self, estimator, coco_format=False, **kwargs):

super().__init__(estimator=estimator, **kwargs)

# Set`loss.backward(retain_graph=False)` and zero out Adam optimizer gradients before each attack iteration
def _train_step(
self,
images: "torch.Tensor",
target: "torch.Tensor",
mask: Optional["torch.Tensor"] = None,
) -> "torch.Tensor":

self.estimator.model.zero_grad()
# only zero gradients when there is a non-pgd optimizer; pgd optimizer appears to perform better when gradients accumulate
if self._optimizer_string == "Adam":
self._optimizer.zero_grad(set_to_none=True)
loss = self._loss(images, target, mask)
loss.backward(retain_graph=False)

if self._optimizer_string == "pgd":
if self._patch.grad is not None:
gradients = self._patch.grad.sign() * self.learning_rate
else:
raise ValueError("Gradient term in PyTorch model is `None`.")

with torch.no_grad():
self._patch[:] = torch.clamp(
self._patch + gradients,
min=self.estimator.clip_values[0],
max=self.estimator.clip_values[1],
)
else:
self._optimizer.step()

with torch.no_grad():
self._patch[:] = torch.clamp(
self._patch,
min=self.estimator.clip_values[0],
max=self.estimator.clip_values[1],
)

return loss

def create_initial_image(self, size):
"""
Create initial patch based on a user-defined image
Expand Down Expand Up @@ -400,7 +439,7 @@ def generate(self, x, y, y_patch_metadata):
# Use this mask to embed patch into the background in the event of occlusion
self.patch_masks_video = y_patch_metadata[i]["masks"]

# self._patch needs to be re-initialized with the correct shape
# self._patch and optimizer need to be re-initialized
if self.patch_base_image is not None:
self.patch_base = self.create_initial_image(
(patch_width, patch_height),
Expand All @@ -412,6 +451,8 @@ def generate(self, x, y, y_patch_metadata):
self._patch = torch.tensor(
patch_init, requires_grad=True, device=self.estimator.device
)
if self._optimizer_string == "Adam":
self._optimizer = torch.optim.Adam([self._patch], lr=self.learning_rate)

# Perform batch attack by attacking multiple frames from the same video
for batch_i in range(0, x[i].shape[0], self.batch_frame_size):
Expand Down
65 changes: 59 additions & 6 deletions armory/art_experimental/attacks/carla_obj_det_adversarial_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,13 @@ def _train_step(
import torch # lgtm [py/repeated-import]

self.estimator.model.zero_grad()
# only zero gradients when there is a non-pgd optimizer; pgd optimizer appears to perform better when gradients accumulate
if self._optimizer_string == "Adam":
self._optimizer_rgb.zero_grad(set_to_none=True)
if images.shape[-1] == 6:
self._optimizer_depth.zero_grad(set_to_none=True)
loss = self._loss(images, target, mask)
loss.backward(retain_graph=True)
loss.backward(retain_graph=False)

if self._optimizer_string == "pgd":
patch_grads = self._patch.grad
Expand Down Expand Up @@ -159,7 +164,6 @@ def _train_step(
min=self.min_depth,
max=self.max_depth,
)
self.depth_perturbation[:] = perturbed_images - images_depth
else:
images_depth_linear = rgb_depth_to_linear(
images_depth[:, 0, :, :],
Expand All @@ -175,12 +179,49 @@ def _train_step(
perturbed_images = torch.stack(
[depth_r, depth_g, depth_b], dim=1
)
self.depth_perturbation[:] = perturbed_images - images_depth
self.depth_perturbation[:] = perturbed_images - images_depth

else:
raise ValueError(
"Adam optimizer for CARLA Adversarial Patch not supported."
)
self._optimizer_rgb.step()
if images.shape[-1] == 6:
self._optimizer_depth.step()

with torch.no_grad():
self._patch[:] = torch.clamp(
self._patch,
min=self.estimator.clip_values[0],
max=self.estimator.clip_values[1],
)

if images.shape[-1] == 6:
images_depth = torch.permute(images[:, :, :, 3:], (0, 3, 1, 2))
if self.depth_type == "log":
perturbed_images = torch.clamp(
images_depth + self.depth_perturbation,
min=self.min_depth,
max=self.max_depth,
)
else:
images_depth_linear = rgb_depth_to_linear(
images_depth[:, 0, :, :],
images_depth[:, 1, :, :],
images_depth[:, 2, :, :],
)
depth_linear = rgb_depth_to_linear(
self.depth_perturbation[:, 0, :, :],
self.depth_perturbation[:, 1, :, :],
self.depth_perturbation[:, 2, :, :],
)
depth_linear = torch.clamp(
images_depth_linear + depth_linear,
min=self.min_depth,
max=self.max_depth,
)
depth_r, depth_g, depth_b = linear_depth_to_rgb(depth_linear)
perturbed_images = torch.stack(
[depth_r, depth_g, depth_b], dim=1
)
self.depth_perturbation[:] = perturbed_images - images_depth

return loss

Expand Down Expand Up @@ -348,6 +389,9 @@ def _random_overlay(
).to(self.estimator.device)
foreground_mask = torch.permute(foreground_mask, (2, 0, 1))
foreground_mask = torch.unsqueeze(foreground_mask, dim=0)
foreground_mask = ~(
~foreground_mask * image_mask.bool()
) # ensure area perturbed in depth is consistent with area perturbed in RGB

# Adjust green screen brightness
v_avg = (
Expand Down Expand Up @@ -501,6 +545,15 @@ def generate(self, x, y=None, y_patch_metadata=None):
device=self.estimator.device,
)

if self._optimizer_string == "Adam":
self._optimizer_rgb = torch.optim.Adam(
[self._patch], lr=self.learning_rate
)
if x.shape[-1] == 6:
self._optimizer_depth = torch.optim.Adam(
[self.depth_perturbation], lr=self.learning_rate_depth
)

patch, _ = super().generate(np.expand_dims(x[i], axis=0), y=[y_gt])

# Patch image
Expand Down
28 changes: 25 additions & 3 deletions armory/art_experimental/attacks/carla_obj_det_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from armory.art_experimental.attacks.carla_obj_det_utils import (
linear_depth_to_rgb,
rgb_depth_to_linear,
linear_to_log,
log_to_linear,
rgb_depth_to_linear,
)
from armory.logs import log
from armory.utils.external_repo import ExternalRepoImport
Expand Down Expand Up @@ -450,7 +450,9 @@ def inner_generate( # type: ignore
if self.depth_type == "log":
depth_log = (
self.depth_perturbation
+ np.sign(depth_gradients) * self.learning_rate_depth
+ np.sign(depth_gradients)
* (1 - 2 * int(self.targeted))
* self.learning_rate_depth
)
perturbed_images = np.clip(
images_depth + depth_log, self.min_depth, self.max_depth
Expand All @@ -468,7 +470,10 @@ def inner_generate( # type: ignore
self.depth_perturbation[:, :, :, 2],
).astype("float32")
depth_linear = (
depth_linear + np.sign(grads_linear) * self.learning_rate_depth
depth_linear
+ np.sign(grads_linear)
* (1 - 2 * int(self.targeted))
* self.learning_rate_depth
)

images_depth_linear = rgb_depth_to_linear(
Expand Down Expand Up @@ -770,6 +775,23 @@ def generate(self, x, y=None, y_patch_metadata=None):
) # (1,H,W,3)
self.foreground = np.all(self.binarized_patch_mask == 255, axis=-1)
self.foreground = np.expand_dims(self.foreground, (-1, 0)) # (1,H,W,1)
# ensure area perturbed in depth is consistent with area perturbed in RGB
h, _ = cv2.findHomography(
np.array(
[
[0, 0],
[patch_width - 1, 0],
[patch_width - 1, patch_height - 1],
[0, patch_height - 1],
]
),
gs_coords,
)
rgb_mask = np.ones((patch_height, patch_width, 3), dtype=np.float32)
rgb_mask = cv2.warpPerspective(
rgb_mask, h, (x.shape[2], x.shape[1]), cv2.INTER_CUBIC
)
self.foreground = self.foreground * rgb_mask[:, :, 0:1]

if y is None:
patch = self.inner_generate(
Expand Down
2 changes: 1 addition & 1 deletion armory/art_experimental/attacks/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _is_robust(self, y, y_pred):
return metric_result > self.metric_threshold

def _get_metric_result(self, y, y_pred):
if isinstance(y, np.ndarray) and y.dtype == np.object:
if isinstance(y, np.ndarray) and y.dtype == object:
# convert np object array to list of dicts
metric_result = self.metric_fn([y[0]], y_pred)
else:
Expand Down
23 changes: 23 additions & 0 deletions armory/baseline_models/model_configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

from armory.data.utils import maybe_download_weights_from_s3
from armory.logs import log

CONFIGS_DIR = Path(__file__).parent


def get_path(filename) -> str:
"""
Get the absolute path of the provided config. Ordering priority is:
1) Check directly for provided filepath
2) Load from `model_configs` directory
3) Attempt to download from s3 as a weights file
"""
filename = Path(filename)
if filename.is_file():
return str(filename)
cfgs_path = CONFIGS_DIR / filename
if cfgs_path.is_file():
return str(cfgs_path)

return maybe_download_weights_from_s3(filename)
7 changes: 4 additions & 3 deletions armory/baseline_models/pytorch/yolov3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Optional

import torch

from art.estimators.object_detection import PyTorchYolo
from pytorchyolo.utils.loss import compute_loss
from pytorchyolo.models import load_model
from pytorchyolo.utils.loss import compute_loss
import torch

from armory.baseline_models import model_configs

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -30,6 +30,7 @@ def get_art_model(
model_kwargs: dict, wrapper_kwargs: dict, weights_path: Optional[str] = None
) -> PyTorchYolo:

model_kwargs["model_path"] = model_configs.get_path(model_kwargs["model_path"])
model = load_model(weights_path=weights_path, **model_kwargs)
model_wrapper = Yolo(model)

Expand Down
Loading