convert CrossValidation PP to use new data objects. #470
@@ -75,7 +75,7 @@ class cls.
         ("n_iter",InputData.IntegerType),
         ("test_size",InputData.StringType),
         ("train_size",InputData.StringType),
-        ("scores",InputData.StringListType)]:
+        ("scores",InputData.StringType)]:
       dataType = InputData.parameterInputFactory(name, contentType=inputType)
       sciKitLearnInput.addSub(dataType)
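With StringType, the `<scores>` node now carries a single value where StringListType previously accepted a comma-separated list. A standalone illustration of the intended input (plain ElementTree, not the RAVEN InputData API):

```python
# Hypothetical illustration: with StringType, <scores> holds one value
# such as 'average'; StringListType previously allowed 'average, maximum'.
import xml.etree.ElementTree as ET

node = ET.fromstring("<scores>Average</scores>")
score = node.text.strip().lower()   # normalized the same way _handleInput does
assert score in ('maximum', 'average', 'median')
print(score)  # -> average
```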
@@ -94,14 +94,18 @@ def __init__(self, messageHandler):
     self.dynamic = False # is it time-dependent?
     self.metricsDict = {} # dictionary of metrics that are going to be assembled
     self.pivotParameter = None
-    self.cvScores = []
+    self.cvScore = 'average'
     # assembler objects to be requested
     self.addAssemblerObject('Metric', 'n', True)
     # The list of cross validation engines that require the parameter 'n'
     # This will be removed once we update scikit-learn to version 0.20
     # We will rely on the code to decide the value for the parameter 'n'
     self.CVList = ['KFold', 'LeaveOneOut', 'LeavePOut', 'ShuffleSplit']
-    self.validMetrics = ['mean_absolute_error', 'explained_variance_score', 'r2_score', 'mean_squared_error', 'median_absolute_error']
+    #self.validMetrics = ['mean_absolute_error', 'explained_variance_score', 'r2_score', 'mean_squared_error', 'median_absolute_error']
+    # 'median_absolute_error' is removed, the reasons for that are:
+    # 1. this metric can not accept multiple outputs
+    # 2. we seldom use this metric.
+    self.validMetrics = ['mean_absolute_error', 'explained_variance_score', 'r2_score', 'mean_squared_error']
     self.invalidRom = ['GaussPolynomialRom', 'HDMRRom']
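For reference, the four metrics kept in validMetrics all accept multi-output (2-D) arrays in scikit-learn, while median_absolute_error did not at the time (reason 1 above; multioutput support for it only arrived in a later scikit-learn release). A quick standalone check using the real sklearn.metrics functions:

```python
# The four retained metrics handle 2-D (multi-output) arrays out of the box.
import numpy as np
from sklearn.metrics import (mean_absolute_error, explained_variance_score,
                             r2_score, mean_squared_error)

yTrue = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])  # two targets per sample
yPred = yTrue + 0.1                                      # small constant offset

for metric in (mean_absolute_error, explained_variance_score,
               r2_score, mean_squared_error):
    # default multioutput='uniform_average' averages across the two targets
    print(metric.__name__, metric(yTrue, yPred))
```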
@@ -139,18 +143,16 @@ def _handleInput(self, paramInput):
       @ In, paramInput, ParameterInput, the already parsed input
       @ Out, None
     """
     self.initializationOptionDict = {}
     scoreList = ['maximum', 'average', 'median']
     cvNode = paramInput.findFirst('SciKitLearn')
     for child in cvNode.subparts:
       if child.getName() == 'scores':
-        for elem in child.value:
-          score = elem.strip().lower()
-          if score in scoreList:
-            self.cvScores.append(score)
-          else:
-            self.raiseAnError(IOError, "Unexpected input '", score, "' for XML node 'scores'! Valid inputs include: ", ",".join(scoreList))
+        score = child.value.strip().lower()
+        if score in scoreList:
+          self.cvScore = score
+        else:
+          self.raiseAnError(IOError, "Unexpected input '", score, "' for XML node 'scores'! Valid inputs include: ", ",".join(scoreList))
         break
     for child in paramInput.subparts:
       if child.getName() == 'SciKitLearn':
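A minimal standalone sketch of the new single-score handling, mirroring the logic above (the function name parse_cv_score is illustrative, not RAVEN's; absence of the node keeps the 'average' default set in __init__):

```python
# Sketch of the new validation: one score string, defaulting to 'average'.
def parse_cv_score(raw, default='average'):
    score_list = ['maximum', 'average', 'median']
    if raw is None:                 # <scores> node absent: keep the default
        return default
    score = raw.strip().lower()
    if score not in score_list:
        raise IOError("Unexpected input '%s' for XML node 'scores'! "
                      "Valid inputs include: %s" % (score, ",".join(score_list)))
    return score

print(parse_cv_score(' Median '))   # -> median
print(parse_cv_score(None))         # -> average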
@@ -185,6 +187,7 @@ def inputToInternal(self, currentInp, full = False):
       understandable by this pp.
       @ In, currentInp, list or DataObject, data object or a list of data objects
       @ In, full, bool, optional, True to retrieve the whole input or False to get the last element of the input
+        TODO, full should be removed
       @ Out, newInputs, tuple, (dictionary of input and output data, instance of estimator)
     """
     if type(currentInp) != list:
@@ -219,42 +222,28 @@ def inputToInternal(self, currentInp, full = False):
     if type(currentInput) != dict:
       dictKeys = list(cvEstimator.initializationOptionDict['Features'].split(',')) + list(cvEstimator.initializationOptionDict['Target'].split(','))
       newInput = dict.fromkeys(dictKeys, None)
-      if not currentInput.isItEmpty():
+      if not len(currentInput) == 0:
+        dataSet = currentInput.asDataset()
         if inputType == 'PointSet':
-          for elem in currentInput.getParaKeys('inputs'):
-            if elem in newInput.keys():
-              newInput[elem] = copy.copy(np.array(currentInput.getParam('input', elem))[0 if full else -1:])
-          for elem in currentInput.getParaKeys('outputs'):
+          for elem in currentInput.getVars('input') + currentInput.getVars('output'):
             if elem in newInput.keys():
-              newInput[elem] = copy.copy(np.array(currentInput.getParam('output', elem))[0 if full else -1:])
+              newInput[elem] = copy.copy(dataSet[elem].values)
         elif inputType == 'HistorySet':
-          if full:
-            for hist in range(len(currentInput)):
-              realization = currentInput.getRealization(hist)
-              for elem in currentInput.getParaKeys('inputs'):
-                if elem in newInput.keys():
-                  if newInput[elem] is None:
-                    newInput[elem] = c1darray(shape = (1,))
-                  newInput[elem].append(realization['inputs'][elem])
-              for elem in currentInput.getParaKeys('outputs'):
-                if elem in newInput.keys():
-                  if newInput[elem] is None:
-                    newInput[elem] = []
-                  newInput[elem].append(realization['outputs'][elem])
-          else:
-            realization = currentInput.getRealization(len(currentInput) - 1)
-            for elem in currentInput.getParaKeys('inputs'):
+          sizeIndex = 0
+          for hist in range(len(currentInput)):
+            for elem in currentInput.indexes + currentInput.getVars('outputs'):
               if elem in newInput.keys():
-                newInput[elem] = [realization['inputs'][elem]]
-            for elem in currentInput.getParaKeys('outputs'):
+                if newInput[elem] is None:
+                  newInput[elem] = []
+                newInput[elem].append(dataSet.isel(RAVEN_sample_ID=hist)[elem].values)
+                sizeIndex = len(newInput[elem][-1])
+            for elem in currentInput.getVars('input'):
               if elem in newInput.keys():
-                newInput[elem] = [realization['outputs'][elem]]
+                if newInput[elem] is None:
+                  newInput[elem] = []
+                newInput[elem].append(np.full((sizeIndex,), dataSet.isel(RAVEN_sample_ID=hist)[elem].values))
         else:
           self.raiseAnError(IOError, "The input type '", inputType, "' can not be accepted")
-        # Now if an OutputPlaceHolder is used it is removed; this happens when the input data set is internally manufactured
-        if 'OutputPlaceHolder' in currentInput.getParaKeys('outputs'):
-          # this removes the counter from the inputs to be placed among the outputs
-          newInput.pop('OutputPlaceHolder')
     else:
       # here we do not make a copy since we assume that the dictionary is just for the model usage and any changes are not impacting outside
       newInput = currentInput
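The rework reads arrays through asDataset() instead of getParam/getRealization. A sketch with plain xarray of the assumed access pattern (the Dataset layout, with samples along a RAVEN_sample_ID dimension, mirrors what the code above indexes; the variable names x1/ans are made up):

```python
# Assumed shape of what DataObject.asDataset() returns: an xarray Dataset
# with one data variable per feature/target, sampled along RAVEN_sample_ID.
import numpy as np
import xarray as xr

dataSet = xr.Dataset({
    'x1':  ('RAVEN_sample_ID', np.array([0.1, 0.2, 0.3])),  # feature
    'ans': ('RAVEN_sample_ID', np.array([1.0, 4.0, 9.0])),  # target
})

# PointSet branch: pull whole columns at once
newInput = {elem: dataSet[elem].values.copy() for elem in ('x1', 'ans')}

# HistorySet branch: index one realization at a time
first = dataSet.isel(RAVEN_sample_ID=0)['ans'].values
print(newInput['x1'], first)
```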
@@ -306,45 +295,40 @@ def run(self, inputIn):
         break
     if cvEngine is None:
       self.raiseAnError(IOError, "No cross validation engine is provided!")

     outputDict = {}
     # construct matrix and pass matrix
     for trainIndex, testIndex in cvEngine.generateTrainTestIndices():
       trainDict, testDict = self.__generateTrainTestInputs(inputDict, trainIndex, testIndex)
       ## train the rom
       cvEstimator.train(trainDict)
       ## evaluate the rom
       outputEvaluation = cvEstimator.evaluate(testDict)
       ## compute the distance between ROM and given data using the Metric system
-      for targetName, targetValue in outputEvaluation.items():
-        if targetName not in outputDict.keys():
-          outputDict[targetName] = {}
-        for metricInstance in self.metricsDict.values():
-          metricValue = metricInstance.distance(targetValue, testDict[targetName])
-          if hasattr(metricInstance, 'metricType'):
-            if metricInstance.metricType not in self.validMetrics:
-              self.raiseAnError(IOError, "The metric type: ", metricInstance.metricType, " can not be used, the accepted metric types are: ", str(self.validMetrics))
-            metricName = metricInstance.metricType
-          else:
-            metricName = metricInstance.type
-          metricName = metricInstance.name + '_' + metricName
-          if metricName not in outputDict[targetName].keys():
-            outputDict[targetName][metricName] = []
-          outputDict[targetName][metricName].append(metricValue[0])
+      cvOutputs = np.asarray(outputEvaluation.values()).T
+      testOutputs = np.asarray([testDict[var] for var in outputEvaluation.keys()]).T
+      for metricInstance in self.metricsDict.values():
+        metricValue = metricInstance.distance(cvOutputs, testOutputs)
+        if hasattr(metricInstance, 'metricType'):
+          if metricInstance.metricType not in self.validMetrics:
+            self.raiseAnError(IOError, "The metric type: ", metricInstance.metricType, " can not be used, the accepted metric types are: ", str(self.validMetrics))
+        else:
+          self.raiseAnError(IOError, "The metric: ", metricInstance.name, " can not be used, the accepted metric types are: ", str(self.validMetrics))
+        varName = 'cv' + '_' + metricInstance.name + '_' + cvEstimator.name
+        if varName not in outputDict.keys():
+          outputDict[varName] = []
+        outputDict[varName].append(metricValue[0])

     scoreDict = {}
-    if not self.cvScores:
-      return outputDict
Inline review thread (on the removed per-fold return):
- So you removed the possibility to get the scores for each fold?
- Can you put the "update documentation" item in the todo list for this work? (The shared sheet in Google Docs.)
- Yes, I think this PP should return a single quantity to tell the user whether the ROM is good or not.
- I have updated the shared Excel sheet.
- I agree it should by default, but honestly I can use the cross validation (k-fold) also as a metric to understand how homogeneously my regression is able to capture the underlying model. I would have changed the default (to be the single scalar) but not removed the possibility to get the stats from all the folds...
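The last comment above asks to keep per-fold statistics reachable rather than always reducing them. A hypothetical sketch of such an opt-out (not part of this PR; the 'none' option is invented for illustration):

```python
# Hypothetical: keep raw per-fold values reachable while defaulting to a scalar.
import numpy as np

def reduce_scores(fold_values, cv_score='average'):
    values = np.atleast_1d(fold_values)
    if cv_score == 'none':            # invented opt-out: return every fold
        return values
    reducers = {'maximum': np.amax, 'median': np.median, 'average': np.mean}
    return np.atleast_1d(reducers[cv_score](values))

print(reduce_scores([0.1, 0.4, 0.2]))                    # -> [0.23333333]
print(reduce_scores([0.1, 0.4, 0.2], cv_score='none'))   # -> [0.1 0.4 0.2]
```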
-    else:
-      for targetName, metricInfo in outputDict.items():
-        scoreDict[targetName] = {}
-        for metricName, metricValues in metricInfo.items():
-          scoreDict[targetName][metricName] = {}
-          for cvScore in self.cvScores:
-            if cvScore == 'maximum':
-              scoreDict[targetName][metricName][cvScore] = np.amax(np.atleast_1d(metricValues))
-            elif cvScore == 'median':
-              scoreDict[targetName][metricName][cvScore] = np.median(np.atleast_1d(metricValues))
-            elif cvScore == 'average':
-              scoreDict[targetName][metricName][cvScore] = np.mean(np.atleast_1d(metricValues))
-      return scoreDict
+    for varName, metricValues in outputDict.items():
+      if self.cvScore.lower() == 'maximum':
+        scoreDict[varName] = np.atleast_1d(np.amax(np.atleast_1d(metricValues)))
+      elif self.cvScore.lower() == 'median':
+        scoreDict[varName] = np.atleast_1d(np.median(np.atleast_1d(metricValues)))
+      else:
+        scoreDict[varName] = np.atleast_1d(np.mean(np.atleast_1d(metricValues)))

+    return scoreDict
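To make the new flow concrete, here is a self-contained sketch of what run() now computes, using plain scikit-learn in place of the ROM and Metric machinery (KFold, LinearRegression, and mean_squared_error are real sklearn APIs; the cv_<metric>_<rom> key mirrors the varName built above):

```python
# End-to-end mini cross validation: per-fold metric values collected under one
# flat 'cv_<metric>_<rom>' key, then reduced to a single score.
import numpy as np
from sklearn.model_selection import KFold
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

X = np.linspace(0, 1, 20).reshape(-1, 1)
y = 3.0 * X.ravel() + 0.5

outputDict = {'cv_mse_myROM': []}
for trainIndex, testIndex in KFold(n_splits=5).split(X):
    rom = LinearRegression().fit(X[trainIndex], y[trainIndex])
    fold_mse = mean_squared_error(y[testIndex], rom.predict(X[testIndex]))
    outputDict['cv_mse_myROM'].append(fold_mse)

# reduce the folds to a single score, as the default 'average' does
scoreDict = {k: np.atleast_1d(np.mean(np.atleast_1d(v)))
             for k, v in outputDict.items()}
print(scoreDict)  # -> {'cv_mse_myROM': array([...])}, near zero for exact data
```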
@@ -358,89 +342,4 @@ def collectOutput(self, finishedJob, output):
       self.raiseAnError(RuntimeError, ' No available output to collect')
     outputDict = evaluation[1]
-    if isinstance(output, Files.File):
-      availExtens = ['xml', 'csv']
-      outputExtension = output.getExt().lower()
-      if outputExtension not in availExtens:
-        self.raiseAMessage('Cross Validation postprocessor did not recognize extension ".', str(outputExtension), '". The output will be dumped to a text file')
-      output.setPath(self._workingDir)
-      self.raiseADebug('Write Cross Validation postprocessor output in file with name: ', output.getAbsFile())
-      output.open('w')
-      if outputExtension == 'xml':
-        self._writeXML(output, outputDict)
-      else:
-        separator = ' ' if outputExtension != 'csv' else ','
-        self._writeText(output, outputDict, separator)
-    else:
-      self.raiseAnError(IOError, 'Output type ', str(output.type), ' can not be used for postprocessor', self.name)
-
-  def _writeXML(self, output, outputDictionary):
-    """
-      Defines the method for writing the post-processor to a .xml file
-      @ In, output, File object, file to write to
-      @ In, outputDictionary, dict, dictionary stores cross validation scores
-      @ Out, None
-    """
-    if output.isOpen():
-      output.close()
-    if self.dynamic:
-      outputInstance = Files.returnInstance('DynamicXMLOutput', self)
-    else:
-      outputInstance = Files.returnInstance('StaticXMLOutput', self)
-    outputInstance.initialize(output.getFilename(), self.messageHandler, path=output.getPath())
-    outputInstance.newTree('CrossValidationPostProcessor', pivotParam=self.pivotParameter)
-    outputResults = [outputDictionary] if not self.dynamic else outputDictionary.values()
-    for ts, outputDict in enumerate(outputResults):
-      pivotVal = outputDictionary.keys()[ts]
-      for nodeName, nodeValues in outputDict.items():
-        for metricName, metricValues in nodeValues.items():
-          if self.cvScores:
-            outputInstance.addVector(nodeName, metricName, metricValues, pivotVal = pivotVal)
-          else:
-            cvRuns = ['cv-' + str(i) for i in range(len(metricValues))]
-            valueDict = dict(zip(cvRuns, metricValues))
-            outputInstance.addVector(nodeName, metricName, valueDict, pivotVal = pivotVal)
-    outputInstance.writeFile()
-
-  def _writeText(self, output, outputDictionary, separator=' '):
-    """
-      Defines the method for writing the post-processor to a .csv file
-      @ In, output, File object, file to write to
-      @ In, outputDictionary, dict, dictionary stores metrics' results of outputs
-      @ In, separator, string, optional, separator string
-      @ Out, None
-    """
-    if self.dynamic:
-      output.write('Dynamic Cross Validation', separator, 'Pivot Parameter', separator, self.pivotParameter, separator, os.linesep)
-      self.raiseAnError(IOError, 'The method to dump the dynamic cross validation results into a csv file is not implemented yet')
-    outputResults = [outputDictionary] if not self.dynamic else outputDictionary.values()
-    for ts, outputDict in enumerate(outputResults):
-      if self.dynamic:
-        output.write('Pivot value', separator, str(outputDictionary.keys()[ts]), os.linesep)
-      nodeNames, nodeValues = outputDict.keys(), outputDict.values()
-      metricNames = nodeValues[0].keys()
-      if not self.cvScores:
-        output.write('CV-Run-Number')
-        for nodeName in nodeNames:
-          for metricName in metricNames:
-            output.write(separator + nodeName + '-' + metricName)
-        output.write(os.linesep)
-        for cvRunNum in range(len(nodeValues[0].values()[0])):
-          output.write(str(cvRunNum))
-          for valueDict in nodeValues:
-            for metricName in metricNames:
-              output.write(separator + str(valueDict[metricName][cvRunNum]))
-          output.write(os.linesep)
-      else:
-        output.write('ScoreMetric')
-        for nodeName in nodeNames:
-          for metricName in metricNames:
-            output.write(separator + nodeName + '-' + metricName)
-        output.write(os.linesep)
-        for score in self.cvScores:
-          output.write(score)
-          for valueDict in nodeValues:
-            for metricName in metricNames:
-              output.write(separator + str(valueDict[metricName][score]))
-          output.write(os.linesep)
+    output.addRealization(outputDict)
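collectOutput now hands the score dictionary straight to the output DataObject instead of writing files itself. A sketch of the shape this appears to assume, with FakeDataObject as an illustrative stand-in (not RAVEN's DataObjects API): one length-1 array per variable, matching the np.atleast_1d values built in run().

```python
# Assumed contract: each realization maps variable names to length-1 arrays.
import numpy as np

class FakeDataObject:
    """Stand-in illustrating the addRealization contract, not RAVEN's class."""
    def __init__(self):
        self.rows = []
    def addRealization(self, rlz):
        # every value is expected to be array-like with one entry per realization
        assert all(np.atleast_1d(v).shape == (1,) for v in rlz.values())
        self.rows.append({k: np.atleast_1d(v) for k, v in rlz.items()})

out = FakeDataObject()
out.addRealization({'cv_mse_myROM': np.atleast_1d(0.0123)})
print(out.rows[0])
```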
Review thread (on the remaining TODO in inputToInternal):
- Is this a dataobject rework TODO, or a future TODO for devel (in which case we might need an issue for it)?
- I think this is a data object rework TODO.
- Okay. What change is needed before it's acted on? Is this a "once someone has time and needs speed" change?
- When we start to finalize the inputToInternal, and remove the collectOutput. The final clean-up, I think.