package astra.learn.library.DTModel;

import astra.core.Agent;
import astra.formula.Comparison;
import astra.formula.Formula;
import astra.learn.LearningProcessException;
import astra.learn.LearningProtectedPredicates;
import astra.learn.library.DataModel;
import astra.term.ListTerm;
import astra.term.Primitive;
import astra.term.Term;
import astra.term.Variable;
import astra.type.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:astra/learn/library/DTModel/DTDataModel.class */
/**
 * ID3-style decision-tree data model for the ASTRA learning library.
 *
 * <p>A tree is grown by {@link #buildTree(ListTerm, ListTerm, TreeNode, int)} from a feature
 * matrix {@code X} (a list of rows, each row a {@link ListTerm} of {@link Primitive} values) and a
 * label list {@code y} (one {@link Primitive} label per row, index-aligned with {@code X}).
 * Categorical features (string, char, boolean) are split on equality with each observed value;
 * numeric features (integer, long, float, double) are split into two branches at the midpoint
 * between the smallest and largest observed value.
 *
 * <p>Split conditions are {@link Comparison} formulas evaluated through {@link Agent#query}, so
 * {@link #setAgent(Agent)} must be called before a tree is built.
 */
public class DTDataModel extends DataModel {
    /** Set from the KNOWLEDGE_TRAIN belief; parseTrainingData currently always yields null. */
    Object trainingData;
    /** Output label name taken from the KNOWLEDGE_LABEL belief. */
    String label;
    /** Root of the most recently built tree, or null before any tree is built. */
    TreeNode root;
    /** Agent used to evaluate split formulas. */
    Agent agent;
    List<TreeNode> leafNodes;
    boolean debug = false;
    /** Depth of the last non-root node expanded by buildTree. */
    int depth = -1;
    /** Identifier handed to the next TreeNode created. */
    int IDCOUNT = 0;
    /** Candidate split formulas per feature index; regenerated for every node in buildTree. */
    private HashMap<Integer, List<Formula>> featureSplitFormulas = new HashMap<>();

    /**
     * Sets the agent whose query mechanism evaluates split formulas.
     *
     * @param agent the owning agent
     */
    public void setAgent(Agent agent) {
        this.agent = agent;
    }

    /**
     * Applies one configuration belief to this model.
     *
     * @param str the belief predicate: KNOWLEDGE_TRAIN, KNOWLEDGE_INPUT or KNOWLEDGE_LABEL
     * @param termArr the belief's arguments
     * @throws LearningProcessException if {@code str} is not one of the supported predicates, or
     *     the label arguments are malformed
     */
    public void addInput(String str, Term[] termArr) {
        if (this.debug) {
            System.out.println("[DTDataModel] Add input for term " + str);
        }
        if (LearningProtectedPredicates.KNOWLEDGE_TRAIN.equals(str)) {
            this.trainingData = parseTrainingData(termArr);
        } else if (LearningProtectedPredicates.KNOWLEDGE_INPUT.equals(str)) {
            // Accepted, but currently ignored by this model.
        } else if (LearningProtectedPredicates.KNOWLEDGE_LABEL.equals(str)) {
            if (this.debug) {
                // Arrays.toString prints the terms themselves rather than the array identity.
                System.out.println("[DTDataModel] Add input for term " + Arrays.toString(termArr));
            }
            this.label = parseOutputLabel(termArr);
        } else {
            throw new LearningProcessException("Could not set ID3 configuration for " + str);
        }
    }

    /**
     * Placeholder for training-data parsing.
     *
     * @param termArr the KNOWLEDGE_TRAIN belief arguments (currently unused)
     * @return always {@code null}; TODO implement training-data parsing
     */
    private Object parseTrainingData(Term[] termArr) {
        return null;
    }

    /**
     * Extracts the output label name from a KNOWLEDGE_LABEL belief.
     *
     * @param termArr the belief arguments; must be exactly one string term
     * @return the label string
     * @throws LearningProcessException if the arguments are not a single string term
     */
    @Override // astra.learn.library.DataModel
    public String parseOutputLabel(Term[] termArr) {
        if (termArr.length == 1) {
            Term term = termArr[0];
            if (term.type().equals(Type.STRING)) {
                return Type.stringValue(term);
            }
        }
        throw new LearningProcessException("Given input for the knowledge_label belief in DT process is incorrect: it should have one term, of type string");
    }

    /** Returns the root of the built tree, or {@code null} if no tree has been built. */
    public TreeNode getRoot() {
        return this.root;
    }

    /**
     * Returns all leaves of the built tree in depth-first order.
     *
     * @return the leaves, or an empty list when no tree has been built yet
     */
    public List<TreeNode> getLeafNodes() {
        return getLeafNodesRecursive(new ArrayList<TreeNode>(), getRoot());
    }

    /**
     * Appends every leaf under {@code treeNode} (inclusive) to {@code list}, depth first.
     *
     * @param list accumulator the leaves are added to
     * @param treeNode subtree root; {@code null} adds nothing
     * @return {@code list}
     */
    public List<TreeNode> getLeafNodesRecursive(List<TreeNode> list, TreeNode treeNode) {
        if (treeNode == null) {
            return list;
        }
        if (treeNode.leaf()) {
            list.add(treeNode);
        } else {
            for (TreeNode child : treeNode.children()) {
                getLeafNodesRecursive(list, child);
            }
        }
        return list;
    }

    /**
     * Recursively grows the decision tree below {@code treeNode}.
     *
     * <p>A node becomes a leaf when its labels are pure (entropy 0), when it reaches the maximum
     * depth, when no feature can be selected, or when the chosen split would not separate the
     * rows (which would otherwise recurse on identical data forever). Otherwise the feature with
     * the highest information gain is chosen (the last one wins ties), and one child is created
     * per split formula of that feature.
     *
     * @param listTerm feature rows {@code X} reaching this node
     * @param listTerm2 labels {@code y}, index-aligned with {@code listTerm}
     * @param treeNode node to expand, or {@code null} to start a new tree (replacing any previous
     *     root)
     * @param i maximum depth; a negative value means unlimited
     */
    public void buildTree(ListTerm listTerm, ListTerm listTerm2, TreeNode treeNode, int i) {
        // informationGain() repopulates this map with the candidate splits for the current node.
        this.featureSplitFormulas = new HashMap<>();
        if (this.debug) {
            System.out.println("[ID3DecisionTree Model] Building tree for size " + listTerm2.size());
        }
        if (listTerm.isEmpty()) {
            if (this.debug) {
                System.out.println("[ID3DecisionTree Model] X is empty, return");
            }
            return;
        }

        TreeNode node = treeNode;
        boolean isRoot = node == null;
        if (isRoot) {
            if (this.debug) {
                System.out.println("[ID3DecisionTree Model] Adding root node");
            }
            this.IDCOUNT = 0;
            this.root = new TreeNode(this.IDCOUNT, listTerm);
            this.IDCOUNT++;
            node = this.root;
        } else {
            this.depth = node.depth();
        }

        if (entropy(listTerm, listTerm2) == 0.0d) {
            if (this.debug) {
                System.out.println("[ID3DecisionTree Model] Entropy of 0 for node with label " + listTerm2.get(0));
            }
            node.setLabel(listTerm2.get(0));
            node.setEntropy(0.0d);
            return;
        }
        // '>=' (not '==') so that small limits such as 0 still stop growth below the root.
        if (!isRoot && i >= 0 && this.depth >= i) {
            makeMajorityLeaf(node, listTerm, listTerm2, "Reached max depth");
            return;
        }

        int bestFeature = selectBestFeature(listTerm, listTerm2);
        // Capture the list now: recursive calls replace this.featureSplitFormulas.
        List<Formula> splits = bestFeature < 0 ? null : this.featureSplitFormulas.get(bestFeature);
        if (splits == null || splits.isEmpty()) {
            makeMajorityLeaf(node, listTerm, listTerm2, "No usable split");
            return;
        }

        List<ListTerm[]> partitions = new ArrayList<>();
        for (Formula formula : splits) {
            ListTerm[] part = partition(listTerm, listTerm2, (Comparison) formula, bestFeature);
            if (part[0].size() == listTerm.size()) {
                // Every row went down one branch: recursing would see the same data again.
                makeMajorityLeaf(node, listTerm, listTerm2, "Split does not separate the data");
                return;
            }
            partitions.add(part);
        }

        for (int k = 0; k < splits.size(); k++) {
            ListTerm[] part = partitions.get(k);
            TreeNode child = new TreeNode(this.IDCOUNT, part[0]);
            this.IDCOUNT++;
            child.setFormula(splits.get(k), bestFeature);
            node.addChild(child);
            buildTree(part[0], part[1], child, i);
        }
    }

    /** Returns the feature index with the highest information gain (last wins ties), or -1 if none. */
    private int selectBestFeature(ListTerm x, ListTerm y) {
        int featureCount = ((ListTerm) x.get(0)).size();
        double bestGain = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int feature = 0; feature < featureCount; feature++) {
            double gain = informationGain(x, y, feature);
            if (gain >= bestGain) {
                bestGain = gain;
                best = feature;
            }
        }
        return best;
    }

    /** Returns {rowsX, rowsY} for the rows whose {@code feature} value satisfies {@code split}. */
    private ListTerm[] partition(ListTerm x, ListTerm y, Comparison split, int feature) {
        ListTerm subsetX = new ListTerm();
        ListTerm subsetY = new ListTerm();
        for (int row = 0; row < x.size(); row++) {
            Primitive value = (Primitive) ((ListTerm) x.get(row)).get(feature);
            Comparison test = new Comparison(split.operator(), value, split.right());
            if (this.agent.query(test, new HashMap<>()) != null) {
                subsetX.add(x.get(row));
                subsetY.add(y.get(row));
            }
        }
        return new ListTerm[] {subsetX, subsetY};
    }

    /** Labels {@code node} with the majority label of {@code y} and records its entropy. */
    private void makeMajorityLeaf(TreeNode node, ListTerm x, ListTerm y, String reason) {
        double nodeEntropy = entropy(x, y);
        Term majority = mostFrequentLabel(y);
        node.setLabel(majority);
        node.setEntropy(nodeEntropy);
        if (this.debug) {
            System.out.println("[ID3DecisionTree Model] " + reason + ", for node with label " + majority + ", entropy " + nodeEntropy);
        }
    }

    /**
     * Computes the base-2 Shannon entropy of the label distribution.
     *
     * @param listTerm feature rows (unused; kept for interface compatibility)
     * @param listTerm2 labels; each element must be a {@link Primitive}
     * @return the entropy in bits, or 0 for an empty label list
     */
    public double entropy(ListTerm listTerm, ListTerm listTerm2) {
        int total = listTerm2.size();
        if (total == 0) {
            return 0.0d;
        }
        Map<Object, Integer> counts = new HashMap<>();
        for (int k = 0; k < total; k++) {
            Object key = ((Primitive) listTerm2.get(k)).value();
            Integer count = counts.get(key);
            counts.put(key, count == null ? 1 : count + 1);
        }
        double result = 0.0d;
        for (int count : counts.values()) {
            // Floating-point division: integer division would make every p either 0 or 1.
            double p = (double) count / total;
            result -= p * log2(p);
        }
        return result;
    }

    /**
     * Returns the most frequent label; ties go to whichever value the count map iterates first.
     *
     * @param listTerm labels; each element must be a {@link Primitive}
     * @return the last term seen with the winning value, or {@code null} if the list is empty
     */
    public Term mostFrequentLabel(ListTerm listTerm) {
        Map<Object, Integer> counts = new HashMap<>();
        Map<Object, Term> representative = new HashMap<>();
        for (int k = 0; k < listTerm.size(); k++) {
            Term term = listTerm.get(k);
            Object key = ((Primitive) term).value();
            Integer count = counts.get(key);
            counts.put(key, count == null ? 1 : count + 1);
            representative.put(key, term);
        }
        Term best = null;
        int bestCount = 0;
        for (Map.Entry<Object, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                best = representative.get(entry.getKey());
            }
        }
        return best;
    }

    /** Returns the base-2 logarithm of {@code d}. */
    public static double log2(double d) {
        return Math.log(d) / Math.log(2.0d);
    }

    /**
     * Computes the information gain of splitting on feature {@code i}, and as a side effect
     * records that feature's candidate split formulas.
     *
     * @param listTerm feature rows; the type of row 0's value selects the split strategy
     * @param listTerm2 labels, index-aligned with {@code listTerm}
     * @param i feature index
     * @return entropy before the split minus the weighted entropy after it
     * @throws LearningProcessException if the feature is not a supported primitive type
     */
    public double informationGain(ListTerm listTerm, ListTerm listTerm2, int i) {
        Primitive sample = (Primitive) ((ListTerm) listTerm.get(0)).get(i);
        double entropyAfterSplit;
        if (sample.type().equals(Type.STRING) || sample.type().equals(Type.CHAR) || sample.type().equals(Type.BOOLEAN)) {
            entropyAfterSplit = entropyAfterSplitCategory(listTerm, listTerm2, i);
        } else if (sample.type().equals(Type.DOUBLE) || sample.type().equals(Type.INTEGER) || sample.type().equals(Type.LONG)) {
            entropyAfterSplit = entropyAfterBestSplitContinousDouble(listTerm, listTerm2, i);
        } else if (sample.type().equals(Type.FLOAT)) {
            entropyAfterSplit = entropyAfterBestSplitContinousFloat(listTerm, listTerm2, i);
        } else {
            throw new LearningProcessException("Expected the X dataset for the DT to be a list of a list of primitives (i.e. char/strings, boolean or numerical), like [ [1,'hi'], [2,'low'] ]");
        }
        return entropy(listTerm, listTerm2) - entropyAfterSplit;
    }

    /**
     * Weighted label entropy after splitting on every distinct value of categorical feature
     * {@code i}; also records one equality formula per distinct value.
     */
    public double entropyAfterSplitCategory(ListTerm listTerm, ListTerm listTerm2, int i) {
        Map<Object, List<Integer>> rowsByValue = new HashMap<>();
        for (int row = 0; row < listTerm.size(); row++) {
            Object value = ((Primitive) ((ListTerm) listTerm.get(row)).get(i)).value();
            List<Integer> rows = rowsByValue.get(value);
            if (rows == null) {
                rows = new ArrayList<>();
                rowsByValue.put(value, rows);
            }
            rows.add(row);
        }
        createEqualsFormulas(rowsByValue.keySet(), i);
        return weightedEntropy(listTerm, listTerm2, rowsByValue.values());
    }

    /** Sum over partitions of (|partition| / |y|) * entropy(partition). */
    private double weightedEntropy(ListTerm listTerm, ListTerm listTerm2, Iterable<List<Integer>> partitions) {
        double total = listTerm2.size();
        double result = 0.0d;
        for (List<Integer> rows : partitions) {
            result += (rows.size() / total) * entropy(buildSubList(listTerm, rows), buildSubList(listTerm2, rows));
        }
        return result;
    }

    /** Weighted entropy of the two-way split {@code value >= threshold} / {@code value < threshold}. */
    private double entropyAfterThresholdSplit(ListTerm listTerm, ListTerm listTerm2, int feature, double threshold) {
        List<Integer> below = new ArrayList<>();
        List<Integer> atOrAbove = new ArrayList<>();
        for (int row = 0; row < listTerm.size(); row++) {
            if (numericFeature(listTerm, row, feature).doubleValue() >= threshold) {
                atOrAbove.add(row);
            } else {
                below.add(row);
            }
        }
        // An empty side contributes 0 to the weighted sum.
        return weightedEntropy(listTerm, listTerm2, Arrays.asList(below, atOrAbove));
    }

    /** Reads feature {@code feature} of row {@code row}; the value must be a java.lang.Number. */
    private static Number numericFeature(ListTerm listTerm, int row, int feature) {
        return (Number) ((Primitive) ((ListTerm) listTerm.get(row)).get(feature)).value();
    }

    /** Returns a new list holding the elements of {@code listTerm} at the given indices, in order. */
    private ListTerm buildSubList(ListTerm listTerm, List<Integer> list) {
        Term[] terms = new Term[list.size()];
        for (int k = 0; k < terms.length; k++) {
            terms[k] = listTerm.get(list.get(k));
        }
        return new ListTerm(terms);
    }

    /** Appends one {@code X == value} split formula per value to feature {@code i}'s candidates. */
    private void createEqualsFormulas(Set<Object> set, int i) {
        List<Formula> list = this.featureSplitFormulas.get(i);
        if (list == null) {
            list = new ArrayList<>();
        }
        for (Object value : set) {
            list.add(new Comparison(Comparison.EQUAL, new Variable(Type.BOOLEAN, "X"), Primitive.newPrimitive(value)));
        }
        this.featureSplitFormulas.put(i, list);
    }

    /**
     * Weighted label entropy after splitting numeric feature {@code i} at the midpoint of its
     * observed range (computed in double precision); also records the two threshold formulas.
     * Integer and long values are widened to double (values beyond 2^53 lose precision).
     */
    public double entropyAfterBestSplitContinousDouble(ListTerm listTerm, ListTerm listTerm2, int i) {
        double min = -1.0d;
        double max = -1.0d;
        for (int row = 0; row < listTerm.size(); row++) {
            double value = numericFeature(listTerm, row, i).doubleValue();
            if (row == 0 || value < min) {
                min = value;
            }
            if (row == 0 || value > max) {
                max = value;
            }
        }
        double threshold = min + ((max - min) / 2.0d);
        createContinuousFormulasDouble(threshold, i);
        return entropyAfterThresholdSplit(listTerm, listTerm2, i, threshold);
    }

    /**
     * Float counterpart of {@link #entropyAfterBestSplitContinousDouble}: the midpoint is
     * computed in float precision and recorded as float threshold formulas.
     */
    public double entropyAfterBestSplitContinousFloat(ListTerm listTerm, ListTerm listTerm2, int i) {
        float min = -1.0f;
        float max = -1.0f;
        for (int row = 0; row < listTerm.size(); row++) {
            float value = numericFeature(listTerm, row, i).floatValue();
            if (row == 0 || value < min) {
                min = value;
            }
            if (row == 0 || value > max) {
                max = value;
            }
        }
        float threshold = min + ((max - min) / 2.0f);
        createContinuousFormulasFloat(threshold, i);
        // float -> double widening is exact, so the comparison matches a float comparison.
        return entropyAfterThresholdSplit(listTerm, listTerm2, i, threshold);
    }

    /** Appends {@code X >= d} and {@code X < d} (double) to feature {@code i}'s candidates. */
    private void createContinuousFormulasDouble(double d, int i) {
        List<Formula> list = this.featureSplitFormulas.get(i);
        if (list == null) {
            list = new ArrayList<>();
        }
        list.add(new Comparison(Comparison.GREATER_THAN_OR_EQUAL, new Variable(Type.DOUBLE, "X"), Primitive.newPrimitive(Double.valueOf(d))));
        list.add(new Comparison(Comparison.LESS_THAN, new Variable(Type.DOUBLE, "X"), Primitive.newPrimitive(Double.valueOf(d))));
        this.featureSplitFormulas.put(i, list);
    }

    /** Appends {@code X >= f} and {@code X < f} (float) to feature {@code i}'s candidates. */
    private void createContinuousFormulasFloat(float f, int i) {
        List<Formula> list = this.featureSplitFormulas.get(i);
        if (list == null) {
            list = new ArrayList<>();
        }
        list.add(new Comparison(Comparison.GREATER_THAN_OR_EQUAL, new Variable(Type.FLOAT, "X"), Primitive.newPrimitive(Float.valueOf(f))));
        list.add(new Comparison(Comparison.LESS_THAN, new Variable(Type.FLOAT, "X"), Primitive.newPrimitive(Float.valueOf(f))));
        this.featureSplitFormulas.put(i, list);
    }

    /** Not supported by this model. */
    @Override // astra.learn.library.DataModel
    public boolean mergeModel(DataModel dataModel) {
        throw new UnsupportedOperationException("Unimplemented method 'mergeModel'");
    }

    /**
     * Renders every node (root first, then children level by level per subtree), with blocks
     * separated by newlines.
     *
     * @return the rendering, or an empty string when no tree has been built
     */
    @Override // astra.learn.library.DataModel
    public String print() {
        StringBuilder out = new StringBuilder();
        boolean first = true;
        for (TreeNode node : flatten()) {
            if (!first) {
                out.append("\n");
            }
            first = false;
            out.append(printNode(node));
        }
        return out.toString();
    }

    /** Returns the root followed by all descendants, in recursiveAdd order; empty if no tree. */
    private List<TreeNode> flatten() {
        List<TreeNode> nodes = new ArrayList<>();
        if (this.root == null) {
            return nodes;
        }
        nodes.add(this.root);
        recursiveAdd(nodes, this.root);
        return nodes;
    }

    /** Adds the direct children of {@code treeNode} (skipping ones already present), then recurses into each. */
    private List<TreeNode> recursiveAdd(List<TreeNode> list, TreeNode treeNode) {
        for (TreeNode child : treeNode.children()) {
            if (!list.contains(child)) {
                list.add(child);
            }
        }
        for (TreeNode child : treeNode.children()) {
            recursiveAdd(list, child);
        }
        return list;
    }

    /** Wraps a node's toString() between separator lines. */
    private String printNode(TreeNode treeNode) {
        return "\n*************************** \n" + treeNode.toString() + "\n*************************** ";
    }
}
