diff --git a/generative_ai/embedding.py b/generative_ai/embedding.py index 4a004d2c7166..133275b24d9c 100644 --- a/generative_ai/embedding.py +++ b/generative_ai/embedding.py @@ -21,13 +21,15 @@ def embed_text( texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"], task: str = "RETRIEVAL_DOCUMENT", - model_name: str = "textembedding-gecko@003" + model_name: str = "textembedding-gecko@003", ) -> List[List[float]]: """Embeds texts with a pre-trained, foundational model.""" model = TextEmbeddingModel.from_pretrained(model_name) inputs = [TextEmbeddingInput(text, task) for text in texts] embeddings = model.get_embeddings(inputs) return [embedding.values for embedding in embeddings] + + # [END aiplatform_sdk_embedding] diff --git a/generative_ai/embedding_model_tuning.py b/generative_ai/embedding_model_tuning.py index 21f357cffed1..55f58d179437 100644 --- a/generative_ai/embedding_model_tuning.py +++ b/generative_ai/embedding_model_tuning.py @@ -32,7 +32,7 @@ def tune_embedding_model( train_label_path: str = "gs://embedding-customization-pipeline/dataset/train.tsv", test_label_path: str = "gs://embedding-customization-pipeline/dataset/test.tsv", batch_size: int = 50, - iterations: int = 300 + iterations: int = 300, ) -> pipeline_jobs.PipelineJob: match = re.search(r"(.+)(-autopush|-staging)?-aiplatform.+", api_endpoint) location = match.group(1) if match else "us-central1" @@ -50,7 +50,8 @@ def tune_embedding_model( train_label_path=train_label_path, test_label_path=test_label_path, batch_size=batch_size, - iterations=iterations) + iterations=iterations, + ), ) job.submit() return job @@ -58,6 +59,8 @@ def tune_embedding_model( # [END aiplatform_sdk_embedding] if __name__ == "__main__": - tune_embedding_model(aiplatform_init.global_config.api_endpoint, - aiplatform_init.global_config.project, - aiplatform_init.global_config.staging_bucket) + tune_embedding_model( + aiplatform_init.global_config.api_endpoint, + aiplatform_init.global_config.project, + aiplatform_init.global_config.staging_bucket, + ) diff --git a/generative_ai/embedding_model_tuning_test.py b/generative_ai/embedding_model_tuning_test.py index 09eeea756471..7ddb6154fd45 100644 --- a/generative_ai/embedding_model_tuning_test.py +++ b/generative_ai/embedding_model_tuning_test.py @@ -35,16 +35,19 @@ def dispose(job: pipeline_jobs.PipelineJob) -> None: def test_tune_embedding_model() -> None: credentials, _ = google.auth.default( # Set explicit credentials with Oauth scopes. - scopes=["https://www.googleapis.com/auth/cloud-platform"]) + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) aiplatform.init( api_endpoint="us-central1-aiplatform.googleapis.com:443", project=os.getenv("GOOGLE_CLOUD_PROJECT"), staging_bucket="gs://ucaip-samples-us-central1/training_pipeline_output", - credentials=credentials) + credentials=credentials, + ) job = embedding_model_tuning.tune_embedding_model( aiplatform_init.global_config.api_endpoint, aiplatform_init.global_config.project, - aiplatform_init.global_config.staging_bucket) + aiplatform_init.global_config.staging_bucket, + ) try: assert job.state != "PIPELINE_STATE_FAILED" finally: diff --git a/generative_ai/embedding_preview.py b/generative_ai/embedding_preview.py index d66af5001188..5ea285fb7a6b 100644 --- a/generative_ai/embedding_preview.py +++ b/generative_ai/embedding_preview.py @@ -22,7 +22,7 @@ def embed_text( texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"], task: str = "QUESTION_ANSWERING", model_name: str = "text-embedding-preview-0409", - dimensionality: Optional[int] = 256 + dimensionality: Optional[int] = 256, ) -> List[List[float]]: """Embeds texts with a pre-trained, foundational model.""" model = TextEmbeddingModel.from_pretrained(model_name) @@ -30,6 +30,8 @@ def embed_text( kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {} embeddings = model.get_embeddings(inputs, **kwargs) return [embedding.values for embedding in embeddings] + + # [END generativeaionvertexai_sdk_embedding] diff --git a/generative_ai/requirements.txt b/generative_ai/requirements.txt index b143b8465fb6..29c26ed4fb87 100644 --- a/generative_ai/requirements.txt +++ b/generative_ai/requirements.txt @@ -1,7 +1,7 @@ pandas==1.3.5; python_version == '3.7' pandas==2.0.1; python_version > '3.7' pillow==9.5.0; python_version < '3.8' -pillow==10.0.1; python_version >= '3.8' +pillow==10.3.0; python_version >= '3.8' google-cloud-aiplatform[pipelines]==1.46.0 google-auth==2.17.3 anthropic[vertex]==0.21.3