Skip to content

Commit

Permalink
Numpy 2+ support with latest deps and refactor dtype handling (#225)
Browse files Browse the repository at this point in the history
Co-authored-by: Altay Sansal <altay.sansal@tgs.com>
  • Loading branch information
tasansal and Altay Sansal authored Nov 20, 2024
1 parent b5885c9 commit fa8291a
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 104 deletions.
141 changes: 75 additions & 66 deletions poetry.lock

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,27 @@ keywords = ["segy", "seismic", "data", "geophysics"]

[tool.poetry.dependencies]
python = ">=3.9, <3.13"
fsspec = ">=2024.9.0"
numpy = "^1.26.4"
pydantic = "^2.9.0"
pydantic-settings = "^2.4.0"
fsspec = "^2024.10.0"
numpy = "^2.0.0"
pydantic = "^2.9.2"
pydantic-settings = "^2.6.1"
numba = ">=0.59.1, <0.70.0"
pandas = "^2.2.2"
typer = "^0.12.5"
rapidfuzz = "^3.9.7"
gcsfs = { version = ">=2024.9.0.post1", optional = true }
s3fs = { version = ">=2024.9.0", optional = true }
adlfs = { version = ">=2024.7.0", optional = true }
typer = "^0.13.1"
rapidfuzz = "^3.10.1"
gcsfs = {version = "^2024.10.0", optional = true}
s3fs = {version = "^2024.10.0", optional = true}
adlfs = {version = "^2024.7.0", optional = true}
eval-type-backport = { version = "^0.2.0", python = "<3.10" }

[tool.poetry.group.dev.dependencies]
ruff = "^0.6.4"
coverage = { version = "^7.5.3", extras = ["toml"] }
mypy = "^1.11.2"
pytest = "^8.3.2"
pre-commit = "^3.8.0"
pre-commit-hooks = "^4.6.0"
typeguard = "^4.3.0"
ruff = "^0.7.4"
coverage = {version = "^7.6.7", extras = ["toml"]}
mypy = "^1.13.0"
pytest = "^8.3.3"
pre-commit = "^4.0.1"
pre-commit-hooks = "^5.0.0"
typeguard = "^4.4.1"
urllib3 = "^1.26.18" # Workaround for poetry-plugin-export/issues/183
pandas-stubs = "^2.2.2.240807"

Expand Down Expand Up @@ -87,7 +87,7 @@ select = [
"TD", # todos
"PL", # pylint
"FLY", # flynt
"NPY", # numpy
"NPY201", # numpy
]

ignore = [
Expand Down
20 changes: 10 additions & 10 deletions src/segy/schema/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from segy.compat import StrEnum

if TYPE_CHECKING:
from typing import Any


class ScalarType(StrEnum):
"""A class representing scalar data types."""
Expand All @@ -29,18 +34,13 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"

@property
def char(self) -> str:
"""Returns the numpy character code for a given data type string."""
# IBM Float
def dtype(self) -> np.dtype[Any]:
"""Return numpy dtype of the format."""
# Special case for IBM 32-bit float
if self.value == "ibm32":
return np.sctype2char("uint32") # noqa: NPY201

# String
if self.name.startswith("STRING"):
return str(self.value)
return np.dtype("uint32")

# Everything Else
return np.sctype2char(str(self.value)) # noqa: NPY201
return np.dtype(self.value)


class TextHeaderEncoding(StrEnum):
Expand Down
4 changes: 2 additions & 2 deletions src/segy/schema/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def offset(self) -> int:

@property
def dtype(self) -> np.dtype[Any]:
"""Converts the byte order and data type of the object into a NumPy dtype."""
return np.dtype(self.format.char)
"""Converts the data type of the object into a NumPy dtype."""
return self.format.dtype


class HeaderSpec(BaseDataType):
Expand Down
4 changes: 2 additions & 2 deletions src/segy/schema/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class TraceDataSpec(BaseDataType):

@property
def dtype(self) -> np.dtype[Any]:
"""Get numpy dtype."""
dtype = (self.format.char, (self.samples,))
"""Get numpy dtype with # of samples."""
dtype = (self.format.dtype, (self.samples,))
return np.dtype(dtype)


Expand Down
4 changes: 3 additions & 1 deletion src/segy/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def transform(self, data: NDArray[Any]) -> NDArray[Any]:
source_order = get_endianness(data)

if source_order != self.target_order:
data = data.byteswap(inplace=True).newbyteorder(self.target_order.symbol)
data = data.byteswap(inplace=True)
swapped_dtype = data.dtype.newbyteorder(self.target_order.symbol)
data = data.view(swapped_dtype)

return data

Expand Down
3 changes: 1 addition & 2 deletions tests/test_schema_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import numpy as np
import pytest

from segy.schema import Endianness
Expand Down Expand Up @@ -47,7 +46,7 @@ def test_trace_spec(

expected_itemsize = header_spec.dtype.itemsize + data_spec.dtype.itemsize
expected_header_itemsize = header_spec.dtype.itemsize
expected_sample_subtype = (np.dtype(sample_format.char), (samples_per_trace,))
expected_sample_subtype = (sample_format.dtype, (samples_per_trace,))
assert trace_spec.dtype.itemsize == expected_itemsize
assert trace_spec.header.dtype.names == ("h1", "h2")
assert trace_spec.header.dtype.itemsize == expected_header_itemsize
Expand Down
6 changes: 4 additions & 2 deletions tests/test_segy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_trace_sample_template(
if mock_segy_factory.sample_format == ScalarType.IBM32:
expected_dtype = np.dtype("float32")
else:
expected_dtype = np.dtype(mock_segy_factory.sample_format.char)
expected_dtype = mock_segy_factory.sample_format.dtype

expected_shape = (num_traces, n_samples)
assert samples.dtype == expected_dtype
Expand Down Expand Up @@ -257,7 +257,9 @@ def test_trace_serialize(
for field_name, values in rand_fields.items():
expected_traces["header"][field_name] = values
if mock_segy_factory.spec.endianness == Endianness.BIG:
expected_traces = expected_traces.byteswap(inplace=True).newbyteorder(">")
expected_traces = expected_traces.byteswap(inplace=True)
expected_dtype = expected_traces.dtype.newbyteorder(">")
expected_traces = expected_traces.view(expected_dtype)

assert trace_bytes == expected_traces.tobytes()

Expand Down
8 changes: 6 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def mock_header_little() -> NDArray[Any]:
@pytest.fixture()
def mock_header_big(mock_header_little: NDArray[Any]) -> NDArray[Any]:
"""Generate a mock structured array to test transforms with big endian."""
return mock_header_little.byteswap().newbyteorder()
data = mock_header_little.byteswap()
swapped_dtype = data.dtype.newbyteorder()
return data.view(swapped_dtype)


@pytest.fixture()
Expand All @@ -62,7 +64,9 @@ def mock_data_little() -> NDArray[Any]:
@pytest.fixture()
def mock_data_big(mock_data_little: NDArray[Any]) -> NDArray[Any]:
"""Generate a mock big endian structured array to test transforms."""
return mock_data_little.byteswap().newbyteorder()
data = mock_data_little.byteswap()
swapped_dtype = data.dtype.newbyteorder()
return data.view(swapped_dtype)


class TestByteSwap:
Expand Down

0 comments on commit fa8291a

Please sign in to comment.