From 893bcf26e40366a988c5d833462903c4ff7d1cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ricks?= Date: Thu, 22 Feb 2024 15:30:25 +0100 Subject: [PATCH] Add: Allow to create a binary file with temp_file context manager Extend temp_file context manager to support also creating binary files. --- pontos/testing/__init__.py | 17 ++++++++++++++--- tests/testing/test_testing.py | 9 +++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pontos/testing/__init__.py b/pontos/testing/__init__.py index 398de367..96ad0c83 100644 --- a/pontos/testing/__init__.py +++ b/pontos/testing/__init__.py @@ -11,7 +11,15 @@ import tempfile from contextlib import contextmanager from pathlib import Path -from typing import Any, AsyncIterator, Awaitable, Generator, Iterable, Optional +from typing import ( + Any, + AsyncIterator, + Awaitable, + Generator, + Iterable, + Optional, + Union, +) from pontos.git._git import exec_git from pontos.helper import add_sys_path, ensure_unload_module, unload_module @@ -136,7 +144,7 @@ def temp_git_repository( @contextmanager def temp_file( - content: Optional[str] = None, + content: Optional[Union[str, bytes]] = None, *, name: str = "test.toml", change_into: bool = False, @@ -166,7 +174,10 @@ def temp_file( with temp_directory(change_into=change_into) as tmp_dir: test_file = tmp_dir / name if content: - test_file.write_text(content, encoding="utf8") + if isinstance(content, bytes): + test_file.write_bytes(content) + else: + test_file.write_text(content, encoding="utf8") else: test_file.touch() diff --git a/tests/testing/test_testing.py b/tests/testing/test_testing.py index 6f618ab5..3e1b6871 100644 --- a/tests/testing/test_testing.py +++ b/tests/testing/test_testing.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later # +import struct import unittest from pathlib import Path @@ -109,6 +110,14 @@ def test_temp_file(self): self.assertFalse(test_file.exists()) + def test_temp_binary_file(self): + data = struct.pack(">if", 42, 2.71828182846) + with temp_file(data) as test_file: + self.assertTrue(test_file.exists()) + self.assertEqual(data, test_file.read_bytes()) + + self.assertFalse(test_file.exists()) + def test_temp_file_without_content(self): with temp_file(name="foo.bar") as test_file: self.assertTrue(test_file.exists())