#
# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
# Use of this file is governed by the BSD 3-clause license that
# can be found in the LICENSE.txt file in the project root.
#

from io import StringIO
from antlr4.Token import Token

from antlr4.CommonTokenStream import CommonTokenStream


class TokenStreamRewriter(object):
    DEFAULT_PROGRAM_NAME = "default"
    PROGRAM_INIT_SIZE = 100
    MIN_TOKEN_INDEX = 0

    def __init__(self, tokens):
        """
        Create a rewriter that records edits against ``tokens``.

        :type  tokens: antlr4.BufferedTokenStream.BufferedTokenStream
        :param tokens: the token stream whose text is rendered by ``getText``
        """
        super().__init__()
        self.tokens = tokens
        # program name -> ordered list of queued rewrite operations
        self.programs = {self.DEFAULT_PROGRAM_NAME: []}
        # program name -> token index stored via setLastRewriteTokenIndex
        self.lastRewriteTokenIndexes = {}

    def getTokenStream(self):
        """Return the token stream passed to the constructor."""
        stream = self.tokens
        return stream

    def rollback(self, instruction_index, program_name):
        """
        Truncate a program so that only instructions before ``instruction_index`` remain.

        Unknown or empty programs are left untouched.

        :param instruction_index: number of leading instructions to keep
        :param program_name: name of the program to truncate
        """
        instructions = self.programs.get(program_name)
        if not instructions:
            return
        kept = instructions[self.MIN_TOKEN_INDEX:instruction_index]
        self.programs[program_name] = kept

    def deleteProgram(self, program_name=DEFAULT_PROGRAM_NAME):
        """Drop every queued instruction of ``program_name`` by rolling it back to the start."""
        self.rollback(instruction_index=self.MIN_TOKEN_INDEX, program_name=program_name)

    def insertAfterToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Queue ``text`` to be emitted right after ``token`` in ``program_name``."""
        position = token.tokenIndex
        self.insertAfter(position, text, program_name)

    def insertAfter(self, index, text, program_name=DEFAULT_PROGRAM_NAME):
        """
        Queue ``text`` to be emitted after the token at ``index``.

        The operation is stored as an ``InsertAfterOp`` anchored at ``index + 1``,
        i.e. before the following token (even if that is past the end of the buffer).
        """
        program = self.getProgram(program_name)
        operation = self.InsertAfterOp(self.tokens, index + 1, text)
        operation.instructionIndex = len(program)
        program.append(operation)

    def insertBeforeIndex(self, index, text):
        """Queue ``text`` before the token at ``index`` in the default program."""
        self.insertBefore(program_name=self.DEFAULT_PROGRAM_NAME, index=index, text=text)

    def insertBeforeToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Queue ``text`` right before ``token`` in ``program_name``."""
        self.insertBefore(program_name=program_name, index=token.tokenIndex, text=text)

    def insertBefore(self, program_name, index, text):
        """
        Queue ``text`` to be emitted before the token at ``index``.

        :param program_name: program to append the instruction to (created if missing)
        :param index: token index the text is inserted in front of
        :param text: text to insert
        """
        program = self.getProgram(program_name)
        operation = self.InsertBeforeOp(self.tokens, index, text)
        operation.instructionIndex = len(program)
        program.append(operation)

    def replaceIndex(self, index, text):
        """Replace the single token at ``index`` with ``text`` in the default program."""
        self.replace(self.DEFAULT_PROGRAM_NAME, from_idx=index, to_idx=index, text=text)

    def replaceRange(self, from_idx, to_idx, text):
        """Replace tokens ``from_idx..to_idx`` (inclusive) with ``text`` in the default program."""
        self.replace(self.DEFAULT_PROGRAM_NAME, from_idx=from_idx, to_idx=to_idx, text=text)

    def replaceSingleToken(self, token, text):
        """Replace ``token`` with ``text`` in the default program."""
        position = token.tokenIndex
        self.replace(self.DEFAULT_PROGRAM_NAME, position, position, text)

    def replaceRangeTokens(self, from_token, to_token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Replace the tokens from ``from_token`` through ``to_token`` (inclusive) with ``text``."""
        first, last = from_token.tokenIndex, to_token.tokenIndex
        self.replace(program_name, first, last, text)

    def replace(self, program_name, from_idx, to_idx, text):
        """
        Queue replacement of tokens ``from_idx..to_idx`` (inclusive) with ``text``.

        :param program_name: program to append the instruction to (created if missing)
        :param from_idx: index of the first replaced token
        :param to_idx: index of the last replaced token
        :param text: replacement text; ``delete`` passes an empty string
        :raises ValueError: if the range is reversed, negative, or past the buffered tokens
        """
        out_of_order = from_idx > to_idx
        negative = from_idx < 0 or to_idx < 0
        past_end = to_idx >= len(self.tokens.tokens)
        if out_of_order or negative or past_end:
            raise ValueError(
                'replace: range invalid: {}..{}(size={})'.format(from_idx, to_idx, len(self.tokens.tokens)))
        program = self.getProgram(program_name)
        operation = self.ReplaceOp(from_idx, to_idx, self.tokens, text)
        operation.instructionIndex = len(program)
        program.append(operation)

    def deleteToken(self, token):
        """Remove ``token`` from the output of the default program."""
        self.delete(program_name=self.DEFAULT_PROGRAM_NAME, from_idx=token, to_idx=token)

    def deleteIndex(self, index):
        """Remove the token at ``index`` from the output of the default program."""
        self.delete(program_name=self.DEFAULT_PROGRAM_NAME, from_idx=index, to_idx=index)

    def delete(self, program_name, from_idx, to_idx):
        """
        Queue removal of a token range, expressed as a replace with empty text.

        :param program_name: program to append the instruction to
        :param from_idx: first token to remove, as a Token or an int index
        :param to_idx: last token to remove (same kind as ``from_idx``)
        """
        if isinstance(from_idx, Token):
            # Token arguments are converted to their buffer indexes.
            from_idx, to_idx = from_idx.tokenIndex, to_idx.tokenIndex
        self.replace(program_name, from_idx, to_idx, "")

    def lastRewriteTokenIndex(self, program_name=DEFAULT_PROGRAM_NAME):
        """Return the token index recorded for ``program_name``, or -1 if none was recorded."""
        indexes = self.lastRewriteTokenIndexes
        return indexes[program_name] if program_name in indexes else -1

    def setLastRewriteTokenIndex(self, program_name, i):
        """Record ``i`` as the last rewrite token index of ``program_name``."""
        self.lastRewriteTokenIndexes.update({program_name: i})

    def getProgram(self, program_name):
        """Return the instruction list of ``program_name``, creating an empty one if needed."""
        if program_name not in self.programs:
            self.programs[program_name] = []
        return self.programs[program_name]
        
    def getDefaultText(self):
        """Render the whole token buffer through the default program."""
        last = len(self.tokens.tokens) - 1
        return self.getText(self.DEFAULT_PROGRAM_NAME, 0, last)

    def getText(self, program_name, start:int, stop:int):
        """
        Render tokens[start, stop] (closed interval) with ``program_name``'s rewrites applied.

        ``start`` is clamped to 0 and ``stop`` to the last buffered token. When the
        program has no instructions, the token stream's own ``getText`` is used.
        When ``stop`` is the last token, operations anchored at or past it (for example
        an ``insertAfter`` on the final token) are appended at the end.

        :param program_name: name of the rewrite program to apply
        :param start: index of the first token to render
        :param stop: index of the last token to render (inclusive)
        :return: the rewritten text
        """
        program = self.programs.get(program_name)

        # Clamp the requested window to the buffered tokens.
        stop = min(stop, len(self.tokens.tokens) - 1)
        start = max(start, 0)

        if not program:
            return self.tokens.getText(start, stop)

        out = StringIO()
        pending = self._reduceToSingleOperationPerIndex(program)
        cursor = start
        while cursor <= stop and cursor < len(self.tokens.tokens):
            op = pending.pop(cursor, None)
            token = self.tokens.get(cursor)
            if op is not None:
                # The op writes its own output and says where to resume.
                cursor = op.execute(out)
                continue
            if token.type != Token.EOF:
                out.write(token.text)
            cursor += 1

        # Ops beyond the last token can only be emitted when rendering reaches the end.
        last_index = len(self.tokens.tokens) - 1
        if stop == last_index:
            for leftover in pending.values():
                if leftover.index >= last_index:
                    out.write(leftover.text)

        return out.getvalue()

    def _reduceToSingleOperationPerIndex(self, rewrites):
        """
        Collapse a program's instruction list into at most one operation per token index.

        Rules (following the Java ``TokenStreamRewriter``):

        * a replace absorbs earlier inserts in its range: an insert exactly at the
          replace's first index is prepended to the replacement text, and inserts
          strictly inside the range are dropped. It also drops earlier replaces it fully
          covers;
        * overlapping deletes (replaces whose text is empty or ``None``) are merged into
          one delete spanning both ranges;
        * inserts at the same index are concatenated, and a later insert at a replace's
          first index is folded into that replace's text.

        The operations are shallow-copied before merging, so neither ``rewrites`` nor the
        op objects it holds are modified. Rendering therefore stays repeatable and does
        not interfere with a later ``rollback``.

        :param rewrites: one program's instruction list; each op's ``instructionIndex``
            is its position in this list. ``None`` entries are ignored.
        :return: dict mapping token index -> the single operation to execute there
        :raises ValueError: if two replaces partially overlap (unless both are deletes),
            if an insert falls inside an earlier replace's range, or if more than one
            operation remains for the same index
        """
        import copy

        def cat_text(a, b):
            # Concatenate op texts, treating None as empty.
            return '{}{}'.format('' if a is None else a, '' if b is None else b)

        # Merging below rewrites op.text / index / last_index and blanks list slots;
        # do it on copies so the stored program is left intact.
        rewrites = [copy.copy(op) for op in rewrites]

        # Walk replaces
        for i, rop in enumerate(rewrites):
            if not isinstance(rop, TokenStreamRewriter.ReplaceOp):
                continue
            # Wipe prior inserts within range
            inserts = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.InsertBeforeOp)]
            for iop in inserts:
                if iop.index == rop.index:
                    # e.g. insert before 2, then replace 2..2: keep the insert's text
                    rewrites[iop.instructionIndex] = None
                    rop.text = cat_text(iop.text, rop.text)
                elif rop.index < iop.index <= rop.last_index:
                    # insert inside the replaced range would never be emitted
                    rewrites[iop.instructionIndex] = None

            # Drop any prior replaces contained within
            prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
            for prevRop in prevReplaces:
                if prevRop.index >= rop.index and prevRop.last_index <= rop.last_index:
                    rewrites[prevRop.instructionIndex] = None
                    continue
                isDisjoint = prevRop.last_index < rop.index or prevRop.index > rop.last_index
                # delete() stores "" as its text, so treat any falsy text as a delete
                if not isDisjoint and not prevRop.text and not rop.text:
                    # overlapping deletes: merge into one delete covering both ranges
                    rewrites[prevRop.instructionIndex] = None
                    rop.index = min(prevRop.index, rop.index)
                    rop.last_index = max(prevRop.last_index, rop.last_index)
                elif not isDisjoint:
                    raise ValueError("replace op boundaries of {} overlap with previous {}".format(rop, prevRop))

        # Walk inserts
        for i, iop in enumerate(rewrites):
            if not isinstance(iop, TokenStreamRewriter.InsertBeforeOp):
                continue
            # Combine the current insert with prior inserts at the same index.
            prevInserts = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.InsertBeforeOp)]
            for prevIop in prevInserts:
                if prevIop.index != iop.index:
                    continue
                if isinstance(prevIop, TokenStreamRewriter.InsertAfterOp):
                    # text placed after the previous token comes first
                    iop.text = cat_text(prevIop.text, iop.text)
                else:
                    iop.text = cat_text(iop.text, prevIop.text)
                # Null the prior insert's own slot (not its position within prevInserts).
                rewrites[prevIop.instructionIndex] = None
            # look for replaces where iop.index is in range; error
            prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
            for rop in prevReplaces:
                if iop.index == rop.index:
                    rop.text = cat_text(iop.text, rop.text)
                    rewrites[i] = None  # the insert now lives in the replace's text
                    continue
                if rop.index <= iop.index <= rop.last_index:
                    raise ValueError("insert op {} within boundaries of previous {}".format(iop, rop))

        reduced = {}
        for op in rewrites:
            if op is None:
                continue
            if op.index in reduced:
                raise ValueError('should be only one op per index')
            reduced[op.index] = op

        return reduced

    class RewriteOperation(object):

        def __init__(self, tokens, index, text=""):
            """
            :type tokens: CommonTokenStream
            :param tokens:
            :param index:
            :param text:
            :return:
            """
            self.tokens = tokens
            self.index = index
            self.text = text
            self.instructionIndex = 0

        def execute(self, buf):
            """
            :type buf: StringIO.StringIO
            :param buf:
            :return:
            """
            return self.index

        def __str__(self):
            return '<{}@{}:"{}">'.format(self.__class__.__name__, self.tokens.get(self.index), self.text)

    class InsertBeforeOp(RewriteOperation):

        def __init__(self, tokens, index, text=""):
            super(TokenStreamRewriter.InsertBeforeOp, self).__init__(tokens, index, text)

        def execute(self, buf):
            buf.write(self.text)
            if self.tokens.get(self.index).type != Token.EOF:
                buf.write(self.tokens.get(self.index).text)
            return self.index + 1

    class InsertAfterOp(InsertBeforeOp):
        pass
            
    class ReplaceOp(RewriteOperation):

        def __init__(self, from_idx, to_idx, tokens, text):
            super(TokenStreamRewriter.ReplaceOp, self).__init__(tokens, from_idx, text)
            self.last_index = to_idx

        def execute(self, buf):
            if self.text:
                buf.write(self.text)
            return self.last_index + 1
            
        def __str__(self):
            if self.text:
                return '<ReplaceOp@{}..{}:"{}">'.format(self.tokens.get(self.index), self.tokens.get(self.last_index),
                                                        self.text)