Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give NamedArray Generic dimension type #8276

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

headtr1ck
Copy link
Collaborator

This aims at making the dimenion type a generic parameter.
I thought I will start with NamedArray when testing this out because it is much less interconnected.

@headtr1ck
Copy link
Collaborator Author

headtr1ck commented Oct 5, 2023

Currently this only works partially and I cannot figure it out why I get this result:

from xarray.namedarray.core import NamedArray
from dask.array import Array as DaskArray

a = [1, 2, 3]
reveal_type(NamedArray("x", a))
reveal_type(NamedArray([None], a))

b = np.array([1, 2, 3])
reveal_type(NamedArray("a", b))
reveal_type(NamedArray([None], b))

c = DaskArray([1, 2, 3], "c", {})
reveal_type(NamedArray("a", c))
reveal_type(NamedArray([None], c))

prints

test.py:5: note: Revealed type is "xarray.namedarray.core.NamedArray[builtins.str, numpy.ndarray[Any, numpy.dtype[numpy.generic]]]"
test.py:6: note: Revealed type is "xarray.namedarray.core.NamedArray[None, numpy.ndarray[Any, numpy.dtype[numpy.generic]]]"

test.py:9: note: Revealed type is "xarray.namedarray.core.NamedArray[Any, Any]"
test.py:10: note: Revealed type is "xarray.namedarray.core.NamedArray[Any, Any]"

test.py:13: note: Revealed type is "xarray.namedarray.core.NamedArray[builtins.str, numpy.ndarray[Any, numpy.dtype[numpy.generic]]]"
test.py:14: note: Revealed type is "xarray.namedarray.core.NamedArray[None, numpy.ndarray[Any, numpy.dtype[numpy.generic]]]"

So with a python list as data I get the correct result.
If I supply a np.ndarray I only get Any (both for dims and data?)
If I supply another T_DuckArray (in this case a dask array) I get again a np.ndarray?

Something seems wrong with T_DuckArray but I cannot figure it out. Anyone have an idea why this is happening?
And then all arrays basically have the __array__ method, so they are compatible with np.typing.ArrayLike which produces the wrong np.ndarray types...

@andersy005 andersy005 mentioned this pull request Oct 5, 2023
20 tasks
@headtr1ck
Copy link
Collaborator Author

@Illviljan any idea why the T_DuckArray overloads don't work?

@Illviljan
Copy link
Contributor

Yeah, seems to be a combination of all the casts in as_compatible_data and failing to update the T_DuckArray after as_compatible_data has finished.

I wonder if we make it too difficult to ourselves by having all the normalization done in the init? The typing becomes very clean if all that stuff is done before initializing the class, see example below. Dask has a from_array function so there's some precedent to use that kind of workflow.

from __future__ import annotations

from typing import (
    Protocol,
    TypeVar,
    overload,
    runtime_checkable,
    Generic,
    cast,
)


@runtime_checkable
class StrBase(Protocol):
    def e(self: StrBase) -> StrBase:
        ...


class StrNumpy(str):
    def e(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError

    def f(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError


class StrSparse(str):
    def e(self: StrSparse) -> StrSparse:
        raise NotImplementedError

    def f(self: StrSparse) -> StrSparse:
        raise NotImplementedError


T = TypeVar("T", bound=StrBase)


class NamedArray2(Generic[T]):
    a: T

    def __init__(self, a: T):
        self.a = a


@overload
def from_array(x: T) -> NamedArray2[T]:
    ...


@overload
def from_array(x: str | float) -> NamedArray2[StrNumpy]:
    ...


def from_array(x: T | str | float) -> NamedArray2[T] | NamedArray2[StrNumpy]:
    if isinstance(x, StrBase):
        # TODO: cast used because of mypy, pyright does not need it:
        x_ = cast(T, x)
        return NamedArray2(x_)
    else:
        return NamedArray2(StrNumpy(x))


# Test from_array:
from_array_a = from_array("s")
from_array_b = from_array(1)
from_array_d = from_array(StrNumpy(223))
from_array_e = from_array(StrSparse(45))
reveal_type(from_array_a)
reveal_type(from_array_b)
reveal_type(from_array_d)
reveal_type(from_array_e)
(xarray-tests) C:\Users\J.W>mypy G:\Program\Dropbox\Python\xarray_namedarray_tests.py
G:\Program\Dropbox\Python\xarray_namedarray_tests.py:215: note: Revealed type is "xarray_namedarray_tests.NamedArray2[xarray_namedarray_tests.StrNumpy]"
G:\Program\Dropbox\Python\xarray_namedarray_tests.py:216: note: Revealed type is "xarray_namedarray_tests.NamedArray2[xarray_namedarray_tests.StrNumpy]"
G:\Program\Dropbox\Python\xarray_namedarray_tests.py:217: note: Revealed type is "xarray_namedarray_tests.NamedArray2[xarray_namedarray_tests.StrNumpy]"
G:\Program\Dropbox\Python\xarray_namedarray_tests.py:218: note: Revealed type is "xarray_namedarray_tests.NamedArray2[xarray_namedarray_tests.StrSparse]"
Success: no issues found in 1 source file

(xarray-tests) C:\Users\J.W>pyright G:\Program\Dropbox\Python\xarray_namedarray_tests.py
WARNING: there is a new pyright version available (v1.1.280 -> v1.1.330).
Please install the new version or set PYRIGHT_PYTHON_FORCE_VERSION to `latest`

No configuration file found.
No pyproject.toml file found.
stubPath C:\Users\J.W\typings is not a valid directory.
Assuming Python platform Windows
Searching for source files
Found 1 source file
pyright 1.1.280
G:\Program\Dropbox\Python\xarray_namedarray_tests.py
  G:\Program\Dropbox\Python\xarray_namedarray_tests.py:215:13 - information: Type of "from_array_a" is "NamedArray2[StrNumpy]"
  G:\Program\Dropbox\Python\xarray_namedarray_tests.py:216:13 - information: Type of "from_array_b" is "NamedArray2[StrNumpy]"
  G:\Program\Dropbox\Python\xarray_namedarray_tests.py:217:13 - information: Type of "from_array_d" is "NamedArray2[StrNumpy]"
  G:\Program\Dropbox\Python\xarray_namedarray_tests.py:218:13 - information: Type of "from_array_e" is "NamedArray2[StrSparse]"
0 errors, 0 warnings, 4 informations
Completed in 1.062sec

Bunch of testing code:

from __future__ import annotations

from typing import (
    Protocol,
    TypeVar,
    Union,
    overload,
    runtime_checkable,
    Generic,
    cast,
)

T_Str = TypeVar("T_Str", bound=str, covariant=True)


@runtime_checkable
class StrBase(Protocol):
    def e(self: StrBase) -> StrBase:
        ...


class StrNumpy(str):
    def e(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError

    def f(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError


class StrSparse(str):
    def e(self: StrSparse) -> StrSparse:
        raise NotImplementedError

    def f(self: StrSparse) -> StrSparse:
        raise NotImplementedError


T = TypeVar("T", bound=StrBase)
DuckArray = Union[T, StrNumpy]


@overload
def normalize(x: T) -> T:
    ...


@overload
def normalize(x: str) -> StrNumpy:
    ...


@overload
def normalize(x: float) -> StrNumpy:
    ...


def normalize(x: T | str | float) -> T | StrNumpy:
    if isinstance(x, StrBase):
        print(x)
        # TODO: cast used because of mypy, pyright does not need it:
        return cast(T, x)
    else:
        return StrNumpy(x)


a = normalize("s")
b = normalize(1)
# c = normalize(StrBase(23))
d = normalize(StrNumpy(223))
e = normalize(StrSparse(45))

f = StrSparse(45).replace("4", "7")


class NamedArray(Generic[T]):
    a: T | StrNumpy

    def __init__(self, a: T | str | float):
        self.a = normalize(a)


narr_a = NamedArray("s")
narr_b = NamedArray(1)
narr_d = NamedArray(StrNumpy(223))
narr_e = NamedArray(StrSparse(45))


# Test normalize:
reveal_type(a)  # Fails
reveal_type(b)  # Fails
# reveal_type(c)
reveal_type(d)
reveal_type(e)

# Test NamedA generics are passed:
reveal_type(narr_a)
reveal_type(narr_b)
reveal_type(narr_d)
reveal_type(narr_e)

# Test NamedA generics are passed:
narr_a2 = NamedArray(normalize("s"))
narr_b2 = NamedArray(normalize(1))
narr_d2 = NamedArray(normalize(StrNumpy(223)))
narr_e2 = NamedArray(normalize(StrSparse(45)))

reveal_type(narr_a2)
reveal_type(narr_b2)
reveal_type(narr_d2)
reveal_type(narr_e2)

# %%


class NamedArray2(Generic[T]):
    a: T

    def __init__(self, a: T):
        self.a = a


@overload
def normalize2(x: T, func: type[NamedArray2]) -> NamedArray2[T]:
    ...


@overload
def normalize2(x: str, func: type[NamedArray2]) -> NamedArray2[StrNumpy]:
    ...


@overload
def normalize2(x: float, func: Callable[StrNumpy, T]) -> NamedArray2[StrNumpy]:
    ...


def normalize2(
    x: T | str | float, func: type[NamedArray2]
) -> NamedArray2[T] | NamedArray2[StrNumpy]:
    if isinstance(x, StrBase):
        print(x)
        # TODO: cast used because of mypy, pyright does not need it:
        return func(cast(T, x))
    else:
        return func(StrNumpy(x))


# %% Normalize before calling class:
from __future__ import annotations

from typing import (
    Protocol,
    TypeVar,
    overload,
    runtime_checkable,
    Generic,
    cast,
)

T_Str = TypeVar("T_Str", bound=str, covariant=True)


@runtime_checkable
class StrBase(Protocol):
    def e(self: StrBase) -> StrBase:
        ...


class StrNumpy(str):
    def e(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError

    def f(self: StrNumpy) -> StrNumpy:
        raise NotImplementedError


class StrSparse(str):
    def e(self: StrSparse) -> StrSparse:
        raise NotImplementedError

    def f(self: StrSparse) -> StrSparse:
        raise NotImplementedError


T = TypeVar("T", bound=StrBase)


class NamedArray2(Generic[T]):
    a: T

    def __init__(self, a: T):
        self.a = a


@overload
def from_array(x: T) -> NamedArray2[T]:
    ...


@overload
def from_array(x: str | float) -> NamedArray2[StrNumpy]:
    ...


def from_array(x: T | str | float) -> NamedArray2[T] | NamedArray2[StrNumpy]:
    if isinstance(x, StrBase):
        # TODO: cast used because of mypy, pyright does not need it:
        x_ = cast(T, x)
        return NamedArray2(x_)
    else:
        return NamedArray2(StrNumpy(x))


# Test from_array:
from_array_a = from_array("s")
from_array_b = from_array(1)
from_array_d = from_array(StrNumpy(223))
from_array_e = from_array(StrSparse(45))
reveal_type(from_array_a)
reveal_type(from_array_b)
reveal_type(from_array_d)
reveal_type(from_array_e)

# %%
# # Test self.from_array:
# from_array_a = NamedArray2.from_array("s")
# from_array_b = NamedArray2.from_array(1)
# from_array_d = NamedArray2.from_array(StrNumpy(223))
# from_array_e = NamedArray2.from_array(StrSparse(45))
# reveal_type(from_array_a)
# reveal_type(from_array_b)
# reveal_type(from_array_d)
# reveal_type(from_array_e)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
topic-NamedArray Lightweight version of Variable topic-typing
Projects
Status: In progress
Development

Successfully merging this pull request may close these issues.

3 participants