Skip to content

Commit

Permalink
fix: __eq__ returns false if two tiers are different
Browse files Browse the repository at this point in the history
  • Loading branch information
timmahrt committed Dec 10, 2023
1 parent 9ebcfad commit 4d579d4
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 8 deletions.
4 changes: 2 additions & 2 deletions praatio/data_classes/textgrid_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def __iter__(self):
yield entry

def __eq__(self, other):
if isinstance(self, type(other)):
if not isinstance(self, type(other)):
return False

isEqual = True
isEqual &= self.name == other.name
isEqual &= math.isclose(self.minTimestamp, other.minTimestamp)
isEqual &= math.isclose(self.maxTimestamp, other.maxTimestamp)
isEqual &= len(self.entries) == len(self.entries)
isEqual &= len(self.entries) == len(other.entries)

# TODO: Intervals and Points now use isclose, so we can simplify this
# logic (selfEntry == otherEntry); however, this will break
Expand Down
45 changes: 43 additions & 2 deletions tests/test_interval_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,51 @@
from tests import testing_utils

makeIntervalTier = testing_utils.makeIntervalTier
makePointTier = testing_utils.makePointTier


class TestIntervalTier(PraatioTestCase):
def test_len_returns_the_number_of_intervals_in_the_interval_tier(self):
def test__eq__(self):
sut = makeIntervalTier(name="foo", intervals=[], minT=1.0, maxT=4.0)
intervalTier = makeIntervalTier(name="foo", intervals=[], minT=1.0, maxT=4.0)
pointTier = makePointTier()
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

# must be the same type
self.assertEqual(sut, intervalTier)
self.assertNotEqual(sut, pointTier)

# must have the same entries
sut.insertEntry(interval1)
self.assertNotEqual(sut, intervalTier)

# just having the same number of entries is not enough
intervalTier.insertEntry(interval2)
self.assertNotEqual(sut, intervalTier)

sut.insertEntry(interval2)
intervalTier.insertEntry(interval1)
self.assertEqual(sut, intervalTier)

# must have the same name
intervalTier.name = "bar"
self.assertNotEqual(sut, intervalTier)
intervalTier.name = "foo"
self.assertEqual(sut, intervalTier)

# must have the same min/max timestamps
intervalTier.minTimestamp = 0.5
self.assertNotEqual(sut, intervalTier)

intervalTier.minTimestamp = 1
intervalTier.maxTimestamp = 5
self.assertNotEqual(sut, intervalTier)

sut.maxTimestamp = 5
self.assertEqual(sut, intervalTier)

def test__len__returns_the_number_of_intervals_in_the_interval_tier(self):
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

Expand All @@ -32,7 +73,7 @@ def test_len_returns_the_number_of_intervals_in_the_interval_tier(self):
sut.deleteEntry(interval2)
self.assertEqual(len(sut), 0)

def test_iter_iterates_through_points_in_the_point_tier(self):
def test__iter__iterates_through_points_in_the_point_tier(self):
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

Expand Down
45 changes: 43 additions & 2 deletions tests/test_point_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,51 @@
from tests import testing_utils

makePointTier = testing_utils.makePointTier
makeIntervalTier = testing_utils.makeIntervalTier


class TestPointTier(PraatioTestCase):
def test_len_returns_the_number_of_points_in_the_point_tier(self):
def test__eq__(self):
sut = makePointTier(name="foo", points=[], minT=1, maxT=4)
pointTier = makePointTier(name="foo", points=[], minT=1, maxT=4)
intervalTier = makeIntervalTier()
point1 = Point(1.0, "hello")
point2 = Point(3.0, "world")

# must be the same type
self.assertEqual(sut, pointTier)
self.assertNotEqual(sut, intervalTier)

# must have the same entries
sut.insertEntry(point1)
self.assertNotEqual(sut, pointTier)

# just having the same number of entries is not enough
pointTier.insertEntry(point2)
self.assertNotEqual(sut, pointTier)

sut.insertEntry(point2)
pointTier.insertEntry(point1)
self.assertEqual(sut, pointTier)

# must have the same name
pointTier.name = "bar"
self.assertNotEqual(sut, pointTier)
pointTier.name = "foo"
self.assertEqual(sut, pointTier)

# must have the same min/max timestamps
pointTier.minTimestamp = 0.5
self.assertNotEqual(sut, pointTier)

pointTier.minTimestamp = 1
pointTier.maxTimestamp = 5
self.assertNotEqual(sut, pointTier)

sut.maxTimestamp = 5
self.assertEqual(sut, pointTier)

def test__len__returns_the_number_of_points_in_the_point_tier(self):
point1 = Point(1, "hello")
point2 = Point(3.5, "world")

Expand All @@ -32,7 +73,7 @@ def test_len_returns_the_number_of_points_in_the_point_tier(self):
sut.deleteEntry(point2)
self.assertEqual(len(sut), 0)

def test_iter_iterates_through_points_in_the_point_tier(self):
def test__iter__iterates_through_points_in_the_point_tier(self):
point1 = Point(1, "hello")
point2 = Point(3.5, "world")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def makePointTier(name="pitch_values", points=None, minT=0, maxT=5.0):


class TestTextgrid(PraatioTestCase):
def test_len_returns_the_number_of_tiers_in_the_textgrid(self):
def test__len__returns_the_number_of_tiers_in_the_textgrid(self):
tier1 = makeIntervalTier()
tier2 = makePointTier()

Expand All @@ -41,7 +41,7 @@ def test_len_returns_the_number_of_tiers_in_the_textgrid(self):
sut.removeTier(tier2.name)
self.assertEqual(len(sut), 0)

def test_iter_iterates_through_tiers(self):
def test__iter__iterates_through_tiers(self):
tier1 = makeIntervalTier()
tier2 = makePointTier()

Expand Down

0 comments on commit 4d579d4

Please sign in to comment.