Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Azure OpenAI Embeddings Provider for the Classification feature. #764

Merged
merged 7 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions includes/Classifai/Command/ClassifaiCommand.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
use Classifai\Providers\Watson\APIRequest;
use Classifai\Providers\Watson\Classifier;
use Classifai\Normalizer;
use Classifai\Providers\Azure\Embeddings as AzureEmbeddings;
use Classifai\Providers\OpenAI\Embeddings;
use Classifai\Providers\Watson\NLU;

use function Classifai\Providers\Watson\get_username;
use function Classifai\Providers\Watson\get_password;
Expand Down Expand Up @@ -63,7 +65,7 @@ public function post( $args = [], $opts = [] ) {
$feature = new Classification();
$provider = $feature->get_feature_provider_instance();

if ( Embeddings::ID !== $provider::ID ) {
if ( NLU::ID !== $provider::ID ) {
\WP_CLI::error( 'This command is only available for the IBM Watson Provider' );
}

Expand Down Expand Up @@ -964,8 +966,8 @@ public function embeddings( $args = [], $opts = [] ) {
$feature = new Classification();
$provider = $feature->get_feature_provider_instance();

if ( Embeddings::ID !== $provider::ID ) {
\WP_CLI::error( 'This command is only available for the OpenAI Embeddings feature' );
if ( Embeddings::ID !== $provider::ID && AzureEmbeddings::ID !== $provider::ID ) {
\WP_CLI::error( 'This command is only available for the OpenAI Embeddings and Azure OpenAI Embeddings providers.' );
}

$embeddings = new Embeddings( false );
Expand Down
15 changes: 9 additions & 6 deletions includes/Classifai/Features/Classification.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

use Classifai\Services\LanguageProcessing;
use Classifai\Providers\Watson\NLU;
use Classifai\Providers\OpenAI\Embeddings;
use Classifai\Providers\OpenAI\Embeddings as OpenAIEmbeddings;
use Classifai\Providers\Azure\Embeddings as AzureEmbeddings;
use WP_REST_Server;
use WP_REST_Request;
use WP_Error;
Expand Down Expand Up @@ -39,8 +40,9 @@ public function __construct() {

// Contains just the providers this feature supports.
$this->supported_providers = [
NLU::ID => __( 'IBM Watson NLU', 'classifai' ),
Embeddings::ID => __( 'OpenAI Embeddings', 'classifai' ),
NLU::ID => __( 'IBM Watson NLU', 'classifai' ),
OpenAIEmbeddings::ID => __( 'OpenAI Embeddings', 'classifai' ),
AzureEmbeddings::ID => __( 'Azure OpenAI Embeddings', 'classifai' ),
];
}

Expand Down Expand Up @@ -246,7 +248,8 @@ public function save( int $post_id, array $results, bool $link = true ) {
case NLU::ID:
$results = $provider_instance->link( $post_id, $results, $link );
break;
case Embeddings::ID:
case AzureEmbeddings::ID:
case OpenAIEmbeddings::ID:
$results = $provider_instance->set_terms( $post_id, $results, $link );
break;
}
Expand Down Expand Up @@ -779,7 +782,7 @@ public function add_custom_settings_fields() {
);

// Embeddings only supports existing terms.
if ( isset( $settings['provider'] ) && Embeddings::ID === $settings['provider'] ) {
if ( isset( $settings['provider'] ) && ( OpenAIEmbeddings::ID === $settings['provider'] || AzureEmbeddings::ID === $settings['provider'] ) ) {
unset( $method_options['recommended_terms'] );
$settings['classification_method'] = 'existing_terms';
}
Expand Down Expand Up @@ -876,7 +879,7 @@ public function sanitize_default_feature_settings( array $new_settings ): array
$new_settings['classification_method'] = sanitize_text_field( $new_settings['classification_method'] ?? $settings['classification_method'] );

// Embeddings only supports existing terms.
if ( isset( $new_settings['provider'] ) && Embeddings::ID === $new_settings['provider'] ) {
if ( isset( $new_settings['provider'] ) && ( OpenAIEmbeddings::ID === $new_settings['provider'] || AzureEmbeddings::ID === $new_settings['provider'] ) ) {
$new_settings['classification_method'] = 'existing_terms';
}

Expand Down
4 changes: 2 additions & 2 deletions includes/Classifai/Helpers.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use Classifai\Features\Classification;
use Classifai\Providers\Provider;
use Classifai\Admin\UserProfile;
use Classifai\Providers\OpenAI\Embeddings;
use Classifai\Providers\Watson\NLU;
use Classifai\Services\Service;
use Classifai\Services\ServicesManager;
use WP_Error;
Expand Down Expand Up @@ -609,7 +609,7 @@ function get_classification_feature_taxonomy( string $classify_by = '' ): string
$taxonomy = $settings[ $classify_by . '_taxonomy' ];
}

if ( Embeddings::ID === $settings['provider'] ) {
if ( NLU::ID !== $settings['provider'] ) {
$taxonomy = $classify_by;
}

Expand Down
Loading
Loading