package astra.learn.library.RLModel;

import astra.formula.Formula;
import astra.formula.Predicate;
import astra.learn.LearningProcessException;
import astra.learn.library.DataModel;
import astra.learn.library.QLearning;
import astra.statement.BeliefUpdate;
import astra.term.Term;
import astra.type.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:astra/learn/library/RLModel/RLDataModel.class */
/**
 * Tabular Q-learning data model.
 *
 * <p>Holds a table {@link #Q} that maps a state to the {@link StateActionValue}
 * recorded for that state, the list of actions currently on offer, and the
 * most recent state / action / reward. {@link #update} applies one Q-learning
 * step, while {@link #nextActionRandom()} and {@link #nextActionHighestValue()}
 * pick the next action and wrap it in a belief update.
 */
public class RLDataModel extends DataModel {
    /** State parsed by the most recent {@link #parseState(List)} call. */
    protected RLDataObject state;
    /** Action chosen by the most recent {@code nextAction*} call. */
    protected RLDataObject action;
    /** Predicate name used for the belief that announces the chosen action. */
    protected String next_action;
    /** State recorded at the end of the previous {@link #update} call. */
    protected RLDataObject previous_state;
    /** Action recorded at the end of the previous {@link #update} call. */
    protected RLDataObject previous_action;
    /** Reward parsed by the most recent {@link #parseReward(List)} call. */
    protected double reward;
    /** Random source for exploration; replace it with a seeded instance for reproducible runs. */
    public Random generator = new Random();
    /** When true, progress messages are written to {@code System.out}. */
    protected boolean trace = false;
    /** Q-table: state -> values of the actions recorded for that state. */
    protected HashMap<RLDataObject, StateActionValue> Q = new HashMap<>();
    /** Actions currently available for selection. */
    protected List<RLDataObject> actions = new ArrayList<>();

    /**
     * Wraps the terms of a belief in an {@link RLDataObject}.
     *
     * @param termArr terms of the belief
     * @return a new data object built from {@code termArr}
     */
    public RLDataObject parseBeliefTerms(Term[] termArr) {
        return new RLDataObject(termArr);
    }

    /**
     * Sets the predicate name used when announcing the selected action.
     *
     * @param termArr must contain exactly one term of type string
     * @throws LearningProcessException if {@code termArr} does not hold exactly
     *         one term, or that term is not a string
     */
    public void setOutputLabel(Term[] termArr) {
        // Reject wrongly typed input instead of silently keeping a stale (or null) label,
        // mirroring the validation done by parseReward(Term[]).
        if (termArr.length != 1 || !termArr[0].type().equals(Type.STRING)) {
            throw new LearningProcessException("Given input for the knowledge_label belief in RL process is incorrect: it should have one term, of type string");
        }
        this.next_action = Type.stringValue(termArr[0]);
    }

    /**
     * Extracts a numeric reward from the terms of a reward belief.
     *
     * @param termArr must contain exactly one term of type float, int, double or long
     * @return the reward as a double
     * @throws LearningProcessException if the term count or the term type is wrong
     */
    public double parseReward(Term[] termArr) {
        if (termArr.length == 1) {
            Term term = termArr[0];
            Type type = term.type();
            if (type.equals(Type.FLOAT) || type.equals(Type.INTEGER) || type.equals(Type.DOUBLE) || type.equals(Type.LONG)) {
                return Type.doubleValue(term);
            }
        }
        throw new LearningProcessException("Given input for the knowledge_reward belief in RL process is incorrect: it should have one term, of type float, int, double or long");
    }

    /**
     * Picks one of the available actions uniformly at random (exploration) and
     * stores it in {@link #action}.
     *
     * @return a belief update for the chosen action together with the value
     *         recorded for it in the current state (0 when none is recorded),
     *         or {@code null} when no actions are available
     */
    public BeliefUpdateScoreTuple nextActionRandom() {
        if (this.trace) {
            System.out.println("Exploring");
            System.out.println("Actions size is " + this.actions.size());
        }
        if (this.actions.isEmpty()) {
            return null;
        }
        // Draw from the instance generator so that a caller-supplied seed is honoured.
        this.action = this.actions.get(this.generator.nextInt(this.actions.size()));
        if (this.trace) {
            System.out.println("Randomly selected action is " + this.action);
        }
        StateActionValue stateActionValue = this.Q.get(this.state);
        Double value = null;
        if (stateActionValue != null) {
            value = stateActionValue.getValueForAction(this.action);
        }
        return createBeliefUpdateForAction(value);
    }

    /**
     * Builds the {@code +next_action(terms...)} belief update for the current
     * {@link #action}, paired with its value.
     *
     * @param d value of the action, or {@code null} to use 0
     * @return the belief update / score pair
     */
    private BeliefUpdateScoreTuple createBeliefUpdateForAction(Double d) {
        // The terms are cloned so the predicate does not share the action's array.
        Predicate predicate = new Predicate(this.next_action, (Term[]) this.action.getTerms().clone());
        double score = (d == null) ? 0.0d : d.doubleValue();
        return new BeliefUpdateScoreTuple(new BeliefUpdate('+', predicate), score);
    }

    /**
     * Picks the highest valued available action for the current state
     * (exploitation) and stores it in {@link #action}. Falls back to
     * {@link #nextActionRandom()} when the state has no entry in the Q-table or
     * none of the available actions matches.
     *
     * @return a belief update for the chosen action together with its value, or
     *         {@code null} when the random fallback finds no available action
     */
    public BeliefUpdateScoreTuple nextActionHighestValue() {
        if (this.trace) {
            System.out.println("Exploiting");
        }
        StateActionValue stateActionValue = this.Q.get(this.state);
        if (stateActionValue != null) {
            RLDataObject best = stateActionValue.getHighestValuedMatchingAction(this.actions);
            if (best != null) {
                this.action = best;
                return createBeliefUpdateForAction(stateActionValue.getValueForAction(best));
            }
        }
        return nextActionRandom();
    }

    /** Replaces the list of available actions with a new empty list. */
    public void clearActions() {
        this.actions = new ArrayList<>();
    }

    /**
     * Performs one Q-learning step.
     *
     * <p>Parses the available actions, the current state and the reward, then
     * computes
     * {@code Q(prev, a) + d2 * (reward + d * max Q(state, .) - Q(prev, a))}
     * where {@code a} is the current value of {@link #action}. The very first
     * call (no previous state and no previous action) only records the current
     * state and action.
     *
     * @param list  state beliefs (at most one)
     * @param list2 beliefs describing the available actions
     * @param list3 reward beliefs (at most one)
     * @param d     factor applied to the best value of the current state (discount)
     * @param d2    factor applied to the value difference (learning rate)
     * @param map   metrics; existing entries keyed by {@code QLearning.TOTAL_REWARD}
     *              and {@code QLearning.TOTAL_DISCOUNTED_REWARD} are incremented
     * @param z     evaluation mode: when true the new value is computed but not stored
     * @return {@code map}
     * @throws LearningProcessException if more than one state or reward belief is
     *         given, or the reward belief is malformed
     */
    public Map<String, Double> update(List<Formula> list, List<Formula> list2, List<Formula> list3, double d, double d2, Map<String, Double> map, boolean z) {
        parseActions(list2);
        parseState(list);
        parseReward(list3);
        if (this.previous_state == null && this.previous_action == null) {
            // Nothing to learn from yet: just remember where we are.
            this.previous_action = this.action;
            this.previous_state = this.state;
            return map;
        }
        if (this.trace) {
            System.out.println("Updating for previous state " + this.previous_state);
            System.out.println("and action " + this.action);
        }

        // Current estimate Q(previous_state, action); a missing entry counts as 0.
        StateActionValue previousValues = this.Q.get(this.previous_state);
        if (previousValues == null) {
            previousValues = new StateActionValue();
        }
        Double currentValue = previousValues.getValueForAction(this.action);
        if (currentValue == null) {
            currentValue = Double.valueOf(0.0d);
        }
        if (this.trace) {
            System.out.println("Current value is " + currentValue);
        }

        // Best value reachable from the current state. Note that an empty entry is
        // registered for an unseen state even in evaluation mode.
        StateActionValue currentValues = this.Q.get(this.state);
        if (currentValues == null) {
            currentValues = new StateActionValue();
            this.Q.put(this.state, currentValues);
        }
        if (this.trace) {
            System.out.println("Now get nax value for state and any subsequent action " + this.state);
        }
        Double highestValue = currentValues.getHighestValue();
        if (highestValue == null) {
            highestValue = Double.valueOf(0.0d);
        }
        if (this.trace) {
            System.out.println("Max value from current state, " + this.state + " is " + highestValue);
        }

        double delta = d2 * ((this.reward + (d * highestValue.doubleValue())) - currentValue.doubleValue());
        double newValue = currentValue.doubleValue() + delta;

        // Only metrics already present in the map are accumulated; put() on an
        // existing key does not structurally modify the map being iterated.
        for (String key : map.keySet()) {
            if (key.equals(QLearning.TOTAL_REWARD)) {
                map.put(key, Double.valueOf(map.get(key).doubleValue() + this.reward));
            }
            if (key.equals(QLearning.TOTAL_DISCOUNTED_REWARD)) {
                map.put(key, Double.valueOf(map.get(key).doubleValue() + delta));
            }
        }
        if (this.trace) {
            System.out.println("After update, new value is " + newValue);
        }

        if (!z) {
            if (this.trace) {
                System.out.println("Updated Q(S, A) for state(" + this.previous_state + "), action(" + this.action + ")");
            }
            previousValues.put(this.action, Double.valueOf(newValue));
            this.Q.put(this.previous_state, previousValues);
        } else if (this.trace) {
            System.out.println("Evaluation mode: Did not update Q(S, A) for state(" + this.previous_state + "), action(" + this.action + ")");
        }

        // Dump the whole table only when tracing; skip the walk otherwise.
        if (this.trace) {
            for (Map.Entry<RLDataObject, StateActionValue> entry : this.Q.entrySet()) {
                System.out.println("For Q(S, ...), S= " + entry.getKey() + ", Q(S,A) is ");
                System.out.println(entry.getValue());
            }
        }

        this.previous_action = this.action;
        this.previous_state = this.state;
        return map;
    }

    /**
     * Rebuilds {@link #actions} from the given beliefs, one action per belief.
     *
     * @param list action beliefs; each element is cast to {@link Predicate}
     */
    public void parseActions(List<Formula> list) {
        this.actions = new ArrayList<>();
        for (Formula formula : list) {
            this.actions.add(parseBeliefTerms(((Predicate) formula).terms()));
        }
    }

    /**
     * Sets {@link #state} from the given belief. An empty list leaves the state
     * unchanged.
     *
     * @param list state beliefs; each element is cast to {@link Predicate}
     * @throws LearningProcessException if more than one belief is supplied
     */
    public void parseState(List<Formula> list) {
        if (list.size() > 1) {
            throw new LearningProcessException("Q-Learning Only handles one input belief at the moment");
        }
        for (Formula formula : list) {
            this.state = parseBeliefTerms(((Predicate) formula).terms());
        }
    }

    /**
     * Sets {@link #reward} from the given belief. An empty list leaves the
     * reward unchanged.
     *
     * @param list reward beliefs; each element is cast to {@link Predicate}
     * @throws LearningProcessException if more than one belief is supplied or the
     *         belief does not hold a single numeric term
     */
    public void parseReward(List<Formula> list) {
        if (list.size() > 1) {
            throw new LearningProcessException("Q-Learning Only handles one reward belief at the moment");
        }
        for (Formula formula : list) {
            this.reward = parseReward(((Predicate) formula).terms());
        }
    }

    /**
     * Merging is not supported for this model.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // astra.learn.library.DataModel
    public boolean mergeModel(DataModel dataModel) {
        throw new UnsupportedOperationException("Unimplemented method 'mergeModel'");
    }

    /**
     * Renders the Q-table as text, one {@code Key:} section per state.
     *
     * @return the textual dump of the model
     */
    @Override // astra.learn.library.DataModel
    public String print() {
        StringBuilder out = new StringBuilder();
        out.append("RLDataModel ");
        for (Map.Entry<RLDataObject, StateActionValue> entry : this.Q.entrySet()) {
            out.append("\nKey: ");
            out.append(entry.getKey().toString());
            out.append("\n");
            out.append(entry.getValue().toString());
        }
        out.append("\nEnd Model");
        return out.toString();
    }
}
