-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Alfoa/go on ensemble model #504
Changes from 9 commits
bd79d3c
65cb763
e352308
364de43
ce1e97b
0085807
cac6a38
171678f
807b10e
39199b4
bb74d5f
e435188
15df01a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -190,7 +190,7 @@ def initialize(self,runInfo,inputs,initDict=None): | |
self.cvInstance = self.retrieveObjectFromAssemblerDict('CV', self.cvInstance) | ||
self.cvInstance.initialize(runInfo, inputs, initDict) | ||
self.targetEvaluationInstance = self.retrieveObjectFromAssemblerDict('TargetEvaluation', self.targetEvaluationInstance) | ||
if not self.targetEvaluationInstance.isItEmpty(): | ||
if len(self.targetEvaluationInstance): | ||
self.raiseAWarning("The provided TargetEvaluation data object is not empty, the existing data will also be used to train the ROMs!") | ||
self.existTrainSize = len(self.targetEvaluationInstance) | ||
self.tempTargetEvaluation = copy.deepcopy(self.targetEvaluationInstance) | ||
|
@@ -204,8 +204,8 @@ def initialize(self,runInfo,inputs,initDict=None): | |
romInfo['Instance'] = self.retrieveObjectFromAssemblerDict('ROM', romName) | ||
if romInfo['Instance'] is None: | ||
self.raiseAnError(IOError, 'ROM XML block needs to be inputted!') | ||
modelInputs = self.targetEvaluationInstance.getParaKeys("inputs") | ||
modelOutputs = self.targetEvaluationInstance.getParaKeys("outputs") | ||
modelInputs = self.targetEvaluationInstance.getVars("input") | ||
modelOutputs = self.targetEvaluationInstance.getVars("output") | ||
modelName = self.modelInstance.name | ||
totalRomOutputs = [] | ||
for romInfo in self.romsDictionary.values(): | ||
|
@@ -365,13 +365,17 @@ def isRomConverged(self, outputDict): | |
@ Out, converged, bool, True if the rom is converged | ||
""" | ||
converged = True | ||
for targetName, metricInfo in outputDict.items(): | ||
if len(metricInfo.keys()) > 1: | ||
# very temporary solution | ||
exploredTargets = [] | ||
for cvKey, metricValues in outputDict.items(): | ||
#for targetName, metricInfo in outputDict.items(): | ||
# very temporary solution | ||
info = self.cvInstance.interface._returnCharacteristicsOfCvGivenOutputName(cvKey) | ||
if info['targetName'] in exploredTargets: | ||
self.raiseAnError(IOError, "Multiple metrics are used in cross validation '", self.cvInstance.name, "'. Currently, this can not be processed by the HybridModel '", self.name, "'!") | ||
for metricName, metricValues in metricInfo.items(): | ||
name = self.cvInstance.interface.metricsDict.keys()[0] | ||
metricType = metricName[len(name)+1:] | ||
converged = self.checkErrors(metricType, metricValues) | ||
exploredTargets.append(info['targetName']) | ||
name = self.cvInstance.interface.metricsDict.keys()[0] | ||
converged = self.checkErrors(info['metricType'], metricValues) | ||
return converged | ||
|
||
def checkErrors(self, metricType, metricResults): | ||
|
@@ -381,7 +385,7 @@ def checkErrors(self, metricType, metricResults): | |
@ In, metricResults, list or dict | ||
@ Out, converged, bool, True if the metric outputs are less than the tolerance | ||
""" | ||
if type(metricResults) == list: | ||
if type(metricResults) == list or isinstance(metricResults,np.ndarray): | ||
errorList = np.atleast_1d(metricResults) | ||
elif type(metricResults) == dict: | ||
errorList = np.atleast_1d(metricResults.values()) | ||
|
@@ -602,7 +606,7 @@ def evaluateSample(self, myInput, samplerType, kwargs): | |
@ In, samplerType, string, is the type of sampler that is calling to generate a new input | ||
@ In, kwargs, dict, is a dictionary that contains the information coming from the sampler, | ||
a mandatory key is the sampledVars'that contains a dictionary {'name variable':value} | ||
@ Out, returnValue, dict, This holds the output information of the evaluated sample. | ||
@ Out, rlz, dict, This holds the output information of the evaluated sample. | ||
""" | ||
self.raiseADebug("Evaluate Sample") | ||
kwargsKeys = kwargs.keys() | ||
|
@@ -611,8 +615,13 @@ def evaluateSample(self, myInput, samplerType, kwargs): | |
jobHandler = kwargs['jobHandler'] | ||
newInput = self.createNewInput(myInput, samplerType, **kwargsToKeep) | ||
## Unpack the specifics for this class, namely just the jobHandler | ||
returnValue = (newInput,self._externalRun(newInput,jobHandler)) | ||
return returnValue | ||
result = self._externalRun(newInput,jobHandler) | ||
# assure rlz has all metadata | ||
rlz = dict((var,np.atleast_1d(kwargsToKeep[var])) for var in kwargsToKeep.keys()) | ||
# update rlz with input space from inRun and output space from result | ||
rlz.update(dict((var,np.atleast_1d(kwargsToKeep['SampledVars'][var] if var in kwargs['SampledVars'] else result[var])) for var in set(result.keys()+kwargsToKeep['SampledVars'].keys()))) | ||
|
||
return rlz | ||
|
||
def _externalRun(self,inRun, jobHandler): | ||
""" | ||
|
@@ -653,7 +662,7 @@ def _externalRun(self,inRun, jobHandler): | |
if isinstance(evaluation, Runners.Error): | ||
self.raiseAnError(RuntimeError, "The job identified by "+finishedRun.identifier+" failed!") | ||
# collect output in temporary data object | ||
tempExportDict = self.createExportDictionaryFromFinishedJob(finishedRun, False) | ||
tempExportDict = evaluation #self.createExportDictionaryFromFinishedJob(finishedRun, False) | ||
exportDict = self.__mergeDict(exportDict, tempExportDict) | ||
if jobHandler.areTheseJobsFinished(uniqueHandler=uniqueHandler): | ||
self.raiseADebug("Jobs with uniqueHandler ", uniqueHandler, "are collected!") | ||
|
@@ -680,7 +689,7 @@ def _externalRun(self,inRun, jobHandler): | |
if isinstance(evaluation, Runners.Error): | ||
self.raiseAnError(RuntimeError, "The model "+self.modelInstance.name+" identified by "+finishedRun[0].identifier+" failed!") | ||
# collect output in temporary data object | ||
exportDict = self.modelInstance.createExportDictionaryFromFinishedJob(finishedRun[0], True) | ||
exportDict = evaluation #self.modelInstance.createExportDictionaryFromFinishedJob(finishedRun[0], True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the commented line? |
||
self.raiseADebug("Create exportDict") | ||
# used in the collectOutput | ||
exportDict['useROM'] = useROM | ||
|
@@ -696,17 +705,19 @@ def collectOutput(self,finishedJob,output): | |
evaluation = finishedJob.getEvaluation() | ||
if isinstance(evaluation, Runners.Error): | ||
self.raiseAnError(RuntimeError,"Job " + finishedJob.identifier +" failed!") | ||
exportDict = evaluation[1] | ||
useROM = exportDict['useROM'] | ||
#exportDict = evaluation[1] | ||
#exportDict = evaluation[1] | ||
useROM = evaluation['useROM'] | ||
try: | ||
jobIndex = self.tempOutputs['uncollectedJobIds'].index(finishedJob.identifier) | ||
self.tempOutputs['uncollectedJobIds'].pop(jobIndex) | ||
except ValueError: | ||
jobIndex = None | ||
if jobIndex is not None and not useROM: | ||
self.collectOutputFromDict(exportDict, self.tempTargetEvaluation) | ||
self.tempTargetEvaluation.addRealization(evaluation) | ||
#self.collectOutputFromDict(exportDict, self.tempTargetEvaluation) | ||
self.raiseADebug("ROM is invalid, collect ouptuts of Model with job identifier: {}".format(finishedJob.identifier)) | ||
Dummy.collectOutput(self, finishedJob, output, options = {'exportDict':exportDict}) | ||
Dummy.collectOutput(self, finishedJob, output )#, options = {'exportDict':exportDict}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as before. |
||
|
||
def __mergeDict(self,exportDict, tempExportDict): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,7 +148,7 @@ | |
</PointSet> | ||
<PointSet inputTs="2" name="Pointset_from_database_for_rom_trainer"> | ||
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
<Output>CladTempThreshold</Output> | ||
<Output>CladTempThreshold,time</Output> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does the variable time need to be added? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed...but the test does not work still... |
||
</PointSet> | ||
<PointSet historyName="1" inputTs="2" name="data_for_sampling_empty_at_begin"> | ||
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
|
@@ -158,17 +158,17 @@ | |
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
<Output>OutputPlaceHolder</Output> | ||
</PointSet> | ||
<PointSet inputTs="2" name="outputMontecarloRom"> | ||
<PointSet name="outputMontecarloRom"> | ||
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
<Output>CladTempThreshold</Output> | ||
</PointSet> | ||
<HistorySet name="outputMontecarloRomHS"> | ||
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
<Output>CladTempThreshold</Output> | ||
<Output>CladTempThreshold,time</Output> | ||
</HistorySet> | ||
<PointSet inputTs="2" name="outputMontecarloRomND"> | ||
<Input>DeltaTimeScramToAux,DG1recoveryTime</Input> | ||
<Output>CladTempThreshold</Output> | ||
<Output>CladTempThreshold,time</Output> | ||
</PointSet> | ||
</DataObjects> | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,41 @@ | ||
rightTemperature,leftTemperature,k | ||
1261.31072723,769.597955763,0.0322050521833 | ||
1174.50459336,1525.24138517,0.0322050521833 | ||
1571.16098317,1648.6958241,0.0322050521833 | ||
1269.60940817,1125.61331285,0.0322050521833 | ||
997.315857489,500.262433477,0.0322050521833 | ||
511.750896278,1179.76339298,0.0322050521833 | ||
1395.79177603,1118.11303082,0.0322050521833 | ||
1335.84571491,1425.07522938,0.0322050521833 | ||
1010.43601057,1138.74257958,0.0322050521833 | ||
1350.63522599,1347.57975741,0.0322050521833 | ||
1054.06268615,726.586376928,0.0322050521833 | ||
1620.82760677,1646.20702526,0.0322050521833 | ||
1667.87949651,1187.79552828,0.0322050521833 | ||
1603.45520049,783.539288371,0.0322050521833 | ||
1508.78469344,1122.47902309,0.0322050521833 | ||
1130.525271,1019.85670256,0.0322050521833 | ||
1002.19029516,669.596191442,0.0322050521833 | ||
1032.38404806,1446.77175976,0.0322050521833 | ||
1310.60222658,1144.03216956,0.0322050521833 | ||
1199.90569556,1367.12919121,0.0322050521833 | ||
601.558282064,1015.01618226,0.0322050521833 | ||
1593.72978217,1632.60342161,0.0322050521833 | ||
861.623179438,1310.71645683,0.0322050521833 | ||
1161.65727038,664.054127355,0.0322050521833 | ||
908.710236011,1200.45852817,0.0322050521833 | ||
914.092090077,1361.69643245,0.0322050521833 | ||
1519.01430306,900.562822167,0.0322050521833 | ||
1667.23133288,566.834234276,0.0322050521833 | ||
1006.46963169,696.796910045,0.0322050521833 | ||
1051.59998521,1364.35290008,0.0322050521833 | ||
919.53782598,872.210895822,0.0322050521833 | ||
599.860508018,1672.43886627,0.0322050521833 | ||
1415.49732641,1572.65856552,0.0322050521833 | ||
1378.3528113,873.482695961,0.0298833181459 | ||
1692.56615657,1156.06354164,0.0301285039345 | ||
1362.96807808,520.29037057,0.0322050521833 | ||
766.296345337,1333.69837553,0.0320933195934 | ||
1496.13095052,1258.43668635,0.0295905227099 | ||
1042.90632357,1526.60560026,0.0322050521833 | ||
1676.54607258,989.114559183,0.0316390586337 | ||
leftTemperature,rightTemperature,k | ||
769.597955763,1261.31072723,0.0322050521833 | ||
1525.24138517,1174.50459336,0.0322050521833 | ||
1648.6958241,1571.16098317,0.0322050521833 | ||
1125.61331285,1269.60940817,0.0322050521833 | ||
500.262433477,997.315857489,0.0322050521833 | ||
1179.76339298,511.750896278,0.0322050521833 | ||
1118.11303082,1395.79177603,0.0322050521833 | ||
1425.07522938,1335.84571491,0.0322050521833 | ||
1138.74257958,1010.43601057,0.0322050521833 | ||
1347.57975741,1350.63522599,0.0322050521833 | ||
726.586376928,1054.06268615,0.0322050521833 | ||
1646.20702526,1620.82760677,0.0322050521833 | ||
1187.79552828,1667.87949651,0.0322050521833 | ||
783.539288371,1603.45520049,0.0322050521833 | ||
1122.47902309,1508.78469344,0.0322050521833 | ||
1019.85670256,1130.525271,0.0322050521833 | ||
669.596191442,1002.19029516,0.0322050521833 | ||
1446.77175976,1032.38404806,0.0322050521833 | ||
1144.03216956,1310.60222658,0.0322050521833 | ||
1367.12919121,1199.90569556,0.0322050521833 | ||
1015.01618226,601.558282064,0.0322050521833 | ||
1632.60342161,1593.72978217,0.0322050521833 | ||
1310.71645683,861.623179438,0.0322050521833 | ||
664.054127355,1161.65727038,0.0322050521833 | ||
1200.45852817,908.710236011,0.0322050521833 | ||
1361.69643245,914.092090077,0.0322050521833 | ||
900.562822167,1519.01430306,0.0322050521833 | ||
566.834234276,1667.23133288,0.0322050521833 | ||
696.796910045,1006.46963169,0.0322050521833 | ||
1364.35290008,1051.59998521,0.0322050521833 | ||
872.210895822,919.53782598,0.0322050521833 | ||
1672.43886627,599.860508018,0.0322050521833 | ||
1572.65856552,1415.49732641,0.0322050521833 | ||
873.482695961,1378.3528113,0.0298833181459 | ||
1156.06354164,1692.56615657,0.0301285039345 | ||
520.29037057,1362.96807808,0.0322050521833 | ||
1333.69837553,766.296345337,0.0320933195934 | ||
1258.43668635,1496.13095052,0.0295905227099 | ||
1526.60560026,1042.90632357,0.0322050521833 | ||
989.114559183,1676.54607258,0.0316390586337 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the comment?