from transforms.base import *
import struct



# Width name -> (struct format letter, size of one element in bits).
FORMAT_INFOS = {
    "byte": ("B", 8),
    "word": ("H", 16),
    "dword": ("I", 32),
    "qword": ("Q", 64),
}


class GenericArithmetic:
    """
    Mixin applying a binary integer operation to every fixed-width element of a buffer.
    """

    def run(self, operation, data, key, width, endianness, key_change=lambda x: x):
        """
        Apply ``operation(element, key)`` to each element of ``data``.

        :param operation: callable taking (element value, key) and returning an int.
        :param data: bytes-like buffer (bytes or bytearray). Any other value is
            first packed as a single element of the given width.
        :param key: second operand. A negative key is reinterpreted as its
            unsigned two's-complement value on ``width`` bits.
        :param width: one of the FORMAT_INFOS keys ("byte", "word", "dword", "qword").
        :param endianness: "lsb" for little-endian; any other value means big-endian.
        :param key_change: called after each element to compute the key used
            for the next element (identity by default).
        :return: a new bytearray of the same length as ``data``. Every result is
            reduced modulo 2**bits before being stored.
        :raises ValueError: if ``key`` cannot be encoded on ``width``, or the
            data length is not a multiple of the element size.
        :raises KeyError: if ``width`` is not a FORMAT_INFOS key.
        """
        letter, bits = FORMAT_INFOS[width]
        modulo = 1 << bits
        byte_order = "<" if endianness == "lsb" else ">"
        # Compile the format once; it is reused for every element.
        codec = struct.Struct(byte_order + letter)
        step = codec.size
        if key < 0:
            # Two's-complement reinterpretation: -k becomes modulo - k
            # (equivalent to the former ((~abs(key)) + 1) % modulo).
            key %= modulo
        try:
            codec.pack(key)
        except struct.error as err:
            # Only an out-of-range or non-integer key should map to this message.
            raise ValueError("Key cannot be encoded on a {}".format(width)) from err
        if not isinstance(data, (bytes, bytearray)):
            # A bare integer is treated as one element of the requested width.
            data = codec.pack(data)
        if len(data) % step != 0:
            raise ValueError("Data length is not a multiple of {}".format(step))
        res = bytearray(len(data))
        for offset in range(0, len(data), step):
            value, = codec.unpack_from(data, offset)
            # Wrap the result so it always fits back into `width` bits.
            codec.pack_into(res, offset, operation(value, key) % modulo)
            key = key_change(key)
        return res


class Add(Transform, GenericArithmetic):
    """
    Add a constant to every byte/word/dword/qword, wrapping around modulo 2**bits
    """
    category = "arithmetic"
    name = "add"
    icon = "wxART_CALC"

    def run(self, data:bytes, value:int=1, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        # `value` is the constant added to each element of the chosen width.
        return GenericArithmetic.run(
            self,
            operation=int.__add__,
            data=data,
            key=value,
            width=width,
            endianness=endianness,
        )


class Sub(Transform, GenericArithmetic):
    """
    Subtract a constant from every byte/word/dword/qword, wrapping around modulo 2**bits
    """
    category = "arithmetic"
    name = "sub"
    icon = "wxART_CALC"

    def run(self, data:bytes, value:int=1, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        # Each element becomes (element - value) mod 2**bits.
        subtract = int.__sub__
        return GenericArithmetic.run(
            self,
            operation=subtract,
            data=data,
            key=value,
            width=width,
            endianness=endianness,
        )


class Mul(Transform, GenericArithmetic):
    """
    Multiply every byte/word/dword/qword by a constant, wrapping around modulo 2**bits
    """
    category = "arithmetic"
    name = "mul"
    icon = "wxART_CALC"

    def run(self, data:bytes, value:int=1, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        # Each element becomes (element * value) mod 2**bits.
        multiply = int.__mul__
        return GenericArithmetic.run(
            self,
            operation=multiply,
            data=data,
            key=value,
            width=width,
            endianness=endianness,
        )

class Xor(Transform, GenericArithmetic):
    """
    Bitwise exclusive-or of every byte/word/dword/qword with a constant
    """
    category = "arithmetic"
    name = "xor"
    icon = "wxART_CALC"

    def run(self, data:bytes, value:int=1, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        def xor_element(element, mask):
            # Both operands already fit in `width` bits, so their XOR does too;
            # the modulo applied afterwards by GenericArithmetic is a no-op.
            return element ^ mask

        return GenericArithmetic.run(self, xor_element, data, value, width, endianness)


class Div(Transform, GenericArithmetic):
    """
    Unsigned floor division of every byte/word/dword/qword by a constant.
    A negative divisor is reinterpreted as its unsigned two's-complement value
    (e.g. -1 on a byte divides by 255).
    """
    category = "arithmetic"
    name = "div"
    icon = "wxART_CALC"

    def run(self, data:bytes, value:int=1, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        # Fail fast with an explicit message instead of an anonymous
        # ZeroDivisionError raised from inside the element loop (which also
        # silently "succeeded" on empty input). The exception type is unchanged.
        if value == 0:
            raise ZeroDivisionError("Cannot divide by zero")
        return GenericArithmetic.run(self, int.__floordiv__, data, value, width, endianness)

class Neg(Transform, GenericArithmetic):
    """
    Two's-complement negation of every byte/word/dword/qword (x -> -x modulo 2**bits)
    """
    category = "arithmetic"
    name = "neg"
    icon = "wxART_CALC"

    def run(self, data:bytes, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb"):
        # Multiplying by -1 (which GenericArithmetic re-encodes as 2**bits - 1)
        # yields -x mod 2**bits for every element.
        minus_one = -1
        return GenericArithmetic.run(
            self,
            operation=int.__mul__,
            data=data,
            key=minus_one,
            width=width,
            endianness=endianness,
        )


class ArithmeticCustom(Transform, GenericArithmetic):
    """
    Apply a user-written Python function to every byte, word, dword or qword,
    read and written in lsb or msb order.
    The code must define operation(value, index) returning the new value.
    """
    category = "arithmetic"
    name = "user operation (python)"
    icon = "wxART_EDIT"

    TEMPLATE = """def operation(value:int, index:int):
    # value: the numerical value of the current byte/word/dword/qword
    # index: the 0-based index of the current byte/word/dword/qword in the buffer 
    # insert your code here --v
    return value
"""

    def run(self, data:bytes, width:["byte", "word", "dword", "qword"]="byte", endianness:["lsb", "msb"]="lsb", python_code:str=TEMPLATE):
        # NOTE(review): this executes arbitrary Python with a copy of this
        # module's globals. Presumably the code is typed by the local user;
        # never route untrusted input here.
        namespace = globals().copy()
        exec(python_code, namespace)
        user_operation = namespace["operation"]
        # The "key" handed to the operation is the element index: it starts
        # at 0 and is incremented after every element.
        return GenericArithmetic.run(
            self,
            user_operation,
            data,
            0,
            width,
            endianness,
            key_change=lambda index: index + 1,
        )

