Skip to content

Commit

Permalink
refactor(shapefile_utils): pathlib compatibility, other improvements
Browse files Browse the repository at this point in the history
* Allow write_gridlines_shapefile to write .prj file
* Add test_write_gridlines_shapefile
* Add remaining file missed from modflowpy#1563 for Mt3dms
  • Loading branch information
mwtoews committed Oct 13, 2022
1 parent b8d471c commit 797678b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 43 deletions.
39 changes: 34 additions & 5 deletions autotest/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import pytest
from autotest.conftest import (
SHAPEFILE_EXTENSIONS,
excludes_platform,
get_example_data_path,
has_pkg,
requires_exe,
requires_pkg,
requires_spatial_reference, excludes_platform,
requires_spatial_reference,
)
from flaky import flaky

Expand Down Expand Up @@ -197,13 +198,38 @@ def test_export_output(tmpdir, example_data_path):
nc.nc.close()


@requires_pkg("shapefile")
def test_write_gridlines_shapefile(tmpdir):
import shapefile

from flopy.discretization import StructuredGrid
from flopy.export.shapefile_utils import write_gridlines_shapefile

sg = StructuredGrid(
delr=np.ones(10) * 1.1,
# cell spacing along model rows
delc=np.ones(10) * 1.1,
# cell spacing along model columns
epsg=26715,
)
outshp = tmpdir / "gridlines.shp"
write_gridlines_shapefile(outshp, sg)

for suffix in [".dbf", ".prj", ".shp", ".shx"]:
assert outshp.with_suffix(suffix).exists()

with shapefile.Reader(str(outshp)) as sf:
assert sf.shapeType == shapefile.POLYLINE
assert len(sf) == 22


@flaky
@requires_pkg("shapefile", "shapely")
def test_write_grid_shapefile(tmpdir):
from shapefile import Reader

from flopy.discretization import StructuredGrid
from flopy.export.shapefile_utils import shp2recarray, write_grid_shapefile
from flopy.export.shapefile_utils import write_grid_shapefile

sg = StructuredGrid(
delr=np.ones(10) * 1.1,
Expand All @@ -212,14 +238,17 @@ def test_write_grid_shapefile(tmpdir):
# cell spacing along model columns
epsg=26715,
)
outshp = os.path.join(tmpdir, "junk.shp")
outshp = tmpdir / "junk.shp"
write_grid_shapefile(outshp, sg, array_dict={})

for suffix in [".dbf", ".prj", ".shp", ".shx"]:
assert outshp.with_suffix(suffix).exists()

# test that vertices aren't getting altered by writing shapefile
# check that pyshp reads integers
# this only check that row/column were recorded as "N"
# not how they will be cast by python or numpy
sfobj = Reader(outshp)
sfobj = Reader(str(outshp))
for f in sfobj.fields:
if f[0] == "row" or f[0] == "column":
assert f[1] == "N"
Expand All @@ -240,7 +269,7 @@ def test_write_grid_shapefile(tmpdir):
meta = src.meta
assert "int" in meta["schema"]["properties"]["row"]
assert "int" in meta["schema"]["properties"]["column"]
except:
except ImportError:
pass


Expand Down
71 changes: 37 additions & 34 deletions flopy/export/shapefile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
"""
import copy
import json
import os
import shutil
import sys
import warnings
from os.path import expandvars
from pathlib import Path

import numpy as np
Expand All @@ -26,7 +26,7 @@ def write_gridlines_shapefile(filename, mg):
Parameters
----------
filename : string
filename : Path or str
name of the shapefile to write
mg : model grid
Expand All @@ -36,7 +36,7 @@ def write_gridlines_shapefile(filename, mg):
"""
shapefile = import_optional_dependency("shapefile")
wr = shapefile.Writer(filename, shapeType=shapefile.POLYLINE)
wr = shapefile.Writer(str(filename), shapeType=shapefile.POLYLINE)
wr.field("number", "N", 18, 0)
if mg.__class__.__name__ == "SpatialReference":
grid_lines = mg.get_grid_lines()
Expand All @@ -52,6 +52,7 @@ def write_gridlines_shapefile(filename, mg):
wr.record(i)

wr.close()
write_prj(filename, mg=mg)
return


Expand All @@ -68,7 +69,7 @@ def write_grid_shapefile(
Parameters
----------
path : str
path : Path or str
shapefile file path
mg : flopy.discretization.Grid object
flopy model grid
Expand All @@ -78,7 +79,7 @@ def write_grid_shapefile(
value to fill nans
epsg : str, int
epsg code
prj : str
prj : Path or str
projection file name path
Returns
Expand All @@ -87,7 +88,7 @@ def write_grid_shapefile(
"""
shapefile = import_optional_dependency("shapefile")
w = shapefile.Writer(path, shapeType=shapefile.POLYGON)
w = shapefile.Writer(str(path), shapeType=shapefile.POLYGON)
w.autoBalance = 1

if mg.__class__.__name__ == "SpatialReference":
Expand Down Expand Up @@ -214,7 +215,7 @@ def model_attributes_to_shapefile(
Parameters
----------
path : string
path : Path or str
path to write the shapefile to
ml : flopy.mbase
model instance
Expand All @@ -230,7 +231,7 @@ def model_attributes_to_shapefile(
of the modelgrid attached to the modflow model object
epsg : int
epsg projection information
prj : str
prj : Path or str
user supplied prj file
Returns
Expand Down Expand Up @@ -388,7 +389,7 @@ def shape_attr_name(name, length=6, keep_layer=False):
Parameters
----------
name : string
name : str
data array name
length : int
maximum length of string to return. Value passed to function is
Expand All @@ -400,7 +401,7 @@ def shape_attr_name(name, length=6, keep_layer=False):
Returns
-------
String
str
Examples
--------
Expand Down Expand Up @@ -442,7 +443,8 @@ def enforce_10ch_limit(names):
Returns
-------
names : list of unique strings of len <= 10.
list
list of unique strings of len <= 10.
"""
names = [n[:5] + n[-4:] + "_" if len(n) > 10 else n for n in names]
dups = {x: names.count(x) for x in names}
Expand Down Expand Up @@ -488,19 +490,19 @@ def shp2recarray(shpname):
Parameters
----------
shpname : str
shpname : Path or str
ESRI Shapefile.
Returns
-------
recarray : np.recarray
np.recarray
"""
from ..utils.geospatial_utils import GeoSpatialCollection

sf = import_optional_dependency("shapefile")

sfobj = sf.Reader(shpname)
sfobj = sf.Reader(str(shpname))
dtype = [
(str(f[0]), get_pyshp_field_dtypes(f[1])) for f in sfobj.fields[1:]
]
Expand Down Expand Up @@ -540,11 +542,11 @@ def recarray2shp(
list of shapefile.Shape objects, or geojson geometry collection
The number of geometries in geoms must equal the number of records in
recarray.
shpname : str
shpname : Path or str, default "recarray.shp"
Path for the output shapefile
epsg : int
EPSG code. See https://www.epsg-registry.org/ or spatialreference.org
prj : str
prj : Path or str
Existing projection file to be used with new shapefile.
Notes
Expand Down Expand Up @@ -578,7 +580,7 @@ def recarray2shp(

# set up for pyshp 2
shapefile = import_optional_dependency("shapefile")
w = shapefile.Writer(shpname, shapeType=geomtype)
w = shapefile.Writer(str(shpname), shapeType=geomtype)
w.autoBalance = 1

# set up the attribute fields
Expand Down Expand Up @@ -615,7 +617,7 @@ def recarray2shp(

def write_prj(shpname, mg=None, epsg=None, prj=None, wkt_string=None):
# projection file name
prjname = shpname.replace(".shp", ".prj")
prjname = Path(shpname).with_suffix(".prj")

# figure which CRS option to use
# prioritize args over grid reference
Expand All @@ -626,10 +628,10 @@ def write_prj(shpname, mg=None, epsg=None, prj=None, wkt_string=None):
prjtxt = CRS.getprj(epsg)
# copy a supplied prj file
elif prj is not None:
if os.path.exists(prjname):
print(".prj file {} already exists ".format(prjname))
if prjname.exists():
print(f".prj file {prjname} already exists")
else:
shutil.copy(prj, prjname)
shutil.copy(str(prj), str(prjname))

elif mg is not None:
if mg.epsg is not None:
Expand All @@ -643,8 +645,7 @@ def write_prj(shpname, mg=None, epsg=None, prj=None, wkt_string=None):
"(writing .prj files from proj4 strings not supported)"
)
if prjtxt is not None:
with open(prjname, "w") as output:
output.write(prjtxt)
prjname.write_text(prjtxt)


class CRS:
Expand Down Expand Up @@ -868,9 +869,10 @@ def getprj(epsg, addlocalreference=True, text="esriwkt"):
addlocalreference : boolean
adds the projection file text associated with epsg to a local
database, epsgref.json, located in the user's data directory.
Returns
-------
prj : str
str
text for a projection (*.prj) file.
"""
Expand All @@ -896,9 +898,10 @@ def get_spatialreference(epsg, text="esriwkt"):
epsg code for coordinate system
text : str
string added to url
Returns
-------
url : str
str
"""
from ..utils.flopy_io import get_url_text
Expand Down Expand Up @@ -938,9 +941,10 @@ def getproj4(epsg):
----------
epsg : int
epsg code for coordinate system
Returns
-------
prj : str
str
text for a projection (*.prj) file.
"""
return CRS.get_spatialreference(epsg, text="proj4")
Expand All @@ -957,22 +961,21 @@ class EpsgReference:

def __init__(self):
if sys.platform.startswith("win"):
flopy_appdata = Path(os.path.expandvars(r"%LOCALAPPDATA%\flopy"))
flopy_appdata = Path(expandvars(r"%LOCALAPPDATA%\flopy"))
else:
flopy_appdata = Path.home() / ".local" / "share" / "flopy"
if not flopy_appdata.exists():
flopy_appdata.mkdir(parents=True, exist_ok=True)
dbname = "epsgref.json"
self.location = str(flopy_appdata / dbname)
self.location = flopy_appdata / dbname

def to_dict(self):
"""
returns dict with EPSG code integer key, and WKT CRS text
"""
data = {}
if os.path.exists(self.location):
with open(self.location, "r") as f:
loaded_data = json.load(f)
if self.location.exists():
loaded_data = json.loads(self.location.read_text())
# convert JSON key from str to EPSG integer
for key, value in loaded_data.items():
try:
Expand All @@ -982,15 +985,15 @@ def to_dict(self):
return data

def _write(self, data):
with open(self.location, "w") as f:
with self.location.open("w") as f:
json.dump(data, f, indent=0)
f.write("\n")

def reset(self, verbose=True):
if os.path.exists(self.location):
if self.location.exists():
if verbose:
print(f"Resetting {self.location}")
os.remove(self.location)
self.location.unlink()
elif verbose:
print(f"{self.location} does not exist, no reset required")

Expand Down
8 changes: 4 additions & 4 deletions flopy/mt3d/mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Mt3dms(BaseModel):
(False, default).
version : str, default "mt3dms"
Mt3d version. Choose one of: "mt3dms" (default) or "mt3d-usgs".
exe_name : str, default "mt3dms.exe"
exe_name : str, default "mt3dms"
The name of the executable to use.
structured : bool, default True
Specify if model grid is structured (default) or unstructured.
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
ftlfilename="mt3d_link.ftl",
ftlfree=False,
version="mt3dms",
exe_name="mt3dms.exe",
exe_name="mt3dms",
structured=True,
listunit=16,
ftlunit=10,
Expand Down Expand Up @@ -444,7 +444,7 @@ def load(
cls,
f,
version="mt3dms",
exe_name="mt3dms.exe",
exe_name="mt3dms",
verbose=False,
model_ws=".",
load_only=None,
Expand All @@ -460,7 +460,7 @@ def load(
Path to MT3D name file to load.
version : str, default "mt3dms"
Mt3d version. Choose one of: "mt3dms" (default) or "mt3d-usgs".
exe_name : str, default "mt3dms.exe"
exe_name : str, default "mt3dms"
The name of the executable to use.
verbose : bool, default False
Print information on the load process if True.
Expand Down

0 comments on commit 797678b

Please sign in to comment.