Skip to content

Commit

Permalink
add get_var_for_nextsim
Browse files Browse the repository at this point in the history
  • Loading branch information
akorosov committed Jan 13, 2022
1 parent 67d610d commit 6c9f982
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 14 deletions.
11 changes: 7 additions & 4 deletions geodataset/custom_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ def get_lonlat_arrays(self):

class JaxaAmsr2IceConc(CustomDatasetRead):
    """ Reader settings for files named like ``Arc_<8 digits>_res3.125_pyres.nc``

    NOTE(review): judging by the class name only, these are JAXA AMSR2
    sea-ice concentration products -- confirm against the data source.
    """
    # File-name regex. The dots are not escaped, so each of them matches any
    # single character and not only a literal '.'.
    pattern = re.compile(r'Arc_\d{8}_res3.125_pyres.nc')
    # NOTE(review): not referenced in the visible code -- confirm it is
    # still used by the base class.
    _filename_suffix = '_res3.125_pyres.nc'
    # Presumably the names of the longitude and latitude variables, in that
    # order -- confirm in the base class.
    lonlat_names = 'longitude', 'latitude'
    # Projection created from the EPSG code 3411.
    projection = pyproj.Proj(3411)
    # NOTE(review): 'absent' looks like a sentinel meaning that the file has
    # no grid-mapping variable -- confirm in the base class.
    grid_mapping_variable = 'absent'


class MooringsNextsim(CustomDatasetRead):
    """ Reader settings for files named like ``Moorings*.nc``

    NOTE(review): judging by the class name these are neXtSIM moorings
    files -- confirm. MooringsArcMfc declares the same file-name pattern, so
    the two classes cannot be told apart by file name alone.
    """
    # NOTE(review): not referenced in the visible code -- confirm it is
    # still used by the base class.
    _filename_prefix = 'Moorings'
    # File-name regex. The dots are not escaped: '.*' matches any run of
    # characters and the last '.' matches any single character.
    pattern = re.compile(r'Moorings.*.nc')
    # Stereographic projection on an ellipsoid (a != b) with lat_0=90,
    # lon_0=-45 and lat_ts=60.
    projection = pyproj.Proj(
        '+proj=stere +a=6378273.0 +b=6356889.448910593 '
        '+lon_0=-45.0 +lat_0=90.0 +lat_ts=60.0')
    # NOTE(review): presumably the name of the grid-mapping variable stored
    # in the file -- confirm in the base class.
    grid_mapping_variable = 'Polar_Stereographic_Grid'


class MooringsArcMfc(CustomDatasetRead):
    """ Reader settings for files named like ``Moorings*.nc``

    NOTE(review): judging by the class name these are moorings files on the
    ArcMFC grid -- confirm. MooringsNextsim declares the same file-name
    pattern, so the two classes cannot be told apart by file name alone.
    """
    # NOTE(review): not referenced in the visible code -- confirm it is
    # still used by the base class.
    _filename_prefix = 'Moorings'
    # File-name regex. The dots are not escaped: '.*' matches any run of
    # characters and the last '.' matches any single character.
    pattern = re.compile(r'Moorings.*.nc')
    # Stereographic projection on a sphere (a == b) with lat_0=90,
    # lon_0=-45 and lat_ts=90.
    projection = pyproj.Proj(
        '+proj=stere +a=6378273.0 +b=6378273.0 '
        '+lon_0=-45.0 +lat_0=90.0 +lat_ts=90.0')
    # NOTE(review): 'absent' looks like a sentinel meaning that the file has
    # no grid-mapping variable -- confirm in the base class.
    grid_mapping_variable = 'absent'


class NerscSarProducts(CustomDatasetRead):
Expand All @@ -83,21 +83,24 @@ class OsisafDriftersNextsim(CustomDatasetRead):
pattern = re.compile(r'OSISAF_Drifters_.*.nc')
projection = pyproj.Proj("+proj=stere +lat_0=90 +lat_ts=70 +lon_0=-45 "
" +a=6378273 +b=6356889.44891 ")
grid_mapping_variable = 'absent'


class SmosIceThickness(CustomDatasetRead):
    """ Reader settings for files named like
    ``SMOS_Icethickness_v3.2_north_<8 digits>.nc``

    NOTE(review): judging by the class name only, these are SMOS sea-ice
    thickness products -- confirm against the data source.
    """
    # File-name regex. The dots are not escaped, so each of them matches any
    # single character and not only a literal '.'.
    pattern = re.compile(r'SMOS_Icethickness_v3.2_north_\d{8}.nc')
    # Projection created from the EPSG code 3411.
    projection = pyproj.Proj(3411)
    # NOTE(review): 'absent' looks like a sentinel meaning that the file has
    # no grid-mapping variable -- confirm in the base class.
    grid_mapping_variable = 'absent'


class Topaz4Forecast(CustomDatasetRead):
    """ Reader settings for files named like
    ``<8 digits>_dm-metno-MODEL-topaz4-ARC-b<8 digits>-fv02.0.nc``

    NOTE(review): judging by the class name only, these are TOPAZ4 forecast
    files -- confirm against the data source.
    """
    # File-name regex. The dots are not escaped, so each of them matches any
    # single character and not only a literal '.'.
    pattern = re.compile(r'\d{8}_dm-metno-MODEL-topaz4-ARC-b\d{8}-fv02.0.nc')
    # Stereographic projection on the WGS84 datum with lat_0=90, lon_0=-45
    # and unit scale factor (k=1).
    projection = pyproj.Proj("+proj=stere +lat_0=90 +lon_0=-45 +k=1 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs")
    # NOTE(review): presumably the name of the grid-mapping variable stored
    # in the file -- confirm in the base class.
    grid_mapping_variable = 'stereographic'


class NetcdfArcMFC(GeoDatasetWrite):
""" wrapper for netCDF4.Dataset with info about ArcMFC products """
grid_mapping_name = 'stereographic'
grid_mapping_variable = 'stereographic'
projection = pyproj.Proj(
'+proj=stere +a=6378273 +b=6378273.0 '
' +lon_0=-45 +lat_0=90 +lat_ts=90')
70 changes: 66 additions & 4 deletions geodataset/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import pyproj
from pyproj.exceptions import CRSError
from pyresample.utils import load_cf_area
from scipy.interpolate import RegularGridInterpolator
from xarray.core.variable import MissingDimensionsError

from geodataset.utils import InvalidDatasetError
from geodataset.utils import InvalidDatasetError, fill_nan_gaps

class GeoDatasetBase(Dataset):
""" Abstract wrapper for netCDF4.Dataset for common input or ouput tasks """
Expand Down Expand Up @@ -103,7 +104,7 @@ def get_nearest_date(self, pivot):

class GeoDatasetWrite(GeoDatasetBase):
""" Wrapper for netCDF4.Dataset for common ouput tasks """
grid_mapping_name = 'Polar_Stereographic_Grid'
grid_mapping_variable = None
spatial_dim_names = ('x', 'y')
time_name = 'time'
lonlat_names = ('longitude', 'latitude')
Expand All @@ -121,7 +122,7 @@ def set_projection_variable(self):
Check netcdf files at:
http://cfconventions.org/compliance-checker.html
"""
pvar = self.createVariable(self.grid_mapping_name, 'i1')
pvar = self.createVariable(self.grid_mapping_variable, 'i1')
pvar.setncatts(self.get_grid_mapping_ncattrs())

def set_time_variables_dimensions(self, time_data, time_atts, time_bnds_data):
Expand Down Expand Up @@ -225,7 +226,7 @@ def set_variable(self, vname, data, dims, atts, dtype='f4'):
# needs to be of right data type
ncatts['missing_value'] = type_converter(atts['missing_value'])
dst_var = self.createVariable(vname, dtype, dims, **kw)
ncatts['grid_mapping'] = self.grid_mapping_name
ncatts['grid_mapping'] = self.grid_mapping_variable
dst_var.setncatts(ncatts)
dst_var[:] = data

Expand Down Expand Up @@ -437,3 +438,64 @@ def get_proj_info_kwargs(self):
ecc = g.es,
)
return kwargs

def get_var_for_nextsim(
        self, var_name, nbo, distance=5, on_elements=True, fill_value=np.nan):
    """ Interpolate a 2D netCDF variable onto the mesh of a NextsimBin object

    Parameters
    ----------
    var_name : str
        name of the variable to read from the dataset
    nbo : NextsimBin
        nextsim bin object; its mesh_info attribute provides the node
        coordinates, the element indices and the mesh projection
    distance : int
        extrapolation distance (in pixels) handed over to fill_nan_gaps
        to avoid land contamination
    on_elements : bool
        if True, interpolate at the mean of the node coordinates of each
        element; otherwise interpolate at the nodes
    fill_value : float
        value given to mesh points outside the bounding box of the dataset

    Returns
    -------
    v_pro : numpy.array
        values from netCDF interpolated on the nextsim mesh, one value per
        element (or per node if on_elements is False)

    Raises
    ------
    ValueError
        if the lon/lat arrays of the dataset have less than two dimensions
    """
    # coordinates of the dataset grid
    grid_lon, grid_lat = self.get_lonlat_arrays()
    if any(len(coord.shape) < 2 for coord in (grid_lon, grid_lat)):
        raise ValueError('Can inteporlate only 2D data from netCDF file')
    # values of the variable, masked entries replaced by NaN
    grid_values = self.get_variable_array(var_name).filled(np.nan)

    # mesh coordinates in the neXtSIM projection
    mesh_x, mesh_y = nbo.mesh_info.get_nodes_xy()
    # shift the indices by one so they can be used for numpy indexing
    elem_nodes = nbo.mesh_info.get_indices() - 1
    if on_elements:
        mesh_x = mesh_x[elem_nodes].mean(axis=1)
        mesh_y = mesh_y[elem_nodes].mean(axis=1)

    # mesh coordinates -> lon/lat
    mesh_x, mesh_y = nbo.mesh_info.projection.pyproj(
        mesh_x, mesh_y, inverse=True)

    # bring grid axes and mesh points to the same coordinate system
    if self.is_lonlat_dim:
        axis_x, axis_y = grid_lon[0], grid_lat[:, 0]
    else:
        axis_x, axis_y = self.get_xy_dims_from_lonlat(grid_lon, grid_lat)
        mesh_x, mesh_y = self.projection(mesh_x, mesh_y)

    # fill nan gaps to avoid land contamination
    grid_values = fill_nan_gaps(grid_values, distance)
    # a step of -1 reverses the Y axis (and the rows) when Y is decreasing
    y_order = slice(None, None, int(np.sign(np.mean(np.diff(axis_y)))))
    interpolator = RegularGridInterpolator(
        (axis_y[y_order], axis_x), grid_values[y_order])

    # only points strictly inside the bounding box of the grid are
    # interpolated; all the others keep fill_value
    inside = (
        (mesh_x > axis_x.min()) &
        (mesh_x < axis_x.max()) &
        (mesh_y > axis_y.min()) &
        (mesh_y < axis_y.max()))
    v_pro = np.full(mesh_x.shape, fill_value, dtype=float)
    v_pro[inside] = interpolator((mesh_y[inside], mesh_x[inside]))
    return v_pro
42 changes: 36 additions & 6 deletions geodataset/tests/test_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ def test_datetimes(self, **kwargs):
get_grid_mapping_ncattrs=DEFAULT)
def test_set_projection_variable(self, **kwargs):
nc = GeoDatasetWrite()
nc.grid_mapping_name = 'Polar_Stereographic_Grid'
nc.grid_mapping_variable = 'psg'
nc.spatial_dim_names = ('x', 'y')
nc.time_name = 'time'
nc.lonlat_names = ('longitude', 'latitude')
nc.set_projection_variable()
kwargs['createVariable'].assert_called_once_with(
'Polar_Stereographic_Grid', 'i1'
)
kwargs['createVariable'].assert_called_once_with('psg', 'i1')
kwargs['get_grid_mapping_ncattrs'].assert_called_once()

@patch.multiple(GeoDatasetWrite,
Expand Down Expand Up @@ -183,7 +181,7 @@ def test_set_time_variables_dimensions(self, **kwargs):
def test_set_variable_1(self, f4, f8, **kwargs):
''' test f4 with _FillValue defined '''
nc = GeoDatasetWrite()
nc.grid_mapping_name = 'gmn'
nc.grid_mapping_variable = 'gmn'
atts = dict(a1='A1', a2='A2', _FillValue='fv')
f4.return_value = 'fv4'
nc.set_variable('vname', 'data', 'dims', atts, dtype='f4')
Expand All @@ -206,7 +204,7 @@ def test_set_variable_1(self, f4, f8, **kwargs):
def test_set_variable_2(self, f4, f8, **kwargs):
''' test f8 with missing_value defined '''
nc = GeoDatasetWrite()
nc.grid_mapping_name = 'gmn'
nc.grid_mapping_variable = 'gmn'
atts = dict(a1='A1', a2='A2', missing_value='fv')
f8.return_value = 'fv8'

Expand Down Expand Up @@ -387,5 +385,37 @@ def test_get_proj_info_kwargs(self, **kwargs):
{'proj': 'stere', 'lat_0': 90, 'lat_ts': 70, 'lon_0': -45,
'a': 6378273.0, 'ecc': 0.0066938828637783665})

# Replace the reader's constructor, context exit, projection and data
# accessors so that no netCDF file is needed. get_xy_dims_from_lonlat is not
# patched, so the real implementation is used.
@patch.multiple(GeoDatasetRead,
    __init__=MagicMock(return_value=None),
    __exit__=MagicMock(return_value=None),
    projection=pyproj.Proj(3411),
    get_lonlat_arrays=MagicMock(return_value=(
        np.array([[0,1],[0,1]]),
        np.array([[1,1],[0,0]]),
    )),
    get_variable_array=MagicMock(
        return_value=np.ma.array([[1,2],[3,4]])),
    is_lonlat_dim=False,
)
@patch('geodataset.geodataset.fill_nan_gaps')
def test_get_var_for_nextsim(self, mock_fng, **kwargs):
    """ Check get_var_for_nextsim on a 2x2 grid and a one-element mesh """
    # the gap filling is mocked and returns the 2x2 data unchanged
    mock_fng.return_value = np.array([[1,2],[3,4]])

    # mesh with three nodes forming one element; node indices start at 1
    # because get_var_for_nextsim subtracts 1 before indexing
    nbo = MagicMock()
    nbo.mesh_info.get_nodes_xy.return_value = (
        np.array([8569000.1, 8569000.2, 8569000.3]),
        np.array([-8569000.1, -8569000.2, -8569000.3]))
    nbo.mesh_info.get_indices.return_value = np.array([[1,2,3],])
    nbo.mesh_info.projection.pyproj = pyproj.Proj(3411)

    with GeoDatasetRead() as ds:
        v_pro = ds.get_var_for_nextsim('var_name', nbo, 10)

    # the third argument is `places`: only one decimal place is compared
    self.assertAlmostEqual(v_pro[0], 1.00000402, 1)
    ds.get_lonlat_arrays.assert_called_once()
    ds.get_variable_array.assert_called_once_with('var_name')
    mock_fng.assert_called_once()


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions geodataset/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def test_open_netcdf(self):
print(nc_file, ds.lonlat_names)
self.assertIsInstance(ds.lonlat_names[0], str)
self.assertIsInstance(ds.lonlat_names[1], str)
self.assertIsInstance(ds.variable_names, list)
self.assertIsInstance(ds.variable_names[0], str)

def test_get_lonlat_arrays(self):
for nc_file in self.nc_files:
Expand Down
19 changes: 19 additions & 0 deletions geodataset/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest

import numpy as np

from geodataset.utils import fill_nan_gaps


class TestsUtils(unittest.TestCase):
    """ Tests for the helper functions of geodataset.utils """

    def test_fill_nan_gaps(self):
        """ NaN elements get the value of the nearest valid element """
        a = np.array(
            [[np.nan,np.nan,3],[np.nan,5,6],[7,8,9]], float)
        b = fill_nan_gaps(a)
        np.testing.assert_array_equal(b,
            np.array([[5,5,3],[7,5,6],[7,8,9]], float)
        )


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions geodataset/utils.py
Original file line number Diff line number Diff line change
@@ -1 +1,27 @@
import numpy as np
from scipy.ndimage.morphology import distance_transform_edt

class InvalidDatasetError(Exception):
    """ Error signalling that a dataset is not valid """

def fill_nan_gaps(array, distance=5):
    """ Fill NaN gaps in input array with the nearest valid value

    Every NaN element lying not farther than <distance> (Euclidean distance
    in pixels) from the nearest non-NaN element gets the value of that
    element. NaN elements farther away stay NaN. The input is not modified.

    Parameters
    ----------
    array : numpy.array
        Raster with data (any number of dimensions), gaps are marked by NaN
    distance : int
        Maximum distance (in pixels) from valid data at which gaps are filled
    Returns
    -------
    array : numpy.array
        Copy of the raster with gaps filled
    """
    filled = np.array(array)
    nan_mask = np.isnan(filled)
    # Nothing to fill, or nothing to fill with: the distance transform needs
    # both NaN and valid elements to give meaningful indices.
    if not nan_mask.any() or nan_mask.all():
        return filled
    dist, indi = distance_transform_edt(
        nan_mask,
        return_distances=True,
        return_indices=True)
    gpi = dist <= distance
    # indi[k] holds the k-th coordinate of the nearest valid element, so a
    # tuple of its rows indexes arrays of any rank (valid elements point to
    # themselves and are left unchanged).
    filled[gpi] = filled[tuple(indi[:, gpi])]
    return filled

0 comments on commit 6c9f982

Please sign in to comment.