Skip to content

Commit

Permalink
fix(performance): Implemented performance improvements including impr…
Browse files Browse the repository at this point in the history
…ovements suggested by @briochh (#1092) (#1097)

* fix(performance): Implemented performance improvements including improvements suggested by @briochh (#1092)

* fix(performance): Implemented additional performance improvements
  • Loading branch information
spaulins-usgs authored Apr 12, 2021
1 parent 801a417 commit 4104cf5
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 29 deletions.
9 changes: 5 additions & 4 deletions flopy/mf6/coordinates/modelgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,13 +594,14 @@ def get_model_dim_names(self):
return ["node"]

def get_num_spatial_coordinates(self):
if self.grid_type() == DiscretizationType.DIS:
grid_type = self.grid_type()
if grid_type == DiscretizationType.DIS:
return 3
elif self.grid_type() == DiscretizationType.DISV:
elif grid_type == DiscretizationType.DISV:
return 2
elif (
self.grid_type() == DiscretizationType.DISU
or self.grid_type() == DiscretizationType.DISL
grid_type == DiscretizationType.DISU
or grid_type == DiscretizationType.DISL
):
return 1

Expand Down
50 changes: 49 additions & 1 deletion flopy/mf6/data/mfdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...datbase import DataInterface, DataType
from .mfdatastorage import DataStructureType
from .mfdatautil import to_string
from ...mbase import ModelInterface


class MFTransient:
Expand Down Expand Up @@ -256,6 +257,10 @@ def __init__(
self._structure_init()
# tie this to the simulation dictionary
sim_data.mfdata[self._path] = self
# set up model grid caching
self._cache_next_grid = False
self._grid_cached = False
self._cached_model_grid = None

def __repr__(self):
return repr(self._get_storage_obj())
Expand Down Expand Up @@ -300,6 +305,20 @@ def plottable(self):
"must define plottable in child " "class to use this base class"
)

@property
def _cache_model_grid(self):
return self._cache_next_grid

@_cache_model_grid.setter
def _cache_model_grid(self, cache_model_grid):
if cache_model_grid:
self._cache_next_grid = True
self._grid_cached = False
else:
self._cache_next_grid = False
self._grid_cached = False
self._cached_model_grid = None

def _resync(self):
model = self.model
if model is not None:
Expand Down Expand Up @@ -440,6 +459,31 @@ def is_valid(self):
# TODO: Implement for each data type
return self._valid

def _get_model_grid(self):
mg = None
if (
self._cache_next_grid
or not self._grid_cached
or self._cached_model_grid is None
):
# construct a new model grid
if isinstance(self._model_or_sim, ModelInterface) and hasattr(
self._model_or_sim, "modelgrid"
):
# get model grid info
mg = self._model_or_sim.modelgrid
else:
mg = None
if self._grid_cached and self._cached_model_grid is not None:
# get the model grid from cache
mg = self._cached_model_grid
elif self._cache_next_grid:
# cache the existing model grid
self._cached_model_grid = mg
self._grid_cached = mg is not None
self._cache_next_grid = False
return mg

def _structure_init(self, data_set=None):
if data_set is None:
# Initialize variables
Expand All @@ -466,7 +510,11 @@ def _get_constant_formatting_string(
sim_data = self._simulation_data
const_format = list(sim_data.constant_formatting)
const_format[1] = to_string(
const_val, data_type, self._simulation_data, self._data_dimensions
const_val,
data_type,
self._simulation_data,
self._data_dimensions,
verify_data=self._simulation_data.verify_data,
)
return "{}{}".format(sim_data.indent_string.join(const_format), suffix)

Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/data/mfdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ def plot(
"not available."
)

modelgrid = self.model.modelgrid
modelgrid = self._get_model_grid()
a = self.array
num_plottable_layers = modelgrid.get_number_plottable_layers(a)

Expand Down
30 changes: 23 additions & 7 deletions flopy/mf6/data/mfdatalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,10 @@ def _set_data(self, data, autofill=False, check_data=True):
data_check = None
else:
data_check = data
if data_check is None or (
isinstance(data_check, list) and len(data_check) == 0
if (
data_check is None
or not self._simulation_data.verify_data
or (isinstance(data_check, list) and len(data_check) == 0)
):
check_data = False
if iterable(data_check) and check_data:
Expand Down Expand Up @@ -428,7 +430,7 @@ def _set_data(self, data, autofill=False, check_data=True):
self._simulation_data.debug,
ex,
)
if check_data:
if check_data and self._simulation_data.verify_data:
# verify cellids
self._check_valid_cellids()

Expand All @@ -438,7 +440,7 @@ def _check_valid_cellids(self):
self._model_or_sim, "modelgrid"
):
# get model grid info
mg = self._model_or_sim.modelgrid
mg = self._get_model_grid()
if not mg.is_complete:
return
idomain = mg.idomain
Expand Down Expand Up @@ -530,8 +532,8 @@ def _check_line_size(self, data_line, min_line_size):
self._simulation_data.debug,
)

def set_data(self, data, autofill=False):
self._set_data(data, autofill)
def set_data(self, data, autofill=False, check_data=True):
self._set_data(data, autofill, check_data=check_data)

def append_data(self, data):
try:
Expand Down Expand Up @@ -811,6 +813,7 @@ def _get_file_entry_record(
data_item.is_cellid,
data_item.possible_cellid,
data_item,
self._simulation_data.verify_data,
)
)
index += 1
Expand Down Expand Up @@ -988,6 +991,7 @@ def _get_file_entry_record(
k_data_item.is_cellid,
k_data_item.possible_cellid,
k_data_item,
self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -1047,6 +1051,7 @@ def _get_file_entry_record(
self._data_dimensions,
False,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
)
index += 1
Expand All @@ -1060,6 +1065,7 @@ def _get_file_entry_record(
data_item.is_cellid,
data_item.possible_cellid,
data_item,
self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -1455,6 +1461,7 @@ def store_as_external_file(
replace_existing_external=True,
check_data=True,
):
self._cache_model_grid = True
sim_time = self._data_dimensions.package_dim.model_dim[
0
].simulation_time
Expand All @@ -1478,6 +1485,7 @@ def store_as_external_file(
replace_existing_external,
check_data,
)
self._cache_model_grid = False

def get_data(self, key=None, apply_mult=False, **kwargs):
if self._data_storage is not None and len(self._data_storage) > 0:
Expand Down Expand Up @@ -1509,6 +1517,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs):
return None

def set_data(self, data, key=None, autofill=False):
self._cache_model_grid = True
if isinstance(data, dict) or isinstance(data, OrderedDict):
if "filename" not in data:
# each item in the dictionary is a list for one stress period
Expand All @@ -1523,8 +1532,14 @@ def set_data(self, data, key=None, autofill=False):
self.empty_keys[key] = True
else:
self.empty_keys[key] = False
if "check" in list_item:
check = list_item["check"]
else:
check = True
self._set_data_prep(list_item, key)
super().set_data(list_item, autofill=autofill)
super().set_data(
list_item, autofill=autofill, check_data=check
)
for key in del_keys:
del data[key]
else:
Expand All @@ -1548,6 +1563,7 @@ def set_data(self, data, key=None, autofill=False):
else:
self._set_data_prep(data, key)
super().set_data(data, autofill)
self._cache_model_grid = False

def get_file_entry(
self, key=0, ext_file_action=ExtFileAction.copy_relative_paths
Expand Down
2 changes: 2 additions & 0 deletions flopy/mf6/data/mfdatascalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def get_file_entry(
self._simulation_data,
self._data_dimensions,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -529,6 +530,7 @@ def get_file_entry(
self._simulation_data,
self._data_dimensions,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
except Exception as ex:
message = (
Expand Down
14 changes: 8 additions & 6 deletions flopy/mf6/data/mfdatautil.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def to_string(
is_cellid=False,
possible_cellid=False,
data_item=None,
verify_data=True,
):
if data_type == DatumType.double_precision:
if data_item is not None and data_item.support_negative_index:
Expand Down Expand Up @@ -158,10 +159,14 @@ def to_string(
elif is_cellid or (possible_cellid and isinstance(val, tuple)):
if DatumUtil.is_int(val):
return str(val + 1)
if len(val) > 0 and isinstance(val, str) and val.lower() == "none":
if len(val) == 4 and isinstance(val, str) and val.lower() == "none":
# handle case that cellid is 'none'
return val
if is_cellid and data_dim.get_model_dim(None).model_name is not None:
if (
verify_data
and is_cellid
and data_dim.get_model_dim(None).model_name is not None
):
model_grid = data_dim.get_model_grid()
cellid_size = model_grid.get_num_spatial_coordinates()
if len(val) != cellid_size:
Expand Down Expand Up @@ -199,10 +204,7 @@ def to_string(
return " ".join(string_val)
elif data_type == DatumType.integer:
if data_item is not None and data_item.numeric_index:
if isinstance(val, str):
return str(int(val) + 1)
else:
return str(int(val) + 1)
return str(int(val) + 1)
return str(int(val))
elif data_type == DatumType.string:
try:
Expand Down
1 change: 1 addition & 0 deletions flopy/mf6/data/mffileaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def get_data_string(self, data, data_type, data_indent=""):
self._simulation_data,
self._data_dimensions,
is_cellid,
verify_data=self._simulation_data.verify_data,
)
except Exception as ex:
type_, value_, traceback_ = sys.exc_info()
Expand Down
21 changes: 11 additions & 10 deletions flopy/utils/datautil.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,34 +265,34 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
# consistent delimiter has been found. continue using that
# delimiter without doing further checks
if PyListUtil.delimiter_used is None:
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = comment_split[0].strip().split()
else:
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = (
comment_split[0].strip().split(PyListUtil.delimiter_used)
)
if len(comment_split) > 1:
clean_line.append("#")
clean_line.append(comment_split[1])
clean_line.append(comment_split[1].strip())
else:
# compare against the default split option without comments split
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = comment_split[0].strip().split()
if len(comment_split) > 1:
clean_line.append("#")
clean_line.append(comment_split[1])
clean_line.append(comment_split[1].strip())
# try different delimiters and use the one the breaks the data
# apart the most
max_split_size = len(clean_line)
max_split_type = None
max_split_list = clean_line
for delimiter in PyListUtil.delimiter_list:
comment_split = line.strip().split("#")
comment_split = line.split("#")
alt_split = comment_split[0].strip().split(delimiter)
if len(comment_split) > 1:
alt_split.append("#")
alt_split.append(comment_split[1])
alt_split.append(comment_split[1].strip())
alt_split_len = len(alt_split)
if alt_split_len > max_split_size:
max_split_size = len(alt_split)
Expand All @@ -318,7 +318,8 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
arr_fixed_line = []
index = 0
# loop through line to fix quotes and delimiters
while index < len(clean_line):
len_cl = len(clean_line)
while index < len_cl:
item = clean_line[index]
if item and item not in PyListUtil.delimiter_list:
if item and item[0] in PyListUtil.quote_list:
Expand All @@ -328,9 +329,9 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
else:
arr_fixed_line.append(item[1:])
# loop until trailing quote found
while index < len(clean_line):
while index < len_cl:
index += 1
if index < len(clean_line):
if index < len_cl:
item = clean_line[index]
if item[-1] in PyListUtil.quote_list:
arr_fixed_line[-1] = "{} {}".format(
Expand Down

0 comments on commit 4104cf5

Please sign in to comment.