Skip to content

Commit

Permalink
feat(Python): Add Pyodide Mesh support
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed Apr 19, 2023
1 parent f8d0fa3 commit fc9f404
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 27 deletions.
2 changes: 1 addition & 1 deletion packages/core/python/itkwasm/itkwasm/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _default_direction() -> ArrayLike:
@dataclass
class Image:
imageType: Union[ImageType, Dict] = field(default_factory=ImageType)
name: str = 'image'
name: str = 'Image'
origin: Sequence[float] = field(default_factory=list)
spacing: Sequence[float] = field(default_factory=list)
direction: ArrayLike = field(default_factory=_default_direction)
Expand Down
4 changes: 2 additions & 2 deletions packages/core/python/itkwasm/itkwasm/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@dataclass
class MeshType:
dimension: int = 2
dimension: int = 3

pointComponentType: Union[IntTypes, FloatTypes] = FloatTypes.Float32
pointPixelComponentType: Union[IntTypes, FloatTypes] = FloatTypes.Float32
Expand All @@ -30,7 +30,7 @@ class MeshType:
class Mesh:
meshType: Union[MeshType, Dict] = field(default_factory=MeshType)

name: str = 'mesh'
name: str = 'Mesh'

numberOfPoints: int = 0
points: Optional[ArrayLike] = None
Expand Down
34 changes: 21 additions & 13 deletions packages/core/python/itkwasm/itkwasm/pointset.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,38 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

from typing import Optional
from typing import Optional, Union, Dict

try:
from numpy.typing import ArrayLike
except ImportError:
from numpy import ndarray as ArrayLike

from .float_types import FloatTypes
from .int_types import IntTypes
from .pixel_types import PixelTypes

@dataclass
class PointSetType:
dimension: int
dimension: int = 3

pointComponentType: str
pointPixelComponentType: str
pointPixelType: str
pointPixelComponents: int
pointComponentType: Union[IntTypes, FloatTypes] = FloatTypes.Float32
pointPixelComponentType: Union[IntTypes, FloatTypes] = FloatTypes.Float32
pointPixelType: PixelTypes = PixelTypes.Scalar
pointPixelComponents: int = 1


@dataclass
class PointSet:
pointSetType: PointSetType
pointSetType: Union[PointSetType, Dict] = field(default_factory=PointSetType)

name: str = 'PointSet'

name: str
numberOfPoints: int = 0
points: Optional[ArrayLike] = None

numberOfPoints: int
points: Optional[ArrayLike]
numberOfPointPixels: int = 0
pointData: Optional[ArrayLike] = None

numberOfPointPixels: int
pointData: Optional[ArrayLike]
def __post_init__(self):
if isinstance(self.pointSetType, dict):
self.pointSetType = PointSetType(**self.pointSetType)
47 changes: 41 additions & 6 deletions packages/core/python/itkwasm/itkwasm/pyodide.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def to_py(js_proxy):
dimension = image_type.dimension
component_type = image_type.componentType
image_dict['direction'] = _to_numpy_array(str(FloatTypes.Float64), image_dict['direction']).reshape((dimension, dimension))
image_dict['data'] = _to_numpy_array(component_type, image_dict['data']).reshape((dimension, dimension))
if image_dict['data'] is not None:
image_dict['data'] = _to_numpy_array(component_type, image_dict['data']).reshape((dimension, dimension))
return Image(**image_dict)
elif hasattr(js_proxy, "pointSetType"):
point_set_dict = js_proxy.to_py()
Expand All @@ -78,9 +79,29 @@ def to_py(js_proxy):
dimension = point_set_type.dimension
point_component_type = point_set_type.pointComponentType
point_pixel_component_type = point_set_type.pointPixelComponentType
point_set_dict['points'] = _to_numpy_array(point_component_type, point_set_dict['points']).reshape((-1, dimension))
point_set_dict['pointData'] = _to_numpy_array(point_pixel_component_type, point_set_dict['pointData'])
if point_set_dict['points'] is not None:
point_set_dict['points'] = _to_numpy_array(point_component_type, point_set_dict['points']).reshape((-1, dimension))
if point_set_dict['pointData'] is not None:
point_set_dict['pointData'] = _to_numpy_array(point_pixel_component_type, point_set_dict['pointData'])
return PointSet(**point_set_dict)
elif hasattr(js_proxy, "meshType"):
mesh_dict = js_proxy.to_py()
mesh_type = MeshType(**mesh_dict['meshType'])
mesh_dict['meshType'] = mesh_type
dimension = mesh_type.dimension
point_component_type = mesh_type.pointComponentType
point_pixel_component_type = mesh_type.pointPixelComponentType
cell_component_type = mesh_type.cellComponentType
cell_pixel_component_type = mesh_type.cellPixelComponentType
if mesh_dict['points'] is not None:
mesh_dict['points'] = _to_numpy_array(point_component_type, mesh_dict['points']).reshape((-1, dimension))
if mesh_dict['pointData'] is not None:
mesh_dict['pointData'] = _to_numpy_array(point_pixel_component_type, mesh_dict['pointData'])
if mesh_dict['cells'] is not None:
mesh_dict['cells'] = _to_numpy_array(cell_component_type, mesh_dict['cells'])
if mesh_dict['cellData'] is not None:
mesh_dict['cellData'] = _to_numpy_array(cell_pixel_component_type, mesh_dict['cellData'])
return Mesh(**mesh_dict)
return js_proxy.to_py()

def to_js(py):
Expand All @@ -90,12 +111,26 @@ def to_js(py):
image_dict = asdict(py)
print('to_js image dict', image_dict['imageType'])
image_dict['direction'] = image_dict['direction'].ravel()
image_dict['data'] = image_dict['data'].ravel()
if image_dict['data'] is not None:
image_dict['data'] = image_dict['data'].ravel()
return pyodide.ffi.to_js(image_dict, dict_converter=js.Object.fromEntries)
elif isinstance(py, PointSet):
point_set_dict = asdict(py)
point_set_dict['points'] = point_set_dict['points'].ravel()
point_set_dict['pointData'] = point_set_dict['pointData'].ravel()
if point_set_dict['points'] is not None:
point_set_dict['points'] = point_set_dict['points'].ravel()
if point_set_dict['pointData'] is not None:
point_set_dict['pointData'] = point_set_dict['pointData'].ravel()
return pyodide.ffi.to_js(point_set_dict, dict_converter=js.Object.fromEntries)
elif isinstance(py, Mesh):
mesh_dict = asdict(py)
if mesh_dict['points'] is not None:
mesh_dict['points'] = mesh_dict['points'].ravel()
if mesh_dict['pointData'] is not None:
mesh_dict['pointData'] = mesh_dict['pointData'].ravel()
if mesh_dict['cells'] is not None:
mesh_dict['cells'] = mesh_dict['cells'].ravel()
if mesh_dict['cellData'] is not None:
mesh_dict['cellData'] = mesh_dict['cellData'].ravel()
return pyodide.ffi.to_js(mesh_dict, dict_converter=js.Object.fromEntries)

return py
4 changes: 2 additions & 2 deletions packages/core/python/itkwasm/test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ def test_image_defaults():
assert image.imageType.pixelType == 'Scalar'
assert image.imageType.components == 1

assert image.name == "image"
assert image.name == "Image"
assert image.origin[0] == 0.0
assert image.origin[1] == 0.0
assert image.spacing[0] == 1.0
assert image.spacing[1] == 1.0
assert np.array_equal(image.direction, np.eye(2).astype(np.float32))
assert np.array_equal(image.direction, np.eye(2).astype(np.float64))

assert image.size[0] == 1
assert image.size[1] == 1
Expand Down
41 changes: 38 additions & 3 deletions packages/core/python/itkwasm/test/test_pyodide.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def test_image_conversion(selenium, package_wheel):
assert image.imageType.pixelType == 'Scalar'
assert image.imageType.components == 1

assert image.name == "image"
assert image.name == "Image"
assert image.origin[0] == 0.0
assert image.origin[1] == 0.0
assert image.spacing[0] == 1.0
Expand All @@ -45,7 +45,7 @@ async def test_image_conversion(selenium, package_wheel):
assert image_py.imageType.pixelType == 'Scalar'
assert image_py.imageType.components == 1

assert image_py.name == "image"
assert image_py.name == "Image"
assert image_py.origin[0] == 0.0
assert image_py.origin[1] == 0.0
assert image_py.spacing[0] == 1.0
Expand Down Expand Up @@ -91,4 +91,39 @@ async def test_point_set_conversion(selenium, package_wheel):
assert point_set.numberOfPoints == point_set_py.numberOfPoints
assert np.array_equal(point_set.points, point_set_py.points)
assert point_set.numberOfPointPixels == point_set_py.numberOfPointPixels
assert np.array_equal(point_set.pointData, point_set_py.pointData)
assert np.array_equal(point_set.pointData, point_set_py.pointData)

@run_in_pyodide(packages=['micropip', 'numpy'])
async def test_mesh_conversion(selenium, package_wheel):
import micropip
await micropip.install(package_wheel)

from itkwasm import Mesh, MeshType, FloatTypes
from itkwasm.pyodide import to_js, to_py
import numpy as np

n_points = 5
dimension = 3

mesh_type = MeshType()

points = np.random.random((n_points, dimension)).astype(np.float32)
point_data = np.random.random((n_points,)).astype(np.float32)

mesh = Mesh(mesh_type, points=points, numberOfPoints=n_points, pointData=point_data, numberOfPointPixels=n_points)

mesh_js = to_js(mesh)
mesh_py = to_py(mesh_js)

mesh_type_py = mesh_py.meshType
assert mesh_type.dimension == mesh_type_py.dimension
assert mesh_type.pointComponentType == mesh_type_py.pointComponentType
assert mesh_type.pointPixelComponentType == mesh_type_py.pointPixelComponentType
assert mesh_type.pointPixelType == mesh_type_py.pointPixelType
assert mesh_type.pointPixelComponents == mesh_type_py.pointPixelComponents

assert mesh.name == mesh_py.name
assert mesh.numberOfPoints == mesh_py.numberOfPoints
assert np.array_equal(mesh.points, mesh_py.points)
assert mesh.numberOfPointPixels == mesh_py.numberOfPointPixels
assert np.array_equal(mesh.pointData, mesh_py.pointData)

0 comments on commit fc9f404

Please sign in to comment.