Skip to content

Commit

Permalink
Patch pprint to make pytest diffs nicer for big objects (#92)
Browse files Browse the repository at this point in the history
Replaces alexmojaki#1

Closes #73
  • Loading branch information
alexmojaki authored Aug 12, 2024
1 parent b1e6384 commit 2ce5fab
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 8 deletions.
22 changes: 22 additions & 0 deletions dirty_equals/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
from abc import ABCMeta
from pprint import PrettyPrinter
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Optional, Protocol, Tuple, TypeVar

from ._utils import Omit
Expand Down Expand Up @@ -131,6 +133,26 @@ def __repr__(self) -> str:
# else return something which explains what's going on.
return self._repr_ne()

def _pprint_format(self, pprinter: PrettyPrinter, stream: io.StringIO, *args: Any, **kwargs: Any) -> None:
# pytest diffs use pprint to format objects, so we patch pprint to call this method
# for DirtyEquals objects. So this method needs to follow the same pattern as __repr__.
# We check that the protected _format method actually exists
# to be safe and to make linters happy.
if self._was_equal and hasattr(pprinter, '_format'):
pprinter._format(self._other, stream, *args, **kwargs)
else:
stream.write(repr(self)) # i.e. self._repr_ne() (for now)


# Patch pprint to call _pprint_format for DirtyEquals objects
# Check that the protected attribute _dispatch exists to be safe and to make linters happy.
# The reason we modify _dispatch rather than _format
# is that pytest sometimes uses a subclass of PrettyPrinter which overrides _format.
if hasattr(PrettyPrinter, '_dispatch'): # pragma: no branch
PrettyPrinter._dispatch[DirtyEquals.__repr__] = lambda pprinter, obj, *args, **kwargs: obj._pprint_format(
pprinter, *args, **kwargs
)


InstanceOrType: 'TypeAlias' = 'Union[DirtyEquals[Any], DirtyEqualsMeta]'

Expand Down
66 changes: 58 additions & 8 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import platform
import pprint

import packaging.version
import pytest

from dirty_equals import Contains, IsApprox, IsInt, IsNegative, IsOneOf, IsPositive, IsStr
from dirty_equals import Contains, IsApprox, IsInt, IsList, IsNegative, IsOneOf, IsPositive, IsStr
from dirty_equals.version import VERSION


Expand Down Expand Up @@ -39,8 +40,7 @@ def test_value_eq():
v.value

assert 'foo' == v
assert str(v) == "'foo'"
assert repr(v) == "'foo'"
assert repr(v) == str(v) == "'foo'" == pprint.pformat(v)
assert v.value == 'foo'


Expand All @@ -50,8 +50,7 @@ def test_value_ne():
with pytest.raises(AssertionError):
assert 1 == v

assert str(v) == 'IsStr()'
assert repr(v) == 'IsStr()'
assert repr(v) == str(v) == 'IsStr()' == pprint.pformat(v)
with pytest.raises(AttributeError, match='value is not available until __eq__ has been called'):
v.value

Expand Down Expand Up @@ -110,7 +109,7 @@ def test_repr():
],
)
def test_repr_class(v, v_repr):
assert repr(v) == v_repr
assert repr(v) == str(v) == v_repr == pprint.pformat(v)


def test_is_approx_without_init():
Expand All @@ -119,11 +118,62 @@ def test_is_approx_without_init():

def test_ne_repr():
v = IsInt
assert repr(v) == 'IsInt'
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)

assert 'x' != v

assert repr(v) == 'IsInt'
assert repr(v) == str(v) == 'IsInt' == pprint.pformat(v)


def test_pprint():
v = [IsList(length=...), 1, [IsList(length=...), 2], 3, IsInt()]
lorem = ['lorem', 'ipsum', 'dolor', 'sit', 'amet'] * 2
with pytest.raises(AssertionError):
assert [lorem, 1, [lorem, 2], 3, '4'] == v

assert repr(v) == (f'[{lorem}, 1, [{lorem}, 2], 3, IsInt()]')
assert pprint.pformat(v) == (
"[['lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet',\n"
" 'lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet'],\n"
' 1,\n'
" [['lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet',\n"
" 'lorem',\n"
" 'ipsum',\n"
" 'dolor',\n"
" 'sit',\n"
" 'amet'],\n"
' 2],\n'
' 3,\n'
' IsInt()]'
)


def test_pprint_not_equal():
v = IsList(*range(30)) # need a big value to trigger pprint
with pytest.raises(AssertionError):
assert [] == v

assert (
pprint.pformat(v)
== (
'IsList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, '
'15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)'
)
== repr(v)
== str(v)
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 2ce5fab

Please sign in to comment.