Skip to content

Commit

Permalink
Merge pull request #52 from harmonydata/search-instruments
Browse files Browse the repository at this point in the history
Search instruments
  • Loading branch information
woodthom2 authored Sep 6, 2024
2 parents d297e63 + 2088b14 commit bf7ebbc
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 6 deletions.
101 changes: 96 additions & 5 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"""

import statistics
from collections import Counter
import heapq
from collections import Counter, OrderedDict
from typing import List, Callable

import numpy as np
Expand Down Expand Up @@ -155,8 +156,8 @@ def match_instruments_with_catalogue_instruments(
:param catalogue_data: The catalogue data.
:param vectorisation_function: A function to vectorize a text.
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector).
:return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments.
:return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog.
Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments.
"""

# Gather all questions
Expand Down Expand Up @@ -206,8 +207,8 @@ def match_questions_with_catalogue_instruments(
) -> List[CatalogueInstrument]:
"""
Match questions with catalogue instruments.
Each question will receive a list of closest instrument matches, and at the end one closest instrument match for
all questions is returned.
Each question from the list will receive the closest instrument match for it.
The closest instrument match for all questions is returned as a result of this function.
:param questions: The questions.
:param catalogue_data: The catalogue data.
Expand Down Expand Up @@ -438,6 +439,96 @@ def match_questions_with_catalogue_instruments(
return top_instruments


def match_query_with_catalogue_instruments(
query: str,
catalogue_data: dict,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, List[float]],
max_results: int = 100,
) -> dict[str, list | dict]:
"""
Match query with catalogue instruments.
:param query: The query.
:param catalogue_data: The catalogue data.
:param vectorisation_function: A function to vectorize a text.
:param texts_cached_vectors: A dictionary of already cached text vectors (text to vector).
:param max_results: The max amount of instruments to return.
:return: A dict containing the list of instruments (up to 100) and the new text vectors.
E.g. {"instruments": [...], "new_text_vectors": {...}}.
"""

response = {"instruments": [], "new_text_vectors": {}}

# Catalogue data
catalogue_instrument_idx_to_catalogue_questions_idx: List[List[int]] = (
catalogue_data["instrument_idx_to_question_idx"]
)
all_catalogue_questions_embeddings_concatenated: np.ndarray = catalogue_data[
"all_embeddings_concatenated"
]
all_catalogue_instruments: List[dict] = catalogue_data["all_instruments"]

# No embeddings = nothing to find
if len(all_catalogue_questions_embeddings_concatenated) == 0:
return response

# Text vectors
text_vectors, new_text_vectors = create_full_text_vectors(
all_questions=[],
query=query,
vectorisation_function=vectorisation_function,
texts_cached_vectors=texts_cached_vectors,
)

# Get an array of dimensions
vectors = np.array([text_vectors[0].vector])

# Get a 2D array of 1 x (number of questions in catalogue)
catalogue_similarities = cosine_similarity(
vectors, all_catalogue_questions_embeddings_concatenated
)

# Get the catalogue questions similarities for the query
catalogue_questions_similarities_for_query = catalogue_similarities[0].tolist()

# Get indexes of top matching questions in the catalogue
# The first index contains the best match
top_catalogue_questions_matches_idxs = [
catalogue_questions_similarities_for_query.index(i)
for i in heapq.nlargest(max_results, catalogue_questions_similarities_for_query)
]

# A dict of matching instruments
# The key is the name of the instrument and the value is the instrument
instrument_matches: OrderedDict[str, Instrument] = OrderedDict()

# Find the matching instruments by looking for the instrument of the top catalogue questions matches indexes
# Loop through indexes of top matched catalogue question
for top_catalogue_question_match_idx in top_catalogue_questions_matches_idxs:
# Loop through instrument index with its question indexes
for catalogue_instrument_idx, catalogue_instrument_questions_idxs in enumerate(
catalogue_instrument_idx_to_catalogue_questions_idx
):
# Check if the index of the top matched catalogue question is in the catalogue instrument's question indexes
if top_catalogue_question_match_idx in catalogue_instrument_questions_idxs:
catalogue_instrument = all_catalogue_instruments[
catalogue_instrument_idx
]

# Add the instrument to the dict if it wasn't already added
instrument_name = catalogue_instrument["instrument_name"]
if instrument_name not in instrument_matches:
instrument_matches[instrument_name] = Instrument.model_validate(
catalogue_instrument
)

response["instruments"] = [x for x in instrument_matches.values()]
response["new_text_vectors"] = new_text_vectors

return response


#
def match_instruments_with_function(
instruments: List[Instrument],
Expand Down
16 changes: 15 additions & 1 deletion src/harmony/schemas/requests/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class Question(BaseModel):
closest_catalogue_question_match: Optional[CatalogueQuestion] = Field(
None, description="The closest question match in the catalogue for the question"
)
seen_in_catalogue_instruments: list[CatalogueInstrument] = Field(
default=None, description="The instruments from the catalogue were the question was seen in"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
Expand Down Expand Up @@ -97,7 +100,7 @@ class Instrument(BaseModel):
description="The ISO 639-2 (alpha-2) encoding of the instrument language")
questions: List[Question] = Field(description="The items inside the instrument")
closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field(
[],
None,
description="The closest instrument matches in the catalogue for the instrument, the first index "
"contains the best match etc"
)
Expand Down Expand Up @@ -203,3 +206,14 @@ class MatchBody(BaseModel):
"model": DEFAULT_MODEL}
}
})


class SearchInstrumentsBody(BaseModel):
parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to search")
model_config = ConfigDict(
json_schema_extra={
"example": {
"parameters": {"framework": DEFAULT_FRAMEWORK,
"model": DEFAULT_MODEL}
}
})
4 changes: 4 additions & 0 deletions src/harmony/schemas/responses/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class MatchResponse(BaseModel):
)


class SearchInstrumentsResponse(BaseModel):
instruments: List[Instrument] = Field(description="A list of instruments")


class InstrumentList(RootModel):
root: List[Instrument]

Expand Down

0 comments on commit bf7ebbc

Please sign in to comment.