Skip to content

Commit

Permalink
Add generative Anthropic module integration (#1149)
Browse files Browse the repository at this point in the history
* Add anthropic generative config and test

* Fix expected response field names

---------

Co-authored-by: Tommy Smith <tommy@weaviate.io>
  • Loading branch information
cdpierse and tsmith023 authored Jul 2, 2024
1 parent 89af73d commit 92a5f76
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 11 deletions.
20 changes: 20 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,26 @@ def test_config_with_vectorizer_and_properties(
}
},
),
(
Configure.Generative.anthropic(
model="model",
max_tokens=100,
stop_sequences=["stop"],
temperature=0.5,
top_k=10,
top_p=0.5,
),
{
"generative-anthropic": {
"model": "model",
"maxTokens": 100,
"stopSequences": ["stop"],
"temperature": 0.5,
"topK": 10,
"topP": 0.5,
}
},
),
]


Expand Down
69 changes: 58 additions & 11 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,23 @@
_NamedVectors,
_NamedVectorsUpdate,
)
from weaviate.collections.classes.config_vector_index import VectorIndexType as VectorIndexTypeAlias
from weaviate.collections.classes.config_vector_index import (
_QuantizerConfigCreate,
_VectorIndexConfigCreate,
_VectorIndexConfigDynamicCreate,
_VectorIndexConfigDynamicUpdate,
_VectorIndexConfigHNSWCreate,
_VectorIndexConfigFlatCreate,
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicCreate,
_VectorIndexConfigHNSWCreate,
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigSkipCreate,
_VectorIndexConfigUpdate,
VectorIndexType as VectorIndexTypeAlias,
)
from weaviate.collections.classes.config_vectorizers import (
_Vectorizer,
_VectorizerConfigCreate,
CohereModel,
Vectorizers as VectorizersAlias,
VectorDistances as VectorDistancesAlias,
)
from weaviate.collections.classes.config_vectorizers import CohereModel
from weaviate.collections.classes.config_vectorizers import VectorDistances as VectorDistancesAlias
from weaviate.collections.classes.config_vectorizers import Vectorizers as VectorizersAlias
from weaviate.collections.classes.config_vectorizers import _Vectorizer, _VectorizerConfigCreate
from weaviate.exceptions import WeaviateInvalidInputError
from weaviate.util import _capitalize_first_letter
from weaviate.warnings import _Warnings
Expand Down Expand Up @@ -161,9 +158,12 @@ class GenerativeSearches(str, Enum):
Weaviate module backed by PaLM generative models.
`AWS`
Weaviate module backed by AWS Bedrock generative models.
`ANTHROPIC`
Weaviate module backed by Anthropic generative models.
"""

AWS = "generative-aws"
ANTHROPIC = "generative-anthropic"
ANYSCALE = "generative-anyscale"
COHERE = "generative-cohere"
MISTRAL = "generative-mistral"
Expand Down Expand Up @@ -494,6 +494,18 @@ class _GenerativeAWSConfig(_GenerativeConfigCreate):
endpoint: Optional[str]


class _GenerativeAnthropicConfig(_GenerativeConfigCreate):
generative: GenerativeSearches = Field(
default=GenerativeSearches.ANTHROPIC, frozen=True, exclude=True
)
model: Optional[str]
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[float]
topK: Optional[int]
topP: Optional[float]


class _RerankerConfigCreate(_ConfigCreateModel):
reranker: Rerankers

Expand Down Expand Up @@ -767,6 +779,41 @@ def aws(
endpoint=endpoint,
)

@staticmethod
def anthropic(
model: Optional[str] = None,
max_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> _GenerativeConfigCreate:
"""
Create a `_GenerativeAnthropicConfig` object for use when performing AI generation using the `generative-anthropic` module.
Arguments:
`model`
The model to use. Defaults to `None`, which uses the server-defined default
`max_tokens`
The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default
`stop_sequences`
The stop sequences to use. Defaults to `None`, which uses the server-defined default
`temperature`
The temperature to use. Defaults to `None`, which uses the server-defined default
`top_k`
The top K to use. Defaults to `None`, which uses the server-defined default
`top_p`
The top P to use. Defaults to `None`, which uses the server-defined default
"""
return _GenerativeAnthropicConfig(
model=model,
maxTokens=max_tokens,
stopSequences=stop_sequences,
temperature=temperature,
topK=top_k,
topP=top_p,
)


class _Reranker:
"""Use this factory class to create the correct object for the `reranker_config` argument in the `collections.create()` method.
Expand Down

0 comments on commit 92a5f76

Please sign in to comment.