Skip to content

Commit

Permalink
feat: add __len__ and __iter__ methods (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmahrt committed Dec 10, 2023
1 parent 5ff7c78 commit 9e01fdc
Show file tree
Hide file tree
Showing 23 changed files with 7,960 additions and 7,788 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 100
ignore = E203
ignore = E203, W503
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

*Praatio uses semantic versioning (Major.Minor.Patch)*

Ver 6.2 (Dec 10, 2023)
- Add `__len__` and `__iter__` methods to Textgrid and TextgridTier
- Fix behavior of `__eq__` method in TextgridTier

Ver 6.1 (Nov 4, 2023)
- Add `TextgridTier.dejitter()` method for improving consistency between tiers

Expand Down
683 changes: 341 additions & 342 deletions docs/praatio/audio.html

Large diffs are not rendered by default.

4,132 changes: 2,059 additions & 2,073 deletions docs/praatio/data_classes/interval_tier.html

Large diffs are not rendered by default.

2,291 changes: 1,143 additions & 1,148 deletions docs/praatio/data_classes/point_tier.html

Large diffs are not rendered by default.

2,996 changes: 1,505 additions & 1,491 deletions docs/praatio/data_classes/textgrid.html

Large diffs are not rendered by default.

1,304 changes: 658 additions & 646 deletions docs/praatio/data_classes/textgrid_tier.html

Large diffs are not rendered by default.

628 changes: 310 additions & 318 deletions docs/praatio/klattgrid.html

Large diffs are not rendered by default.

1,992 changes: 994 additions & 998 deletions docs/praatio/pitch_and_intensity.html

Large diffs are not rendered by default.

1,476 changes: 735 additions & 741 deletions docs/praatio/praatio_scripts.html

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions praatio/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def __init__(self, params: List):
if self.nchannels != 1:
raise (
errors.ArgumentError(
"Only audio with a single channel can be loaded. Your file was #{self.nchannels}."
"Only audio with a single channel can be loaded. "
"Your file was #{self.nchannels}."
)
)

Expand Down Expand Up @@ -280,7 +281,6 @@ def getFrames(self, startTime: float = None, endTime: float = None) -> bytes:
return readFramesAtTime(self.audiofile, startTime, endTime)

def getSamples(self, startTime: float, endTime: float) -> Tuple[int, ...]:

frames = self.getFrames(startTime, endTime)
audioFrameList = convertFromBytes(frames, self.sampleWidth)

Expand Down
5 changes: 0 additions & 5 deletions praatio/data_classes/interval_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def _calculateMinAndMaxTime(entries: Sequence[Interval], minT=None, maxT=None):


class IntervalTier(textgrid_tier.TextgridTier):

tierType = INTERVAL_TIER
entryType = Interval

Expand Down Expand Up @@ -260,7 +259,6 @@ def editTimestamps(

newEntryList = []
for interval in self.entries:

newStart = offset + interval.start
newEnd = offset + interval.end

Expand Down Expand Up @@ -338,7 +336,6 @@ def eraseRegion(
# right edges
# if categorical, it doesn't make it into the list at all
if collisionMode == constants.EraseCollision.TRUNCATE:

# Check left edge
if matchList[0].start < start:
newEntry = Interval(matchList[0].start, start, matchList[0].label)
Expand All @@ -350,7 +347,6 @@ def eraseRegion(
newTier.insertEntry(newEntry)

if doShrink is True:

diff = end - start
newEntryList = []
for interval in newTier.entries:
Expand Down Expand Up @@ -714,7 +710,6 @@ def morph(
newEntryList = []
allIntervals = [self.entries, targetTier.entries]
for sourceInterval, targetInterval in utils.safeZip(allIntervals, True):

# sourceInterval.start - lastFromEnd -> was this interval and the
# last one adjacent?
newStart = sourceInterval.start + cumulativeAdjustAmount
Expand Down
2 changes: 0 additions & 2 deletions praatio/data_classes/point_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _calculateMinAndMaxTime(entries: Sequence[Point], minT=None, maxT=None):


class PointTier(textgrid_tier.TextgridTier):

tierType = POINT_TIER
entryType = Point

Expand Down Expand Up @@ -186,7 +185,6 @@ def editTimestamps(

newEntries: List[Point] = []
for timestamp, label in self.entries:

newTimestamp = timestamp + offset
utils.checkIsUndershoot(newTimestamp, self.minTimestamp, errorReporter)
utils.checkIsOvershoot(newTimestamp, self.maxTimestamp, errorReporter)
Expand Down
7 changes: 7 additions & 0 deletions praatio/data_classes/textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def __init__(self, minTimestamp: float = None, maxTimestamp: float = None):
self.minTimestamp: float = minTimestamp # type: ignore[assignment]
self.maxTimestamp: float = maxTimestamp # type: ignore[assignment]

def __len__(self):
return len(self._tierDict)

def __iter__(self):
for entry in self.tiers:
yield entry

def __eq__(self, other):
if not isinstance(other, Textgrid):
return False
Expand Down
12 changes: 9 additions & 3 deletions praatio/data_classes/textgrid_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class TextgridTier(ABC):

tierType: str
entryType: Union[Type[constants.Point], Type[constants.Interval]]

Expand All @@ -48,15 +47,22 @@ def __init__(
self.maxTimestamp = maxT
self.errorReporter = utils.getErrorReporter(errorMode)

def __len__(self):
return len(self._entries)

def __iter__(self):
for entry in self.entries:
yield entry

def __eq__(self, other):
if type(self) != type(other):
if not isinstance(self, type(other)):
return False

isEqual = True
isEqual &= self.name == other.name
isEqual &= math.isclose(self.minTimestamp, other.minTimestamp)
isEqual &= math.isclose(self.maxTimestamp, other.maxTimestamp)
isEqual &= len(self.entries) == len(self.entries)
isEqual &= len(self.entries) == len(other.entries)

# TODO: Intervals and Points now use isclose, so we can simplify this
# logic (selfEntry == otherEntry); however, this will break
Expand Down
6 changes: 0 additions & 6 deletions praatio/klattgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@


def openKlattgrid(fnFullPath: str) -> Klattgrid:

try:
with io.open(fnFullPath, "r", encoding="utf-16") as fd:
data = fd.read()
Expand Down Expand Up @@ -106,7 +105,6 @@ def resynthesize(
doCascade: bool = True,
scriptFN: Optional[str] = None,
) -> None:

if doCascade:
method = "Cascade"
else:
Expand All @@ -120,7 +118,6 @@ def resynthesize(


def _openNormalKlattgrid(data: str) -> Klattgrid:

kg = Klattgrid()

# Toss header
Expand Down Expand Up @@ -149,12 +146,10 @@ def _openNormalKlattgrid(data: str) -> Klattgrid:
"delta_formants",
"frication_formants",
]:

kct = _proccessContainerTierInput(sectionData, name)
kg.addTier(kct)

else:

# Process entries if this tier has any
entries = _buildEntries(sectionTuple)
tier = KlattPointTier(name, entries, minT, maxT)
Expand Down Expand Up @@ -269,7 +264,6 @@ def _buildEntries(sectionTuple):


def _processSectionData(sectionData: str) -> List[Tuple[float, float]]:

sectionData += "\n"

startI = 0
Expand Down
3 changes: 0 additions & 3 deletions praatio/pitch_and_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _extractPIPiecewise(

firstTime = not os.path.exists(outputFN)
if firstTime or forceRegenerate is True:

utils.makeDir(tmpOutputPath)
splitAudioList = praatio_scripts.splitAudioOnTier(
inputFN, tgFN, tierName, tmpOutputPath, False
Expand Down Expand Up @@ -129,7 +128,6 @@ def _extractPIFile(

firstTime = not os.path.exists(outputFN)
if firstTime or forceRegenerate is True:

# The praat script uses append mode, so we need to clear any prior
# result
if os.path.exists(outputFN):
Expand Down Expand Up @@ -183,7 +181,6 @@ def extractIntensity(

firstTime = not os.path.exists(outputFN)
if firstTime or forceRegenerate is True:

# The praat script uses append mode, so we need to clear any prior
# result
if os.path.exists(outputFN):
Expand Down
3 changes: 0 additions & 3 deletions praatio/praatio_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def audioSplice(

# Ensure all time points involved in splicing fall on zero crossings
if alignToZeroCrossing is True:

# Cut the splice segment to zero crossings
spliceDuration = spliceSegment.duration
spliceZeroStart = spliceSegment.findNearestZeroCrossing(0)
Expand Down Expand Up @@ -178,7 +177,6 @@ def spellCheckEntries(

mispelledEntries = []
for start, end, label in tier.entries:

# Remove punctuation
for char in punctuationList:
label = label.replace(char, "")
Expand Down Expand Up @@ -260,7 +258,6 @@ def splitTierEntries(

# Or insert new entries into existing target tier
else:

for entry in newEntries:
targetTier.insertEntry(entry, constants.IntervalCollision.ERROR)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name="praatio",
python_requires=">3.6.0",
version="6.1.0",
version="6.2.0",
author="Tim Mahrt",
author_email="timmahrt@gmail.com",
url="https://github.com/timmahrt/praatIO",
Expand Down
1 change: 0 additions & 1 deletion tests/test_data_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_point_process_io(self):
)

def test_point_process_io_long_vs_short(self):

shortFN = join(self.dataRoot, "bobby.PointProcess")
longFN = join(self.dataRoot, "bobby_longfile.PointProcess")

Expand Down
73 changes: 73 additions & 0 deletions tests/test_interval_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,82 @@
from tests import testing_utils

makeIntervalTier = testing_utils.makeIntervalTier
makePointTier = testing_utils.makePointTier


class TestIntervalTier(PraatioTestCase):
def test__eq__(self):
sut = makeIntervalTier(name="foo", intervals=[], minT=1.0, maxT=4.0)
intervalTier = makeIntervalTier(name="foo", intervals=[], minT=1.0, maxT=4.0)
pointTier = makePointTier()
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

# must be the same type
self.assertEqual(sut, intervalTier)
self.assertNotEqual(sut, pointTier)

# must have the same entries
sut.insertEntry(interval1)
self.assertNotEqual(sut, intervalTier)

# just having the same number of entries is not enough
intervalTier.insertEntry(interval2)
self.assertNotEqual(sut, intervalTier)

sut.insertEntry(interval2)
intervalTier.insertEntry(interval1)
self.assertEqual(sut, intervalTier)

# must have the same name
intervalTier.name = "bar"
self.assertNotEqual(sut, intervalTier)
intervalTier.name = "foo"
self.assertEqual(sut, intervalTier)

# must have the same min/max timestamps
intervalTier.minTimestamp = 0.5
self.assertNotEqual(sut, intervalTier)

intervalTier.minTimestamp = 1
intervalTier.maxTimestamp = 5
self.assertNotEqual(sut, intervalTier)

sut.maxTimestamp = 5
self.assertEqual(sut, intervalTier)

def test__len__returns_the_number_of_intervals_in_the_interval_tier(self):
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

sut = makeIntervalTier(intervals=[])

self.assertEqual(len(sut), 0)

sut.insertEntry(interval1)
self.assertEqual(len(sut), 1)

sut.insertEntry(interval2)
self.assertEqual(len(sut), 2)

sut.deleteEntry(interval1)
self.assertEqual(len(sut), 1)

sut.deleteEntry(interval2)
self.assertEqual(len(sut), 0)

def test__iter__iterates_through_intervals_in_the_interval_tier(self):
interval1 = Interval(1.0, 2.0, "hello")
interval2 = Interval(2.0, 3.0, "world")

sut = makeIntervalTier(intervals=[interval1, interval2])

seenIntervals = []
for interval in sut:
seenIntervals.append(interval)

self.assertEqual(seenIntervals, [interval1, interval2])

def test_inequivalence_with_non_interval_tiers(self):
sut = makeIntervalTier()
self.assertNotEqual(sut, 55)
Expand Down
Loading

0 comments on commit 9e01fdc

Please sign in to comment.