from transforms.base import Transform, TransformError

class Replace(Transform):
    """
    Replace every occurrence of a pattern with another one

    old_pattern: the byte sequence (or regular expression) to look for
    new_pattern: the replacement bytes
    regexp: if True, old_pattern is treated as a python regular expression and
            capture group references like \\1 can be used in new_pattern
    """
    category = ""
    name = "replace"
    icon = "wxART_TRANSFORM"

    def run(self, data:bytes, old_pattern:bytes=b"old", new_pattern:bytes=b"new", regexp:bool=False):
        # plain substring replacement is the common case: handle it first
        if not regexp:
            return data.replace(old_pattern, new_pattern)
        import re
        return re.sub(old_pattern, new_pattern, data)


class FillRepeat(Transform):
    """
    Fill buffer using a repeating pattern

    The output has the same length as the input buffer. Its content is *pattern*
    repeated as many times as needed, then truncated to that length.

    pattern: the byte sequence to repeat (must not be empty)
    """
    category = ""
    name = "fill (repeat pattern)"
    icon = "wxART_EDIT"


    def run(self, data:bytes, pattern:bytes=b"\x00"):
        if not pattern:
            # an empty pattern cannot fill anything (and would divide by zero below)
            raise ValueError("Pattern must not be empty")
        osize = len(data)
        # one extra repetition guarantees at least osize bytes before truncating
        filled = pattern * (1 + osize // len(pattern))
        return filled[:osize]


class FillGenerator(Transform):
    """
    Fill buffer using a generated pattern

    pattern_generator: python source code defining a function named
    fill_pattern_generator() which returns an iterable of integers in range(0, 256).
    If the produced sequence is finite, it is repeated until the buffer is full.
    The output has the same length as the input buffer.
    """
    category = ""
    name = "fill (python generator)"
    icon = "wxART_EDIT"

    TEMPLATE = """def fill_pattern_generator():
    # generate an (infinite) sequence
    i = 0
    while True:
        yield i & 0xff
        i += 1
    """

    def run(self, data:bytes, pattern_generator:str=TEMPLATE):
        import itertools
        osize = len(data)
        # NOTE: this executes arbitrary user-provided code with a copy of this
        # module's globals; only run code you trust.
        dic = dict(globals())
        exec(pattern_generator, dic)
        if "fill_pattern_generator" not in dic:
            raise ValueError("The code must define a fill_pattern_generator() function")
        # cycle() repeats finite sequences, islice() stops after osize values
        gen = itertools.cycle(dic["fill_pattern_generator"]())
        res = bytearray(itertools.islice(gen, osize))
        if len(res) != osize:
            # only possible when the generator yielded no value at all
            raise ValueError("fill_pattern_generator() did not produce any value")
        return res


class Reverse(Transform):
    """
    Reverse bytes, aka: [A B C D E F G H I] -> [I H G F E D C B A]

    You can also use a sliding window, e.g. with a window of 3:
    [A B C D E F G H I] -> [C B A F E D I H G]

    sliding_window: size of the consecutive chunks that are reversed independently,
                    0 = reverse the whole buffer. The last chunk may be shorter.
    """
    category = ""
    name = "reverse"
    icon = "wxART_NOTE_FIXUP"

    def run(self, data:bytes, sliding_window:int=0):
        if sliding_window < 0:
            # a negative step would make range() empty and yield a zeroed buffer
            raise ValueError("Invalid sliding_window value")
        if sliding_window == 0:
            return data[::-1]
        res = bytearray()
        for i in range(0, len(data), sliding_window):
            # slicing clamps at the end of data, so the last chunk may be shorter
            res += data[i:i + sliding_window][::-1]
        return res


class Skip(Transform):
    """
    skip some bytes, aka: [A B C D E F G H I] with step=3 and start=4 -> [A B C D F G I]

    step: skip every nth byte (must be at least 2)
    start: length of prefix to keep, i.e. the skipping will only start after this many bytes have been seen
    """
    category = ""
    name = "skip"
    icon = "wxART_UNDO"

    def run(self, data:bytes, step:int=2, start:int=0):
        if step <= 1:
            # step == 1 would drop every byte after the prefix
            raise ValueError("Invalid step value")
        # start == len(data) is valid: nothing is left to skip (also allows empty input)
        if start < 0 or start > len(data):
            raise ValueError("Invalid start value")
        res = bytearray(data)
        # drop data[start], data[start + step], data[start + 2*step], ...
        del res[start::step]
        return res


class Keep(Transform):
    """
    keep some bytes, aka: [A B C D E F G H I] with step=3 and start=4 -> [A B C D E H]

    step: keep every nth byte
    start: length of prefix to keep, i.e. the "keeping" will only start after this many bytes have been seen
    """
    category = ""
    name = "keep"
    icon = "wxART_REDO"

    def run(self, data:bytes, step:int=2, start:int=0):
        if step <= 0:
            raise ValueError("Invalid step value")
        # start == len(data) is valid: the whole buffer is the prefix (also allows empty input)
        if start < 0 or start > len(data):
            raise ValueError("Invalid start value")
        # whole prefix, then data[start], data[start + step], ...
        return data[:start] + data[start::step]
    



class Cut(Transform):
    """
    Cut a data buffer, i.e. remove bytes at the beginning and/or end of the buffer

    cut_before: remove all bytes before this offset, 0 = cut nothing
    cut_after: remove the byte at this offset and all bytes after it, 0 = cut nothing
    cut_last: remove the last X bytes, 0 = cut nothing (must not be negative)

    cut_after and cut_last are mutually exclusive.
    """
    category = ""
    name = "cut"
    icon = "wxART_CUT"

    def run(self, data:bytes, cut_before:int=0, cut_after:int=0, cut_last:int=0):
        if cut_after and cut_last:
            raise ValueError("Please use either cut_after or cut_last")
        if cut_last < 0:
            # data[x:-cut_last] with a negative count would silently become an
            # absolute end offset instead of "remove the last X bytes"
            raise ValueError("cut_last must not be negative")
        if cut_after:
            return data[cut_before:cut_after]
        if cut_last:
            return data[cut_before:-cut_last]
        return data[cut_before:]


class ExtractDelimitedTokens(Transform):
    """
    Select (i.e. keep) only the tokens (aka bytes sequences) located between two delimiters

    Tokens are matched non-greedily (the shortest span up to the next end delimiter)
    and concatenated in the output.

    start_delimiter: the starting delimiter (any bytes, not limited to ASCII)
    end_delimiter: the end delimiter (any bytes, not limited to ASCII)
    keep_delimiters: include the delimiters in the output
    multiline: whether a token can span over multiple lines or not
    """
    category = ""
    name = "tokenize"
    icon = "wxART_CUT"

    def run(self, data:bytes, start_delimiter:bytes=b"\"", end_delimiter:bytes=b"\"",
            keep_delimiters:bool=False, multiline:bool=False):
        import re
        # re.S lets "." also match newlines, so a token may span several lines
        flags = re.S if multiline else 0
        # escape the delimiters as bytes so any byte value works (decoding them
        # as ASCII would fail on non-ASCII delimiters)
        r = re.compile(re.escape(start_delimiter) + b"(.*?)" + re.escape(end_delimiter), flags)
        # group 0 = whole match including delimiters, group 1 = token only
        group = 0 if keep_delimiters else 1
        res = bytearray()
        for m in r.finditer(data):
            res.extend(m.group(group))
        return res



class Resize(Transform):
    """
    Resize the buffer. If the new size is bigger than the old one, padding_value will be used to fill the void (repeated as necessary)

    new_size: the size of the output buffer (must not be negative)
    padding_value: bytes used to fill the extra space (must not be empty when growing)
    """
    category = ""
    name = "resize"
    icon = "wxART_SELECT"

    def run(self, data:bytes, new_size:int=1, padding_value:bytes=b"\x00"):
        if new_size < 0:
            # a negative slice bound would silently cut bytes from the end instead
            raise ValueError("Invalid new_size value")
        res = data[:new_size]
        diff = new_size - len(res)
        if diff > 0:
            if not padding_value:
                raise ValueError("padding_value must not be empty when growing the buffer")
            # one extra repetition guarantees at least diff padding bytes
            padding = padding_value * (1 + diff // len(padding_value))
            res += padding[:diff]
        return res


class Flip2d(Transform):
    """
    Flip a 2d array, swapping its dimensions (columns <-> rows)

    The input is read as rows of num_columns cells of cell_size_in_bytes bytes each.
    The output lists the cells column by column, i.e. each input column becomes an output row.

    num_columns: number of cells per input row (must be greater than 0)
    cell_size_in_bytes: size of one cell (must be greater than 0)
    """
    category = ""
    name = "flip2d"
    icon = "wxART_SWITCH"

    def run(self, data:bytes, num_columns:int=16, cell_size_in_bytes:int=1):
        if num_columns <= 0 or cell_size_in_bytes <= 0:
            raise ValueError("num_columns and cell_size_in_bytes must be greater than 0")
        line_width = num_columns * cell_size_in_bytes
        if len(data) % line_width != 0:
            raise ValueError(f"Data length is not a multiple of {num_columns} * {cell_size_in_bytes}")
        num_rows = len(data) // line_width
        res = bytearray(len(data))
        i = 0
        for x in range(num_columns):
            col_offset = x * cell_size_in_bytes
            for y in range(num_rows):
                src = y * line_width + col_offset
                res[i:i + cell_size_in_bytes] = data[src:src + cell_size_in_bytes]
                i += cell_size_in_bytes
        return res



class Insert(Transform):
    """
    Insert arbitrary data into the input buffer

    to_insert: the bytes to insert
    where: the *to_insert* buffer will be inserted in *data* BEFORE this offset. Set to -1 to specify end of file (append).
           Any negative offset, or an offset at or past the end of *data*, appends.
    repeat: number of consecutive copies of *to_insert* to insert (must be at least 1)
    """
    category = ""
    name = "insert"
    icon = "wxART_PLUS"


    def run(self, data:bytes, to_insert:bytes=b"DATA TO INSERT", where:int = -1, repeat:int=1):
        if repeat < 1:
            raise ValueError("Repeat must be greater than 0")
        payload = to_insert * repeat if repeat > 1 else to_insert
        appending = where < 0 or where >= len(data)
        if appending:
            return data + payload
        head, tail = data[:where], data[where:]
        return head + payload + tail


class Custom(Transform):
    """
    Transform the input buffer using custom python code

    python_code: python source defining a function named operation(input) which
    receives the data and returns the transformed buffer (bytes or bytearray).
    The code is executed with a copy of this module's globals as its namespace.
    """
    category = ""
    name = "user code (python)"
    icon = "wxART_EDIT"

    TEMPLATE = """def operation(input:bytes):
    # input: the data, as a bytes buffer
    # return value: a bytes or bytearray buffer

    return input
"""

    def run(self, data:bytes, python_code:str=TEMPLATE):
        # NOTE: this executes arbitrary code; only run code you trust
        namespace = globals().copy()
        exec(python_code, namespace)
        operation = namespace["operation"]
        return operation(data)

    
