-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for matryoshka embeddings #490
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR Summary
Here's my summary of the key changes in this PR:
Adds support for matryoshka (variable-length) embeddings across the infinity library with the following major changes:
- Added
dimensions
field to OpenAI embedding input model inpymodels.py
to specify desired embedding length - Modified BatchHandler to truncate embeddings to requested dimension after generation in
batch_handler.py
- Added matryoshka_dim parameter to embedding methods in AsyncEmbeddingEngine and AsyncEngineArray
- Added comprehensive test coverage verifying matryoshka functionality:
- Tests with nomic-embed-text-v1.5 and jina-clip-v2 models
- Validates truncated embeddings maintain semantic similarity
- Verifies correct dimensions in API responses
The implementation enables compatibility with models like OpenAI's text-embedding-3 that support variable-length embeddings while maintaining backward compatibility.
Note: PR is marked WIP and still needs:
- Integration into client
- Implementation for dummy model
- Additional test coverage for edge cases
💡 (2/5) Greptile learns from your feedback when you react with 👍/👎!
7 file(s) reviewed, 14 comment(s)
Edit PR Review Bot Settings | Greptile
@@ -171,9 +171,14 @@ def stop(self): | |||
self.async_run(self.async_engine_array.astop).result() | |||
|
|||
@add_start_docstrings(AsyncEngineArray.embed.__doc__) | |||
def embed(self, *, model: str, sentences: list[str]): | |||
def embed(self, *, model: str, sentences: list[str], matryoshka_dim=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: matryoshka_dim parameter lacks type annotation. Should be Optional[int]
@@ -206,14 +211,24 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False | |||
) | |||
|
|||
@add_start_docstrings(AsyncEngineArray.image_embed.__doc__) | |||
def image_embed(self, *, model: str, images: list[Union[str, bytes]]): | |||
def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: matryoshka_dim parameter lacks type annotation. Should be Optional[int]
|
||
@add_start_docstrings(AsyncEngineArray.audio_embed.__doc__) | ||
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]]): | ||
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: matryoshka_dim parameter lacks type annotation. Should be Optional[int]
@@ -378,13 +393,14 @@ async def classify( | |||
return await self[model].classify(sentences=sentences, raw_scores=raw_scores) | |||
|
|||
async def image_embed( | |||
self, *, model: str, images: list[Union[str, "ImageClassType"]] | |||
self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: matryoshka_dim parameter is missing type annotation, should be Optional[int]
@@ -416,13 +432,14 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine": | |||
) | |||
|
|||
async def audio_embed( | |||
self, *, model: str, audios: list[Union[str, bytes]] | |||
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: matryoshka_dim parameter is missing type annotation, should be Optional[int]
) | ||
assert engine.capabilities == {"embed"} | ||
async with engine: | ||
embeddings, usage = await engine.embed(sentences=sentences, matryoshka_dim=matryoshka_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: matryoshka_dim parameter should be validated against model's supported dimensions
embeddings = np.array(embeddings) | ||
assert usage == sum([len(s) for s in sentences]) | ||
assert embeddings.shape[0] == len(sentences) | ||
assert embeddings.shape[1] >= 10 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: redundant assertion since line 408 already checks exact dimension
Codecov ReportAttention: Patch coverage is
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## main #490 +/- ##
==========================================
+ Coverage 79.53% 79.63% +0.10%
==========================================
Files 41 41
Lines 3430 3438 +8
==========================================
+ Hits 2728 2738 +10
+ Misses 702 700 -2 ☔ View full report in Codecov by Sentry. |
I did a quick test like this: from openai import OpenAI
client = OpenAI(
base_url="http://0.0.0.0:7997",
api_key="sk",
)
result = client.embeddings.create(
input=["input","input2"],
model="nomic-ai/nomic-embed-text-v1.5",
dimensions=64
)
assert len(result.data[0].embedding) == 64 |
@@ -54,6 +54,7 @@ class _OpenAIEmbeddingInput(BaseModel): | |||
model: str = "default/not-specified" | |||
encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float | |||
user: Optional[str] = None | |||
dimensions: Optional[int] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int should be 0 < x < 8193
, using pydantic v2 conint
LGTM, if you change the OpenAPI spec for the validation of input and add an end-to-end test |
@wirthual
|
Sounds good. Is there an exmaple on how to start a fastapi server within a pytest method without using |
Just add one here: |
Like this? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, nevermind.. :)
Related Issue
#476
Checklist
Additional Notes
WIP to add matryoshka embeddings.
Is there a CLAP model which supports matryoshka embedding for testing?
Is there a TinyCLIP model which supoprts matryoshka embedding for testing?
Currently missing:
[ ] Integration into client
[ ] Implementation for dummy model