perf: Fix OpenSearch memory problems (#2019)
- Use the Lucene engine implementation for vector search instead of nmslib,
  avoiding the OpenSearch memory problems this change addresses
  (opensearch-project/OpenSearch#3545); see the mapping sketch after this list.

  => **This option is only available for OpenSearch >= 2.2**

- Remove the word cloud computation and the indexing with a multilingual
  analyzer (reducing indexing resources). A better way to compute this
  word/term cloud would be to use a list of provided tokens or a tokenizer
  instead of our custom analyzer; see the sketch after the helpers.py diff below.
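
For reference, here is a minimal sketch of the index configuration this change produces, written against the official `opensearch-py` client. The index name, vector field name, and dimension are invented for illustration; only the `settings`/`method` payload mirrors the diff below.

```python
from opensearchpy import OpenSearch

client = OpenSearch(hosts=["http://localhost:9200"])

# Disable the kNN plugin codec for the index and let Lucene's native HNSW
# implementation handle the vectors (requires OpenSearch >= 2.2).
client.indices.create(
    index="argilla-demo",  # hypothetical index name
    body={
        "settings": {"index.knn": False},
        "mappings": {
            "properties": {
                "my_vector": {  # hypothetical vector field
                    "type": "knn_vector",
                    "dimension": 768,  # hypothetical dimension
                    "method": {
                        "name": "hnsw",
                        "engine": "lucene",
                        "space_type": "l2",
                        "parameters": {"m": 2, "ef_construction": 4},
                    },
                },
            }
        },
    },
)
```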

Co-authored-by: Francisco Aranda <francisco@recogn.ai>
frascuchon and Francisco Aranda authored Dec 15, 2022
1 parent 6913df0 commit 5cd2fee
Showing 18 changed files with 45 additions and 108 deletions.
7 changes: 5 additions & 2 deletions opensearch.compose.yaml
@@ -17,11 +17,14 @@ services:
         soft: -1
         hard: -1
     ports:
-      - 9200:9200
+      - "9200:9200"
     networks:
       - argilla
     volumes:
       - opensearch-data:/usr/share/opensearch/data
 
 networks:
   argilla:
     driver: bridge
 volumes:
-  opensearchdata:
+  opensearch-data:
6 changes: 2 additions & 4 deletions src/argilla/server/daos/backend/client_adapters/factory.py
@@ -34,13 +34,12 @@ def get(
         ca_path: str,
         retry_on_timeout: bool = True,
         max_retries: int = 5,
-        opensearch_enable_knn: bool = False,
     ) -> IClientAdapter:
 
         (
             client_class,
             support_vector_search,
-        ) = cls._resolve_client_class_with_vector_support(hosts, opensearch_enable_knn)
+        ) = cls._resolve_client_class_with_vector_support(hosts)
 
         return client_class(
             index_shards=index_shards,
@@ -58,15 +57,14 @@ def get(
     def _resolve_client_class_with_vector_support(
         cls,
         hosts: str,
-        enable_for_opensearch: bool = False,
     ) -> Tuple[Type, bool]:
         version, distribution = cls._fetch_cluster_version_info(hosts)
 
         support_vector_search = True
 
         if distribution == "elasticsearch" and parse("8.5") <= parse(version):
             client_class = ElasticsearchClient
-        elif distribution == "opensearch" and enable_for_opensearch:
+        elif distribution == "opensearch" and parse("2.2") <= parse(version):
             client_class = OpenSearchClient
         else:
             client_class = OpenSearchClient
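In short, the `opensearch_enable_knn` flag is gone and vector support is now derived from the cluster itself: Elasticsearch >= 8.5, or OpenSearch >= 2.2 (the first release shipping the Lucene kNN engine). A standalone sketch of the rule, assuming `parse` is `packaging.version.parse` and using stand-in classes:

```python
from typing import Tuple, Type

from packaging.version import parse


class ElasticsearchClient: ...  # stand-in for the real adapter class


class OpenSearchClient: ...  # stand-in for the real adapter class


def resolve_client(version: str, distribution: str) -> Tuple[Type, bool]:
    # Vector search is enabled automatically when the cluster can support it.
    support_vector_search = True
    if distribution == "elasticsearch" and parse("8.5") <= parse(version):
        client_class = ElasticsearchClient
    elif distribution == "opensearch" and parse("2.2") <= parse(version):
        client_class = OpenSearchClient
    else:
        client_class = OpenSearchClient
        support_vector_search = False
    return client_class, support_vector_search


assert resolve_client("2.4.0", "opensearch") == (OpenSearchClient, True)
assert resolve_client("1.3.6", "opensearch") == (OpenSearchClient, False)
```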
14 changes: 3 additions & 11 deletions src/argilla/server/daos/backend/client_adapters/opensearch.py
@@ -67,12 +67,7 @@ def configure_index_vectors(
 
         self.set_index_settings(
             index=index,
-            settings={
-                "index": {
-                    "knn": True,
-                    "knn.algo_param.ef_search": 128,  # Ignored when engine=lucene
-                }
-            },
+            settings={"index.knn": False},
         )
         vector_mappings = {}
         for vector_name, vector_dimension in vectors.items():
@@ -81,12 +76,9 @@ def configure_index_vectors(
                 "dimension": vector_dimension,
                 "method": {
                     "name": "hnsw",
+                    "engine": "lucene",
                     "space_type": "l2",
-                    "engine": "nmslib",
-                    "parameters": {
-                        "m": 16,
-                        "ef_construction": 128,
-                    },
+                    "parameters": {"m": 2, "ef_construction": 4},
                 },
             }
             vector_field = self.query_builder.get_vector_field_name(vector_name)
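The hunk above both disables the plugin-level codec (`index.knn: False`) and switches the HNSW method to the Lucene engine with much smaller graph parameters (`m` 16 -> 2, `ef_construction` 128 -> 4). The query side is not part of this diff; as a hedged illustration only, one way to search a `knn_vector` field when `index.knn` is disabled is the k-NN plugin's exact scoring script (index and field names are invented):

```python
from opensearchpy import OpenSearch

client = OpenSearch(hosts=["http://localhost:9200"])

# Exact (brute-force) vector scoring via the k-NN plugin's "knn_score" script,
# which does not need the approximate-search codec enabled on the index.
response = client.search(
    index="argilla-demo",  # hypothetical index name
    body={
        "size": 5,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "knn_score",
                    "lang": "knn",
                    "params": {
                        "field": "my_vector",  # hypothetical vector field
                        "query_value": [0.1, 0.2, 0.3],
                        "space_type": "l2",
                    },
                },
            }
        },
    },
)
```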
1 change: 0 additions & 1 deletion src/argilla/server/daos/backend/generic_elastic.py
@@ -91,7 +91,6 @@ def get_instance(cls) -> "GenericElasticEngineBackend":
             index_shards=settings.es_records_index_shards,
             ssl_verify=settings.elasticsearch_ssl_verify,
             ca_path=settings.elasticsearch_ca_path,
-            opensearch_enable_knn=settings.opensearch_enable_knn,
         ),
         metrics={**ALL_METRICS},
         mappings={
75 changes: 10 additions & 65 deletions src/argilla/server/daos/backend/mappings/helpers.py
@@ -12,16 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List
+from typing import List
 
 from argilla.server.settings import settings
 
-EXTENDED_ANALYZER_REF = "extended_analyzer"
-
-MULTILINGUAL_STOP_ANALYZER_REF = "multilingual_stop_analyzer"
-
-DEFAULT_SUPPORTED_LANGUAGES = ["es", "en", "fr", "de"]  # TODO: env var configuration
-
 
 class mappings:
     @staticmethod
@@ -52,31 +46,6 @@ def path_match_keyword_template(
         ),
     }
 
-    @staticmethod
-    def words_text_field():
-        """Mappings config for old `word` field. Deprecated"""
-
-        default_analyzer = settings.default_es_search_analyzer
-        exact_analyzer = settings.exact_es_search_analyzer
-
-        if default_analyzer == "standard":
-            default_analyzer = MULTILINGUAL_STOP_ANALYZER_REF
-
-        if exact_analyzer == "whitespace":
-            exact_analyzer = EXTENDED_ANALYZER_REF
-
-        return {
-            "type": "text",
-            "fielddata": True,
-            "analyzer": default_analyzer,
-            "fields": {
-                "extended": {
-                    "type": "text",
-                    "analyzer": exact_analyzer,
-                }
-            },
-        }
-
     @staticmethod
     def text_field():
         """Mappings config for textual field"""
@@ -87,14 +56,16 @@ def text_field():
             "type": "text",
             "analyzer": default_analyzer,
             "fields": {
-                "exact": {"type": "text", "analyzer": exact_analyzer},
-                "wordcloud": {
+                "exact": {
                     "type": "text",
-                    "analyzer": MULTILINGUAL_STOP_ANALYZER_REF,
-                    "fielddata": True,
+                    "analyzer": exact_analyzer,
                 },
+                # "wordcloud": {
+                #     "type": "text",
+                #     "analyzer": MULTILINGUAL_STOP_ANALYZER_REF,
+                #     "fielddata": True,
+                # },
             },
             # TODO(@frascuchon): verify min es version that support meta fields
             # "meta": {"experimental": "true"},
         }
 
@@ -122,44 +93,19 @@ def dynamic_field(cls):
         return {"dynamic": True, "type": "object"}
 
 
-def multilingual_stop_analyzer(supported_langs: List[str] = None) -> Dict[str, Any]:
-    """Multilingual stop analyzer"""
-    from stopwordsiso import stopwords
-
-    supported_langs = supported_langs or DEFAULT_SUPPORTED_LANGUAGES
-    return {
-        "type": "stop",
-        "stopwords": [w for w in stopwords(supported_langs)],
-    }
-
-
-def extended_analyzer():
-    """Extended analyzer (used only in `word` field). Deprecated"""
-    return {
-        "type": "custom",
-        "tokenizer": "whitespace",
-        "filter": ["lowercase", "asciifolding"],
-    }
-
-
 def tasks_common_settings():
     """Common index settings"""
     return {
         "number_of_shards": settings.es_records_index_shards,
         "number_of_replicas": settings.es_records_index_replicas,
-        "analysis": {
-            "analyzer": {
-                MULTILINGUAL_STOP_ANALYZER_REF: multilingual_stop_analyzer(),
-                EXTENDED_ANALYZER_REF: extended_analyzer(),
-            }
-        },
     }
 
 
 def dynamic_metrics_text():
     return {
         "metrics.*": mappings.path_match_keyword_template(
-            path="metrics.*", enable_text_search_in_keywords=False
+            path="metrics.*",
+            enable_text_search_in_keywords=False,
         )
     }
 
@@ -190,7 +136,6 @@ def tasks_common_mappings():
         "dynamic": "strict",
         "properties": {
             "id": mappings.keyword_field(),
-            "words": mappings.words_text_field(),
             "text": mappings.text_field(),
             # TODO(@frascuchon): Enable prediction and annotation
             # so we can build extra metrics based on these fields
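With the multilingual analyzer and the `words` field gone, the word/term cloud the commit message mentions would be computed outside the index, from tokens the caller provides. A minimal sketch of that idea in plain Python (illustrative only, not Argilla API):

```python
from collections import Counter
from typing import Dict, Iterable, List


def word_cloud(
    token_lists: Iterable[List[str]],
    stopwords: frozenset = frozenset(),
    top: int = 50,
) -> Dict[str, int]:
    # Count caller-provided tokens instead of relying on an index-time analyzer.
    counts = Counter(
        token.lower()
        for tokens in token_lists
        for token in tokens
        if token.lower() not in stopwords
    )
    return dict(counts.most_common(top))


# Tokens could come from a user-supplied tokenizer (e.g. spaCy) or a plain split.
print(word_cloud([["Argilla", "rocks"], ["argilla", "and", "OpenSearch"]], stopwords=frozenset({"and"})))
# -> {'argilla': 2, 'rocks': 1, 'opensearch': 1}
```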
3 changes: 1 addition & 2 deletions src/argilla/server/daos/backend/mappings/text2text.py
@@ -20,7 +20,7 @@ def text2text_mappings():
     return {
         "_source": mappings.source(
             excludes=[
-                # "words", # Cannot be excluded since comment text_length metric is computed using this source fields
+                "words",  # Cannot be excluded since comment text_length metric is computed using this source fields
                 "predicted",
                 "predicted_as",
                 "predicted_by",
@@ -32,7 +32,6 @@ def text2text_mappings():
         "properties": {
             "annotated_as": mappings.keyword_field(),
             "predicted_as": mappings.keyword_field(),
-            "text_predicted": mappings.words_text_field(),
             "score": {"type": "float"},
         },
     }
10 changes: 5 additions & 5 deletions src/argilla/server/services/metrics/models.py
@@ -158,11 +158,11 @@ def record_metrics(cls, record: ServiceRecord) -> Dict[str, Any]:
             name="Record status distribution",
             description="The dataset record status distribution",
         ),
-        ServiceBaseMetric(
-            id="words_cloud",
-            name="Inputs words cloud",
-            description="The words cloud for dataset inputs",
-        ),
+        # ServiceBaseMetric(
+        #     id="words_cloud",
+        #     name="Inputs words cloud",
+        #     description="The words cloud for dataset inputs",
+        # ),
         ServiceBaseMetric(id="metadata", name="Metadata fields stats"),
         ServiceBaseMetric(
             id="predicted_by",
2 changes: 1 addition & 1 deletion src/argilla/server/services/tasks/text2text/models.py
@@ -81,7 +81,7 @@ def extended_fields(self) -> Dict[str, Any]:
             "annotated_by": self.annotated_by,
             "predicted_by": self.predicted_by,
             "score": self.scores,
-            "words": self.all_text(),
+            # "words": self.all_text(),
         }


4 changes: 2 additions & 2 deletions src/argilla/server/services/tasks/text2text/service.py
@@ -85,7 +85,7 @@ def search(
         metrics = TasksFactory.find_task_metrics(
             dataset.task,
             metric_ids={
-                "words_cloud",
+                # "words_cloud",
                 "predicted_by",
                 "annotated_by",
                 "status_distribution",
@@ -108,7 +108,7 @@ def search(
         )
 
         if results.metrics:
-            results.metrics["words"] = results.metrics["words_cloud"]
+            results.metrics["words"] = results.metrics.get("words_cloud", {})
             results.metrics["status"] = results.metrics["status_distribution"]
 
         return results
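The service keeps the `words` key in the metrics payload for API compatibility, but it now falls back to an empty dict since the `words_cloud` metric is no longer computed; this is also why the test expectations further down change to `"words": {}`. A tiny illustration of the lookup:

```python
metrics = {"status_distribution": {"Default": 2}}  # no "words_cloud" entry anymore
words = metrics.get("words_cloud", {})  # -> {} instead of a KeyError
assert words == {}
```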
@@ -329,7 +329,7 @@ def extended_fields(self) -> Dict[str, Any]:
         words = self.all_text()
         return {
             **super().extended_fields(),
-            "words": words,
+            # "words": words,
             "text": words,
         }

@@ -119,7 +119,7 @@ def search(
         metrics = TasksFactory.find_task_metrics(
             dataset.task,
             metric_ids={
-                "words_cloud",
+                # "words_cloud",
                 "predicted_by",
                 "predicted_as",
                 "annotated_by",
@@ -145,7 +145,7 @@ def search(
         )
 
         if results.metrics:
-            results.metrics["words"] = results.metrics["words_cloud"]
+            results.metrics["words"] = results.metrics.get("words_cloud", {})
             results.metrics["status"] = results.metrics["status_distribution"]
             results.metrics["predicted"] = results.metrics["error_distribution"]
             results.metrics["predicted"].pop("unknown", None)
@@ -99,7 +99,7 @@ def extended_fields(self) -> Dict[str, Any]:
                 {"mention": mention, "entity": entity.label}
                 for mention, entity in self.annotated_mentions()
             ],
-            "words": self.all_text(),
+            # "words": self.all_text(),
         }
 
     def __init__(self, **data):
@@ -102,7 +102,7 @@ def search(
         metrics = TasksFactory.find_task_metrics(
             dataset.task,
             metric_ids={
-                "words_cloud",
+                # "words_cloud",
                 "predicted_by",
                 "predicted_as",
                 "annotated_by",
@@ -128,7 +128,7 @@ def search(
         )
 
         if results.metrics:
-            results.metrics["words"] = results.metrics["words_cloud"]
+            results.metrics["words"] = results.metrics.get("words_cloud", {})
             results.metrics["status"] = results.metrics["status_distribution"]
             results.metrics["predicted"] = results.metrics["error_distribution"]
             results.metrics["predicted"].pop("unknown", None)
3 changes: 0 additions & 3 deletions src/argilla/server/settings.py
@@ -65,9 +65,6 @@ class ApiSettings(BaseSettings):
     elasticsearch_ca_path: Optional[str] = None
     cors_origins: List[str] = ["*"]
 
-    # TODO: Document this variable
-    opensearch_enable_knn: bool = False
-
     docs_enabled: bool = True
 
     namespace: str = Field(default=None, regex=r"^[a-z]+$")
2 changes: 1 addition & 1 deletion tests/client/test_api.py
@@ -177,7 +177,7 @@ def test_load_limits(mocked_client, supported_vector_search):
     create_some_data_for_text_classification(
         mocked_client,
         dataset,
-        50,
+        n=50,
         with_vectors=supported_vector_search,
     )

5 changes: 4 additions & 1 deletion tests/conftest.py
@@ -45,7 +45,10 @@ def telemetry_track_data(mocker):
 
 
 @pytest.fixture
-def mocked_client(monkeypatch, telemetry_track_data) -> SecuredClient:
+def mocked_client(
+    monkeypatch,
+    telemetry_track_data,
+) -> SecuredClient:
 
     with TestClient(app, raise_server_exceptions=False) as _client:
         client_ = SecuredClient(_client)
2 changes: 1 addition & 1 deletion tests/server/text2text/test_api.py
@@ -80,7 +80,7 @@ def test_search_records(mocked_client):
         "predicted_by": {"test": 1},
         "predicted_text": {},
         "status": {"Default": 2},
-        "words": {"data": 2, "ånother": 1},
+        "words": {},
     }


Expand Down
7 changes: 4 additions & 3 deletions tests/server/text_classification/test_api.py
@@ -164,7 +164,8 @@ def test_create_records_for_text_classification(mocked_client, telemetry_track_data):
     assert created_dataset.metadata == metadata
 
     response = mocked_client.post(
-        f"/api/datasets/{dataset}/TextClassification:search", json={}
+        f"/api/datasets/{dataset}/TextClassification:search",
+        json={},
     )
 
     assert response.status_code == 200
@@ -178,7 +179,7 @@ def test_create_records_for_text_classification(mocked_client, telemetry_track_data):
         "predicted_as": {"Mocking": 1},
         "predicted_by": {"test": 1},
         "status": {"Default": 1},
-        "words": {"data": 1},
+        "words": {},
     }
 
     telemetry_track_data.assert_called_once()
@@ -273,7 +274,7 @@ def test_create_records_for_text_classification_vector_search(
         "predicted_as": {"Mocking": 3},
         "predicted_by": {"test": 3},
         "status": {"Default": 3},
-        "words": {"data": 3},
+        "words": {},
     }
 
     response = mocked_client.post(
