MetaCAT fixes and upgrades (#495)
* MetaCAT fixes and upgrades

Pushing 3 updates:
1) Removed the check-and-remap step for labels with zero data, as it was causing issues during evaluation
2) Resolved an issue where the confusion matrix couldn't be built when testing on a single class with an F1 score of 1, because it expected the original number of training classes (3)
3) Updated attention mask creation to use the actual pad_idx value instead of assuming it is 0

* Pushing type fix

* Pushing for type fix

* Fixing type issues

* Pushing change

* Pushing update w/o try except block

The confusion matrix couldn't be calculated when testing on a single class with an F1 score of 1, because it expected the original number of training classes (3). Pushing an optimized fix for this issue w/o the try except block
shubham-s-agarwal authored and mart-r committed Oct 14, 2024
1 parent 38a60e3 commit b054bdb
Showing 2 changed files with 9 additions and 20 deletions.
17 changes: 0 additions & 17 deletions medcat/utils/meta_cat/data_utils.py
@@ -180,23 +180,6 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
         category_value2id = {}

     category_values = set([x[2] for x in data])
-    # Ensuring that each label has data and checking for class imbalance
-
-    label_data = {key: 0 for key in category_value2id}
-    for i in range(len(data)):
-        if data[i][2] in category_value2id:
-            label_data[data[i][2]] = label_data[data[i][2]] + 1
-
-    # If a label has no data, changing the mapping
-    if 0 in label_data.values():
-        category_value2id_: Dict = {}
-        keys_ls = [key for key, value in category_value2id.items() if value != 0]
-        for k in keys_ls:
-            category_value2id_[k] = len(category_value2id_)
-
-        logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping: %s",category_value2id_)
-        category_value2id = category_value2id_
-
     for c in category_values:
         if c not in category_value2id:
             category_value2id[c] = len(category_value2id)
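With the block above removed, an existing label mapping is only ever extended, never renumbered, so an evaluation set that lacks some labels no longer shifts the ids the model was trained with. Note that the removed filter compared the mapping's values (label ids) to 0 rather than the per-label counts, so it appears to have dropped whichever label held id 0. The following is a minimal sketch of the remaining behaviour, not the MedCAT implementation: `encode_labels`, the sample tuples, and the label names are hypothetical, and `sorted()` is added here only to make the assigned ids reproducible.

```python
from typing import Dict, List, Optional, Tuple


def encode_labels(
    data: List[Tuple[list, int, str]],
    existing: Optional[Dict[str, int]] = None,
) -> Dict[str, int]:
    """Hypothetical helper: extend a label -> id mapping with unseen labels."""
    # Copy so the caller's mapping is not mutated (a choice made for this sketch).
    category_value2id = dict(existing) if existing is not None else {}
    # The label sits at index 2 of each sample, as in the diff above.
    for label in sorted({sample[2] for sample in data}):
        if label not in category_value2id:
            category_value2id[label] = len(category_value2id)
    return category_value2id


# Assumed mapping from training with three classes.
trained = {"Other": 0, "Patient": 1, "Family": 2}
# Evaluation data that contains a single class only.
eval_data = [([101, 7, 102], 1, "Patient"), ([101, 9, 102], 1, "Patient")]
mapping = encode_labels(eval_data, trained)
```

Here `mapping` equals `trained`: labels without evaluation samples keep their original ids.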
12 changes: 9 additions & 3 deletions medcat/utils/meta_cat/ml_utils.py
@@ -66,7 +66,7 @@ class label of the data

     x = torch.tensor(x, dtype=torch.long).to(device)
     # cpos = torch.tensor(cpos, dtype=torch.long).to(device)
-    attention_masks = (x != 0).type(torch.int)
+    attention_masks = (x != pad_id).type(torch.int)
     return x, cpos, attention_masks, y
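The effect of this one-line change is easiest to see when the tokenizer's padding id is not 0. The sketch below mirrors the `(x != pad_id)` comparison in plain Python lists so that it runs without PyTorch; `build_attention_masks` and the token ids are hypothetical and chosen only for illustration.

```python
from typing import List


def build_attention_masks(batch: List[List[int]], pad_id: int) -> List[List[int]]:
    """Return 1 for real tokens and 0 for padding, per position."""
    return [[int(token != pad_id) for token in row] for row in batch]


# Assumed batch where padding uses id 3 and id 0 is a real token.
batch = [[5, 0, 9, 3, 3], [7, 2, 3, 3, 3]]

masks = build_attention_masks(batch, pad_id=3)      # new behaviour
old_masks = build_attention_masks(batch, pad_id=0)  # old hard-coded assumption
```

With the hard-coded 0, the real token with id 0 is masked out and every padding position is attended to; passing the actual padding id gives the intended mask.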


@@ -412,10 +412,16 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
     precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average)

     labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])]
+    labels_present_: set = set(predictions)
+    labels_present: List[str] = [str(x) for x in labels_present_]
+
+    if len(labels) != len(labels_present):
+        logger.warning(
+            "The evaluation dataset does not contain all the labels, some labels are missing. Performance displayed for labels found...")
     confusion = pd.DataFrame(
         data=confusion_matrix(y_eval, predictions, ),
-        columns=["true " + label for label in labels],
-        index=["predicted " + label for label in labels],
+        columns=["true " + label for label in labels_present],
+        index=["predicted " + label for label in labels_present],
     )

     examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}}
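When no `labels` argument is given, scikit-learn's `confusion_matrix` sizes its output by the sorted union of the values in the true and predicted labels, with rows for true labels and columns for predicted ones. That is why a single-class evaluation yields a 1x1 matrix that cannot take three row and column names. The sketch below shows one way to keep names and matrix in step; it is a variant for illustration, not the code in this commit: `labelled_confusion` is hypothetical, it takes the labels present from both inputs, and it maps ids back to names instead of stringifying them.

```python
from typing import Dict, List, Sequence, Tuple


def labelled_confusion(
    y_true: Sequence[int],
    y_pred: Sequence[int],
    category_value2id: Dict[str, int],
) -> Tuple[List[str], List[List[int]]]:
    """Hypothetical helper: confusion matrix over the labels actually present."""
    id2name = {idx: name for name, idx in category_value2id.items()}
    # Sorted union of ids seen in either input.
    present_ids = sorted(set(y_true) | set(y_pred))
    position = {idx: i for i, idx in enumerate(present_ids)}
    # Rows index true labels, columns index predicted labels.
    matrix = [[0] * len(present_ids) for _ in present_ids]
    for true_id, pred_id in zip(y_true, y_pred):
        matrix[position[true_id]][position[pred_id]] += 1
    return [id2name[idx] for idx in present_ids], matrix


# Assumed mapping from training with three classes.
category_value2id = {"Other": 0, "Patient": 1, "Family": 2}

# Single-class evaluation with perfect predictions (F1 of 1).
names, matrix = labelled_confusion([1, 1, 1], [1, 1, 1], category_value2id)

# One sample of class 0 misclassified as class 1.
names2, matrix2 = labelled_confusion([0, 1, 1], [1, 1, 1], category_value2id)
```

In the second call the matrix is 2x2 even though only one class is predicted, which is the case a predictions-only label list would not cover.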