from transforms.base import *
from transforms.text import STANDARD_CODECS
import base64
import re
import codecs
import struct


class Base64Decode(Transform):
    """
    base64 decoding

    automatically_add_padding: complete the input to a multiple of 4
        characters before decoding ("=" padding, or "A==" when a single
        character is left over)
    ignore_invalid_characters: remove every character outside the standard
        base64 alphabet (the "=" padding included) before decoding; this
        always enables automatically_add_padding
    """
    category = "encoding"
    name = "base64 decode"
    icon = "wxART_VIEW_STR"

    def run(self, data:bytes, automatically_add_padding:bool=True, ignore_invalid_characters:bool=True):
        if ignore_invalid_characters:
            # "=" is removed together with the other non-alphabet characters,
            # so the padding has to be rebuilt below whatever the flag says
            data = re.sub(b"[^a-zA-Z0-9+/]", b"", data)
        leftover = len(data) % 4
        if leftover and (automatically_add_padding or ignore_invalid_characters):
            # one dangling character cannot form a group with "=" alone, so a
            # filler "A" (value 0 in base64) is inserted before the padding
            data += {1: b"A==", 2: b"==", 3: b"="}[leftover]
        return base64.b64decode(data)


class Base64Encode(Transform):
    """
    base64 encoding (standard alphabet, "=" padded output)
    """
    category = "encoding"
    name = "base64 encode"
    icon = "wxART_VIEW_STR"

    def run(self, data:bytes):
        encoded = base64.b64encode(data)
        return encoded

class Base85Decode(Transform):
    """
    base85 decoding, using the alphabet of base64.b85decode
    """
    category = "encoding"
    name = "base85 decode"
    icon = "wxART_VIEW_STR"

    def run(self, data:bytes):
        decoded = base64.b85decode(data)
        return decoded


class Base85Encode(Transform):
    """
    base85 encoding, using the alphabet of base64.b85encode
    """
    category = "encoding"
    name = "base85 encode"
    icon = "wxART_VIEW_STR"

    def run(self, data:bytes):
        encoded = base64.b85encode(data)
        return encoded


class HexDecode(Transform):
    """
    hex decoding

    Whitespace between byte pairs is accepted ("01 02 ff"). When the input
    holds an odd number of hex digits, a trailing "0" completes the last byte.

    ignore_nonhexa_characters: drop every non-hex character before decoding
    """
    category = "encoding"
    name = "hex decode"
    icon = "wxART_FIND_HEX"

    def run(self, data:bytes, ignore_nonhexa_characters:bool=False):
        digits = re.sub(b"[^a-fA-F0-9]", b"", data)
        if ignore_nonhexa_characters:
            data = digits
        else:
            # bytes.fromhex skips whitespace between pairs; trailing whitespace
            # would otherwise end up between a lone digit and the padding "0"
            data = data.rstrip()
        # decide on padding from the hex digit count, not the raw length:
        # "01 02" has an odd length but an even number of digits
        if len(digits) % 2:
            data += b"0"
        return bytes.fromhex(data.decode("ascii"))


class HexEncode(Transform):
    """
    hex encoding (two hex digits per byte, no separator)

    uppercase: emit "A"-"F" instead of "a"-"f"
    """
    category = "encoding"
    name = "hex encode"
    icon = "wxART_FIND_HEX"

    def run(self, data:bytes, uppercase:bool=False):
        text = data.hex().upper() if uppercase else data.hex()
        return text.encode("ascii")


class ListDecode(Transform):
    """
    convert a list (with separators) of numbers to an array of bytes

    "1, 32, 128" -> { 01 20 80 }
    "0x01, 0x20, 0x80" (base hexadecimal) -> { 01 20 80 }

    base: base the numbers are written in; for binary, octal and hexadecimal
        an optional "0b" / "0o" / "0x" prefix is accepted
    width: width of each number, i.e. "byte" will encode each number on a byte
    little_endian: byte order used for the multi-byte widths
    ignore_errors: when a number does not fit in the chosen width, stop and
        return the bytes decoded so far instead of raising
    """
    category = "encoding"
    name = "list decode"
    icon = "wxART_NOTE_HEADER"

    def run(self, data:bytes, base:("binary", "octal", "decimal", "hexadecimal")="decimal", width:("byte", "word", "dword", "qword") = "byte", little_endian:bool = True, ignore_errors:bool=False):
        # The radix prefix has to be part of the match: with a bare digit class
        # "0x10" would be split into the two numbers "0" and "10". int() accepts
        # the prefix matching its radix, so the token can be decoded as is.
        pattern, radix = {
            "binary": (rb"(?:0[bB])?[01]+", 2),
            "octal": (rb"(?:0[oO])?[0-7]+", 8),
            "decimal": (rb"[0-9]+", 10),
            "hexadecimal": (rb"(?:0[xX])?[0-9a-fA-F]+", 16),
        }[base]
        nformat, nmax = {
            "byte": ("B", 0xff),
            "word": ("H", 0xffff),
            "dword": ("I", 0xffffffff),
            "qword": ("Q", 0xffffffffffffffff),
        }[width]
        packer = struct.Struct(("<" if little_endian else ">") + nformat)
        res = []
        try:
            for m in re.finditer(pattern, data):
                token = m.group(0)
                number = int(token, radix)
                # the patterns never match a sign, so only the upper bound matters
                if number > nmax:
                    raise ValueError("Number not in range: {}".format(token))
                res.append(packer.pack(number))
        except (ValueError, struct.error):
            # with ignore_errors, keep everything decoded before the failure
            if not ignore_errors:
                raise
        return b"".join(res)


class ListEncode(Transform):
    """
    convert a byte array to a list of numbers in the chosen base, joined with
    the chosen separator; each number is read from `width` bytes

    { 01 20 80 } -> "1,32,128" (default separator ",")

    base: base used to write each number (no prefix is emitted)
    width: number of bytes per number ("byte", "word", "dword", "qword")
    little_endian: byte order used for the multi-byte widths
    separator: bytes inserted between consecutive numbers
    ignore_errors: when the input length is not a multiple of the width,
        encode the complete numbers and drop the trailing bytes instead of
        raising struct.error
    """
    category = "encoding"
    name = "list encode"
    icon = "wxART_NOTE_HEADER"

    def run(self, data:bytes, base:("binary", "octal", "decimal", "hexadecimal")="decimal", width:("byte", "word", "dword", "qword") = "byte", little_endian:bool = True, separator:bytes=b",", ignore_errors:bool=False):
        spec = {
            "binary": "b",
            "octal": "o",
            "decimal": "d",
            "hexadecimal": "x",
        }[base]
        nformat = {
            "byte": "B",
            "word": "H",
            "dword": "I",
            "qword": "Q",
        }[width]
        unpacker = struct.Struct(("<" if little_endian else ">") + nformat)
        if ignore_errors:
            # iter_unpack rejects the whole buffer when its size is not a
            # multiple of the item size; drop the incomplete tail instead
            data = data[:len(data) - len(data) % unpacker.size]
        return separator.join(
            format(number, spec).encode("ascii")
            for (number,) in unpacker.iter_unpack(data)
        )


class TextTranscode(Transform):
    """
    change text encoding: decode with `source`, then encode with `dest`

    errors: error handler used for both steps. "xmlcharrefreplace" only
        exists for encoding, so decoding falls back to "backslashreplace"
        (undecodable bytes become \\xNN escapes) when it is selected.
    """
    category = "encoding"
    name = "change text encoding"
    icon = "wxART_REFERENCE_SYMBOL"

    def run(self, data:bytes, source:STANDARD_CODECS="utf-16le", dest:STANDARD_CODECS="utf-8", errors:("strict", "replace", "ignore", "xmlcharrefreplace", "backslashreplace")="strict"):
        # passing "xmlcharrefreplace" to bytes.decode makes any undecodable byte
        # raise TypeError ("don't know how to handle UnicodeDecodeError")
        decode_errors = "backslashreplace" if errors == "xmlcharrefreplace" else errors
        return data.decode(source, errors=decode_errors).encode(dest, errors=errors)
