Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(performance): Implemented performance improvements including improvements suggested by @briochh (#1092) #1097

Merged
merged 2 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions flopy/mf6/coordinates/modelgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,13 +594,14 @@ def get_model_dim_names(self):
return ["node"]

def get_num_spatial_coordinates(self):
if self.grid_type() == DiscretizationType.DIS:
grid_type = self.grid_type()
if grid_type == DiscretizationType.DIS:
return 3
elif self.grid_type() == DiscretizationType.DISV:
elif grid_type == DiscretizationType.DISV:
return 2
elif (
self.grid_type() == DiscretizationType.DISU
or self.grid_type() == DiscretizationType.DISL
grid_type == DiscretizationType.DISU
or grid_type == DiscretizationType.DISL
):
return 1

Expand Down
50 changes: 49 additions & 1 deletion flopy/mf6/data/mfdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...datbase import DataInterface, DataType
from .mfdatastorage import DataStructureType
from .mfdatautil import to_string
from ...mbase import ModelInterface


class MFTransient:
Expand Down Expand Up @@ -256,6 +257,10 @@ def __init__(
self._structure_init()
# tie this to the simulation dictionary
sim_data.mfdata[self._path] = self
# set up model grid caching
self._cache_next_grid = False
self._grid_cached = False
self._cached_model_grid = None

def __repr__(self):
return repr(self._get_storage_obj())
Expand Down Expand Up @@ -300,6 +305,20 @@ def plottable(self):
"must define plottable in child " "class to use this base class"
)

@property
def _cache_model_grid(self):
return self._cache_next_grid

@_cache_model_grid.setter
def _cache_model_grid(self, cache_model_grid):
if cache_model_grid:
self._cache_next_grid = True
self._grid_cached = False
else:
self._cache_next_grid = False
self._grid_cached = False
self._cached_model_grid = None

def _resync(self):
model = self.model
if model is not None:
Expand Down Expand Up @@ -440,6 +459,31 @@ def is_valid(self):
# TODO: Implement for each data type
return self._valid

def _get_model_grid(self):
mg = None
if (
self._cache_next_grid
or not self._grid_cached
or self._cached_model_grid is None
):
# construct a new model grid
if isinstance(self._model_or_sim, ModelInterface) and hasattr(
self._model_or_sim, "modelgrid"
):
# get model grid info
mg = self._model_or_sim.modelgrid
else:
mg = None
if self._grid_cached and self._cached_model_grid is not None:
# get the model grid from cache
mg = self._cached_model_grid
elif self._cache_next_grid:
# cache the existing model grid
self._cached_model_grid = mg
self._grid_cached = mg is not None
self._cache_next_grid = False
return mg

def _structure_init(self, data_set=None):
if data_set is None:
# Initialize variables
Expand All @@ -466,7 +510,11 @@ def _get_constant_formatting_string(
sim_data = self._simulation_data
const_format = list(sim_data.constant_formatting)
const_format[1] = to_string(
const_val, data_type, self._simulation_data, self._data_dimensions
const_val,
data_type,
self._simulation_data,
self._data_dimensions,
verify_data=self._simulation_data.verify_data,
)
return "{}{}".format(sim_data.indent_string.join(const_format), suffix)

Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/data/mfdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ def plot(
"not available."
)

modelgrid = self.model.modelgrid
modelgrid = self._get_model_grid()
a = self.array
num_plottable_layers = modelgrid.get_number_plottable_layers(a)

Expand Down
30 changes: 23 additions & 7 deletions flopy/mf6/data/mfdatalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,10 @@ def _set_data(self, data, autofill=False, check_data=True):
data_check = None
else:
data_check = data
if data_check is None or (
isinstance(data_check, list) and len(data_check) == 0
if (
data_check is None
or not self._simulation_data.verify_data
or (isinstance(data_check, list) and len(data_check) == 0)
):
check_data = False
if iterable(data_check) and check_data:
Expand Down Expand Up @@ -428,7 +430,7 @@ def _set_data(self, data, autofill=False, check_data=True):
self._simulation_data.debug,
ex,
)
if check_data:
if check_data and self._simulation_data.verify_data:
# verify cellids
self._check_valid_cellids()

Expand All @@ -438,7 +440,7 @@ def _check_valid_cellids(self):
self._model_or_sim, "modelgrid"
):
# get model grid info
mg = self._model_or_sim.modelgrid
mg = self._get_model_grid()
if not mg.is_complete:
return
idomain = mg.idomain
Expand Down Expand Up @@ -530,8 +532,8 @@ def _check_line_size(self, data_line, min_line_size):
self._simulation_data.debug,
)

def set_data(self, data, autofill=False):
self._set_data(data, autofill)
def set_data(self, data, autofill=False, check_data=True):
self._set_data(data, autofill, check_data=check_data)

def append_data(self, data):
try:
Expand Down Expand Up @@ -811,6 +813,7 @@ def _get_file_entry_record(
data_item.is_cellid,
data_item.possible_cellid,
data_item,
self._simulation_data.verify_data,
)
)
index += 1
Expand Down Expand Up @@ -988,6 +991,7 @@ def _get_file_entry_record(
k_data_item.is_cellid,
k_data_item.possible_cellid,
k_data_item,
self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -1047,6 +1051,7 @@ def _get_file_entry_record(
self._data_dimensions,
False,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
)
index += 1
Expand All @@ -1060,6 +1065,7 @@ def _get_file_entry_record(
data_item.is_cellid,
data_item.possible_cellid,
data_item,
self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -1455,6 +1461,7 @@ def store_as_external_file(
replace_existing_external=True,
check_data=True,
):
self._cache_model_grid = True
sim_time = self._data_dimensions.package_dim.model_dim[
0
].simulation_time
Expand All @@ -1478,6 +1485,7 @@ def store_as_external_file(
replace_existing_external,
check_data,
)
self._cache_model_grid = False

def get_data(self, key=None, apply_mult=False, **kwargs):
if self._data_storage is not None and len(self._data_storage) > 0:
Expand Down Expand Up @@ -1509,6 +1517,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs):
return None

def set_data(self, data, key=None, autofill=False):
self._cache_model_grid = True
if isinstance(data, dict) or isinstance(data, OrderedDict):
if "filename" not in data:
# each item in the dictionary is a list for one stress period
Expand All @@ -1523,8 +1532,14 @@ def set_data(self, data, key=None, autofill=False):
self.empty_keys[key] = True
else:
self.empty_keys[key] = False
if "check" in list_item:
check = list_item["check"]
else:
check = True
self._set_data_prep(list_item, key)
super().set_data(list_item, autofill=autofill)
super().set_data(
list_item, autofill=autofill, check_data=check
)
for key in del_keys:
del data[key]
else:
Expand All @@ -1548,6 +1563,7 @@ def set_data(self, data, key=None, autofill=False):
else:
self._set_data_prep(data, key)
super().set_data(data, autofill)
self._cache_model_grid = False

def get_file_entry(
self, key=0, ext_file_action=ExtFileAction.copy_relative_paths
Expand Down
2 changes: 2 additions & 0 deletions flopy/mf6/data/mfdatascalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def get_file_entry(
self._simulation_data,
self._data_dimensions,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
)
except Exception as ex:
Expand Down Expand Up @@ -529,6 +530,7 @@ def get_file_entry(
self._simulation_data,
self._data_dimensions,
data_item=data_item,
verify_data=self._simulation_data.verify_data,
)
except Exception as ex:
message = (
Expand Down
14 changes: 8 additions & 6 deletions flopy/mf6/data/mfdatautil.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def to_string(
is_cellid=False,
possible_cellid=False,
data_item=None,
verify_data=True,
):
if data_type == DatumType.double_precision:
if data_item is not None and data_item.support_negative_index:
Expand Down Expand Up @@ -158,10 +159,14 @@ def to_string(
elif is_cellid or (possible_cellid and isinstance(val, tuple)):
if DatumUtil.is_int(val):
return str(val + 1)
if len(val) > 0 and isinstance(val, str) and val.lower() == "none":
if len(val) == 4 and isinstance(val, str) and val.lower() == "none":
# handle case that cellid is 'none'
return val
if is_cellid and data_dim.get_model_dim(None).model_name is not None:
if (
verify_data
and is_cellid
and data_dim.get_model_dim(None).model_name is not None
):
model_grid = data_dim.get_model_grid()
cellid_size = model_grid.get_num_spatial_coordinates()
if len(val) != cellid_size:
Expand Down Expand Up @@ -199,10 +204,7 @@ def to_string(
return " ".join(string_val)
elif data_type == DatumType.integer:
if data_item is not None and data_item.numeric_index:
if isinstance(val, str):
return str(int(val) + 1)
else:
return str(int(val) + 1)
return str(int(val) + 1)
return str(int(val))
elif data_type == DatumType.string:
try:
Expand Down
1 change: 1 addition & 0 deletions flopy/mf6/data/mffileaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def get_data_string(self, data, data_type, data_indent=""):
self._simulation_data,
self._data_dimensions,
is_cellid,
verify_data=self._simulation_data.verify_data,
)
except Exception as ex:
type_, value_, traceback_ = sys.exc_info()
Expand Down
21 changes: 11 additions & 10 deletions flopy/utils/datautil.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,34 +265,34 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
# consistent delimiter has been found. continue using that
# delimiter without doing further checks
if PyListUtil.delimiter_used is None:
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = comment_split[0].strip().split()
else:
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = (
comment_split[0].strip().split(PyListUtil.delimiter_used)
)
if len(comment_split) > 1:
clean_line.append("#")
clean_line.append(comment_split[1])
clean_line.append(comment_split[1].strip())
else:
# compare against the default split option without comments split
comment_split = line.strip().split("#", 1)
comment_split = line.split("#", 1)
clean_line = comment_split[0].strip().split()
if len(comment_split) > 1:
clean_line.append("#")
clean_line.append(comment_split[1])
clean_line.append(comment_split[1].strip())
# try different delimiters and use the one the breaks the data
# apart the most
max_split_size = len(clean_line)
max_split_type = None
max_split_list = clean_line
for delimiter in PyListUtil.delimiter_list:
comment_split = line.strip().split("#")
comment_split = line.split("#")
alt_split = comment_split[0].strip().split(delimiter)
if len(comment_split) > 1:
alt_split.append("#")
alt_split.append(comment_split[1])
alt_split.append(comment_split[1].strip())
alt_split_len = len(alt_split)
if alt_split_len > max_split_size:
max_split_size = len(alt_split)
Expand All @@ -318,7 +318,8 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
arr_fixed_line = []
index = 0
# loop through line to fix quotes and delimiters
while index < len(clean_line):
len_cl = len(clean_line)
while index < len_cl:
item = clean_line[index]
if item and item not in PyListUtil.delimiter_list:
if item and item[0] in PyListUtil.quote_list:
Expand All @@ -328,9 +329,9 @@ def split_data_line(line, external_file=False, delimiter_conf_length=15):
else:
arr_fixed_line.append(item[1:])
# loop until trailing quote found
while index < len(clean_line):
while index < len_cl:
index += 1
if index < len(clean_line):
if index < len_cl:
item = clean_line[index]
if item[-1] in PyListUtil.quote_list:
arr_fixed_line[-1] = "{} {}".format(
Expand Down