2421. Number of Good Paths

    There is a tree (i.e. a connected, undirected graph with no cycles) consisting of n nodes numbered from 0 to n - 1 and exactly n - 1 edges.

    You are given a 0-indexed integer array vals of length n where vals[i] denotes the value of the ith node. You are also given a 2D integer array edges where edges[i] = [ai, bi] denotes that there exists an undirected edge connecting nodes ai and bi.

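    A good path is a simple path that satisfies the following conditions:

    1. The starting node and the ending node have the same value.
    2. All nodes between the starting node and the ending node have values less than or equal to the starting node (i.e. the starting node's value should be the maximum value along the path).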

    Return the number of distinct good paths.

    Note that a path and its reverse are counted as the same path. For example, 0 -> 1 is considered to be the same as 1 -> 0. A single node is also considered as a valid path.
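
    The definition can be sketched as a small check. This is an illustration only: GoodPathCheck and isGoodPath are hypothetical names, and the helper assumes the caller passes the nodes of a valid simple path in order (it does not verify adjacency). It tests that both endpoints share a value and that no node on the path exceeds it.

```java
final class GoodPathCheck {
    // Hypothetical helper: assumes `path` lists the nodes of a valid simple
    // path in the tree, in order. Adjacency is NOT verified here.
    static boolean isGoodPath(int[] vals, int[] path) {
        int start = vals[path[0]];
        int end = vals[path[path.length - 1]];
        if (start != end) {
            return false; // endpoints must carry the same value
        }
        for (int node : path) {
            if (vals[node] > start) {
                return false; // no node may exceed the endpoint value
            }
        }
        return true;
    }
}
```

    With vals = [1,3,2,1,3], the path 1 -> 0 -> 2 -> 4 passes this check, while 0 -> 2 -> 3 fails it because vals[2] > vals[0].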

    Example 1:

    Input: vals = [1,3,2,1,3], edges = [[0,1],[0,2],[2,3],[2,4]]
    Output: 6
    Explanation: There are 5 good paths consisting of a single node.
    There is 1 additional good path: 1 -> 0 -> 2 -> 4.
    (The reverse path 4 -> 2 -> 0 -> 1 is treated as the same as 1 -> 0 -> 2 -> 4.)
    Note that 0 -> 2 -> 3 is not a good path because vals[2] > vals[0].

    Example 2:

    Input: vals = [1,1,2,2,3], edges = [[0,1],[1,2],[2,3],[2,4]]
    Output: 7
    Explanation: There are 5 good paths consisting of a single node.
    There are 2 additional good paths: 0 -> 1 and 2 -> 3.

    Example 3:

    Input: vals = [1], edges = []
    Output: 1
    Explanation: The tree consists of only one node, so there is one good path.
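
    The three examples can be cross-checked with a quadratic brute force. This is a sketch for verification on small inputs, not an approach suited to large trees; BruteForceGoodPaths is a hypothetical name. From each start node it walks only through nodes whose value does not exceed the start value, and counts every reached node with the same value. Requiring the end index to be at least the start index counts a path and its reverse once.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

final class BruteForceGoodPaths {
    static int count(int[] vals, int[][] edges) {
        int n = vals.length;
        // Build an adjacency list for the undirected tree.
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int total = 0;
        for (int start = 0; start < n; start++) {
            boolean[] seen = new boolean[n];
            Deque<Integer> stack = new ArrayDeque<>();
            stack.push(start);
            seen[start] = true;
            while (!stack.isEmpty()) {
                int u = stack.pop();
                // u >= start counts each unordered pair once and
                // includes the single-node path when u == start.
                if (vals[u] == vals[start] && u >= start) {
                    total++;
                }
                for (int v : adj.get(u)) {
                    // Never step onto a node larger than the start value.
                    if (!seen[v] && vals[v] <= vals[start]) {
                        seen[v] = true;
                        stack.push(v);
                    }
                }
            }
        }
        return total;
    }
}
```

    Because a tree has exactly one simple path between any two nodes, reaching a node under this restriction means every node on that path is within the limit.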


    Solution (Java)

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class Solution {
        public int numberOfGoodPaths(int[] vals, int[][] edges) {
            // Create one node per index
            Node[] nodes = new Node[vals.length];
            for (int i = 0; i < vals.length; i++) nodes[i] = new Node(i, vals[i]);
            // Store each edge on its larger endpoint (ties broken by position),
            // so it is processed only after both endpoints appear in sorted order
            for (int[] edge : edges) {
                Node n1 = nodes[edge[0]];
                Node n2 = nodes[edge[1]];
                if (n1.val > n2.val || (n1.val == n2.val && n1.pos > n2.pos)) n1.edges.add(n2);
                else n2.edges.add(n1);
            }
            int result = 0;
            // Sort by node value, then by position
            Arrays.sort(nodes, (n1, n2) -> n1.val == n2.val ? n1.pos - n2.pos : n1.val - n2.val);
            // Component root position -> number of nodes carrying the current value
            Map<Integer, Integer> map = new HashMap<>();
            int val = nodes[0].val;
            int prev = 0;
            for (int i = 0; i < nodes.length; i++) {
                // The value changed: count paths for the finished run nodes[prev..i)
                if (nodes[i].val != val) {
                    result += count(nodes, map, prev, i);
                    val = nodes[i].val;
                    prev = i;
                }
                // Grow the tree: merge this node's component with each smaller neighbor's
                for (Node next : nodes[i].edges) {
                    Node srcParent = nodes[i].parent();
                    Node tgtParent = next.parent();
                    if (srcParent.pos < tgtParent.pos) tgtParent.parent = srcParent;
                    else if (srcParent.pos > tgtParent.pos) srcParent.parent = tgtParent;
                }
            }
            result += count(nodes, map, prev, nodes.length);
            return result;
        }

        // Counts good paths whose endpoints lie in nodes[l..r), which all share one value
        private int count(Node[] nodes, Map<Integer, Integer> groups, int l, int r) {
            int count = 0;
            for (int i = l; i < r; i++) groups.compute(nodes[i].parent().pos, (k, v) -> v == null ? 1 : v + 1);
            // k equal-valued nodes in one component give k single-node paths plus k*(k-1)/2 pairs
            for (int num : groups.values()) count += num * (num + 1) / 2;
            // Reset the map so the next run of values does not re-count these groups
            groups.clear();
            return count;
        }

        class Node {
            // node position (original index)
            int pos;
            // node value
            int val;
            // parent of node; Union-Find
            Node parent;
            // edges from this node to neighbors with a smaller (value, position)
            List<Node> edges;

            public Node(int pos, int val) {
                this.pos = pos;
                this.val = val;
                this.edges = new ArrayList<>();
                this.parent = this;
            }

            // Returns the root of this node's set
            public Node parent() {
                // If this node is not the root, find the root and cache it in this.parent
                if (this.parent.pos != this.pos) this.parent = this.parent.parent();
                return this.parent;
            }
        }
    }
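
    The same union-find idea is often written with plain arrays instead of node objects. The following is a sketch of that alternative formulation, not the solution above: GoodPathsByEdges and its field names are hypothetical. It sorts edges by the larger of their two endpoint values and keeps, for each component root, the component's maximum value and how many nodes carry it. When two components with equal maxima are joined, every pair of maximum-valued nodes across them forms a new good path.

```java
import java.util.Arrays;

final class GoodPathsByEdges {
    private static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]]; // path halving
            x = parent[x];
        }
        return x;
    }

    static int count(int[] vals, int[][] edges) {
        int n = vals.length;
        int[] parent = new int[n];
        int[] maxVal = new int[n];   // largest value in the component rooted here
        int[] maxCount = new int[n]; // number of nodes in it carrying that value
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            maxVal[i] = vals[i];
            maxCount[i] = 1;
        }
        // Process edges in increasing order of their larger endpoint value.
        int[][] sorted = edges.clone();
        Arrays.sort(sorted, (a, b) -> Integer.compare(
                Math.max(vals[a[0]], vals[a[1]]),
                Math.max(vals[b[0]], vals[b[1]])));
        int result = n; // every single node is a good path
        for (int[] e : sorted) {
            int ra = find(parent, e[0]);
            int rb = find(parent, e[1]);
            if (maxVal[ra] == maxVal[rb]) {
                // Each max-valued node on one side pairs with each on the other.
                result += maxCount[ra] * maxCount[rb];
                parent[rb] = ra;
                maxCount[ra] += maxCount[rb];
            } else if (maxVal[ra] > maxVal[rb]) {
                parent[rb] = ra; // keep the root that holds the larger maximum
            } else {
                parent[ra] = rb;
            }
        }
        return result;
    }
}
```

    Since the input is a tree, the two endpoints of an edge are always in different components when that edge is processed, so no same-root check is needed in this sketch.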


