# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. # Use of this file is governed by the BSD 3-clause license that # can be found in the LICENSE.txt file in the project root. #/ from uuid import UUID from io import StringIO from typing import Callable from antlr4.Token import Token from antlr4.atn.ATN import ATN from antlr4.atn.ATNType import ATNType from antlr4.atn.ATNState import * from antlr4.atn.Transition import * from antlr4.atn.LexerAction import * from antlr4.atn.ATNDeserializationOptions import ATNDeserializationOptions # This is the earliest supported serialized UUID. BASE_SERIALIZED_UUID = UUID("AADB8D7E-AEEF-4415-AD2B-8204D6CF042E") # This UUID indicates the serialized ATN contains two sets of # IntervalSets, where the second set's values are encoded as # 32-bit integers to support the full Unicode SMP range up to U+10FFFF. ADDED_UNICODE_SMP = UUID("59627784-3BE5-417A-B9EB-8131A7286089") # This list contains all of the currently supported UUIDs, ordered by when # the feature first appeared in this branch. SUPPORTED_UUIDS = [ BASE_SERIALIZED_UUID, ADDED_UNICODE_SMP ] SERIALIZED_VERSION = 3 # This is the current serialized UUID. SERIALIZED_UUID = ADDED_UNICODE_SMP class ATNDeserializer (object): def __init__(self, options : ATNDeserializationOptions = None): if options is None: options = ATNDeserializationOptions.defaultOptions self.deserializationOptions = options # Determines if a particular serialized representation of an ATN supports # a particular feature, identified by the {@link UUID} used for serializing # the ATN at the time the feature was first introduced. # # @param feature The {@link UUID} marking the first time the feature was # supported in the serialized ATN. # @param actualUuid The {@link UUID} of the actual serialized ATN which is # currently being deserialized. # @return {@code true} if the {@code actualUuid} value represents a # serialized ATN at or after the feature identified by {@code feature} was # introduced; otherwise, {@code false}. def isFeatureSupported(self, feature : UUID , actualUuid : UUID ): idx1 = SUPPORTED_UUIDS.index(feature) if idx1<0: return False idx2 = SUPPORTED_UUIDS.index(actualUuid) return idx2 >= idx1 def deserialize(self, data : str): self.reset(data) self.checkVersion() self.checkUUID() atn = self.readATN() self.readStates(atn) self.readRules(atn) self.readModes(atn) sets = [] # First, read all sets with 16-bit Unicode code points <= U+FFFF. self.readSets(atn, sets, self.readInt) # Next, if the ATN was serialized with the Unicode SMP feature, # deserialize sets with 32-bit arguments <= U+10FFFF. if self.isFeatureSupported(ADDED_UNICODE_SMP, self.uuid): self.readSets(atn, sets, self.readInt32) self.readEdges(atn, sets) self.readDecisions(atn) self.readLexerActions(atn) self.markPrecedenceDecisions(atn) self.verifyATN(atn) if self.deserializationOptions.generateRuleBypassTransitions \ and atn.grammarType == ATNType.PARSER: self.generateRuleBypassTransitions(atn) # re-verify after modification self.verifyATN(atn) return atn def reset(self, data:str): def adjust(c): v = ord(c) return v-2 if v>1 else v + 65533 temp = [ adjust(c) for c in data ] # don't adjust the first value since that's the version number temp[0] = ord(data[0]) self.data = temp self.pos = 0 def checkVersion(self): version = self.readInt() if version != SERIALIZED_VERSION: raise Exception("Could not deserialize ATN with version " + str(version) + " (expected " + str(SERIALIZED_VERSION) + ").") def checkUUID(self): uuid = self.readUUID() if not uuid in SUPPORTED_UUIDS: raise Exception("Could not deserialize ATN with UUID: " + str(uuid) + \ " (expected " + str(SERIALIZED_UUID) + " or a legacy UUID).", uuid, SERIALIZED_UUID) self.uuid = uuid def readATN(self): idx = self.readInt() grammarType = ATNType.fromOrdinal(idx) maxTokenType = self.readInt() return ATN(grammarType, maxTokenType) def readStates(self, atn:ATN): loopBackStateNumbers = [] endStateNumbers = [] nstates = self.readInt() for i in range(0, nstates): stype = self.readInt() # ignore bad type of states if stype==ATNState.INVALID_TYPE: atn.addState(None) continue ruleIndex = self.readInt() if ruleIndex == 0xFFFF: ruleIndex = -1 s = self.stateFactory(stype, ruleIndex) if stype == ATNState.LOOP_END: # special case loopBackStateNumber = self.readInt() loopBackStateNumbers.append((s, loopBackStateNumber)) elif isinstance(s, BlockStartState): endStateNumber = self.readInt() endStateNumbers.append((s, endStateNumber)) atn.addState(s) # delay the assignment of loop back and end states until we know all the state instances have been initialized for pair in loopBackStateNumbers: pair[0].loopBackState = atn.states[pair[1]] for pair in endStateNumbers: pair[0].endState = atn.states[pair[1]] numNonGreedyStates = self.readInt() for i in range(0, numNonGreedyStates): stateNumber = self.readInt() atn.states[stateNumber].nonGreedy = True numPrecedenceStates = self.readInt() for i in range(0, numPrecedenceStates): stateNumber = self.readInt() atn.states[stateNumber].isPrecedenceRule = True def readRules(self, atn:ATN): nrules = self.readInt() if atn.grammarType == ATNType.LEXER: atn.ruleToTokenType = [0] * nrules atn.ruleToStartState = [0] * nrules for i in range(0, nrules): s = self.readInt() startState = atn.states[s] atn.ruleToStartState[i] = startState if atn.grammarType == ATNType.LEXER: tokenType = self.readInt() if tokenType == 0xFFFF: tokenType = Token.EOF atn.ruleToTokenType[i] = tokenType atn.ruleToStopState = [0] * nrules for state in atn.states: if not isinstance(state, RuleStopState): continue atn.ruleToStopState[state.ruleIndex] = state atn.ruleToStartState[state.ruleIndex].stopState = state def readModes(self, atn:ATN): nmodes = self.readInt() for i in range(0, nmodes): s = self.readInt() atn.modeToStartState.append(atn.states[s]) def readSets(self, atn:ATN, sets:list, readUnicode:Callable[[], int]): m = self.readInt() for i in range(0, m): iset = IntervalSet() sets.append(iset) n = self.readInt() containsEof = self.readInt() if containsEof!=0: iset.addOne(-1) for j in range(0, n): i1 = readUnicode() i2 = readUnicode() iset.addRange(range(i1, i2 + 1)) # range upper limit is exclusive def readEdges(self, atn:ATN, sets:list): nedges = self.readInt() for i in range(0, nedges): src = self.readInt() trg = self.readInt() ttype = self.readInt() arg1 = self.readInt() arg2 = self.readInt() arg3 = self.readInt() trans = self.edgeFactory(atn, ttype, src, trg, arg1, arg2, arg3, sets) srcState = atn.states[src] srcState.addTransition(trans) # edges for rule stop states can be derived, so they aren't serialized for state in atn.states: for i in range(0, len(state.transitions)): t = state.transitions[i] if not isinstance(t, RuleTransition): continue outermostPrecedenceReturn = -1 if atn.ruleToStartState[t.target.ruleIndex].isPrecedenceRule: if t.precedence == 0: outermostPrecedenceReturn = t.target.ruleIndex trans = EpsilonTransition(t.followState, outermostPrecedenceReturn) atn.ruleToStopState[t.target.ruleIndex].addTransition(trans) for state in atn.states: if isinstance(state, BlockStartState): # we need to know the end state to set its start state if state.endState is None: raise Exception("IllegalState") # block end states can only be associated to a single block start state if state.endState.startState is not None: raise Exception("IllegalState") state.endState.startState = state if isinstance(state, PlusLoopbackState): for i in range(0, len(state.transitions)): target = state.transitions[i].target if isinstance(target, PlusBlockStartState): target.loopBackState = state elif isinstance(state, StarLoopbackState): for i in range(0, len(state.transitions)): target = state.transitions[i].target if isinstance(target, StarLoopEntryState): target.loopBackState = state def readDecisions(self, atn:ATN): ndecisions = self.readInt() for i in range(0, ndecisions): s = self.readInt() decState = atn.states[s] atn.decisionToState.append(decState) decState.decision = i def readLexerActions(self, atn:ATN): if atn.grammarType == ATNType.LEXER: count = self.readInt() atn.lexerActions = [ None ] * count for i in range(0, count): actionType = self.readInt() data1 = self.readInt() if data1 == 0xFFFF: data1 = -1 data2 = self.readInt() if data2 == 0xFFFF: data2 = -1 lexerAction = self.lexerActionFactory(actionType, data1, data2) atn.lexerActions[i] = lexerAction def generateRuleBypassTransitions(self, atn:ATN): count = len(atn.ruleToStartState) atn.ruleToTokenType = [ 0 ] * count for i in range(0, count): atn.ruleToTokenType[i] = atn.maxTokenType + i + 1 for i in range(0, count): self.generateRuleBypassTransition(atn, i) def generateRuleBypassTransition(self, atn:ATN, idx:int): bypassStart = BasicBlockStartState() bypassStart.ruleIndex = idx atn.addState(bypassStart) bypassStop = BlockEndState() bypassStop.ruleIndex = idx atn.addState(bypassStop) bypassStart.endState = bypassStop atn.defineDecisionState(bypassStart) bypassStop.startState = bypassStart excludeTransition = None if atn.ruleToStartState[idx].isPrecedenceRule: # wrap from the beginning of the rule to the StarLoopEntryState endState = None for state in atn.states: if self.stateIsEndStateFor(state, idx): endState = state excludeTransition = state.loopBackState.transitions[0] break if excludeTransition is None: raise Exception("Couldn't identify final state of the precedence rule prefix section.") else: endState = atn.ruleToStopState[idx] # all non-excluded transitions that currently target end state need to target blockEnd instead for state in atn.states: for transition in state.transitions: if transition == excludeTransition: continue if transition.target == endState: transition.target = bypassStop # all transitions leaving the rule start state need to leave blockStart instead ruleToStartState = atn.ruleToStartState[idx] count = len(ruleToStartState.transitions) while count > 0: bypassStart.addTransition(ruleToStartState.transitions[count-1]) del ruleToStartState.transitions[-1] # link the new states atn.ruleToStartState[idx].addTransition(EpsilonTransition(bypassStart)) bypassStop.addTransition(EpsilonTransition(endState)) matchState = BasicState() atn.addState(matchState) matchState.addTransition(AtomTransition(bypassStop, atn.ruleToTokenType[idx])) bypassStart.addTransition(EpsilonTransition(matchState)) def stateIsEndStateFor(self, state:ATNState, idx:int): if state.ruleIndex != idx: return None if not isinstance(state, StarLoopEntryState): return None maybeLoopEndState = state.transitions[len(state.transitions) - 1].target if not isinstance(maybeLoopEndState, LoopEndState): return None if maybeLoopEndState.epsilonOnlyTransitions and \ isinstance(maybeLoopEndState.transitions[0].target, RuleStopState): return state else: return None # # Analyze the {@link StarLoopEntryState} states in the specified ATN to set # the {@link StarLoopEntryState#isPrecedenceDecision} field to the # correct value. # # @param atn The ATN. # def markPrecedenceDecisions(self, atn:ATN): for state in atn.states: if not isinstance(state, StarLoopEntryState): continue # We analyze the ATN to determine if this ATN decision state is the # decision for the closure block that determines whether a # precedence rule should continue or complete. # if atn.ruleToStartState[state.ruleIndex].isPrecedenceRule: maybeLoopEndState = state.transitions[len(state.transitions) - 1].target if isinstance(maybeLoopEndState, LoopEndState): if maybeLoopEndState.epsilonOnlyTransitions and \ isinstance(maybeLoopEndState.transitions[0].target, RuleStopState): state.isPrecedenceDecision = True def verifyATN(self, atn:ATN): if not self.deserializationOptions.verifyATN: return # verify assumptions for state in atn.states: if state is None: continue self.checkCondition(state.epsilonOnlyTransitions or len(state.transitions) <= 1) if isinstance(state, PlusBlockStartState): self.checkCondition(state.loopBackState is not None) if isinstance(state, StarLoopEntryState): self.checkCondition(state.loopBackState is not None) self.checkCondition(len(state.transitions) == 2) if isinstance(state.transitions[0].target, StarBlockStartState): self.checkCondition(isinstance(state.transitions[1].target, LoopEndState)) self.checkCondition(not state.nonGreedy) elif isinstance(state.transitions[0].target, LoopEndState): self.checkCondition(isinstance(state.transitions[1].target, StarBlockStartState)) self.checkCondition(state.nonGreedy) else: raise Exception("IllegalState") if isinstance(state, StarLoopbackState): self.checkCondition(len(state.transitions) == 1) self.checkCondition(isinstance(state.transitions[0].target, StarLoopEntryState)) if isinstance(state, LoopEndState): self.checkCondition(state.loopBackState is not None) if isinstance(state, RuleStartState): self.checkCondition(state.stopState is not None) if isinstance(state, BlockStartState): self.checkCondition(state.endState is not None) if isinstance(state, BlockEndState): self.checkCondition(state.startState is not None) if isinstance(state, DecisionState): self.checkCondition(len(state.transitions) <= 1 or state.decision >= 0) else: self.checkCondition(len(state.transitions) <= 1 or isinstance(state, RuleStopState)) def checkCondition(self, condition:bool, message=None): if not condition: if message is None: message = "IllegalState" raise Exception(message) def readInt(self): i = self.data[self.pos] self.pos += 1 return i def readInt32(self): low = self.readInt() high = self.readInt() return low | (high << 16) def readLong(self): low = self.readInt32() high = self.readInt32() return (low & 0x00000000FFFFFFFF) | (high << 32) def readUUID(self): low = self.readLong() high = self.readLong() allBits = (low & 0xFFFFFFFFFFFFFFFF) | (high << 64) return UUID(int=allBits) edgeFactories = [ lambda args : None, lambda atn, src, trg, arg1, arg2, arg3, sets, target : EpsilonTransition(target), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ RangeTransition(target, Token.EOF, arg2) if arg3 != 0 else RangeTransition(target, arg1, arg2), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ RuleTransition(atn.states[arg1], arg2, arg3, target), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ PredicateTransition(target, arg1, arg2, arg3 != 0), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ AtomTransition(target, Token.EOF) if arg3 != 0 else AtomTransition(target, arg1), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ ActionTransition(target, arg1, arg2, arg3 != 0), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ SetTransition(target, sets[arg1]), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ NotSetTransition(target, sets[arg1]), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ WildcardTransition(target), lambda atn, src, trg, arg1, arg2, arg3, sets, target : \ PrecedencePredicateTransition(target, arg1) ] def edgeFactory(self, atn:ATN, type:int, src:int, trg:int, arg1:int, arg2:int, arg3:int, sets:list): target = atn.states[trg] if type > len(self.edgeFactories) or self.edgeFactories[type] is None: raise Exception("The specified transition type: " + str(type) + " is not valid.") else: return self.edgeFactories[type](atn, src, trg, arg1, arg2, arg3, sets, target) stateFactories = [ lambda : None, lambda : BasicState(), lambda : RuleStartState(), lambda : BasicBlockStartState(), lambda : PlusBlockStartState(), lambda : StarBlockStartState(), lambda : TokensStartState(), lambda : RuleStopState(), lambda : BlockEndState(), lambda : StarLoopbackState(), lambda : StarLoopEntryState(), lambda : PlusLoopbackState(), lambda : LoopEndState() ] def stateFactory(self, type:int, ruleIndex:int): if type> len(self.stateFactories) or self.stateFactories[type] is None: raise Exception("The specified state type " + str(type) + " is not valid.") else: s = self.stateFactories[type]() if s is not None: s.ruleIndex = ruleIndex return s CHANNEL = 0 #The type of a {@link LexerChannelAction} action. CUSTOM = 1 #The type of a {@link LexerCustomAction} action. MODE = 2 #The type of a {@link LexerModeAction} action. MORE = 3 #The type of a {@link LexerMoreAction} action. POP_MODE = 4 #The type of a {@link LexerPopModeAction} action. PUSH_MODE = 5 #The type of a {@link LexerPushModeAction} action. SKIP = 6 #The type of a {@link LexerSkipAction} action. TYPE = 7 #The type of a {@link LexerTypeAction} action. actionFactories = [ lambda data1, data2: LexerChannelAction(data1), lambda data1, data2: LexerCustomAction(data1, data2), lambda data1, data2: LexerModeAction(data1), lambda data1, data2: LexerMoreAction.INSTANCE, lambda data1, data2: LexerPopModeAction.INSTANCE, lambda data1, data2: LexerPushModeAction(data1), lambda data1, data2: LexerSkipAction.INSTANCE, lambda data1, data2: LexerTypeAction(data1) ] def lexerActionFactory(self, type:int, data1:int, data2:int): if type > len(self.actionFactories) or self.actionFactories[type] is None: raise Exception("The specified lexer action type " + str(type) + " is not valid.") else: return self.actionFactories[type](data1, data2)