package astra.learn.library.RLModel;

import astra.formula.Formula;
import astra.learn.library.QLearning;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:astra/learn/library/RLModel/RLDataModelSARSA.class */
/**
 * SARSA variant of {@link RLDataModel}.
 *
 * <p>Keeps a table {@code Q} mapping each state to its per-action values and, on every call to
 * {@link #update}, applies the on-policy update
 * {@code Q(S,A) <- Q(S,A) + alpha * (R + gamma * Q(S',A') - Q(S,A))}, where (S, A) is the previously
 * observed state/action pair and (S', A') the current one.
 *
 * <p>The current state, action and reward are read from the inherited fields {@code state},
 * {@code action} and {@code reward}; presumably the caller sets them before each update — TODO confirm.
 */
public class RLDataModelSARSA extends RLDataModel {

    /** Creates a model with an empty Q-table and an empty action list. */
    public RLDataModelSARSA() {
        this.Q = new HashMap<>();
        // Diamond instead of the raw ArrayList to avoid an unchecked raw-type assignment.
        this.actions = new ArrayList<>();
    }

    /**
     * Performs one SARSA step for the transition (previous state, previous action) -> (state, action).
     *
     * <p>On the first call (no previous state/action recorded yet) nothing is learned: the current
     * state and action are remembered and {@code map} is returned unchanged.
     *
     * <p>If the current state S' has no row in {@code Q}, an empty row is inserted for it even in
     * evaluation mode.
     *
     * @param list  unused by this implementation
     * @param list2 unused by this implementation
     * @param list3 unused by this implementation
     * @param d     factor applied to {@code Q(S',A')} in the TD target (the discount factor gamma)
     * @param d2    factor applied to the TD error (the learning rate alpha)
     * @param map   running statistics, mutated in place: if present, the
     *              {@link QLearning#TOTAL_REWARD} entry is increased by the reward and the
     *              {@link QLearning#TOTAL_DISCOUNTED_REWARD} entry by the scaled TD error
     * @param z     when {@code true} (evaluation mode) the Q-table entry for (S, A) is not written
     * @return the same {@code map} instance that was passed in
     */
    @Override
    public Map<String, Double> update(List<Formula> list, List<Formula> list2, List<Formula> list3,
            double d, double d2, Map<String, Double> map, boolean z) {
        if (this.previous_state == null && this.previous_action == null) {
            // First observation: there is no (S, A) pair to update yet.
            this.previous_action = this.action;
            this.previous_state = this.state;
            return map;
        }
        if (this.trace) {
            System.out.println("S, previous state " + this.previous_state.toString());
            System.out.println("A, previous action " + this.previous_action.toString());
        }

        // Q(S, A): the value being updated. A missing row is created here but only stored back
        // into Q below when learning (!z); a missing value counts as 0.
        StateActionValue previousValues = this.Q.get(this.previous_state);
        if (previousValues == null) {
            previousValues = new StateActionValue();
        }
        Double storedQ = previousValues.getValueForAction(this.previous_action);
        double qSA = storedQ == null ? 0.0d : storedQ.doubleValue();
        if (this.trace) {
            System.out.println("Q(S,A) is " + qSA);
        }

        // Q(S', A'): value of the action actually selected in the current state (on-policy target).
        StateActionValue currentValues = this.Q.get(this.state);
        if (currentValues == null) {
            currentValues = new StateActionValue();
            this.Q.put(this.state, currentValues);
        }
        if (this.trace) {
            System.out.println("S', current state " + this.state.toString());
            System.out.println("A', next action " + this.action.toString());
        }
        Double storedNextQ = currentValues.getValueForAction(this.action);
        double qNext = storedNextQ == null ? 0.0d : storedNextQ.doubleValue();
        if (this.trace) {
            System.out.println("Q(S',A') " + qNext);
        }

        // alpha * (R + gamma * Q(S',A') - Q(S,A))
        double scaledTdError = d2 * ((this.reward + (d * qNext)) - qSA);
        double updatedQ = qSA + scaledTdError;

        accumulateStatistic(map, QLearning.TOTAL_REWARD, this.reward);
        // NOTE(review): this accumulates the learning-rate-scaled TD error, not a discounted
        // return; kept as-is — confirm this is the intended statistic.
        accumulateStatistic(map, QLearning.TOTAL_DISCOUNTED_REWARD, scaledTdError);

        if (this.trace) {
            System.out.println("After update, new value is " + updatedQ);
        }
        // The messages below name previous_action, the action whose entry is (or would be) written.
        if (!z) {
            if (this.trace) {
                System.out.println("Updated Q(S, A) for state(" + this.previous_state.toString() + "), action(" + this.previous_action.toString() + ")");
            }
            previousValues.put(this.previous_action, Double.valueOf(updatedQ));
            this.Q.put(this.previous_state, previousValues);
        } else if (this.trace) {
            System.out.println("Evaluation mode: Did not update Q(S, A) for state(" + this.previous_state.toString() + "), action(" + this.previous_action.toString() + ")");
        }

        if (this.trace) {
            printQTableTrace();
        }
        // The current pair becomes (S, A) for the next step.
        this.previous_action = this.action;
        this.previous_state = this.state;
        return map;
    }

    /**
     * Adds {@code delta} to the statistic stored under {@code key}; does nothing when {@code map}
     * is null or holds no (non-null) value for that key.
     */
    private static void accumulateStatistic(Map<String, Double> map, String key, double delta) {
        if (map == null) {
            return;
        }
        Double current = map.get(key);
        if (current != null) {
            map.put(key, Double.valueOf(current.doubleValue() + delta));
        }
    }

    /** Prints every row of the Q-table to standard output; used only when tracing is enabled. */
    private void printQTableTrace() {
        for (RLDataObject stateKey : this.Q.keySet()) {
            System.out.println("For Q(S, ...), S= " + stateKey.toString() + ", Q(S,A) is ");
            System.out.println(this.Q.get(stateKey).toString());
        }
    }

    /**
     * Renders the whole Q-table as text: a header line, then for each state its key and its
     * action values, then an end marker.
     *
     * @return the textual dump of this model
     */
    @Override
    public String print() {
        // Method-local buffer: StringBuffer's synchronization is unnecessary here.
        StringBuilder sb = new StringBuilder("RLDataModelSARSA ");
        for (RLDataObject stateKey : this.Q.keySet()) {
            sb.append("\nKey: ").append(stateKey.toString());
            sb.append("\n").append(this.Q.get(stateKey).toString());
        }
        sb.append("\nEnd Model");
        return sb.toString();
    }
}
