Skip to content

Commit

Permalink
mmscan-devkit v1
Browse files Browse the repository at this point in the history
  • Loading branch information
rbler1234 committed Nov 19, 2024
1 parent 772c3aa commit 6c76334
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mmscan/evaluator/metrics/lang_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def em_evaluation(batch_input):


class simcse_evaluator:
"""A class for calculating the simcse similarity score.
Args:
model_path: path to the simcse pretrained model.
"""

def __init__(self, model_path, eval_bs=500) -> None:
self.eval_bs = eval_bs
Expand Down Expand Up @@ -208,6 +213,11 @@ def evaluation(self, batch_input):


class sbert_evaluator:
"""A class for calculating the sbert similarity score.
Args:
model_path: path to the sbert pretrained model.
"""

def __init__(self, model_path, eval_bs=500) -> None:
self.eval_bs = eval_bs
Expand Down
1 change: 1 addition & 0 deletions mmscan/evaluator/qa_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, model_config={}, max_length=256, verbose=True) -> None:
self.reset()

def reset(self):
"""Reset the evaluator, clear the buffer and records."""
self.metric_record = {}
self.save_results = {}
self.save_buffer = []
Expand Down
1 change: 1 addition & 0 deletions mmscan/evaluator/vg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, verbose=True) -> None:
self.reset()

def reset(self):
"""Reset the evaluator, clear the buffer and records."""
self.save_buffer = []
self.records = []
self.category_records = {}
Expand Down
17 changes: 17 additions & 0 deletions mmscan/utils/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def euler_iou3d_bbox(center1, size1, rot1, center2, size2, rot2):


def box_num(box):
"""Return the number of boxes in a grounp.
Args:
box (list/tuple, tensor): boxes in a grounp.
Returns:
int : the number
"""
if isinstance(box, (list, tuple)):
return box[0].shape[0]
else:
Expand All @@ -261,6 +269,15 @@ def index_box(boxes, indices):


def to_9dof_box(box):
"""Convert a grounp of bounding boxes represented in [center, size, rot]
format to 9 DoF format.
Args:
box (list/tuple, tensor): boxes in a grounp.
Returns:
Tensor : 9 DoF format. (num,9)
"""
if isinstance(box, (list, tuple)):
center, size, rotmat = box
euler = matrix_to_euler_angles(rotmat, 'ZXY')
Expand Down
15 changes: 15 additions & 0 deletions mmscan/utils/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ def reverse_dict(mapping):
self.mp3d_mapping_trans = reverse_dict(self.mp3d_mapping)

def forward(self, scan_name):
"""map forward the original scan names to the new names.
Args:
scan_name (str): the original scan name.
Returns:
str: the new name.
"""

if 'matterport3d/' in scan_name:
scan_, region_ = (
self.mp3d_mapping[scan_name.split('/')[1]],
Expand All @@ -182,6 +190,13 @@ def forward(self, scan_name):
raise ValueError(f'{scan_name} is not a scan name')

def backward(self, scan_name):
"""map backward the new names to the original scan names.
Args:
scan_name (str): the new name.
Returns:
str: the original scan name.
"""
if '1mp3d' in scan_name:
scene1, scene2, region = scan_name.split('_')
return ('matterport3d/' +
Expand Down
14 changes: 14 additions & 0 deletions mmscan/utils/lang_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def clean_answer(data):


def normalize_answer(s):
"""Help to 'normalize' the answer.
Args:
s (str): the raw answer.
Returns:
str : the processed sentence.
"""

def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
Expand Down Expand Up @@ -207,6 +214,13 @@ def qa_prompt_define():


def qa_metric_map(eval_type):
"""Map the class type to the corresponding Abbrev.
Args:
eval_type (str): the class name.
Returns:
str : the corresponding Abbrev.
"""
if 'Attribute_OO' in eval_type:
target = 'OOa'
elif 'Space_OO' in eval_type:
Expand Down

0 comments on commit 6c76334

Please sign in to comment.