de.bund.bfr.knime.pmm.common/src/de/bund/bfr/knime/pmm/common/ModelCombiner.java
/*******************************************************************************
* Copyright (c) 2015 Federal Institute for Risk Assessment (BfR), Germany
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* Contributors:
* Department Biological Safety - BfR
*******************************************************************************/
package de.bund.bfr.knime.pmm.common;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import de.bund.bfr.knime.pmm.common.generictablemodel.KnimeSchema;
import de.bund.bfr.knime.pmm.common.generictablemodel.KnimeTuple;
import de.bund.bfr.knime.pmm.common.math.MathUtilities;
import de.bund.bfr.knime.pmm.common.pmmtablemodel.Model1Schema;
import de.bund.bfr.knime.pmm.common.pmmtablemodel.Model2Schema;
import de.bund.bfr.knime.pmm.common.pmmtablemodel.SchemaFactory;
import de.bund.bfr.knime.pmm.common.pmmtablemodel.TimeSeriesSchema;
import de.bund.bfr.knime.pmm.common.units.Categories;
import de.bund.bfr.knime.pmm.common.units.Category;
import de.bund.bfr.knime.pmm.common.units.ConvertException;
public class ModelCombiner {
private Map<KnimeTuple, List<KnimeTuple>> tupleCombinations;
private Map<KnimeTuple, Map<KnimeTuple, Map<String, String>>> parameterRenaming;
public ModelCombiner(List<KnimeTuple> tuples, boolean containsData, Map<String, String> initParams,
Map<String, String> lagParams) {
if (initParams == null) {
initParams = new LinkedHashMap<>();
}
if (lagParams == null) {
lagParams = new LinkedHashMap<>();
}
tupleCombinations = getTuplesToCombine(tuples, containsData);
parameterRenaming = new LinkedHashMap<>();
for (KnimeTuple newTuple : tupleCombinations.keySet()) {
List<KnimeTuple> usedTuples = tupleCombinations.get(newTuple);
Map<KnimeTuple, Map<String, String>> rename = new LinkedHashMap<>();
for (KnimeTuple tuple : usedTuples) {
rename.put(tuple, new LinkedHashMap<String, String>());
String modelID = ((CatalogModelXml) tuple.getPmmXml(Model1Schema.ATT_MODELCATALOG).get(0)).id + "";
String depVarSec = ((DepXml) tuple.getPmmXml(Model2Schema.ATT_DEPENDENT).get(0)).name;
if (depVarSec.equals(initParams.get(modelID)) || depVarSec.equals(lagParams.get(modelID))) {
continue;
}
String formulaSec = ((CatalogModelXml) tuple.getPmmXml(Model2Schema.ATT_MODELCATALOG).get(0))
.formula;
PmmXmlDoc indepVarsSec = tuple.getPmmXml(Model2Schema.ATT_INDEPENDENT);
PmmXmlDoc paramsSec = tuple.getPmmXml(Model2Schema.ATT_PARAMETER);
for (PmmXmlElementConvertable el : paramsSec.getElementSet()) {
ParamXml element = (ParamXml) el;
int index = 1;
String paramName = element.name;
String newParamName = paramName;
while (CellIO.getNameList(newTuple.getPmmXml(Model1Schema.ATT_PARAMETER)).contains(newParamName)) {
index++;
newParamName = paramName + index;
}
rename.get(tuple).put(paramName, newParamName);
if (index > 1) {
formulaSec = MathUtilities.replaceVariable(formulaSec, paramName, newParamName);
element.name = newParamName;
}
element.correlations.clear();
}
String replacement = "(" + formulaSec.replace(depVarSec + "=", "") + ")";
String formula = ((CatalogModelXml) newTuple.getPmmXml(Model1Schema.ATT_MODELCATALOG).get(0))
.formula;
PmmXmlDoc newParams = newTuple.getPmmXml(Model1Schema.ATT_PARAMETER);
PmmXmlDoc newIndepVars = newTuple.getPmmXml(Model1Schema.ATT_INDEPENDENT);
newParams.getElementSet().remove(CellIO.getNameList(newParams).indexOf(depVarSec));
newParams.getElementSet().addAll(paramsSec.getElementSet());
for (PmmXmlElementConvertable el : indepVarsSec.getElementSet()) {
IndepXml element = (IndepXml) el;
if (!CellIO.getNameList(newIndepVars).contains(element.name)) {
newIndepVars.getElementSet().add(element);
} else {
IndepXml original = null;
for (PmmXmlElementConvertable el2 : newIndepVars.getElementSet()) {
if (((IndepXml) el2).name.equals(element.name)) {
original = (IndepXml) el2;
break;
}
}
Double min = element.min;
Double max = element.max;
if (original.unit != null && !original.unit.equals(element.unit)) {
Category cat = Categories.getCategoryByUnit(original.unit);
try {
String conversion = "(" + cat.getConversionString(element.name, original.unit,
element.unit) + ")";
replacement = MathUtilities.replaceVariable(replacement, element.name, conversion);
min = cat.convert(min, element.unit, original.unit);
max = cat.convert(max, element.unit, original.unit);
} catch (ConvertException e) {
e.printStackTrace();
}
}
if (min != null) {
if (original.min != null) {
original.min = Math.min(original.min, min);
} else {
original.min = min;
}
}
if (max != null) {
if (original.max != null) {
original.max = Math.max(original.max, max);
} else {
original.max = max;
}
}
}
}
PmmXmlDoc modelXml = tuple.getPmmXml(Model1Schema.ATT_MODELCATALOG);
((CatalogModelXml) modelXml.get(0))
.formula = MathUtilities.replaceVariable(formula, depVarSec, replacement);
newTuple.setValue(Model1Schema.ATT_MODELCATALOG, modelXml);
newTuple.setValue(Model1Schema.ATT_INDEPENDENT, newIndepVars);
newTuple.setValue(Model1Schema.ATT_PARAMETER, newParams);
}
int newID = ((CatalogModelXml) newTuple.getPmmXml(Model1Schema.ATT_MODELCATALOG).get(0)).id;
for (KnimeTuple tuple : usedTuples) {
newID += ((CatalogModelXml) tuple.getPmmXml(Model2Schema.ATT_MODELCATALOG).get(0)).id;
}
newID = MathUtilities.generateID(newID);
PmmXmlDoc modelXml = newTuple.getPmmXml(Model1Schema.ATT_MODELCATALOG);
PmmXmlDoc estModelXml = newTuple.getPmmXml(Model1Schema.ATT_ESTMODEL);
((CatalogModelXml) modelXml.get(0)).id = newID;
((EstModelXml) estModelXml.get(0)).id = usedTuples.get(0).getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
((EstModelXml) estModelXml.get(0)).sse = null;
((EstModelXml) estModelXml.get(0)).rms = null;
((EstModelXml) estModelXml.get(0)).r2 = null;
((EstModelXml) estModelXml.get(0)).aic = null;
newTuple.setValue(Model1Schema.ATT_MODELCATALOG, modelXml);
newTuple.setValue(Model1Schema.ATT_ESTMODEL, estModelXml);
newTuple.setValue(Model1Schema.ATT_DBUUID, null);
newTuple.setValue(Model1Schema.ATT_DATABASEWRITABLE, Model1Schema.NOTWRITABLE);
parameterRenaming.put(newTuple, rename);
}
if (containsData) {
updateMetaData(tuples, tupleCombinations);
}
updateParamValues(tuples, tupleCombinations);
updatePrimaryIndepRanges(tuples, tupleCombinations);
}
/**
 * Returns the combinations computed by the constructor.
 *
 * @return map from each combined tuple to the list of input tuples that were
 *         used to build it (the internal map, not a copy)
 */
public Map<KnimeTuple, List<KnimeTuple>> getTupleCombinations() {
    return this.tupleCombinations;
}
/**
 * Returns the parameter renaming recorded by the constructor.
 *
 * @return for each combined tuple, a map from every used input tuple to its
 *         "old parameter name -> new parameter name" table (the internal
 *         map, not a copy)
 */
public Map<KnimeTuple, Map<KnimeTuple, Map<String, String>>> getParameterRenaming() {
    return this.parameterRenaming;
}
/**
 * Groups the input tuples and creates one new tuple per group.
 * <p>
 * The group key is the global model id; when {@code containsData} is set the
 * condition id is appended in parentheses. Tuples whose global model id cannot
 * be read are skipped. The new tuple of a group is a copy of the first tuple of
 * that group restricted to the output schema, with all parameter values set to
 * {@code null}. Within a group only the first tuple per secondary dependent
 * variable is kept.
 *
 * @param tuples
 *            the input tuples
 * @param containsData
 *            selects the output schema with data columns and makes the
 *            condition id part of the group key
 * @return map from each new tuple to the tuples used for it, in encounter order
 */
private static Map<KnimeTuple, List<KnimeTuple>> getTuplesToCombine(List<KnimeTuple> tuples, boolean containsData) {
    KnimeSchema targetSchema = containsData ? SchemaFactory.createM1DataSchema()
            : SchemaFactory.createM1Schema();

    Map<String, KnimeTuple> combinedByKey = new LinkedHashMap<>();
    Map<String, List<KnimeTuple>> membersByKey = new LinkedHashMap<>();
    Map<String, Set<String>> seenDepVarsByKey = new LinkedHashMap<>();

    for (KnimeTuple source : tuples) {
        String key;

        try {
            key = String.valueOf(source.getInt(Model2Schema.ATT_GLOBAL_MODEL_ID));
        } catch (Exception e) {
            // No readable global model id: this tuple takes no part in combining.
            continue;
        }

        if (containsData) {
            key += "(" + source.getInt(TimeSeriesSchema.ATT_CONDID) + ")";
        }

        if (!combinedByKey.containsKey(key)) {
            // First tuple of the group: use it as template for the combined tuple.
            KnimeTuple combined = new KnimeTuple(targetSchema);

            for (int col = 0; col < targetSchema.size(); col++) {
                String column = targetSchema.getName(col);

                combined.setCell(column, source.getCell(column));
            }

            PmmXmlDoc paramDoc = combined.getPmmXml(Model1Schema.ATT_PARAMETER);

            for (PmmXmlElementConvertable item : paramDoc.getElementSet()) {
                ((ParamXml) item).value = null;
            }

            combined.setValue(Model1Schema.ATT_PARAMETER, paramDoc);
            combinedByKey.put(key, combined);
            membersByKey.put(key, new ArrayList<KnimeTuple>());
            seenDepVarsByKey.put(key, new LinkedHashSet<String>());
        }

        String depVar = ((DepXml) source.getPmmXml(Model2Schema.ATT_DEPENDENT).get(0)).name;
        boolean firstOccurrence = seenDepVarsByKey.get(key).add(depVar);

        if (firstOccurrence) {
            membersByKey.get(key).add(source);
        }
    }

    Map<KnimeTuple, List<KnimeTuple>> result = new LinkedHashMap<>();

    for (Map.Entry<String, KnimeTuple> entry : combinedByKey.entrySet()) {
        result.put(entry.getValue(), membersByKey.get(entry.getKey()));
    }

    return result;
}
/**
 * Replaces agent, matrix and comment of every combined tuple by a summary of
 * all input tuples that share its global model id.
 * <p>
 * For each global model id the distinct non-null agent names, agent details,
 * matrix names, matrix details and comments are collected in encounter order
 * and joined with {@code ","}. Input tuples whose global model id cannot be
 * read are skipped.
 *
 * @param tuples
 *            all input tuples
 * @param tupleCombinations
 *            combined tuples (keys, modified in place) and the tuples used for
 *            them; the global model id is taken from the first used tuple
 */
private static void updateMetaData(List<KnimeTuple> tuples, Map<KnimeTuple, List<KnimeTuple>> tupleCombinations) {
    Map<Integer, Set<String>> agentNames = new LinkedHashMap<>();
    Map<Integer, Set<String>> matrixNames = new LinkedHashMap<>();
    Map<Integer, Set<String>> agentDetails = new LinkedHashMap<>();
    Map<Integer, Set<String>> matrixDetails = new LinkedHashMap<>();
    Map<Integer, Set<String>> comments = new LinkedHashMap<>();

    for (KnimeTuple source : tuples) {
        int modelId;

        try {
            modelId = source.getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
        } catch (Exception e) {
            continue;
        }

        Set<String> names = agentNames.computeIfAbsent(modelId, unused -> new LinkedHashSet<>());
        Set<String> matrices = matrixNames.computeIfAbsent(modelId, unused -> new LinkedHashSet<>());
        Set<String> namesDetail = agentDetails.computeIfAbsent(modelId, unused -> new LinkedHashSet<>());
        Set<String> matricesDetail = matrixDetails.computeIfAbsent(modelId, unused -> new LinkedHashSet<>());
        Set<String> notes = comments.computeIfAbsent(modelId, unused -> new LinkedHashSet<>());

        AgentXml agent = (AgentXml) source.getPmmXml(TimeSeriesSchema.ATT_AGENT).get(0);
        MatrixXml matrix = (MatrixXml) source.getPmmXml(TimeSeriesSchema.ATT_MATRIX).get(0);
        MdInfoXml info = (MdInfoXml) source.getPmmXml(TimeSeriesSchema.ATT_MDINFO).get(0);

        if (agent.name != null) {
            names.add(agent.name);
        }

        if (matrix.name != null) {
            matrices.add(matrix.name);
        }

        if (agent.detail != null) {
            namesDetail.add(agent.detail);
        }

        if (matrix.detail != null) {
            matricesDetail.add(matrix.detail);
        }

        if (info.comment != null) {
            notes.add(info.comment);
        }
    }

    for (KnimeTuple combined : tupleCombinations.keySet()) {
        int modelId = tupleCombinations.get(combined).get(0).getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);

        String joinedAgents = String.join(",", agentNames.get(modelId));
        String joinedMatrices = String.join(",", matrixNames.get(modelId));
        String joinedAgentDetails = String.join(",", agentDetails.get(modelId));
        String joinedMatrixDetails = String.join(",", matrixDetails.get(modelId));
        String joinedComments = String.join(",", comments.get(modelId));

        AgentXml agentXml = new AgentXml();
        MatrixXml matrixXml = new MatrixXml();
        MdInfoXml infoXml = new MdInfoXml(null, null, joinedComments, null, null);

        agentXml.name = joinedAgents;
        agentXml.detail = joinedAgentDetails;
        matrixXml.name = joinedMatrices;
        matrixXml.detail = joinedMatrixDetails;

        combined.setValue(TimeSeriesSchema.ATT_AGENT, new PmmXmlDoc(agentXml));
        combined.setValue(TimeSeriesSchema.ATT_MATRIX, new PmmXmlDoc(matrixXml));
        combined.setValue(TimeSeriesSchema.ATT_MDINFO, new PmmXmlDoc(infoXml));
    }
}
/**
 * Fills the still empty primary parameter values of the combined tuples.
 * <p>
 * For each global model id the primary parameter values of all input tuples
 * are averaged (ignoring {@code null} values). Error, t and P value of a
 * parameter are carried over only if they are equal in all of these tuples,
 * otherwise they become {@code null}. A parameter of a combined tuple is
 * updated only if it has no value yet and an average exists for its name; its
 * correlations are cleared in that case. Input tuples whose global model id
 * cannot be read are skipped.
 *
 * @param tuples
 *            all input tuples
 * @param tupleCombinations
 *            combined tuples (keys, modified in place) and the tuples used for
 *            them; the global model id is taken from the first used tuple
 */
private static void updateParamValues(List<KnimeTuple> tuples,
        Map<KnimeTuple, List<KnimeTuple>> tupleCombinations) {
    Map<Integer, Map<String, Double>> sumsById = new LinkedHashMap<>();
    Map<Integer, Map<String, Integer>> countsById = new LinkedHashMap<>();
    Map<Integer, Map<String, Double>> meansById = new LinkedHashMap<>();
    Map<Integer, Map<String, Double>> errorsById = new LinkedHashMap<>();
    Map<Integer, Map<String, Double>> tById = new LinkedHashMap<>();
    Map<Integer, Map<String, Double>> pById = new LinkedHashMap<>();

    for (KnimeTuple source : tuples) {
        int modelId;

        try {
            modelId = source.getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
        } catch (Exception e) {
            continue;
        }

        if (!sumsById.containsKey(modelId)) {
            // First tuple of this model: start all of its parameters at zero.
            Map<String, Double> zeroSums = new LinkedHashMap<>();
            Map<String, Integer> zeroCounts = new LinkedHashMap<>();

            for (PmmXmlElementConvertable item : source.getPmmXml(Model1Schema.ATT_PARAMETER).getElementSet()) {
                String paramName = ((ParamXml) item).name;

                zeroSums.put(paramName, 0.0);
                zeroCounts.put(paramName, 0);
            }

            sumsById.put(modelId, zeroSums);
            countsById.put(modelId, zeroCounts);
            meansById.put(modelId, new LinkedHashMap<String, Double>());
            errorsById.put(modelId, new LinkedHashMap<String, Double>());
            tById.put(modelId, new LinkedHashMap<String, Double>());
            pById.put(modelId, new LinkedHashMap<String, Double>());
        }

        Map<String, Double> sums = sumsById.get(modelId);
        Map<String, Integer> counts = countsById.get(modelId);
        Map<String, Double> errors = errorsById.get(modelId);
        Map<String, Double> ts = tById.get(modelId);
        Map<String, Double> ps = pById.get(modelId);

        for (PmmXmlElementConvertable item : source.getPmmXml(Model1Schema.ATT_PARAMETER).getElementSet()) {
            ParamXml param = (ParamXml) item;
            String paramName = param.name;

            if (param.value != null) {
                sums.put(paramName, sums.get(paramName) + param.value);
                counts.put(paramName, counts.get(paramName) + 1);
            }

            keepIfConsistent(errors, paramName, param.error);
            keepIfConsistent(ts, paramName, param.t);
            keepIfConsistent(ps, paramName, param.P);
        }
    }

    for (Integer modelId : sumsById.keySet()) {
        Map<String, Double> sums = sumsById.get(modelId);
        Map<String, Integer> counts = countsById.get(modelId);
        Map<String, Double> means = meansById.get(modelId);

        for (Map.Entry<String, Double> sumEntry : sums.entrySet()) {
            int observations = counts.get(sumEntry.getKey());

            if (observations != 0) {
                means.put(sumEntry.getKey(), sumEntry.getValue() / observations);
            }
        }
    }

    for (KnimeTuple combined : tupleCombinations.keySet()) {
        int modelId = tupleCombinations.get(combined).get(0).getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
        PmmXmlDoc paramDoc = combined.getPmmXml(Model1Schema.ATT_PARAMETER);

        for (PmmXmlElementConvertable item : paramDoc.getElementSet()) {
            ParamXml param = (ParamXml) item;

            if (param.value != null) {
                continue;
            }

            Double mean = meansById.get(modelId).get(param.name);

            if (mean == null) {
                continue;
            }

            param.value = mean;
            param.correlations.clear();
            param.error = errorsById.get(modelId).get(param.name);
            param.t = tById.get(modelId).get(param.name);
            param.P = pById.get(modelId).get(param.name);
        }

        combined.setValue(Model1Schema.ATT_PARAMETER, paramDoc);
    }
}

/**
 * Stores {@code observed} under {@code name} on first sight; on any later
 * sighting with a different value the stored value is replaced by
 * {@code null}.
 */
private static void keepIfConsistent(Map<String, Double> seen, String name, Double observed) {
    if (!seen.containsKey(name)) {
        seen.put(name, observed);
    } else if (!Objects.equals(seen.get(name), observed)) {
        seen.put(name, null);
    }
}
/**
 * Sets the ranges of the primary independent variables of the combined tuples
 * to the overall range found in the input tuples.
 * <p>
 * For each global model id and variable name the smallest {@code min} and the
 * largest {@code max} over all input tuples are determined. Every independent
 * variable of a combined tuple whose name is known for its model gets these
 * bounds; a bound that is still infinite is written as {@code null}. Input
 * tuples whose global model id cannot be read are skipped.
 *
 * @param tuples
 *            all input tuples
 * @param tupleCombinations
 *            combined tuples (keys, modified in place) and the tuples used for
 *            them; the global model id is taken from the first used tuple
 */
private static void updatePrimaryIndepRanges(List<KnimeTuple> tuples,
        Map<KnimeTuple, List<KnimeTuple>> tupleCombinations) {
    Map<Integer, Map<String, Double>> lowerById = new LinkedHashMap<>();
    Map<Integer, Map<String, Double>> upperById = new LinkedHashMap<>();

    for (KnimeTuple source : tuples) {
        int modelId;

        try {
            modelId = source.getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
        } catch (Exception e) {
            continue;
        }

        if (!lowerById.containsKey(modelId)) {
            // Start with an empty range so that any real bound replaces it.
            Map<String, Double> initialLower = new LinkedHashMap<>();
            Map<String, Double> initialUpper = new LinkedHashMap<>();

            for (PmmXmlElementConvertable item : source.getPmmXml(Model1Schema.ATT_INDEPENDENT)
                    .getElementSet()) {
                String varName = ((IndepXml) item).name;

                initialLower.put(varName, Double.POSITIVE_INFINITY);
                initialUpper.put(varName, Double.NEGATIVE_INFINITY);
            }

            lowerById.put(modelId, initialLower);
            upperById.put(modelId, initialUpper);
        }

        Map<String, Double> lower = lowerById.get(modelId);
        Map<String, Double> upper = upperById.get(modelId);

        for (PmmXmlElementConvertable item : source.getPmmXml(Model1Schema.ATT_INDEPENDENT).getElementSet()) {
            IndepXml indep = (IndepXml) item;

            if (indep.min != null) {
                lower.put(indep.name, Math.min(lower.get(indep.name), indep.min));
            }

            if (indep.max != null) {
                upper.put(indep.name, Math.max(upper.get(indep.name), indep.max));
            }
        }
    }

    for (KnimeTuple combined : tupleCombinations.keySet()) {
        int modelId = tupleCombinations.get(combined).get(0).getInt(Model2Schema.ATT_GLOBAL_MODEL_ID);
        Map<String, Double> lower = lowerById.get(modelId);
        Map<String, Double> upper = upperById.get(modelId);
        PmmXmlDoc indepDoc = combined.getPmmXml(Model1Schema.ATT_INDEPENDENT);

        for (PmmXmlElementConvertable item : indepDoc.getElementSet()) {
            IndepXml indep = (IndepXml) item;

            if (!lower.containsKey(indep.name)) {
                continue;
            }

            Double lo = lower.get(indep.name);
            Double hi = upper.get(indep.name);

            indep.min = lo.isInfinite() ? null : lo;
            indep.max = hi.isInfinite() ? null : hi;
        }

        combined.setValue(Model1Schema.ATT_INDEPENDENT, indepDoc);
    }
}
}