-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Level of Detail VTK source loader (#36)
* Add Level of Detail VTK source loader * Cleanup * codespell * LOD example improvements * Improve LOD with global example * Internal move * Add time support to VTK source * Cleanup * Add test * Do not test mem share for vtk source
- Loading branch information
1 parent
e19eab1
commit 2a765ca
Showing
6 changed files
with
376 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import pyvista as pv | ||
from trame.app import get_server | ||
from trame.ui.vuetify import SinglePageLayout | ||
from trame.widgets import vtk, vuetify | ||
import xarray as xr | ||
|
||
from pvxarray.vtk_source import PyVistaXarraySource | ||
|
||
server = get_server() | ||
state, ctrl = server.state, server.controller | ||
|
||
state.trame__title = "PyVista Xarray Level of Detail" | ||
|
||
# ----------------------------------------------------------------------------- | ||
ds = xr.tutorial.load_dataset("air_temperature") | ||
da = ds.air | ||
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0) | ||
|
||
# ds = xr.open_dataset("oisst-avhrr-v02r01.19810901.nc") | ||
# da = ds.err[dict(time=0, zlev=0)] | ||
# source = PyVistaXarraySource(da, x="lon", y="lat", resolution=0.25) | ||
|
||
# ----------------------------------------------------------------------------- | ||
plotter = pv.Plotter(off_screen=True) | ||
# Requires https://github.com/pyvista/pyvista/pull/3318 | ||
plotter.add_mesh(source, name="data_array", show_edges=True) | ||
plotter.view_xy() | ||
|
||
|
||
@state.change("resolution") | ||
def update_resolution(resolution=25, **kwargs): | ||
source.resolution = resolution / 100.0 | ||
source.Update() | ||
ctrl.view_update() | ||
|
||
|
||
@state.change("time_index") | ||
def update_time_index(time_index=0, **kwargs): | ||
source.time_index = time_index | ||
source.Update() | ||
ctrl.view_update() | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# GUI | ||
# ----------------------------------------------------------------------------- | ||
|
||
with SinglePageLayout(server) as layout: | ||
layout.icon.click = ctrl.view_reset_camera | ||
layout.title.set_text("PyVista Xarray Level of Detail") | ||
|
||
with layout.toolbar: | ||
vuetify.VSpacer() | ||
vuetify.VSlider( | ||
v_model=("time_index", 0), | ||
min=0, | ||
max=len(da.time) - 1, | ||
step=1, | ||
hide_details=True, | ||
label="Time Index", | ||
dense=True, | ||
style="max-width: 300px", | ||
) | ||
vuetify.VSelect( | ||
label="Resolution %", | ||
v_model=("resolution", source.resolution * 100.0), | ||
items=("array_list", [5, 25, 50, 100]), | ||
hide_details=True, | ||
dense=True, | ||
outlined=True, | ||
classes="pt-1 ml-2", | ||
style="max-width: 150px", | ||
) | ||
|
||
with layout.content: | ||
with vuetify.VContainer( | ||
fluid=True, | ||
classes="pa-0 fill-height", | ||
): | ||
view = vtk.VtkRemoteView( | ||
plotter.ren_win, | ||
ref="view", | ||
interactive_ratio=1, | ||
) | ||
ctrl.view_update = view.update | ||
ctrl.view_reset_camera = view.reset_camera | ||
|
||
# Uncomment following line to hide footer | ||
# layout.footer.hide() | ||
|
||
# ----------------------------------------------------------------------------- | ||
# Main | ||
# ----------------------------------------------------------------------------- | ||
|
||
if __name__ == "__main__": | ||
server.start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import geovista as gv | ||
from trame.app import get_server | ||
from trame.ui.vuetify import SinglePageLayout | ||
from trame.widgets import vtk as trame_vtk, vuetify | ||
import xarray as xr | ||
|
||
from pvxarray.vtk_source import PyVistaXarraySource | ||
|
||
server = get_server() | ||
state, ctrl = server.state, server.controller | ||
|
||
state.trame__title = "PyVista Xarray Level of Detail" | ||
|
||
# ----------------------------------------------------------------------------- | ||
ds = xr.tutorial.load_dataset("air_temperature") | ||
da = ds.air | ||
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0) | ||
|
||
# ds = xr.open_dataset("oisst-avhrr-v02r01.19810901.nc") | ||
# da = ds.err[dict(time=0, zlev=0)] | ||
# source = PyVistaXarraySource(da, x="lon", y="lat", resolution=0.25) | ||
|
||
# ----------------------------------------------------------------------------- | ||
DS_NAME = "mydata" | ||
|
||
|
||
def apply(): | ||
src = source.apply() | ||
return gv.Transform.from_1d(src.x, src.y, data=src.active_scalars).threshold() | ||
|
||
|
||
mesh = apply() | ||
|
||
plotter = gv.GeoPlotter(off_screen=True) | ||
|
||
|
||
def _update(): | ||
mesh = apply() | ||
# mesh.overwrite(apply()) | ||
plotter.remove_actor(DS_NAME) | ||
plotter.add_mesh(mesh, cmap="coolwarm", show_edges=True, name=DS_NAME) | ||
ctrl.view_update() | ||
|
||
|
||
@state.change("resolution") | ||
def update_resolution(resolution=25, **kwargs): | ||
source.resolution = resolution / 100.0 | ||
_update() | ||
|
||
|
||
@state.change("time_index") | ||
def update_time_index(time_index=0, **kwargs): | ||
source.time_index = time_index | ||
_update() | ||
|
||
|
||
plotter.add_mesh(mesh, cmap="coolwarm", show_edges=True, name=DS_NAME) | ||
plotter.add_base_layer(texture=gv.blue_marble()) | ||
resolution = "10m" | ||
plotter.add_coastlines(resolution=resolution, color="white") | ||
plotter.add_axes() | ||
plotter.add_text( | ||
f"NOAA/NCEI OISST AVHRR ({resolution} Coastlines)", | ||
position="upper_left", | ||
font_size=10, | ||
shadow=True, | ||
) | ||
plotter.view_isometric() | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# GUI | ||
# ----------------------------------------------------------------------------- | ||
|
||
with SinglePageLayout(server) as layout: | ||
layout.icon.click = ctrl.view_reset_camera | ||
layout.title.set_text("PyVista Xarray Level of Detail") | ||
|
||
with layout.toolbar: | ||
vuetify.VSpacer() | ||
vuetify.VSlider( | ||
v_model=("time_index", 0), | ||
min=0, | ||
max=len(da.time) - 1, | ||
step=1, | ||
hide_details=True, | ||
label="Time Index", | ||
dense=True, | ||
style="max-width: 300px", | ||
) | ||
vuetify.VSelect( | ||
label="Resolution %", | ||
v_model=("resolution", source.resolution * 100.0), | ||
items=("array_list", [5, 25, 50, 100]), | ||
hide_details=True, | ||
dense=True, | ||
outlined=True, | ||
classes="pt-1 ml-2", | ||
style="max-width: 150px", | ||
) | ||
|
||
with layout.content: | ||
with vuetify.VContainer( | ||
fluid=True, | ||
classes="pa-0 fill-height", | ||
): | ||
view = trame_vtk.VtkRemoteView( | ||
plotter.ren_win, | ||
ref="view", | ||
interactive_ratio=1, | ||
) | ||
ctrl.view_update = view.update | ||
ctrl.view_reset_camera = view.reset_camera | ||
|
||
# Uncomment following line to hide footer | ||
# layout.footer.hide() | ||
|
||
# ----------------------------------------------------------------------------- | ||
# Main | ||
# ----------------------------------------------------------------------------- | ||
|
||
if __name__ == "__main__": | ||
server.start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pyvista as pv | ||
from vtk.util.vtkAlgorithm import VTKPythonAlgorithmBase | ||
import xarray as xr | ||
|
||
|
||
class BaseSource(VTKPythonAlgorithmBase): | ||
def __init__(self, nOutputPorts=1, outputType="vtkTable", **kwargs): | ||
VTKPythonAlgorithmBase.__init__( | ||
self, nInputPorts=0, nOutputPorts=nOutputPorts, outputType=outputType, **kwargs | ||
) | ||
|
||
def GetOutput(self, port=0): | ||
output = pv.wrap(self.GetOutputDataObject(port)) | ||
if output.active_scalars is None and output.n_arrays: | ||
if len(output.point_data): | ||
output.set_active_scalars(output.point_data.keys()[0]) | ||
elif len(output.cell_data): | ||
output.set_active_scalars(output.cell_data.keys()[0]) | ||
return output | ||
|
||
def apply(self): | ||
self.Update() | ||
return self.GetOutput() | ||
|
||
def update(self): | ||
"""Alias for self.Update()""" | ||
return self.Update() | ||
|
||
def get_output(self, port=0): | ||
"""Alias for self.GetOutput()""" | ||
return self.GetOutput(port=port) | ||
|
||
|
||
class PyVistaXarraySource(BaseSource): | ||
def __init__( | ||
self, | ||
data_array: xr.DataArray, | ||
x: Optional[str] = None, | ||
y: Optional[str] = None, | ||
z: Optional[str] = None, | ||
time: Optional[str] = None, | ||
order: Optional[str] = "C", | ||
component: Optional[str] = None, | ||
resolution: float = 1.0, | ||
**kwargs, | ||
): | ||
BaseSource.__init__( | ||
self, | ||
nOutputPorts=1, | ||
outputType="vtkRectilinearGrid", | ||
) | ||
self._data_array = data_array | ||
self._resolution = resolution | ||
|
||
self._x = x | ||
self._y = y | ||
self._z = z | ||
self._order = order | ||
self._component = component | ||
|
||
self._time = None | ||
self._time_index = 0 | ||
if isinstance(time, str): | ||
self._time = time | ||
elif time is not None: | ||
raise TypeError | ||
|
||
@property | ||
def data_array(self): | ||
return self._data_array | ||
|
||
@property | ||
def resolution(self): | ||
return self._resolution | ||
|
||
@resolution.setter | ||
def resolution(self, resolution: int): | ||
self._resolution = resolution | ||
self.Modified() | ||
|
||
def resolution_to_sampling_rate(self, data_array): | ||
"""Convert percentage to sampling rate.""" | ||
shape = np.array(data_array.shape) | ||
n = np.floor(shape * self._resolution) | ||
rate = np.ceil(shape / n).astype(int) | ||
return np.pad(rate, (0, 3 - len(rate)), mode="constant") | ||
|
||
@property | ||
def time_index(self): | ||
return self._time_index | ||
|
||
@time_index.setter | ||
def time_index(self, time_index: int): | ||
# TODO: hook into the VTK pipeling to get requested time | ||
self._time_index = time_index | ||
self.Modified() | ||
|
||
def RequestData(self, request, inInfo, outInfo): | ||
# Use open data_array handle to fetch data at | ||
# desired Level of Detail | ||
if self._time is not None: | ||
da = self.data_array[{self._time: self.time_index}] | ||
else: | ||
da = self.data_array | ||
|
||
rx, ry, rz = self.resolution_to_sampling_rate(da) | ||
if da.ndim == 1: | ||
da = da[::rx] | ||
elif da.ndim == 2: | ||
da = da[::rx, ::ry] | ||
elif da.ndim == 3: | ||
da = da[::rx, ::ry, ::rz] | ||
else: | ||
raise ValueError | ||
|
||
mesh = da.pyvista.mesh( | ||
x=self._x, y=self._y, z=self._z, order=self._order, component=self._component | ||
) | ||
|
||
pdo = self.GetOutputData(outInfo, 0) | ||
pdo.ShallowCopy(mesh) | ||
return 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import xarray as xr | ||
|
||
from pvxarray.vtk_source import PyVistaXarraySource | ||
|
||
|
||
def test_vtk_source(): | ||
ds = xr.tutorial.load_dataset("air_temperature") | ||
|
||
da = ds.air | ||
source = PyVistaXarraySource(da, x="lon", y="lat", time="time", resolution=1.0) | ||
|
||
mesh = source.apply() | ||
assert mesh | ||
assert mesh.n_points == 1325 | ||
assert "air" in mesh.point_data | ||
|
||
assert np.array_equal(mesh["air"], da[dict(time=0)].values.ravel()) | ||
# assert np.may_share_memory(mesh["air"], da[dict(time=0)].values.ravel()) | ||
assert np.array_equal(mesh.x, da.lon) | ||
assert np.array_equal(mesh.y, da.lat) | ||
|
||
source.time_index = 1 | ||
mesh = source.apply() | ||
assert np.array_equal(mesh["air"], da[dict(time=1)].values.ravel()) | ||
# assert np.may_share_memory(mesh["air"], da[dict(time=1)].values.ravel()) | ||
|
||
source.resolution = 0.5 | ||
mesh = source.apply() | ||
assert mesh.n_points < 1325 |