#
# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
# Use of this file is governed by the BSD 3-clause license that
# can be found in the LICENSE.txt file in the project root.
#

#
# A tree pattern matching mechanism for ANTLR {@link ParseTree}s.
#

# Patterns are strings of source input text with special tags representing
# token or rule references such as:

# #

# {@code <ID> = <expr>;}

# #

# Given a pattern start rule such as {@code statement}, this object constructs
# a {@link ParseTree} with placeholders for the {@code ID} and {@code expr}
# subtree. Then the {@link #match} routines can compare an actual
# {@link ParseTree} from a parse with this pattern. Tag {@code <ID>} matches
# any {@code ID} token and tag {@code <expr>} references the result of the
# {@code expr} rule (generally an instance of {@code ExprContext}).

# #

# Pattern {@code x = 0;} is a similar pattern that matches the same pattern
# except that it requires the identifier to be {@code x} and the expression to
# be {@code 0}.

# #

# The {@link #matches} routines return {@code true} or {@code false} based
# upon a match for the tree rooted at the parameter sent in. The
# {@link #match} routines return a {@link ParseTreeMatch} object that
# contains the parse tree, the parse tree pattern, and a map from tag name to
# matched nodes (more below). A subtree that fails to match returns with
# {@link ParseTreeMatch#mismatchedNode} set to the first tree node that did not
# match.

# #

# For efficiency, you can compile a tree pattern in string form to a
# {@link ParseTreePattern} object.

# #

# See {@code TestParseTreeMatcher} for lots of examples.
# {@link ParseTreePattern} has two static helper methods:
# {@link ParseTreePattern#findAll} and {@link ParseTreePattern#match} that
# are easy to use but not super efficient because they create new
# {@link ParseTreePatternMatcher} objects each time and have to compile the
# pattern in string form before using it.

# #

# The lexer and parser that you pass into the {@link ParseTreePatternMatcher}
# constructor are used to parse the pattern in string form. The lexer converts
# the {@code <ID> = <expr>;} into a sequence of four tokens (assuming lexer
# throws out whitespace or puts it on a hidden channel). Be aware that the
# input stream is reset for the lexer (but not the parser; a
# {@link ParserInterpreter} is created to parse the input). Any user-defined
# fields you have put into the lexer might get changed when this mechanism asks
# it to scan the pattern string.

# #

# Normally a parser does not accept token {@code <expr>} as a valid
# {@code expr} but, from the parser passed in, we create a special version of
# the underlying grammar representation (an {@link ATN}) that allows imaginary
# tokens representing rules ({@code <expr>}) to match entire rules. We call
# these bypass alternatives.

# #

# Delimiters are {@code <} and {@code >}, with {@code \} as the escape string
# by default, but you can set them to whatever you want using
# {@link #setDelimiters}. You must escape both start and stop strings
# {@code \<} and {@code \>}.

#
from antlr4.CommonTokenStream import CommonTokenStream
from antlr4.InputStream import InputStream
from antlr4.ParserRuleContext import ParserRuleContext
from antlr4.Lexer import Lexer
from antlr4.ListTokenSource import ListTokenSource
from antlr4.Token import Token
from antlr4.error.ErrorStrategy import BailErrorStrategy
from antlr4.error.Errors import RecognitionException, ParseCancellationException
from antlr4.tree.Chunk import TagChunk, TextChunk
from antlr4.tree.RuleTagToken import RuleTagToken
from antlr4.tree.TokenTagToken import TokenTagToken
from antlr4.tree.Tree import ParseTree, TerminalNode, RuleNode

# need forward declaration
Parser = None
ParseTreePattern = None


# Raised by compileTreePattern when parsing the pattern fails with an error
# that is neither a ParseCancellationException nor a RecognitionException.
# The original exception is passed along as the single argument.
class CannotInvokeStartRule(Exception):

    def __init__(self, e:Exception):
        super().__init__(e)


# Raised by compileTreePattern when the start rule finishes before the whole
# pattern token stream has been consumed.
class StartRuleDoesNotConsumeFullPattern(Exception):

    pass


class ParseTreePatternMatcher(object):

    # Constructs a {@link ParseTreePatternMatcher} from a {@link Lexer} and
    # {@link Parser} object. The lexer input stream is altered for tokenizing
    # the tree patterns. The parser is used as a convenient mechanism to get
    # the grammar name, plus token, rule names.
    def __init__(self, lexer:Lexer, parser:Parser):
        self.lexer = lexer
        self.parser = parser
        self.start = "<"
        self.stop = ">"
        self.escape = "\\" # e.g., \< and \> must escape BOTH!

    # Set the delimiters used for marking rule and token tags within concrete
    # syntax used by the tree pattern parser.
    #
    # @param start The start delimiter.
    # @param stop The stop delimiter.
    # @param escapeLeft The escape sequence to use for escaping a start or stop delimiter.
    #
    # @exception Exception if {@code start} is {@code None} or empty.
    # @exception Exception if {@code stop} is {@code None} or empty.
    #
    def setDelimiters(self, start:str, stop:str, escapeLeft:str):
        if start is None or len(start)==0:
            raise Exception("start cannot be null or empty")
        if stop is None or len(stop)==0:
            raise Exception("stop cannot be null or empty")
        self.start = start
        self.stop = stop
        self.escape = escapeLeft

    # Does {@code pattern} matched as rule {@code patternRuleIndex} match {@code tree}?
    #
    # Compiles the string pattern first, then delegates to matchesPattern.
    #
    def matchesRuleIndex(self, tree:ParseTree, pattern:str, patternRuleIndex:int):
        p = self.compileTreePattern(pattern, patternRuleIndex)
        # This class defines no "matches" method; the compiled-pattern
        # variant is matchesPattern.
        return self.matchesPattern(tree, p)

    # Does {@code pattern} matched as rule patternRuleIndex match tree? Pass in a
    # compiled pattern instead of a string representation of a tree pattern.
    #
    # @return {@code True} when matchImpl reports no mismatched node.
    #
    def matchesPattern(self, tree:ParseTree, pattern:ParseTreePattern):
        mismatchedNode = self.matchImpl(tree, pattern.patternTree, dict())
        return mismatchedNode is None

    #
    # Compare {@code pattern} matched as rule {@code patternRuleIndex} against
    # {@code tree} and return a {@link ParseTreeMatch} object that contains the
    # matched elements, or the node at which the match failed.
    #
    def matchRuleIndex(self, tree:ParseTree, pattern:str, patternRuleIndex:int):
        p = self.compileTreePattern(pattern, patternRuleIndex)
        return self.matchPattern(tree, p)

    #
    # Compare {@code pattern} matched against {@code tree} and return a
    # {@link ParseTreeMatch} object that contains the matched elements, or the
    # node at which the match failed. Pass in a compiled pattern instead of a
    # string representation of a tree pattern.
    #
    def matchPattern(self, tree:ParseTree, pattern:ParseTreePattern):
        labels = dict()
        mismatchedNode = self.matchImpl(tree, pattern.patternTree, labels)
        # imported here rather than at module level (as in the original)
        from antlr4.tree.ParseTreeMatch import ParseTreeMatch
        return ParseTreeMatch(tree, pattern, labels, mismatchedNode)

    #
    # For repeated use of a tree pattern, compile it to a
    # {@link ParseTreePattern} using this method.
    #
    # @exception RecognitionException if the pattern cannot be parsed.
    # @exception CannotInvokeStartRule for any other failure while parsing.
    # @exception StartRuleDoesNotConsumeFullPattern if tokens remain after
    # the start rule has been parsed.
    #
    def compileTreePattern(self, pattern:str, patternRuleIndex:int):
        tokenList = self.tokenize(pattern)
        tokenSrc = ListTokenSource(tokenList)
        tokens = CommonTokenStream(tokenSrc)
        # imported here rather than at module level (as in the original)
        from antlr4.ParserInterpreter import ParserInterpreter
        parserInterp = ParserInterpreter(self.parser.grammarFileName, self.parser.tokenNames,
                                         self.parser.ruleNames, self.parser.getATNWithBypassAlts(), tokens)
        tree = None
        try:
            parserInterp.setErrorHandler(BailErrorStrategy())
            tree = parserInterp.parse(patternRuleIndex)
        except ParseCancellationException as e:
            # Surface the underlying error when one is attached. Fall back to
            # an exception carried in args, and finally to the cancellation
            # itself, instead of failing with AttributeError on a missing
            # "cause" attribute.
            cause = getattr(e, "cause", None)
            if cause is None and e.args and isinstance(e.args[0], BaseException):
                cause = e.args[0]
            if cause is None:
                raise
            raise cause
        except RecognitionException:
            raise
        except Exception as e:
            raise CannotInvokeStartRule(e) from e

        # Make sure tree pattern compilation checks for a complete parse
        if tokens.LA(1)!=Token.EOF:
            raise StartRuleDoesNotConsumeFullPattern()

        from antlr4.tree.ParseTreePattern import ParseTreePattern
        return ParseTreePattern(self, pattern, patternRuleIndex, tree)

    #
    # Recursively walk {@code tree} against {@code patternTree}, filling
    # {@code labels} with a label -> list-of-nodes mapping.
    #
    # @return the first node encountered in {@code tree} which does not match
    # a corresponding node in {@code patternTree}, or {@code None} if the match
    # was successful. The specific node returned depends on the matching
    # algorithm used by the implementation, and may be overridden.
    #
    # @exception Exception if {@code tree} or {@code patternTree} is {@code None}.
    #
    def matchImpl(self, tree:ParseTree, patternTree:ParseTree, labels:dict):
        if tree is None:
            raise Exception("tree cannot be null")
        if patternTree is None:
            raise Exception("patternTree cannot be null")

        # x and <ID>, x and y, or x and x; or could be mismatched types
        if isinstance(tree, TerminalNode) and isinstance(patternTree, TerminalNode):
            # tokens of different types can never match
            if tree.symbol.type != patternTree.symbol.type:
                return tree
            if isinstance(patternTree.symbol, TokenTagToken): # x and <ID>
                tokenTagToken = patternTree.symbol
                # track label->list-of-nodes for both token name and label (if any)
                self.map(labels, tokenTagToken.tokenName, tree)
                if tokenTagToken.label is not None:
                    self.map(labels, tokenTagToken.label, tree)
                return None
            if tree.getText()==patternTree.getText():
                # x and x
                return None
            # x and y
            return tree

        if isinstance(tree, ParserRuleContext) and isinstance(patternTree, ParserRuleContext):
            # (expr ...) and <expr>
            ruleTagToken = self.getRuleTagToken(patternTree)
            if ruleTagToken is not None:
                if tree.ruleContext.ruleIndex != patternTree.ruleContext.ruleIndex:
                    return tree
                # track label->list-of-nodes for both rule name and label (if any)
                self.map(labels, ruleTagToken.ruleName, tree)
                if ruleTagToken.label is not None:
                    self.map(labels, ruleTagToken.label, tree)
                return None

            # (expr ...) and (expr ...)
            n = tree.getChildCount()
            if n != patternTree.getChildCount():
                return tree

            for i in range(n):
                childMatch = self.matchImpl(tree.getChild(i), patternTree.getChild(i), labels)
                if childMatch is not None:
                    return childMatch

            return None

        # if nodes aren't both tokens or both rule nodes, can't match
        return tree

    # Append {@code tree} to the list stored under {@code label} in
    # {@code labels}, creating the list on first use.
    def map(self, labels, label, tree):
        v = labels.get(label, None)
        if v is None:
            v = list()
            labels[label] = v
        v.append(tree)

    # Is {@code tree} a {@code (expr <expr>)} subtree? Returns the
    # {@link RuleTagToken} of its single terminal child when it is,
    # otherwise {@code None}.
    def getRuleTagToken(self, tree:ParseTree):
        if isinstance(tree, RuleNode):
            if tree.getChildCount()==1 and isinstance(tree.getChild(0), TerminalNode):
                c = tree.getChild(0)
                if isinstance(c.symbol, RuleTagToken):
                    return c.symbol
        return None

    # Convert {@code pattern} into a list of tokens: tag chunks become
    # {@link TokenTagToken} / {@link RuleTagToken} objects and text chunks are
    # run through the lexer.
    #
    # @exception Exception if a tag names an unknown token or rule, or is
    # empty or does not start with a cased letter.
    #
    def tokenize(self, pattern:str):
        # split pattern into chunks: sea (raw input) and islands (<ID>, <expr>)
        chunks = self.split(pattern)

        # create token stream from text and tags
        tokens = list()
        for chunk in chunks:
            if isinstance(chunk, TagChunk):
                # An empty tag has no first character to classify; let it fall
                # through to the "invalid tag" error instead of an IndexError.
                first = chunk.tag[0] if chunk.tag else ""
                # add special rule token or conjure up new token from name
                if first.isupper():
                    ttype = self.parser.getTokenType(chunk.tag)
                    if ttype==Token.INVALID_TYPE:
                        raise Exception("Unknown token " + str(chunk.tag) + " in pattern: " + pattern)
                    tokens.append(TokenTagToken(chunk.tag, ttype, chunk.label))
                elif first.islower():
                    ruleIndex = self.parser.getRuleIndex(chunk.tag)
                    if ruleIndex==-1:
                        raise Exception("Unknown rule " + str(chunk.tag) + " in pattern: " + pattern)
                    ruleImaginaryTokenType = self.parser.getATNWithBypassAlts().ruleToTokenType[ruleIndex]
                    tokens.append(RuleTagToken(chunk.tag, ruleImaginaryTokenType, chunk.label))
                else:
                    raise Exception("invalid tag: " + str(chunk.tag) + " in pattern: " + pattern)
            else:
                # plain text: reset the lexer onto this chunk and drain it
                self.lexer.setInputStream(InputStream(chunk.text))
                t = self.lexer.nextToken()
                while t.type!=Token.EOF:
                    tokens.append(t)
                    t = self.lexer.nextToken()
        return tokens

    # Split {@code <ID> = <expr>;} into 4 chunks for tokenizing by {@link #tokenize}.
    #
    # @return a list of {@link TextChunk} and {@link TagChunk} objects in
    # pattern order; escape sequences are stripped from text chunks only.
    #
    # @exception Exception if start and stop delimiters are unbalanced or
    # out of order.
    #
    def split(self, pattern:str):
        p = 0
        n = len(pattern)
        chunks = list()
        # find all start and stop indexes first, then collect
        starts = list()
        stops = list()
        escapedStart = self.escape + self.start
        escapedStop = self.escape + self.stop
        while p < n:
            # startswith(x, p) tests only position p, unlike find(x, p)
            # which scans to the end of the string on every step
            if pattern.startswith(escapedStart, p):
                p += len(escapedStart)
            elif pattern.startswith(escapedStop, p):
                p += len(escapedStop)
            elif pattern.startswith(self.start, p):
                starts.append(p)
                p += len(self.start)
            elif pattern.startswith(self.stop, p):
                stops.append(p)
                p += len(self.stop)
            else:
                p += 1

        nt = len(starts)

        if nt > len(stops):
            raise Exception("unterminated tag in pattern: " + pattern)
        if nt < len(stops):
            raise Exception("missing start tag in pattern: " + pattern)

        for startIndex, stopIndex in zip(starts, stops):
            if startIndex >= stopIndex:
                raise Exception("tag delimiters out of order in pattern: " + pattern)

        # collect into chunks now
        if nt==0:
            chunks.append(TextChunk(pattern))

        if nt>0 and starts[0]>0: # copy text up to first tag into chunks
            text = pattern[0:starts[0]]
            # chunks is a list, so this must be append (list has no add)
            chunks.append(TextChunk(text))

        for i in range(nt):
            # copy inside of <tag>
            tag = pattern[starts[i] + len(self.start) : stops[i]]
            ruleOrToken = tag
            label = None
            colon = tag.find(':')
            if colon >= 0:
                label = tag[0:colon]
                ruleOrToken = tag[colon+1 : len(tag)]
            # NOTE(review): positional order (label, tag) is kept from the
            # original; tokenize reads chunk.tag / chunk.label, so confirm it
            # matches the TagChunk constructor's parameter order.
            chunks.append(TagChunk(label, ruleOrToken))
            if i+1 < nt:
                # copy from end of <tag> to start of next
                text = pattern[stops[i] + len(self.stop) : starts[i + 1]]
                chunks.append(TextChunk(text))

        if nt > 0:
            afterLastTag = stops[nt - 1] + len(self.stop)
            if afterLastTag < n: # copy text from end of last tag to end
                text = pattern[afterLastTag : n]
                chunks.append(TextChunk(text))

        # strip out the escape sequences from text chunks but not tags
        for i, c in enumerate(chunks):
            if isinstance(c, TextChunk):
                unescaped = c.text.replace(self.escape, "")
                if len(unescaped) < len(c.text):
                    chunks[i] = TextChunk(unescaped)
        return chunks