From 6c763343d138df659dc6accb1b5d88ee8930bc0e Mon Sep 17 00:00:00 2001 From: rbler1234 Date: Tue, 19 Nov 2024 20:44:28 +0800 Subject: [PATCH] mmscan-devkit v1 --- mmscan/evaluator/metrics/lang_metric.py | 10 ++++++++++ mmscan/evaluator/qa_evaluation.py | 1 + mmscan/evaluator/vg_evaluation.py | 1 + mmscan/utils/box_utils.py | 17 +++++++++++++++++ mmscan/utils/data_io.py | 15 +++++++++++++++ mmscan/utils/lang_utils.py | 14 ++++++++++++++ 6 files changed, 58 insertions(+) diff --git a/mmscan/evaluator/metrics/lang_metric.py b/mmscan/evaluator/metrics/lang_metric.py index 3b66dad..8b8e92e 100644 --- a/mmscan/evaluator/metrics/lang_metric.py +++ b/mmscan/evaluator/metrics/lang_metric.py @@ -119,6 +119,11 @@ def em_evaluation(batch_input): class simcse_evaluator: + """A class for calculating the simcse similarity score. + + Args: + model_path: path to the simcse pretrained model. + """ def __init__(self, model_path, eval_bs=500) -> None: self.eval_bs = eval_bs @@ -208,6 +213,11 @@ def evaluation(self, batch_input): class sbert_evaluator: + """A class for calculating the sbert similarity score. + + Args: + model_path: path to the sbert pretrained model. + """ def __init__(self, model_path, eval_bs=500) -> None: self.eval_bs = eval_bs diff --git a/mmscan/evaluator/qa_evaluation.py b/mmscan/evaluator/qa_evaluation.py index 04637c8..5761c22 100644 --- a/mmscan/evaluator/qa_evaluation.py +++ b/mmscan/evaluator/qa_evaluation.py @@ -57,6 +57,7 @@ def __init__(self, model_config={}, max_length=256, verbose=True) -> None: self.reset() def reset(self): + """Reset the evaluator, clear the buffer and records.""" self.metric_record = {} self.save_results = {} self.save_buffer = [] diff --git a/mmscan/evaluator/vg_evaluation.py b/mmscan/evaluator/vg_evaluation.py index e1d5c3e..90ff78b 100644 --- a/mmscan/evaluator/vg_evaluation.py +++ b/mmscan/evaluator/vg_evaluation.py @@ -45,6 +45,7 @@ def __init__(self, verbose=True) -> None: self.reset() def reset(self): + """Reset the evaluator, clear the buffer and records.""" self.save_buffer = [] self.records = [] self.category_records = {} diff --git a/mmscan/utils/box_utils.py b/mmscan/utils/box_utils.py index 5cae75e..68afb88 100644 --- a/mmscan/utils/box_utils.py +++ b/mmscan/utils/box_utils.py @@ -247,6 +247,14 @@ def euler_iou3d_bbox(center1, size1, rot1, center2, size2, rot2): def box_num(box): + """Return the number of boxes in a grounp. + + Args: + box (list/tuple, tensor): boxes in a grounp. + + Returns: + int : the number + """ if isinstance(box, (list, tuple)): return box[0].shape[0] else: @@ -261,6 +269,15 @@ def index_box(boxes, indices): def to_9dof_box(box): + """Convert a grounp of bounding boxes represented in [center, size, rot] + format to 9 DoF format. + + Args: + box (list/tuple, tensor): boxes in a grounp. + + Returns: + Tensor : 9 DoF format. (num,9) + """ if isinstance(box, (list, tuple)): center, size, rotmat = box euler = matrix_to_euler_angles(rotmat, 'ZXY') diff --git a/mmscan/utils/data_io.py b/mmscan/utils/data_io.py index 8dbc67b..1b75110 100644 --- a/mmscan/utils/data_io.py +++ b/mmscan/utils/data_io.py @@ -168,6 +168,14 @@ def reverse_dict(mapping): self.mp3d_mapping_trans = reverse_dict(self.mp3d_mapping) def forward(self, scan_name): + """map forward the original scan names to the new names. + + Args: + scan_name (str): the original scan name. + Returns: + str: the new name. + """ + if 'matterport3d/' in scan_name: scan_, region_ = ( self.mp3d_mapping[scan_name.split('/')[1]], @@ -182,6 +190,13 @@ def forward(self, scan_name): raise ValueError(f'{scan_name} is not a scan name') def backward(self, scan_name): + """map backward the new names to the original scan names. + + Args: + scan_name (str): the new name. + Returns: + str: the original scan name. + """ if '1mp3d' in scan_name: scene1, scene2, region = scan_name.split('_') return ('matterport3d/' + diff --git a/mmscan/utils/lang_utils.py b/mmscan/utils/lang_utils.py index f99306e..1ea208b 100644 --- a/mmscan/utils/lang_utils.py +++ b/mmscan/utils/lang_utils.py @@ -71,6 +71,13 @@ def clean_answer(data): def normalize_answer(s): + """Help to 'normalize' the answer. + + Args: + s (str): the raw answer. + Returns: + str : the processed sentence. + """ def remove_articles(text): return re.sub(r'\b(a|an|the)\b', ' ', text) @@ -207,6 +214,13 @@ def qa_prompt_define(): def qa_metric_map(eval_type): + """Map the class type to the corresponding Abbrev. + + Args: + eval_type (str): the class name. + Returns: + str : the corresponding Abbrev. + """ if 'Attribute_OO' in eval_type: target = 'OOa' elif 'Space_OO' in eval_type: