ci: Simplify Python code with ruff rules SIM #5833

Merged: 5 commits, Sep 20, 2023
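The PR description is not preserved in this capture. As a rough, non-authoritative sketch of what ruff's flake8-simplify (SIM) rule family does, the snippet below mirrors the three patterns rewritten throughout this diff: Yoda-style comparisons, nested ifs that can be joined with "and", and ternaries that just re-encode a boolean. The rule numbers in the comments are indicative only; consult the ruff documentation for the exact codes enabled in this repository.

# Illustrative sketch, not code from this PR: typical rewrites suggested by ruff's
# flake8-simplify (SIM) rules. Rule numbers are indicative (SIM300 Yoda conditions,
# SIM102 collapsible if, SIM211 needless ternary); check the ruff docs for specifics.

def before(meta: dict, embedding) -> bool:
    # SIM300: constant on the left-hand side of a comparison
    assert "stop" == meta["finish_reason"]
    # SIM102: nested ifs that can be collapsed with `and`
    if "vector_id" in meta:
        if meta["vector_id"] is not None:
            print("has vector")
    # SIM211: `False if x else True` is just a negated test
    return False if embedding is None else True


def after(meta: dict, embedding) -> bool:
    assert meta["finish_reason"] == "stop"
    if "vector_id" in meta and meta["vector_id"] is not None:
        print("has vector")
    return embedding is not None


if __name__ == "__main__":
    sample = {"finish_reason": "stop", "vector_id": 3}
    assert before(sample, [0.1]) == after(sample, [0.1])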
2 changes: 1 addition & 1 deletion e2e/pipelines/test_standard_pipelines.py
@@ -306,7 +306,7 @@ def test_summarization_pipeline():
     output = pipeline.run(query=query, params={"Retriever": {"top_k": 1}})
     answers = output["answers"]
     assert len(answers) == 1
-    assert "The Eiffel Tower is one of the world's tallest structures." == answers[0]["answer"].strip()
+    assert answers[0]["answer"].strip() == "The Eiffel Tower is one of the world's tallest structures."


 def test_summarization_pipeline_one_summary():
4 changes: 2 additions & 2 deletions e2e/preview/components/test_gpt35_generator.py
@@ -17,7 +17,7 @@ def test_gpt35_generator_run(generator_class, model_name):
     assert "Paris" in results["replies"][0]
     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"


 @pytest.mark.skipif(
@@ -54,6 +54,6 @@ def __call__(self, chunk):

     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"

     assert callback.responses == results["replies"][0]
8 changes: 4 additions & 4 deletions e2e/preview/components/test_whisper_local.py
@@ -14,14 +14,14 @@ def test_whisper_local_transcriber(preview_samples_path):
     docs = output["documents"]
     assert len(docs) == 3

-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<<binary stream>>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<<binary stream>>"
8 changes: 4 additions & 4 deletions e2e/preview/components/test_whisper_remote.py
@@ -22,14 +22,14 @@ def test_whisper_remote_transcriber(preview_samples_path):
     docs = output["documents"]
     assert len(docs) == 3

-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<<binary stream>>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<<binary stream>>"
27 changes: 14 additions & 13 deletions haystack-linter/haystack_linter/linting.py
@@ -37,16 +37,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in [
-                "debug",
-                "info",
-                "warning",
-                "error",
-                "critical",
-                "exception",
-            ]:
-                self.add_message("no-direct-logging", args=node.func.attrname, node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["debug", "info", "warning", "error", "critical", "exception"]
+        ):
+            self.add_message("no-direct-logging", args=node.func.attrname, node=node)


 class NoLoggingConfigurationChecker(BaseChecker):
@@ -71,9 +68,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in ["basicConfig"]:
-                self.add_message("no-logging-basicconfig", node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["basicConfig"]
+        ):
+            self.add_message("no-logging-basicconfig", node=node)


 def register(linter: "PyLinter") -> None:
2 changes: 1 addition & 1 deletion haystack/agents/base.py
@@ -346,7 +346,7 @@ def run(
         You can only pass parameters to tools that are pipelines, but not nodes.
         """
         try:
-            if not self.hash == self.last_hash:
+            if self.hash != self.last_hash:
                 self.last_hash = self.hash
                 send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
         except Exception as exc:
7 changes: 4 additions & 3 deletions haystack/document_stores/elasticsearch/es8.py
@@ -299,9 +299,10 @@ def _init_elastic_client(
         return client

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.options(headers=headers).indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.options(headers=headers).indices.exists_alias(
+            name=index_name
+        ):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.options(headers=headers).indices.exists(index=index_name)
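A note on the es8.py change above (and the similar search_engine.py change further down), added here as an illustration rather than PR commentary: collapsing the nested ifs into one condition keeps the original lazy behaviour, because Python's "and" short-circuits. The sketch below uses placeholder names, not Haystack APIs, to show that the potentially expensive alias lookup still only runs when DEBUG logging is enabled.

import logging

logger = logging.getLogger("example")
logger.setLevel(logging.INFO)  # DEBUG is disabled

def expensive_alias_check() -> bool:
    print("alias check executed")
    return True

if logger.isEnabledFor(logging.DEBUG) and expensive_alias_check():
    logger.debug("Index name is an alias.")
# Nothing is printed: `and` short-circuits, so expensive_alias_check() never runs,
# exactly as in the original nested-if version.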
5 changes: 2 additions & 3 deletions haystack/document_stores/es_converter.py
@@ -228,9 +228,8 @@ def elasticsearch_index_to_document_store(
         content = record["_source"].pop(original_content_field, "")
         if content:
             meta = {}
-            if original_name_field is not None:
-                if original_name_field in record["_source"]:
-                    meta["name"] = record["_source"].pop(original_name_field)
+            if original_name_field is not None and original_name_field in record["_source"]:
+                meta["name"] = record["_source"].pop(original_name_field)
             # Only add selected metadata fields
             if included_metadata_fields is not None:
                 for metadata_field in included_metadata_fields:
5 changes: 2 additions & 3 deletions haystack/document_stores/faiss.py
@@ -447,9 +447,8 @@ def get_all_documents_generator(
             return_embedding = self.return_embedding

         for doc in documents:
-            if return_embedding:
-                if doc.meta and doc.meta.get("vector_id") is not None:
-                    doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
+            if return_embedding and doc.meta and doc.meta.get("vector_id") is not None:
+                doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
             yield doc

     def get_documents_by_id(
7 changes: 3 additions & 4 deletions haystack/document_stores/opensearch.py
@@ -382,10 +382,9 @@ def write_documents(
             self.index_type in ["ivf", "ivf_pq"]
             and not index.startswith(".")
             and not self._ivf_model_exists(index=index)
-        ):
-            if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
-                train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
-                self._train_ivf_index(index=index, documents=train_docs, headers=headers)
+        ) and self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
+            train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
+            self._train_ivf_index(index=index, documents=train_docs, headers=headers)

     def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
         """
2 changes: 1 addition & 1 deletion haystack/document_stores/pinecone.py
@@ -487,7 +487,7 @@ def write_documents(
             documents=document_objects, index=index, duplicate_documents=duplicate_documents
         )
         if document_objects:
-            add_vectors = False if document_objects[0].embedding is None else True
+            add_vectors = document_objects[0].embedding is not None
             # If these are not labels, we need to find the correct value for `doc_type` metadata field
             if not labels:
                 type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING
5 changes: 2 additions & 3 deletions haystack/document_stores/search_engine.py
@@ -1620,9 +1620,8 @@ def delete_index(self, index: str):
         self._index_delete(index)

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.indices.exists_alias(name=index_name):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.indices.exists(index=index_name, headers=headers)
35 changes: 16 additions & 19 deletions haystack/document_stores/utils.py
@@ -40,9 +40,8 @@ def eval_data_from_json(
         logger.warning("No title information found for documents in QA file: %s", filename)

     for squad_document in data["data"]:
-        if max_docs:
-            if len(docs) > max_docs:
-                break
+        if max_docs and len(docs) > max_docs:
+            break
         # Extracting paragraphs and their labels from a SQuAD document dict
         cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
             squad_document, preprocessor, open_domain
@@ -84,9 +83,8 @@ def eval_data_from_jsonl(

     with open(filename, "r", encoding="utf-8") as file:
         for document in file:
-            if max_docs:
-                if len(docs) > max_docs:
-                    break
+            if max_docs and len(docs) > max_docs:
+                break
             # Extracting paragraphs and their labels from a SQuAD document dict
             squad_document = json.loads(document)
             cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
@@ -96,19 +94,18 @@
             labels.extend(cur_labels)
             problematic_ids.extend(cur_problematic_ids)

-            if batch_size is not None:
-                if len(docs) >= batch_size:
-                    if len(problematic_ids) > 0:
-                        logger.warning(
-                            "Could not convert an answer for %s questions.\n"
-                            "There were conversion errors for question ids: %s",
-                            len(problematic_ids),
-                            problematic_ids,
-                        )
-                    yield docs, labels
-                    docs = []
-                    labels = []
-                    problematic_ids = []
+            if batch_size is not None and len(docs) >= batch_size:
+                if len(problematic_ids) > 0:
+                    logger.warning(
+                        "Could not convert an answer for %s questions.\n"
+                        "There were conversion errors for question ids: %s",
+                        len(problematic_ids),
+                        problematic_ids,
+                    )
+                yield docs, labels
+                docs = []
+                labels = []
+                problematic_ids = []

     yield docs, labels
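For the eval_data_from_jsonl change above, the guard "if batch_size is not None and len(docs) >= batch_size" is equivalent to the old nested ifs because "and" short-circuits when batch_size is None. A stripped-down, illustrative generator (not Haystack code) showing the same yield-then-reset batching pattern:

# Illustrative only: the collapsed guard behaves exactly like the old nested ifs,
# since `and` short-circuits on a None batch_size.
from typing import Iterator, List, Optional

def batched(items: List[int], batch_size: Optional[int]) -> Iterator[List[int]]:
    docs: List[int] = []
    for item in items:
        docs.append(item)
        if batch_size is not None and len(docs) >= batch_size:
            yield docs
            docs = []
    yield docs  # final, possibly partial (or empty) batch

print(list(batched([1, 2, 3, 4, 5], batch_size=2)))  # [[1, 2], [3, 4], [5]]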
38 changes: 21 additions & 17 deletions haystack/document_stores/weaviate.py
@@ -661,10 +661,9 @@ def write_documents(
                     if isinstance(v, dict):
                         json_fields.append(k)
                         v = json.dumps(v)
-                    elif isinstance(v, list):
-                        if len(v) > 0 and isinstance(v[0], dict):
-                            json_fields.append(k)
-                            v = [json.dumps(item) for item in v]
+                    elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
+                        json_fields.append(k)
+                        v = [json.dumps(item) for item in v]
                     _doc[k] = v
                 _doc.pop("meta")

@@ -734,9 +733,8 @@ def update_document_meta(
         # Weaviate requires dates to be in RFC3339 format
         date_fields = self._get_date_properties(index)
         for date_field in date_fields:
-            if date_field in meta:
-                if isinstance(meta[date_field], str):
-                    meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
+            if date_field in meta and isinstance(meta[date_field], str):
+                meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))

         self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)

@@ -771,10 +769,8 @@ def get_document_count(
         else:
             result = self.weaviate_client.query.aggregate(index).with_meta_count().do()

-        if "data" in result:
-            if "Aggregate" in result.get("data"):
-                if result.get("data").get("Aggregate").get(index):
-                    doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
+        if "data" in result and "Aggregate" in result.get("data") and result.get("data").get("Aggregate").get(index):
+            doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]

         return doc_count

@@ -1153,9 +1149,13 @@ def query(
         query_output = self.weaviate_client.query.raw(gql_query)

         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)
@@ -1421,9 +1421,13 @@ def query_by_embedding(
         )

         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)
10 changes: 6 additions & 4 deletions haystack/modeling/data_handler/data_silo.py
@@ -111,10 +111,12 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
         if dicts is None:
             dicts = list(self.processor.file_to_dicts(filename))  # type: ignore
         # shuffle list of dicts here if we later want to have a random dev set split from train set
-        if str(self.processor.train_filename) in str(filename):
-            if not self.processor.dev_filename:
-                if self.processor.dev_split > 0.0:
-                    random.shuffle(dicts)
+        if (
+            str(self.processor.train_filename) in str(filename)
+            and not self.processor.dev_filename
+            and self.processor.dev_split > 0.0
+        ):
+            random.shuffle(dicts)

         num_dicts = len(dicts)
         datasets = []
5 changes: 2 additions & 3 deletions haystack/modeling/data_handler/processor.py
@@ -488,9 +488,8 @@ def dataset_from_dicts(
         dataset, tensor_names, baskets = self._create_dataset(baskets)

         # Logging
-        if indices:
-            if 0 in indices:
-                self._log_samples(n_samples=1, baskets=baskets)
+        if indices and 0 in indices:
+            self._log_samples(n_samples=1, baskets=baskets)

         # During inference we need to keep the information contained in baskets.
         if return_baskets:
15 changes: 9 additions & 6 deletions haystack/modeling/evaluation/eval.py
@@ -194,12 +194,15 @@ def log_results(
         logger.info("\n _________ %s _________", head["task_name"])
         for metric_name, metric_val in head.items():
             # log with experiment tracking framework (e.g. Mlflow)
-            if logging:
-                if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
-                    if isinstance(metric_val, numbers.Number):
-                        tracker.track_metrics(
-                            metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
-                        )
+            if (
+                logging
+                and not metric_name in ["preds", "labels"]
+                and not metric_name.startswith("_")
+                and isinstance(metric_val, numbers.Number)
+            ):
+                tracker.track_metrics(
+                    metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
+                )
             # print via standard python logger
             if print:
                 if metric_name == "report":