/*
 * Decompiled with CFR 0.152.
 */
package org.scribble.visit;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.scribble.ast.Module;
import org.scribble.ast.ProtocolDecl;
import org.scribble.ast.ScribNode;
import org.scribble.ast.global.GProtocolDecl;
import org.scribble.main.ScribbleException;
import org.scribble.model.global.GIOAction;
import org.scribble.model.local.EndpointFSM;
import org.scribble.model.local.EndpointGraph;
import org.scribble.model.local.EndpointState;
import org.scribble.model.local.IOAction;
import org.scribble.model.local.Send;
import org.scribble.model.wf.WFBuffers;
import org.scribble.model.wf.WFConfig;
import org.scribble.model.wf.WFState;
import org.scribble.model.wf.WFStateErrors;
import org.scribble.sesstype.name.GProtocolName;
import org.scribble.sesstype.name.Role;
import org.scribble.visit.Job;
import org.scribble.visit.ModuleContextVisitor;

public class GlobalModelChecker
extends ModuleContextVisitor {
    /**
     * Creates a checker for the given job.
     *
     * @param job the job passed straight through to the {@link ModuleContextVisitor} constructor
     */
    public GlobalModelChecker(Job job) {
        super(job);
    }

    /**
     * Runs the superclass entry hook and then the child's delegate
     * {@code enterCompatCheck} callback, passing this visitor.
     *
     * @param parent the node containing {@code child}
     * @param child the node being entered
     * @throws ScribbleException if either the superclass hook or the delegate callback throws
     */
    @Override
    public void enter(ScribNode parent, ScribNode child) throws ScribbleException {
        // Superclass bookkeeping first, so the delegate callback runs after it.
        super.enter(parent, child);
        child.del().enterCompatCheck(parent, child, this);
    }

    /**
     * Runs the visited node's delegate {@code leaveCompatCheck} callback and
     * hands its result to the superclass exit hook.
     *
     * @param parent the node containing {@code child}
     * @param child the node that was entered
     * @param visited the node produced by visiting {@code child}
     * @return whatever the superclass {@code leave} returns for the delegate's result
     * @throws ScribbleException if the delegate callback or the superclass hook throws
     */
    @Override
    public ScribNode leave(ScribNode parent, ScribNode child, ScribNode visited) throws ScribbleException {
        ScribNode checked = visited.del().leaveCompatCheck(parent, child, this, visited);
        return super.leave(parent, child, checked);
    }

    /**
     * Dispatches on the kind of {@code child}.
     * <ul>
     *   <li>Anything that is not a {@link ProtocolDecl} goes through the normal superclass visit.</li>
     *   <li>A {@link ProtocolDecl} that is not a {@link GProtocolDecl} is returned untouched.</li>
     *   <li>A {@link GProtocolDecl} with the aux modifier is returned untouched.</li>
     *   <li>Any other {@link GProtocolDecl} is model-checked via
     *       {@code visitOverrideForGProtocolDecl}, with {@code parent} cast to {@link Module}.</li>
     * </ul>
     *
     * @param parent the node containing {@code child}
     * @param child the node to visit
     * @return the visited node
     * @throws ScribbleException if the superclass visit or the model check fails
     */
    @Override
    public ScribNode visit(ScribNode parent, ScribNode child) throws ScribbleException {
        if (!(child instanceof ProtocolDecl)) {
            return super.visit(parent, child);
        }
        if (!(child instanceof GProtocolDecl)) {
            return child;
        }
        GProtocolDecl global = (GProtocolDecl) child;
        if (global.isAuxModifier()) {
            return global;
        }
        return this.visitOverrideForGProtocolDecl((Module) parent, global);
    }

    /**
     * Builds the global model for one global protocol declaration, registers
     * it with the job context, and then checks it.
     *
     * @param parent the module used to compute the declaration's full member name
     * @param child the global protocol declaration to process
     * @return {@code child}, unchanged
     * @throws ScribbleException if building or checking the model fails
     */
    private GProtocolDecl visitOverrideForGProtocolDecl(Module parent, GProtocolDecl child) throws ScribbleException {
        GProtocolName protoName = child.getFullMemberName(parent);
        Map<Role, EndpointFSM> fsms = this.getEndpointFSMs(protoName, child);
        // Filled by buildGlobalModel: every explored state, keyed by its id.
        Map<Integer, WFState> statesById = new HashMap<>();
        WFState initial = this.buildGlobalModel(protoName, child, fsms, statesById);
        this.getJobContext().addGlobalModel(protoName, initial);
        this.checkGlobalModel(protoName, initial, statesById);
        return child;
    }

    /**
     * Checks every state of a built global model and, unless liveness checking
     * is disabled on the job, every terminal set, collecting all violations
     * into a single error message.
     *
     * @param fullname protocol name, used only in debug output
     * @param init the initial state (start point for error traces)
     * @param all every state of the model, keyed by state id
     * @throws ScribbleException carrying the accumulated message if any violation was found
     */
    private void checkGlobalModel(GProtocolName fullname, WFState init, Map<Integer, WFState> all) throws ScribbleException {
        Job job = this.getJob();
        StringBuilder errorMsg = new StringBuilder();
        Map<Integer, Set<Integer>> reach = GlobalModelChecker.getReachability(job, all);

        // Note: count only advances when job.debug is set (the increment is short-circuited).
        int count = 0;
        for (WFState s : all.values()) {
            if (job.debug && ++count % 50 == 0) {
                job.debugPrintln("(" + fullname + ") Checking global states: " + count);
            }
            WFStateErrors errors = s.getErrors();
            if (!errors.isEmpty()) {
                List<GIOAction> trace = GlobalModelChecker.getTrace(all, init, s, reach);
                errorMsg.append("\nSafety violation(s) at ").append(s.toString()).append(":\n    Trace=").append(trace);
            }
            if (!errors.stuck.isEmpty()) {
                errorMsg.append("\n    Stuck messages: ").append(errors.stuck);
            }
            if (!errors.waitFor.isEmpty()) {
                errorMsg.append("\n    Wait-for errors: ").append(errors.waitFor);
            }
            if (!errors.orphans.isEmpty()) {
                errorMsg.append("\n    Orphan messages: ").append(errors.orphans);
            }
        }
        job.debugPrintln("(" + fullname + ") Checked all states: " + count);

        if (!job.noLiveness) {
            Set<Set<Integer>> termsets = new HashSet<>();
            GlobalModelChecker.findTerminalSets(all, reach, termsets);
            for (Set<Integer> termset : termsets) {
                Set<Role> safety = new HashSet<>();
                Set<Role> roleLiveness = new HashSet<>();
                GlobalModelChecker.checkTerminalSet(all, init, termset, safety, roleLiveness);
                if (!safety.isEmpty()) {
                    errorMsg.append("\nSafety violation for ").append(safety)
                            .append(" in terminal set:\n    ").append(GlobalModelChecker.describeTerminalSet(all, termset));
                }
                if (!roleLiveness.isEmpty()) {
                    errorMsg.append("\nRole progress violation for ").append(roleLiveness)
                            .append(" in terminal set:\n    ").append(GlobalModelChecker.describeTerminalSet(all, termset));
                }
                Map<Role, Set<Send>> msgLiveness = GlobalModelChecker.checkMessageLiveness(all, init, termset);
                if (!msgLiveness.isEmpty()) {
                    errorMsg.append("\nMessage liveness violation for ").append(msgLiveness)
                            .append(" in terminal set:\n    ").append(GlobalModelChecker.describeTerminalSet(all, termset));
                }
            }
        }

        if (errorMsg.length() > 0) {
            throw new ScribbleException(errorMsg.toString());
        }
    }

    /** Renders the states of a terminal set as a comma-joined list of their {@code toString()} forms. */
    private static String describeTerminalSet(Map<Integer, WFState> all, Set<Integer> termset) {
        return termset.stream().map(id -> all.get(id).toString()).collect(Collectors.joining(","));
    }

    /**
     * Explores the global state space from the initial configuration using a
     * work-list, recording every processed state in {@code seen} and adding an
     * edge for every transition taken.
     *
     * @param fullname protocol name, used only in debug output
     * @param gpd the protocol declaration; its EXPLICIT modifier decides the flag passed to {@link WFBuffers}
     * @param egraphs endpoint FSM for each role
     * @param seen output map: every processed state, keyed by state id
     * @return the initial state of the built model
     * @throws ScribbleException if a state offers actions that agree on peer and message id but differ in payload
     */
    private WFState buildGlobalModel(GProtocolName fullname, GProtocolDecl gpd, Map<Role, EndpointFSM> egraphs, Map<Integer, WFState> seen) throws ScribbleException {
        Job job = this.getJob();
        WFBuffers b0 = new WFBuffers(egraphs.keySet(), !gpd.modifiers.contains(GProtocolDecl.Modifiers.EXPLICIT));
        WFConfig c0 = new WFConfig(egraphs, b0);
        WFState init = new WFState(c0);
        LinkedHashSet<WFState> todo = new LinkedHashSet<>();
        todo.add(init);
        int count = 0;
        while (!todo.isEmpty()) {
            // Pop the oldest pending state (LinkedHashSet iterates in insertion order).
            Iterator<WFState> pending = todo.iterator();
            WFState curr = pending.next();
            pending.remove();
            seen.put(curr.id, curr);
            if (job.debug && ++count % 50 == 0) {
                job.debugPrintln("(" + fullname + ") Building global states: " + count);
            }

            Map<Role, List<IOAction>> takeable = curr.getTakeable();
            GlobalModelChecker.checkDeterministicPayloads(curr, takeable);

            for (Role r : takeable.keySet()) {
                for (IOAction a : takeable.get(r)) {
                    if (a.isSend() || a.isReceive() || a.isDisconnect()) {
                        this.getNextStates(todo, seen, curr, a.toGlobal(r), curr.take(r, a));
                    } else if (a.isAccept() || a.isConnect() || a.isWrapClient() || a.isWrapServer()) {
                        // Synchronous actions fire only when the peer currently offers the dual.
                        List<IOAction> peerActions = takeable.get(a.peer);
                        IOAction dual = a.toDual(r);
                        if (peerActions != null && peerActions.contains(dual)) {
                            // Consume one occurrence so the pair is not fired again from the peer's side.
                            peerActions.remove(dual);
                            // NOTE(review): for wrap actions isConnect() is presumably false, so the edge is
                            // always labelled with the dual's global action; confirm whether isWrapClient()
                            // was intended here, by analogy with connect/accept.
                            GIOAction g = a.isConnect() ? a.toGlobal(r) : dual.toGlobal(a.peer);
                            this.getNextStates(todo, seen, curr, g, curr.sync(r, a, a.peer, dual));
                        }
                    } else {
                        throw new RuntimeException("Shouldn't get in here: " + a);
                    }
                }
            }
        }
        job.debugPrintln("(" + fullname + ") Building global model..\n" + init.toDot() + "\n(" + fullname + ") Built global model (" + count + " states)");
        return init;
    }

    /**
     * Rejects a state in which some role has two actions with the same peer and
     * message id but different payloads. Output states compare the takeable
     * actions against each other; unary-input, poly-input and accept states
     * compare them against everything the role's FSM reports via {@code getAllTakeable()}.
     */
    private static void checkDeterministicPayloads(WFState curr, Map<Role, List<IOAction>> takeable) throws ScribbleException {
        for (Role r : takeable.keySet()) {
            List<IOAction> enabled = takeable.get(r);
            EndpointFSM fsm = curr.config.states.get(r);
            EndpointState.Kind kind = fsm.getStateKind();
            if (kind == EndpointState.Kind.OUTPUT) {
                for (IOAction a : enabled) {
                    if (enabled.stream().anyMatch(x -> GlobalModelChecker.clashesOnPayload(a, x))) {
                        throw new ScribbleException("Bad non-deterministic action payloads: " + enabled);
                    }
                }
            } else if (kind == EndpointState.Kind.UNARY_INPUT || kind == EndpointState.Kind.POLY_INPUT || kind == EndpointState.Kind.ACCEPT) {
                for (IOAction a : enabled) {
                    if (fsm.getAllTakeable().stream().anyMatch(x -> GlobalModelChecker.clashesOnPayload(a, x))) {
                        throw new ScribbleException("Bad non-deterministic action payloads: " + fsm.getAllTakeable());
                    }
                }
            }
        }
    }

    /** True when two distinct actions share peer and message id but carry different payloads. */
    private static boolean clashesOnPayload(IOAction a, IOAction x) {
        return !a.equals(x) && a.peer.equals(x.peer) && a.mid.equals(x.mid) && !a.payload.equals(x.payload);
    }

    /**
     * Collects one endpoint FSM per role declared in the protocol header.
     * The regular endpoint graph is always fetched and debug-printed; when the
     * job is neither {@code fair} nor {@code noLiveness}, the unfair graph is
     * fetched as well and used in its place.
     *
     * @param fullname protocol name used to look the graphs up and to tag debug output
     * @param gpd the protocol declaration whose header lists the roles
     * @return a map from each role to the FSM of the selected graph
     * @throws ScribbleException if a graph lookup fails
     */
    private Map<Role, EndpointFSM> getEndpointFSMs(GProtocolName fullname, GProtocolDecl gpd) throws ScribbleException {
        Job job = this.getJob();
        HashMap<Role, EndpointFSM> fsmByRole = new HashMap<>();
        for (Role self : gpd.header.roledecls.getRoles()) {
            EndpointGraph selected = job.getContext().getEndpointGraph(fullname, self);
            job.debugPrintln("(" + fullname + ") EFSM for " + self + ":\n" + selected);
            if (!(job.fair || job.noLiveness)) {
                selected = job.getContext().getUnfairEndpointGraph(fullname, self);
                job.debugPrintln("(" + fullname + ") Non-fair EFSM for " + self + ":\n" + selected.init.toDot());
            }
            fsmByRole.put(self, selected.toFsm());
        }
        return fsmByRole;
    }

    /**
     * Adds an edge from {@code curr} to the state of each successor
     * configuration. An existing equal state is reused, looking first in
     * {@code seen} and then in {@code todo}; otherwise the fresh state is
     * queued in {@code todo}.
     *
     * @param todo states waiting to be processed; new states are appended here
     * @param seen states already processed, keyed by id
     * @param curr the source state of the new edges
     * @param a the action labelling the new edges
     * @param nexts the successor configurations reached by {@code a}
     */
    private void getNextStates(LinkedHashSet<WFState> todo, Map<Integer, WFState> seen, WFState curr, GIOAction a, List<WFConfig> nexts) {
        for (WFConfig next : nexts) {
            WFState candidate = new WFState(next);
            WFState succ = GlobalModelChecker.findEqualState(seen.values(), candidate);
            if (succ == null) {
                succ = GlobalModelChecker.findEqualState(todo, candidate);
            }
            if (succ == null) {
                succ = candidate;
                todo.add(succ);
            }
            curr.addEdge(a, succ);
        }
    }

    /** Returns the first state equal to {@code target}, or {@code null}; stops scanning at the first match. */
    private static WFState findEqualState(Iterable<WFState> states, WFState target) {
        for (WFState existing : states) {
            if (existing.equals(target)) {
                return existing;
            }
        }
        return null;
    }

    /**
     * Finds roles that make no progress inside a terminal set and adds them to
     * {@code liveness}. A role is reported when no state after the first one
     * in the set's iteration order has an action containing it, the role
     * cannot safely terminate in the first state, and every entry of the
     * role's buffer in the first state is {@code null}.
     *
     * <p>NOTE(review): the first state's own actions are never inspected, so
     * the result can depend on the set's iteration order. Confirm whether
     * that is intended.
     *
     * @param all every state of the model, keyed by state id
     * @param init unused here
     * @param termset ids of the states in the terminal set; must be non-empty
     * @param safety never written by this method
     * @param liveness output set receiving the roles found not to progress
     * @throws ScribbleException declared for callers; nothing in this body throws it directly
     */
    private static void checkTerminalSet(Map<Integer, WFState> all, WFState init, Set<Integer> termset, Set<Role> safety, Set<Role> liveness) throws ScribbleException {
        Iterator<Integer> ids = termset.iterator();
        WFState first = all.get(ids.next());

        // Each role starts out mapped to the first state; a null value marks a role seen acting.
        Map<Role, WFState> idle = new HashMap<>();
        for (Role r : first.config.states.keySet()) {
            idle.put(r, first);
        }

        while (ids.hasNext()) {
            WFState next = all.get(ids.next());
            for (Role r : next.config.states.keySet()) {
                if (idle.get(r) == null) {
                    continue;
                }
                for (GIOAction a : next.getActions()) {
                    if (a.containsRole(r)) {
                        idle.put(r, null);
                        break;
                    }
                }
            }
        }

        for (Role r : idle.keySet()) {
            WFState witness = idle.get(r);
            if (witness == null || witness.config.states.get(r) == null || witness.config.canSafelyTerminate(r)) {
                continue;
            }
            if (first.config.buffs.get(r).values().stream().allMatch(v -> v == null)) {
                liveness.add(r);
            }
        }
    }

    /**
     * Finds buffered messages that stay in place across a whole terminal set.
     * Starting from the buffers of the first state in iteration order, an
     * entry is set to {@code null} as soon as any later state has no message
     * in the same slot. Whatever remains is grouped under the second role of
     * its slot.
     *
     * @param all every state of the model, keyed by state id
     * @param init unused here
     * @param termset ids of the states in the terminal set; must be non-empty
     * @return surviving messages grouped by the second role of their buffer slot; empty if none survive
     * @throws ScribbleException declared for callers; nothing in this body throws it directly
     */
    private static Map<Role, Set<Send>> checkMessageLiveness(Map<Integer, WFState> all, WFState init, Set<Integer> termset) throws ScribbleException {
        Set<Role> roles = all.get(termset.iterator().next()).config.states.keySet();
        Iterator<Integer> ids = termset.iterator();
        // Mutated below: slots are nulled out as soon as a later state shows them empty.
        Map<Role, Map<Role, Send>> pending = all.get(ids.next()).config.buffs.getBuffers();
        while (ids.hasNext()) {
            WFBuffers later = all.get(ids.next()).config.buffs;
            for (Role r1 : roles) {
                for (Role r2 : roles) {
                    if (pending.get(r1).get(r2) != null && later.get(r1).get(r2) == null) {
                        pending.get(r1).put(r2, null);
                    }
                }
            }
        }

        HashMap<Role, Set<Send>> res = new HashMap<>();
        for (Role r1 : roles) {
            for (Role r2 : roles) {
                Send stuck = pending.get(r1).get(r2);
                if (stuck != null) {
                    res.computeIfAbsent(r2, unused -> new HashSet<>()).add(stuck);
                }
            }
        }
        return res;
    }

    /**
     * Adds to {@code termsets} the reachability set of every state that can
     * reach itself and passes {@code isTerminalSetMember}. A given
     * reachability set is tested at most once.
     *
     * @param all every state of the model, keyed by state id
     * @param reach for each state id with at least one reachable state, the ids reachable from it
     * @param termsets output collection receiving the terminal sets
     */
    private static void findTerminalSets(Map<Integer, WFState> all, Map<Integer, Set<Integer>> reach, Set<Set<Integer>> termsets) {
        Set<Set<Integer>> alreadyTested = new HashSet<>();
        for (Integer id : reach.keySet()) {
            WFState state = all.get(id);
            Set<Integer> reachable = reach.get(state.id);
            boolean candidate = !alreadyTested.contains(reachable) && reachable.contains(state.id);
            if (candidate) {
                alreadyTested.add(reachable);
                if (GlobalModelChecker.isTerminalSetMember(all, reach, state)) {
                    termsets.add(reachable);
                }
            }
        }
    }

    /**
     * Tests whether every state reachable from {@code s}, other than
     * {@code s} itself, has exactly the same reachability set as {@code s}.
     *
     * @param all unused here
     * @param reach for each state id with at least one reachable state, the ids reachable from it
     * @param s the state to test; must have an entry in {@code reach}
     * @return {@code true} if all other reachable states share {@code s}'s reachability set
     */
    private static boolean isTerminalSetMember(Map<Integer, WFState> all, Map<Integer, Set<Integer>> reach, WFState s) {
        Set<Integer> own = reach.get(s.id);
        for (Integer other : own) {
            if (other.equals(s.id)) {
                continue;
            }
            // A state with no reach entry, or a different one, can leave or never return to the set.
            if (!reach.containsKey(other) || !reach.get(other).equals(own)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Searches for a sequence of actions leading from {@code start} to
     * {@code end2} by seeding {@code getTraceAux} with {@code start} at
     * distance 0.
     *
     * @param all every state of the model, keyed by state id
     * @param start the state to search from
     * @param end2 the state to reach
     * @param reach reachability map used to prune states that cannot reach {@code end2}
     * @return the actions along a path that was found, or {@code null} if the search found none
     */
    private static List<GIOAction> getTrace(Map<Integer, WFState> all, WFState start, WFState end2, Map<Integer, Set<Integer>> reach) {
        Set<Integer> atDistanceZero = new HashSet<>();
        atDistanceZero.add(start.id);
        SortedMap<Integer, Set<Integer>> candidates = new TreeMap<>();
        candidates.put(0, atDistanceZero);

        Set<Integer> seen = new HashSet<>();
        seen.add(start.id);

        List<GIOAction> emptyTrace = new LinkedList<>();
        return GlobalModelChecker.getTraceAux(emptyTrace, all, seen, candidates, end2, reach);
    }

    /**
     * Recursive step of the trace search. Takes one state id out of the
     * lowest-distance candidate set and walks that state's actions and
     * successors in parallel. It returns as soon as a successor is
     * {@code end2}; otherwise it recurses into each unseen successor that can
     * still reach {@code end2}.
     *
     * @param trace the actions taken so far; the last action is appended in place when {@code end2} is hit
     * @param all every state of the model, keyed by state id
     * @param seen ids already queued; updated in place
     * @param candidates pending state ids grouped by distance; must be non-empty on entry, updated in place
     * @param end2 the state to reach
     * @param reach reachability map used to prune successors that cannot reach {@code end2}
     * @return the completed trace, or {@code null} if no path was found from here
     */
    private static List<GIOAction> getTraceAux(List<GIOAction> trace, Map<Integer, WFState> all, Set<Integer> seen, SortedMap<Integer, Set<Integer>> candidates, WFState end2, Map<Integer, Set<Integer>> reach) {
        Integer dis = candidates.firstKey();
        Set<Integer> nearest = candidates.get(dis);
        Iterator<Integer> picker = nearest.iterator();
        Integer currid = picker.next();
        picker.remove();
        if (nearest.isEmpty()) {
            candidates.remove(dis);
        }

        WFState curr = all.get(currid);
        // Actions and successors are iterated in lock-step: the n-th action leads to the n-th successor.
        Iterator<GIOAction> actions = curr.getActions().iterator();
        Iterator<WFState> successors = curr.getSuccessors().iterator();
        while (actions.hasNext()) {
            GIOAction a = actions.next();
            WFState s = successors.next();
            if (s.id == end2.id) {
                trace.add(a);
                return trace;
            }
            boolean explore = !seen.contains(s.id) && reach.containsKey(s.id) && reach.get(s.id).contains(end2.id);
            if (!explore) {
                continue;
            }
            seen.add(s.id);
            candidates.computeIfAbsent(dis + 1, unused -> new HashSet<>()).add(s.id);
            List<GIOAction> extended = new LinkedList<>(trace);
            extended.add(a);
            List<GIOAction> found = GlobalModelChecker.getTraceAux(extended, all, seen, candidates, end2, reach);
            if (found != null) {
                return found;
            }
        }
        return null;
    }

    /**
     * Assigns each state a dense index (0, 1, 2, … in the iteration order of
     * {@code all.values()}) and delegates to {@code getReachabilityAux} with
     * the id-to-index and index-to-id tables.
     *
     * @param job forwarded to {@code getReachabilityAux}
     * @param all every state of the model, keyed by state id
     * @return for each state id with at least one reachable state, the ids reachable from it
     */
    private static Map<Integer, Set<Integer>> getReachability(Job job, Map<Integer, WFState> all) {
        Map<Integer, Integer> indexOfId = new HashMap<>();
        Map<Integer, Integer> idOfIndex = new HashMap<>();
        int nextIndex = 0;
        for (WFState state : all.values()) {
            indexOfId.put(state.id, nextIndex);
            idOfIndex.put(nextIndex, state.id);
            nextIndex++;
        }
        return GlobalModelChecker.getReachabilityAux(job, all, indexOfId, idOfIndex);
    }

    /**
     * Computes, for every state, the set of states reachable in one or more
     * steps. The direct-successor relation is loaded into a boolean matrix,
     * closed transitively with a single Warshall pass, and converted back to
     * id sets.
     *
     * @param job unused here
     * @param all every state of the model, keyed by state id
     * @param all1 state id to matrix index
     * @param all2 matrix index to state id
     * @return for each state id with at least one reachable state, the ids reachable from it;
     *         states that reach nothing have no entry
     */
    private static Map<Integer, Set<Integer>> getReachabilityAux(Job job, Map<Integer, WFState> all, Map<Integer, Integer> all1, Map<Integer, Integer> all2) {
        int size = all1.keySet().size();
        boolean[][] reach = new boolean[size][size];
        for (Integer s1id : all1.keySet()) {
            int from = all1.get(s1id);
            for (WFState s2 : all.get(s1id).getSuccessors()) {
                reach[from][all1.get(s2.id)] = true;
            }
        }

        // Warshall's algorithm: after processing pivot k, reach[i][j] holds iff j is reachable
        // from i using only intermediate states with index <= k.
        for (int k = 0; k < size; k++) {
            for (int i = 0; i < size; i++) {
                if (!reach[i][k]) {
                    continue;
                }
                for (int j = 0; j < size; j++) {
                    if (reach[k][j]) {
                        reach[i][j] = true;
                    }
                }
            }
        }

        Map<Integer, Set<Integer>> res = new HashMap<>();
        for (int i = 0; i < size; i++) {
            Set<Integer> targets = res.get(all2.get(i));
            for (int j = 0; j < size; j++) {
                if (!reach[i][j]) {
                    continue;
                }
                // Create the entry lazily so that states reaching nothing stay absent from the map.
                if (targets == null) {
                    targets = new HashSet<>();
                    res.put(all2.get(i), targets);
                }
                targets.add(all2.get(j));
            }
        }
        return res;
    }
}

