Skip to content

Commit

Permalink
feat: add __hash__ and __eq__ to Index (#1809)
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng authored Dec 14, 2024
1 parent e5cadb5 commit 313ee76
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ source = ["tortoise"]
[tool.coverage.report]
show_missing = true

[tool.ruff]
line-length = 100
[tool.ruff.lint]
ignore = ["E501"]

Expand Down
34 changes: 34 additions & 0 deletions tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,40 @@
from tortoise import fields
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.indexes import Index


class CustomIndex(Index):
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self._foo = ""


class TestIndexHashEqual(test.TestCase):
def test_index_eq(self):
assert Index(fields=("id",)) == Index(fields=("id",))
assert CustomIndex(fields=("id",)) == CustomIndex(fields=("id",))
assert Index(fields=("id", "name")) == Index(fields=["id", "name"])

assert Index(fields=("id", "name")) != Index(fields=("name", "id"))
assert Index(fields=("id",)) != Index(fields=("name",))
assert CustomIndex(fields=("id",)) != Index(fields=("id",))

def test_index_hash(self):
assert hash(Index(fields=("id",))) == hash(Index(fields=("id",)))
assert hash(Index(fields=("id", "name"))) == hash(Index(fields=["id", "name"]))
assert hash(CustomIndex(fields=("id", "name"))) == hash(CustomIndex(fields=["id", "name"]))

assert hash(Index(fields=("id", "name"))) != hash(Index(fields=["name", "id"]))
assert hash(Index(fields=("id",))) != hash(Index(fields=("name",)))

indexes = {Index(fields=("id",))}
indexes.add(Index(fields=("id",)))
assert len(indexes) == 1
indexes.add(CustomIndex(fields=("id",)))
assert len(indexes) == 2
indexes.add(Index(fields=("name",)))
assert len(indexes) == 3


class TestIndexAlias(test.TestCase):
Expand Down
8 changes: 7 additions & 1 deletion tortoise/indexes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type

from pypika.terms import Term, ValueWrapper

Expand Down Expand Up @@ -59,6 +59,12 @@ def get_sql(
def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
return self.name or schema_generator._generate_index_name("idx", model, self.fields)

def __hash__(self) -> int:
return hash((tuple(self.fields), self.name, tuple(self.expressions)))

def __eq__(self, other: Any) -> bool:
return type(self) is type(other) and self.__dict__ == other.__dict__


class PartialIndex(Index):
def __init__(
Expand Down

0 comments on commit 313ee76

Please sign in to comment.