package ai.search.tree.montecarlo;

import java.util.Arrays;
import java.util.Random;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/search/tree/montecarlo/TreeNode.class */
/**
 * A node of a Monte Carlo search tree.
 *
 * <p>Each node wraps one search state and a link to its parent. It accumulates the visit count
 * and the summed rollout value that {@link #search()} propagates back up the tree. Successor
 * nodes are created lazily on the first expansion, in shuffled order, and are then handed out
 * one per expansion.
 *
 * <p>Thread-safety: selection and expansion of a node run under that node's monitor, and
 * back-propagation updates each node under that node's own monitor. {@link #getBestChild()} and
 * the getters are not synchronized.
 *
 * @param <S> the state type
 */
public final class TreeNode<S> {
    /** Keeps UCT divisions finite for unvisited children and adds tie-breaking noise. */
    private static final double EPSILON = 1.0E-6d;
    private static final Random random = new Random();
    private final MonteCarloTreeSearch<S> search;
    private final TreeNode<S> parent;
    private final S state;
    /** All successor nodes; {@code null} until this node is first expanded. */
    private TreeNode<S>[] children;
    /** Number of leading entries of {@link #children} that expansion has handed out so far. */
    private int childcount;
    // NOTE(review): visitcount/totalvalue are written under this node's monitor, but the
    // parent's selectChild() reads them under the *parent's* monitor. Those reads are therefore
    // not synchronized with the writes (and non-volatile doubles may tear, JLS 17.7). Consider
    // volatile or a shared lock if searches really run concurrently.
    private double visitcount;
    private double totalvalue;

    /**
     * Creates a node for {@code s}.
     *
     * @param monteCarloTreeSearch the search that supplies successors and state values
     * @param treeNode the parent node; {@code null} for the root (back-propagation stops there)
     * @param s the state this node represents
     */
    public TreeNode(MonteCarloTreeSearch<S> monteCarloTreeSearch, TreeNode<S> treeNode, S s) {
        this.parent = treeNode;
        this.search = monteCarloTreeSearch;
        this.state = s;
    }

    /** Returns the state this node represents. */
    public S getState() {
        return this.state;
    }

    /** Returns how many rollouts have been propagated through this node. */
    public double getVisitCount() {
        return this.visitcount;
    }

    /** Returns the sum of all rollout values propagated through this node. */
    public double getTotalValue() {
        return this.totalvalue;
    }

    /**
     * Returns the best child expanded so far, as ranked by {@link TreeNodeComparator}.
     *
     * <p>The child the comparator orders first is returned. Among equally ranked children, the
     * earliest one wins. The children array is left in its current order.
     *
     * @return the best expanded child, or {@code null} if no child has been expanded yet
     * @throws AmbigousResultException if at least two children are expanded and the best child's
     *     visit count is below twice this node's visit count divided by the number of expanded
     *     children
     */
    public TreeNode<S> getBestChild() {
        if (this.childcount == 0) {
            return null;
        }
        if (this.childcount == 1) {
            return this.children[0];
        }
        // A linear scan for the comparator's minimum picks the same node a stable sort would put
        // first. Unlike sorting, it does not reorder the shared children array, which concurrent
        // searches may be reading. It also cannot fail with a TimSort "contract violation" if
        // visit counts change while it runs.
        TreeNodeComparator comparator = new TreeNodeComparator();
        TreeNode<S> best = this.children[0];
        for (int i = 1; i < this.childcount; i++) {
            TreeNode<S> candidate = this.children[i];
            if (comparator.compare(candidate, best) < 0) {
                best = candidate;
            }
        }
        double average = this.visitcount / this.childcount;
        if (average * 2.0d > best.visitcount) {
            throw new AmbigousResultException(String.format("average %s, best node %s", Double.valueOf(average), Double.valueOf(best.visitcount)));
        }
        return best;
    }

    /**
     * Runs one search iteration: select, expand, roll out, and back-propagate.
     *
     * <p>Selection descends from this node by UCT score. Expansion hands out one new child, or
     * the selected node itself once all of its successors are out. The value of that node's
     * state is then added to it and all of its ancestors.
     */
    public void search() {
        TreeNode<S> expand = select(this).expand();
        propagateBack(expand, expand.rollOut());
    }

    /** Descends from {@code treeNode} via {@link #selectChild()} until no child is chosen. */
    private static <S> TreeNode<S> select(TreeNode<S> treeNode) {
        TreeNode<S> current = treeNode;
        TreeNode<S> next = current.selectChild();
        while (next != null) {
            current = next;
            next = current.selectChild();
        }
        return current;
    }

    /**
     * Chooses the child with the highest UCT score.
     *
     * <p>The score is the mean value, plus {@code sqrt(2 ln(N + 1) / n)}, plus a tiny random
     * tie-breaker.
     *
     * @return the chosen child, or {@code null} if this node has no expanded children or still
     *     has successors to expand (selection should stop here)
     */
    private synchronized TreeNode<S> selectChild() {
        if (!hasChildren() || hasMoreStates()) {
            return null;
        }
        TreeNode<S> best = null;
        // Start below every finite score. Double.MIN_VALUE is the smallest *positive* double, so
        // it would leave no child selected whenever all scores are <= 0 (e.g. negative rollout
        // values). Selection would then stall at this fully expanded node.
        double bestScore = Double.NEGATIVE_INFINITY;
        double logParentVisits = Math.log(this.visitcount + 1.0d);
        for (TreeNode<S> child : this.children) {
            double denominator = child.visitcount + EPSILON;
            double score = child.totalvalue / denominator
                    + Math.sqrt(2.0d * logParentVisits / denominator)
                    + random.nextDouble() * EPSILON;
            if (score > bestScore) {
                best = child;
                bestScore = score;
            }
        }
        return best;
    }

    /** True once children were created but not all of them have been handed out yet. */
    private boolean hasMoreStates() {
        return this.children != null && this.childcount < this.children.length;
    }

    /** True if at least one child has been handed out by expansion. */
    private boolean hasChildren() {
        return this.childcount > 0;
    }

    /**
     * Hands out the next unexpanded child, creating all children on the first call.
     *
     * @return the next child, or this node itself if every successor has already been handed out
     *     (including the case of a state with no successors)
     */
    private synchronized TreeNode<S> expand() {
        if (this.children == null) {
            S[] successors = this.search.findSuccessors(this.state);
            shuffle(successors);
            @SuppressWarnings("unchecked") // array only ever holds TreeNode<S> instances
            TreeNode<S>[] created = (TreeNode<S>[]) new TreeNode<?>[successors.length];
            for (int i = 0; i < successors.length; i++) {
                created[i] = new TreeNode<>(this.search, this, successors[i]);
            }
            // Publish the array only after it is fully populated.
            this.children = created;
        }
        if (this.childcount >= this.children.length) {
            return this;
        }
        return this.children[this.childcount++];
    }

    /**
     * Shuffles {@code sArr} in place (Fisher-Yates).
     *
     * <p>NOTE(review): this mutates the array returned by findSuccessors. Confirm that the search
     * does not reuse that array.
     */
    private static <S> void shuffle(S[] sArr) {
        for (int length = sArr.length; length > 1; length--) {
            int nextInt = random.nextInt(length);
            S s = sArr[length - 1];
            sArr[length - 1] = sArr[nextInt];
            sArr[nextInt] = s;
        }
    }

    /** Evaluates this node's state via the owning search. */
    private double rollOut() {
        return this.search.getValue(this.state);
    }

    /**
     * Adds one visit and {@code value} to {@code leaf} and to every ancestor up to the root.
     * Each node is updated under its own monitor.
     */
    private static <S> void propagateBack(TreeNode<S> leaf, double value) {
        for (TreeNode<S> node = leaf; node != null; node = node.parent) {
            synchronized (node) {
                node.visitcount += 1.0d;
                node.totalvalue += value;
            }
        }
    }
}
