Skip to content

Commit

Permalink
Add Level of Detail VTK source loader (#36)
Browse files Browse the repository at this point in the history
* Add Level of Detail VTK source loader

* Cleanup

* codespell

* LOD example improvements

* Improve LOD with global example

* Internal move

* Add time support to VTK source

* Cleanup

* Add test

* Do not test mem share for vtk source
  • Loading branch information
banesullivan authored Sep 20, 2022
1 parent e19eab1 commit 2a765ca
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ doctest:

lint:
@echo "Linting with flake8"
flake8 --ignore=E501 pvxarray tests
flake8 --ignore=E501 pvxarray tests examples

format:
@echo "Formatting"
Expand Down
96 changes: 96 additions & 0 deletions examples/level_of_detail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pyvista as pv
from trame.app import get_server
from trame.ui.vuetify import SinglePageLayout
from trame.widgets import vtk, vuetify
import xarray as xr

from pvxarray.vtk_source import PyVistaXarraySource

server = get_server()
state, ctrl = server.state, server.controller

state.trame__title = "PyVista Xarray Level of Detail"

# -----------------------------------------------------------------------------
ds = xr.tutorial.load_dataset("air_temperature")
da = ds.air
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0)

# ds = xr.open_dataset("oisst-avhrr-v02r01.19810901.nc")
# da = ds.err[dict(time=0, zlev=0)]
# source = PyVistaXarraySource(da, x="lon", y="lat", resolution=0.25)

# -----------------------------------------------------------------------------
plotter = pv.Plotter(off_screen=True)
# Requires https://github.com/pyvista/pyvista/pull/3318
plotter.add_mesh(source, name="data_array", show_edges=True)
plotter.view_xy()


@state.change("resolution")
def update_resolution(resolution=25, **kwargs):
source.resolution = resolution / 100.0
source.Update()
ctrl.view_update()


@state.change("time_index")
def update_time_index(time_index=0, **kwargs):
source.time_index = time_index
source.Update()
ctrl.view_update()


# -----------------------------------------------------------------------------
# GUI
# -----------------------------------------------------------------------------

with SinglePageLayout(server) as layout:
layout.icon.click = ctrl.view_reset_camera
layout.title.set_text("PyVista Xarray Level of Detail")

with layout.toolbar:
vuetify.VSpacer()
vuetify.VSlider(
v_model=("time_index", 0),
min=0,
max=len(da.time) - 1,
step=1,
hide_details=True,
label="Time Index",
dense=True,
style="max-width: 300px",
)
vuetify.VSelect(
label="Resolution %",
v_model=("resolution", source.resolution * 100.0),
items=("array_list", [5, 25, 50, 100]),
hide_details=True,
dense=True,
outlined=True,
classes="pt-1 ml-2",
style="max-width: 150px",
)

with layout.content:
with vuetify.VContainer(
fluid=True,
classes="pa-0 fill-height",
):
view = vtk.VtkRemoteView(
plotter.ren_win,
ref="view",
interactive_ratio=1,
)
ctrl.view_update = view.update
ctrl.view_reset_camera = view.reset_camera

# Uncomment following line to hide footer
# layout.footer.hide()

# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

if __name__ == "__main__":
server.start()
123 changes: 123 additions & 0 deletions examples/level_of_detail_geovista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import geovista as gv
from trame.app import get_server
from trame.ui.vuetify import SinglePageLayout
from trame.widgets import vtk as trame_vtk, vuetify
import xarray as xr

from pvxarray.vtk_source import PyVistaXarraySource

server = get_server()
state, ctrl = server.state, server.controller

state.trame__title = "PyVista Xarray Level of Detail"

# -----------------------------------------------------------------------------
ds = xr.tutorial.load_dataset("air_temperature")
da = ds.air
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0)

# ds = xr.open_dataset("oisst-avhrr-v02r01.19810901.nc")
# da = ds.err[dict(time=0, zlev=0)]
# source = PyVistaXarraySource(da, x="lon", y="lat", resolution=0.25)

# -----------------------------------------------------------------------------
DS_NAME = "mydata"


def apply():
src = source.apply()
return gv.Transform.from_1d(src.x, src.y, data=src.active_scalars).threshold()


mesh = apply()

plotter = gv.GeoPlotter(off_screen=True)


def _update():
mesh = apply()
# mesh.overwrite(apply())
plotter.remove_actor(DS_NAME)
plotter.add_mesh(mesh, cmap="coolwarm", show_edges=True, name=DS_NAME)
ctrl.view_update()


@state.change("resolution")
def update_resolution(resolution=25, **kwargs):
source.resolution = resolution / 100.0
_update()


@state.change("time_index")
def update_time_index(time_index=0, **kwargs):
source.time_index = time_index
_update()


plotter.add_mesh(mesh, cmap="coolwarm", show_edges=True, name=DS_NAME)
plotter.add_base_layer(texture=gv.blue_marble())
resolution = "10m"
plotter.add_coastlines(resolution=resolution, color="white")
plotter.add_axes()
plotter.add_text(
f"NOAA/NCEI OISST AVHRR ({resolution} Coastlines)",
position="upper_left",
font_size=10,
shadow=True,
)
plotter.view_isometric()


# -----------------------------------------------------------------------------
# GUI
# -----------------------------------------------------------------------------

with SinglePageLayout(server) as layout:
layout.icon.click = ctrl.view_reset_camera
layout.title.set_text("PyVista Xarray Level of Detail")

with layout.toolbar:
vuetify.VSpacer()
vuetify.VSlider(
v_model=("time_index", 0),
min=0,
max=len(da.time) - 1,
step=1,
hide_details=True,
label="Time Index",
dense=True,
style="max-width: 300px",
)
vuetify.VSelect(
label="Resolution %",
v_model=("resolution", source.resolution * 100.0),
items=("array_list", [5, 25, 50, 100]),
hide_details=True,
dense=True,
outlined=True,
classes="pt-1 ml-2",
style="max-width: 150px",
)

with layout.content:
with vuetify.VContainer(
fluid=True,
classes="pa-0 fill-height",
):
view = trame_vtk.VtkRemoteView(
plotter.ren_win,
ref="view",
interactive_ratio=1,
)
ctrl.view_update = view.update
ctrl.view_reset_camera = view.reset_camera

# Uncomment following line to hide footer
# layout.footer.hide()

# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

if __name__ == "__main__":
server.start()
1 change: 1 addition & 0 deletions pvxarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from pvxarray.errors import DataCopyWarning
from pvxarray.io import pyvista_to_xarray
from pvxarray.report import Report
from pvxarray.vtk_source import PyVistaXarraySource
125 changes: 125 additions & 0 deletions pvxarray/vtk_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Optional

import numpy as np
import pyvista as pv
from vtk.util.vtkAlgorithm import VTKPythonAlgorithmBase
import xarray as xr


class BaseSource(VTKPythonAlgorithmBase):
def __init__(self, nOutputPorts=1, outputType="vtkTable", **kwargs):
VTKPythonAlgorithmBase.__init__(
self, nInputPorts=0, nOutputPorts=nOutputPorts, outputType=outputType, **kwargs
)

def GetOutput(self, port=0):
output = pv.wrap(self.GetOutputDataObject(port))
if output.active_scalars is None and output.n_arrays:
if len(output.point_data):
output.set_active_scalars(output.point_data.keys()[0])
elif len(output.cell_data):
output.set_active_scalars(output.cell_data.keys()[0])
return output

def apply(self):
self.Update()
return self.GetOutput()

def update(self):
"""Alias for self.Update()"""
return self.Update()

def get_output(self, port=0):
"""Alias for self.GetOutput()"""
return self.GetOutput(port=port)


class PyVistaXarraySource(BaseSource):
def __init__(
self,
data_array: xr.DataArray,
x: Optional[str] = None,
y: Optional[str] = None,
z: Optional[str] = None,
time: Optional[str] = None,
order: Optional[str] = "C",
component: Optional[str] = None,
resolution: float = 1.0,
**kwargs,
):
BaseSource.__init__(
self,
nOutputPorts=1,
outputType="vtkRectilinearGrid",
)
self._data_array = data_array
self._resolution = resolution

self._x = x
self._y = y
self._z = z
self._order = order
self._component = component

self._time = None
self._time_index = 0
if isinstance(time, str):
self._time = time
elif time is not None:
raise TypeError

@property
def data_array(self):
return self._data_array

@property
def resolution(self):
return self._resolution

@resolution.setter
def resolution(self, resolution: int):
self._resolution = resolution
self.Modified()

def resolution_to_sampling_rate(self, data_array):
"""Convert percentage to sampling rate."""
shape = np.array(data_array.shape)
n = np.floor(shape * self._resolution)
rate = np.ceil(shape / n).astype(int)
return np.pad(rate, (0, 3 - len(rate)), mode="constant")

@property
def time_index(self):
return self._time_index

@time_index.setter
def time_index(self, time_index: int):
# TODO: hook into the VTK pipeling to get requested time
self._time_index = time_index
self.Modified()

def RequestData(self, request, inInfo, outInfo):
# Use open data_array handle to fetch data at
# desired Level of Detail
if self._time is not None:
da = self.data_array[{self._time: self.time_index}]
else:
da = self.data_array

rx, ry, rz = self.resolution_to_sampling_rate(da)
if da.ndim == 1:
da = da[::rx]
elif da.ndim == 2:
da = da[::rx, ::ry]
elif da.ndim == 3:
da = da[::rx, ::ry, ::rz]
else:
raise ValueError

mesh = da.pyvista.mesh(
x=self._x, y=self._y, z=self._z, order=self._order, component=self._component
)

pdo = self.GetOutputData(outInfo, 0)
pdo.ShallowCopy(mesh)
return 1
30 changes: 30 additions & 0 deletions tests/test_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import xarray as xr

from pvxarray.vtk_source import PyVistaXarraySource


def test_vtk_source():
ds = xr.tutorial.load_dataset("air_temperature")

da = ds.air
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0)

mesh = source.apply()
assert mesh
assert mesh.n_points == 1325
assert "air" in mesh.point_data

assert np.array_equal(mesh["air"], da[dict(time=0)].values.ravel())
# assert np.may_share_memory(mesh["air"], da[dict(time=0)].values.ravel())
assert np.array_equal(mesh.x, da.lon)
assert np.array_equal(mesh.y, da.lat)

source.time_index = 1
mesh = source.apply()
assert np.array_equal(mesh["air"], da[dict(time=1)].values.ravel())
# assert np.may_share_memory(mesh["air"], da[dict(time=1)].values.ravel())

source.resolution = 0.5
mesh = source.apply()
assert mesh.n_points < 1325

0 comments on commit 2a765ca

Please sign in to comment.