Skip to content

Commit

Permalink
feat: add dejitter method
Browse files Browse the repository at this point in the history
  • Loading branch information
timmahrt committed Nov 4, 2023
1 parent 0f0544e commit 89bb335
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 56 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

*Praatio uses semantic versioning (Major.Minor.Patch)*

Ver 6.1 (Feb 6, 2023)
- Add `TextgridTier.dejitter()` method for improving consistency between tiers

Ver 6.0 (Feb 4, 2023)
- Refactored 'audio.py' for maintainability (see [UPGRADING.md](https://github.com/timmahrt/praatIO/blob/main/UPGRADING.md) for details)
- Added unit tests for 'audio.py'
Expand Down
54 changes: 54 additions & 0 deletions praatio/data_classes/interval_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from praatio.utilities import errors
from praatio.utilities import utils
from praatio.utilities import my_math
from praatio.utilities import constants

from praatio.data_classes import textgrid_tier
Expand Down Expand Up @@ -100,6 +101,23 @@ def _validate(self):
f"({nextEntry.start}, {nextEntry.end}, {nextEntry.label})"
)

@property
def timestamps(self) -> List[float]:
    """Sorted list of the distinct start and stop times used by this tier"""
    boundaries = set()
    for start, stop, _ in self.entries:
        # Both edges of every interval count as a timestamp
        boundaries.add(start)
        boundaries.add(stop)

    return sorted(boundaries)

def crop(
self,
cropStart: float,
Expand Down Expand Up @@ -156,6 +174,42 @@ def crop(

return croppedTier

def dejitter(
    self,
    referenceTier: textgrid_tier.TextgridTier,
    maxDifference: float = 0.001,
) -> textgrid_tier.TextgridTier:
    """
    Set timestamps in this tier to be the same as values in the reference tier

    Each start and stop time is compared against the closest timestamp in the
    reference tier and is only moved if it is no more than maxDifference away
    from it (the comparison is tolerant of floating point error).

    This can be used to correct minor alignment errors between tiers, as made when
    annotating files manually, etc.

    If the reference tier has no timestamps, no entry is moved.

    Args:
        referenceTier: the IntervalTier or PointTier to use as a reference
        maxDifference: the maximum amount to allow timestamps to be moved by

    Returns:
        the tier produced by self.new() from the adjusted entries
    """
    referenceTimestamps = referenceTier.timestamps

    def snap(time: float) -> float:
        # min() raises ValueError on an empty sequence; with nothing to
        # align to, the time is simply left where it is
        if not referenceTimestamps:
            return time

        nearest = min(referenceTimestamps, key=lambda x: abs(x - time))
        if my_math.lessThanOrEqual(abs(time - nearest), maxDifference):
            return nearest
        return time

    # NOTE(review): if both edges of a short interval snap to the same reference
    # time the entry becomes zero-length -- confirm how self.new() handles that
    newEntries = [
        (snap(start), snap(stop), label) for start, stop, label in self.entries
    ]

    return self.new(entries=newEntries)

def deleteEntry(self, entry: Interval) -> None:
"""Removes an entry from the entries"""
self._entries.pop(self._entries.index(entry))
Expand Down
9 changes: 8 additions & 1 deletion praatio/data_classes/klattgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def __init__(
def crop(self):
    """Not implemented for this class; calling it raises NotImplementedError"""
    raise NotImplementedError

def dejitter(self):
    """Not implemented for this class; calling it raises NotImplementedError"""
    raise NotImplementedError

Check warning on line 142 in praatio/data_classes/klattgrid.py

View check run for this annotation

Codecov / codecov/patch

praatio/data_classes/klattgrid.py#L142

Added line #L142 was not covered by tests

def deleteEntry(self, entry):
    """Not implemented for this class; calling it raises NotImplementedError"""
    raise NotImplementedError

Expand All @@ -153,6 +156,10 @@ def insertEntry(self):
def insertSpace(self):
    """Not implemented for this class; calling it raises NotImplementedError"""
    raise NotImplementedError

@property
def timestamps(self):
    """Not implemented for this class; accessing it raises NotImplementedError"""
    raise NotImplementedError

Check warning on line 161 in praatio/data_classes/klattgrid.py

View check run for this annotation

Codecov / codecov/patch

praatio/data_classes/klattgrid.py#L161

Added line #L161 was not covered by tests

def validate(self):
    """Not implemented for this class; calling it raises NotImplementedError"""
    raise NotImplementedError

Expand All @@ -161,7 +168,7 @@ def modifyValues(self, modFunc: Callable[[float], bool]) -> None:
(timestamp, modFunc(float(value))) for timestamp, value in self.entries
]

self.entries = newEntries
self._entries = newEntries

Check warning on line 171 in praatio/data_classes/klattgrid.py

View check run for this annotation

Codecov / codecov/patch

praatio/data_classes/klattgrid.py#L171

Added line #L171 was not covered by tests

def getAsText(self) -> str:
outputList = []
Expand Down
42 changes: 42 additions & 0 deletions praatio/data_classes/point_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from praatio.utilities import constants
from praatio.utilities import errors
from praatio.utilities import utils
from praatio.utilities import my_math

from praatio.data_classes import textgrid_tier

Expand Down Expand Up @@ -72,6 +73,16 @@ def __init__(

super(PointTier, self).__init__(name, entries, calculatedMinT, calculatedMaxT)

@property
def timestamps(self) -> List[float]:
    """Sorted list of the distinct times at which this tier has a point"""
    return sorted({time for time, _ in self.entries})

def crop(
self,
cropStart: float,
Expand Down Expand Up @@ -121,6 +132,37 @@ def deleteEntry(self, entry: Point) -> None:
"""Removes an entry from the entries"""
self._entries.pop(self._entries.index(entry))

def dejitter(
    self, referenceTier: textgrid_tier.TextgridTier, maxDifference: float = 0.001
) -> "PointTier":
    """
    Set timestamps in this tier to be the same as values in the reference tier

    Each point is compared against the closest timestamp in the reference tier
    and is only moved if it is no more than maxDifference away from it (the
    comparison is tolerant of floating point error).

    This can be used to correct minor alignment errors between tiers, as made when
    annotating files manually, etc.

    If the reference tier has no timestamps, no point is moved.

    Args:
        referenceTier: the IntervalTier or PointTier to use as a reference
        maxDifference: the maximum amount to allow timestamps to be moved by

    Returns:
        the tier produced by self.new() from the adjusted entries
    """
    referenceTimestamps = referenceTier.timestamps

    newEntries = []
    for time, label in self.entries:
        # min() raises ValueError on an empty sequence; with nothing to
        # align to, the point is simply left where it is
        if referenceTimestamps:
            timeCompare = min(referenceTimestamps, key=lambda x: abs(x - time))

            if my_math.lessThanOrEqual(abs(time - timeCompare), maxDifference):
                time = timeCompare
        newEntries.append((time, label))

    return self.new(entries=newEntries)

def editTimestamps(
self,
offset: float,
Expand Down
13 changes: 13 additions & 0 deletions praatio/data_classes/textgrid_tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def __eq__(self, other):
def entries(self):
return tuple(self._entries)

@property
@abstractmethod
def timestamps(self) -> List[float]:  # pragma: no cover
    """All unique timestamps used in this tier; provided by each tier subclass"""
    pass

Check warning on line 81 in praatio/data_classes/textgrid_tier.py

View check run for this annotation

Codecov / codecov/patch

praatio/data_classes/textgrid_tier.py#L81

Added line #L81 was not covered by tests

def appendTier(self, tier: "TextgridTier") -> "TextgridTier":
"""Append a tier to the end of this one.
Expand Down Expand Up @@ -209,6 +214,14 @@ def insertEntry(
) -> None: # pragma: no cover
pass

@abstractmethod
def dejitter(
    self,
    referenceTier: "TextgridTier",
    maxDifference: float = 0.001,
) -> "TextgridTier":  # pragma: no cover
    """
    Set timestamps in this tier to be the same as values in the reference tier

    Abstract; each tier subclass supplies the implementation.

    Args:
        referenceTier: the tier whose timestamps are used as a reference
        maxDifference: the maximum amount to allow timestamps to be moved by

    Returns:
        the modified version of the current tier
    """
    pass

@abstractmethod
def eraseRegion(
self,
Expand Down
46 changes: 4 additions & 42 deletions praatio/praatio_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ def getValue(myBool) -> Literal["strict", "lax", "truncated"]:
return outputFNList


# TODO: Remove this method in the next major version
# Migrate to using the new TextgridTier.dejitter()
def alignBoundariesAcrossTiers(
tg: textgrid.Textgrid, tierName: str, maxDifference: float = 0.005
) -> textgrid.Textgrid:
Expand Down Expand Up @@ -464,7 +466,7 @@ def alignBoundariesAcrossTiers(
In such a case, choose a smaller maxDifference.
"""
referenceTier = tg.getTier(tierName)
times = _getTimestampsFromTier(referenceTier)
times = referenceTier.timestamps

for time, nextTime in zip(times[1::], times[2::]):
if nextTime - time < maxDifference:
Expand All @@ -480,47 +482,7 @@ def alignBoundariesAcrossTiers(
if tier.name == tierName:
continue

newEntries: list = []
if tier.entryType == constants.Interval:
for start, stop, label in tier.entries:
startCompare = min(times, key=lambda x: abs(x - start))
stopCompare = min(times, key=lambda x: abs(x - stop))

if abs(start - startCompare) <= maxDifference:
start = startCompare
if abs(stop - stopCompare) <= maxDifference:
stop = stopCompare
newEntries.append((start, stop, label))
elif tier.entryType == constants.Point:
for time, label in tier.entries:
timeCompare = min(times, key=lambda x: abs(x - time))

if abs(time - timeCompare) <= maxDifference:
time = timeCompare
newEntries.append((time, label))

tier = tier.new(entries=newEntries)
tier = tier.dejitter(referenceTier, maxDifference)
tg.replaceTier(tier.name, tier)

return tg


def _getTimestampsFromTier(tier: textgrid_tier.TextgridTier) -> List[float]:
    """Return the sorted, de-duplicated timestamps used by a tier's entries"""
    collected = []
    if tier.entryType == constants.Interval:
        # Interval entries contribute both of their edges
        for start, stop, _ in tier.entries:
            collected.extend([start, stop])
    elif tier.entryType == constants.Point:
        for time, _ in tier.entries:
            collected.append(time)

    return sorted(set(collected))
4 changes: 4 additions & 0 deletions praatio/utilities/my_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def isclose(a: float, b: float, rel_tol: float = 1e-14, abs_tol: float = 0.0) ->
return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)


def lessThanOrEqual(a: float, b: float):
    """Return True if a is smaller than b or isclose() treats them as equal"""
    if a < b:
        return True
    return isclose(a, b)


def filterTimeSeriesData(
filterFunc: Callable[[List[float], int, bool], List[float]],
featureTimeList: List[list],
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name="praatio",
python_requires=">3.6.0",
version="6.0.1",
version="6.1.0",
author="Tim Mahrt",
author_email="timmahrt@gmail.com",
url="https://github.com/timmahrt/praatIO",
Expand Down
57 changes: 51 additions & 6 deletions tests/test_interval_tier.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import unittest

from praatio import textgrid
from praatio.utilities.constants import Interval, INTERVAL_TIER
from praatio.utilities.constants import Interval, INTERVAL_TIER, Point
from praatio.utilities import errors
from praatio.utilities import constants

from tests.praatio_test_case import PraatioTestCase
from tests import testing_utils


def makeIntervalTier(name="words", intervals=None, minT=0, maxT=5.0):
if intervals is None:
intervals = [Interval(1, 2, "hello"), Interval(3.5, 4.0, "world")]
return textgrid.IntervalTier(name, intervals, minT, maxT)
makeIntervalTier = testing_utils.makeIntervalTier


class TestIntervalTier(PraatioTestCase):
Expand Down Expand Up @@ -352,6 +348,55 @@ def test_crop_drops_overlapping_intervals_if_mode_is_strict_and_rebase_false(
)
self.assertEqual(expectedIntervalTier, sut)

def test_dejitter_when_reference_tier_is_interval_tier(self):
    # Reference boundaries are 1, 2.0, 2.65 and 3.45; only boundaries within
    # 0.1 of one of them should move
    referenceTier = makeIntervalTier(
        intervals=[Interval(1, 2.0, "foo"), Interval(2.65, 3.45, "bar")]
    )
    sut = makeIntervalTier(
        intervals=[
            Interval(0, 0.9, "start will be modified"),
            Interval(1, 2.1, "stop will be modified"),
            Interval(2.2, 2.5, "will not be modified"),
            Interval(2.5, 3.56, "will also not be modified"),
        ]
    )

    dejitteredTier = sut.dejitter(referenceTier, 0.1)

    # Note: in the first entry it is the stop time (0.9 -> 1) that moves
    expectedEntries = [
        Interval(0, 1, "start will be modified"),
        Interval(1, 2.0, "stop will be modified"),
        Interval(2.2, 2.5, "will not be modified"),
        Interval(2.5, 3.56, "will also not be modified"),
    ]
    self.assertSequenceEqual(expectedEntries, dejitteredTier._entries)

def test_dejitter_when_reference_tier_is_point_tier(self):
    # Reference points sit at 1, 2.0, 2.65 and 3.45; only boundaries within
    # 0.1 of one of them should move
    referenceTier = testing_utils.makePointTier(
        points=[
            Point(1, "foo"),
            Point(2.0, "bar"),
            Point(2.65, "bizz"),
            Point(3.45, "whomp"),
        ]
    )
    sut = makeIntervalTier(
        intervals=[
            Interval(0, 0.9, "start will be modified"),
            Interval(1, 2.1, "stop will be modified"),
            Interval(2.2, 2.5, "will not be modified"),
            Interval(2.5, 3.56, "will also not be modified"),
        ]
    )

    dejitteredTier = sut.dejitter(referenceTier, 0.1)

    # Note: in the first entry it is the stop time (0.9 -> 1) that moves
    expectedEntries = [
        Interval(0, 1, "start will be modified"),
        Interval(1, 2.0, "stop will be modified"),
        Interval(2.2, 2.5, "will not be modified"),
        Interval(2.5, 3.56, "will also not be modified"),
    ]
    self.assertSequenceEqual(expectedEntries, dejitteredTier._entries)

def test_delete_entry(self):
sut = makeIntervalTier(
intervals=[
Expand Down
Loading

0 comments on commit 89bb335

Please sign in to comment.