from transforms.base import *
import struct
import math
import sys


class Deflate(Transform):
    """
    Deflate -> pack a stream as raw deflate data (no zlib/gzip header)
    window_size: base-two logarithm of window size, must be >= 9 and <= 15, defaults to 15
    """
    category = "compress"
    name = "deflate"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, window_size:int=15):
        import zlib
        # zlib (>= 1.2.9) refuses windowBits=8 for raw deflate streams (only the
        # zlib-wrapped format silently bumps 8 to 9), so compressobj(wbits=-8)
        # fails with an opaque "Invalid initialization option". Reject it here
        # with the same message used for every other out-of-range value.
        if window_size < 9 or window_size > 15:
            raise ValueError("Invalid window_size value")
        # negative wbits = raw deflate, no header/trailer
        o = zlib.compressobj(wbits=-window_size)
        return o.compress(data) + o.flush()

class Inflate(Transform):
    """
    Inflate -> unpack a raw deflate stream (no zlib/gzip header)
    window_size: base-two logarithm of window size, must be >= 8 and <= 15, defaults to 15
    """
    category = "compress"
    name = "inflate"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes, window_size:int=15):
        import zlib
        if window_size > 15 or window_size < 8:
            raise ValueError("Invalid window_size value")
        # A negative wbits tells zlib to expect headerless (raw) deflate data.
        raw_wbits = -window_size
        return zlib.decompress(data, wbits=raw_wbits)


class ZlibCompress(Transform):
    """
    Compress a stream using zlib (zlib header and adler32 trailer included)
    window_size: base-two logarithm of window size, must be >= 8 and <= 15, defaults to 15
    """
    category = "compress"
    name = "zlib compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, window_size:int=15):
        import zlib
        if window_size > 15 or window_size < 8:
            raise ValueError("Invalid window_size value")
        compressor = zlib.compressobj(wbits=window_size)
        body = compressor.compress(data)
        tail = compressor.flush()
        return body + tail


class ZlibDecompress(Transform):
    """
    Uncompress a zlib stream (zlib header expected)
    """
    category = "compress"
    name = "zlib decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        from zlib import decompress
        return decompress(data)


class ZstdCompress(Transform):
    """
    Compress a stream using zstd (standard library on Python 3.14+, backports.zstd before)
    """
    category = "compress"
    name = "zstd compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes):
        if sys.version_info < (3, 14):
            from backports import zstd
        else:
            from compression import zstd
        return zstd.compress(data)


class ZstdDecompress(Transform):
    """
    Uncompress a zstd stream (standard library on Python 3.14+, backports.zstd before)
    """
    category = "compress"
    name = "zstd decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        if sys.version_info < (3, 14):
            from backports import zstd
        else:
            from compression import zstd
        return zstd.decompress(data)


class GzipCompress(Transform):
    """
    Compress a stream using gzip (single member, 32 KiB window)
    """
    category = "compress"
    name = "gzip compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes):
        import zlib
        # wbits = 16 + 15: gzip header/trailer around a 2**15 window
        packer = zlib.compressobj(wbits=16 + 15)
        return b"".join((packer.compress(data), packer.flush()))

class GzipDecompress(Transform):
    """
    Uncompress a gzip archive (all concatenated members are decoded)
    """
    category = "compress"
    name = "gzip decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        import zlib
        # First member: strict, exactly like zlib.decompress(data, wbits=31):
        # a bad header raises zlib.error, and so does a truncated stream.
        dec = zlib.decompressobj(wbits=31)
        result = bytearray(dec.decompress(data))
        if not dec.eof:
            raise zlib.error("Error -5 while decompressing data: incomplete or truncated stream")
        rest = dec.unused_data
        # RFC 1952: a gzip file is a series of members, but zlib.decompress()
        # stops after the first one. Keep decoding while the leftover data starts
        # with the gzip magic. Any other trailing data (padding, overlay) is
        # ignored as before. A damaged or truncated extra member ends decoding
        # (best effort); bytes already produced from it are kept.
        while rest.startswith(b"\x1f\x8b"):
            dec = zlib.decompressobj(wbits=31)
            try:
                result += dec.decompress(rest)
            except zlib.error:
                break
            if not dec.eof:
                break
            rest = dec.unused_data
        return bytes(result)


class LzmaCompress(Transform):
    """
    Compress a stream using LZMA (with or without xz header)
    format: "ALONE" (legacy .lzma header, default) or "XZ" (xz container)
    compression: lzma preset level, defaults to 6
    """
    category = "compress"
    name = "lzma compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, format:["XZ", "ALONE"]="ALONE", compression:int=6):
        import lzma
        container = lzma.FORMAT_ALONE if format == "ALONE" else lzma.FORMAT_XZ
        return lzma.compress(data, format=container, preset=compression)

class LzmaDecompress(Transform):
    """
    Uncompress an lzma stream: xz or lzma-alone container, or a raw LZMA1/LZMA2 stream
    raw_mode: input is a raw stream prefixed with its filter properties
    lzma2: in raw_mode, the raw stream is LZMA2 (1 property byte) instead of LZMA1 (5 property bytes)
    max_length: maximum number of bytes to output, -1 for no limit
    """
    category = "compress"
    name = "lzma decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes, raw_mode:bool=False, lzma2:bool=False, max_length:int=-1):
        import lzma
        if not raw_mode:
            dec = lzma.LZMADecompressor()
        else:
            # Raw streams have no container: their filter properties are prepended
            # to the payload and must be decoded by hand.
            if lzma2:
                filter_id, props_len = lzma.FILTER_LZMA2, 1
            else:
                filter_id, props_len = lzma.FILTER_LZMA1, 5
            props = lzma._decode_filter_properties(filter_id, data[:props_len])
            data = data[props_len:]
            dec = lzma.LZMADecompressor(lzma.FORMAT_RAW, filters=[props])
        try:
            return dec.decompress(data, max_length=max_length)
        except lzma.LZMAError:
            if raw_mode:
                raise
            # Retry with bytes 5..12 (presumably the uncompressed-size field of an
            # lzma-alone header) forced to -1, i.e. "unknown"
            # (https://github.com/python/cpython/issues/92018#issuecomment-1113879639)
            patched = data[:5] + b"\xff" * 8 + data[13:]
            return lzma.LZMADecompressor().decompress(patched, max_length=max_length)


class Bzip2Compress(Transform):
    """
    Compress a stream using bzip2
    level: compression level between 1 and 9 (default 9)
    """
    category = "compress"
    name = "bz2 compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, level:int=9):
        import bz2
        if level > 9 or level < 1:
            raise ValueError("Invalid level value")
        return bz2.compress(data, compresslevel=level)


class Bzip2Decompress(Transform):
    """
    Decompress a bzip2-compressed stream
    """
    category = "compress"
    name = "bz2 decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        from bz2 import decompress
        return decompress(data)


class LzoCompress(Transform):
    """
    Compress using LZO (requires the python-lzo package)
    level: compression level between 1 (default) and 9
    header: include metadata header for decompression in the output (default: True)
    """
    category = "compress"
    name = "lzo compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, level:int=1, header:bool=True):
        try:
            import lzo as lzo_module
        except ImportError:
            raise ImportError("You need to install python-lzo library first. On Windows I suggest that you use some pre-compiled wheel")
        if level > 9 or level < 1:
            raise ValueError("Invalid level value")
        return lzo_module.compress(data, level, header)

class LzoDecompress(Transform):
    """
    Decompress using LZO (requires the python-lzo package)
    header: header is included in input (default: True)
    buflen: if header is False, a buffer length in bytes must be given that will fit the output
    """
    category = "compress"
    name = "lzo decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes, header:bool=True, buflen:int=0):
        try:
            import lzo
        except ImportError as err:
            # chain the original error so the real import failure stays visible
            raise ImportError("You need to install python-lzo library first. On Windows I suggest that you use some pre-compiled wheel") from err
        return lzo.decompress(data, header, buflen)

class Lznt1Compress(Transform):
    """
    Compress using LZNT1 (RtlCompressBuffer)
    chunk_size: chunk size handed to the LZNT1 compressor (default: 1000)
    """
    category = "compress"
    name = "lznt1 compress"
    icon = "wxART_FILETYPE_ARCHIVE"

    def run(self, data:bytes, chunk_size:int=1000):
        from .libs import lznt1
        return lznt1.compress(data, chunk_size)

class Lznt1Decompress(Transform):
    """
    Decompress using LZNT1 (RtlDecompressBuffer)
    length_check: verify the length of each chunk while decompressing (default: True)
    """
    category = "compress"
    name = "lznt1 decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes, length_check:bool=True):
        from .libs import lznt1 as lznt1_lib
        return lznt1_lib.decompress(data, length_check)

class LzxDecompress(Transform):
    """
    Decompress using LZX (used for instance in CAB files).
    Note that you need to give the output size
    output_size: size in bytes of the decompressed output (mandatory)
    """
    category = "compress"
    name = "lzx decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes, window_size:bool=True, output_size:int=0):
        # Validate first: there is no point building the decompressor when the
        # mandatory output size is missing.
        if not output_size:
            raise ValueError("You need to give the size of the output buffer beforehand")
        # NOTE(review): window_size is annotated bool (default True) but is passed
        # straight to malcat.LzxDecompressor; confirm which type that API expects.
        dec = malcat.LzxDecompressor(window_size)
        return dec.decompress(data, output_size)


class LzfseDecompress(Transform):
    """
    Decompress using LZFSE (Apple inflate-like compression).
    The output size is read from the input header.
    """
    category = "compress"
    name = "lzfse decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        # The output size is a little-endian uint32 at offset 4, presumably the
        # raw-size field of the first LZFSE block header (after the 4-byte
        # magic). NOTE(review): for multi-block streams this may only cover the
        # first block; confirm how malcat uses it.
        if len(data) < 8:
            raise ValueError("Input too short to contain an LZFSE block header")
        output_size, = struct.unpack_from("<I", data, 4)
        if not output_size:
            raise ValueError("Could not read the size of the output buffer from the LZFSE header")
        dec = malcat.LzfseDecompressor()
        return dec.decompress(data, output_size)


class APLibDecompress(Transform):
    """
    Decompress using APLIB, with or without AP32 header
    """
    category = "compress"
    name = "aplib decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        decompressor = malcat.APLibDecompressor()
        # Optional 24-byte AP32 header: the "AP32" tag followed by five native-order
        # uint32 fields (header size, packed size, packed crc, original size,
        # original crc). Only the first two are needed to locate the payload.
        if data.startswith(b'AP32') and len(data) >= 24:
            fields = struct.unpack_from('=IIIII', data, 4)
            header_size, packed_size = fields[0], fields[1]
            data = data[header_size : header_size + packed_size]
        return decompressor.decompress(data)


class XpressLz77Decompress(Transform):
    """
    Decompress using the Xpress algorithm with LZ77 compression (used in RtlDecompressBuffer)
    """
    category = "compress"
    name = "xpress (lz77) decompress"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        return malcat.XpressDecompressor().decompress(data)


class OfficeRleUnpack(Transform):
    """
    RLE decoding as used in office documents
    (compressed container of [MS-OVBA], e.g. VBA macro streams)
    """
    category = "compress"
    name = "Office RLE"
    icon = "wxART_FILE_COMPRESSED"

    def run(self, data:bytes):
        result = bytearray()
        if len(data) <= 1 or data[0] != 1:
            raise ValueError("Invalid data")
        ptr = 1
        while ptr < len(data):
            block_ptr = ptr
            header, = struct.unpack("<H", data[ptr:ptr+2])
            # low 12 bits: total chunk size (2-byte header included) minus 3
            block_size = min((header & 0xFFF) + 3, len(data) - ptr)
            # Bits 12-14 hold the chunk signature (expected 0b011). It is
            # deliberately not enforced, so that slightly malformed streams still decode.
            is_compressed = (header >> 15) != 0
            ptr += 2
            if not is_compressed:
                # uncompressed chunk: always 4096 raw bytes
                result.extend(data[ptr:ptr+4096])
                ptr += 4096
                continue
            # compressed chunk: groups of one flag byte + up to 8 tokens
            chunk_start = len(result)
            block_end = block_ptr + block_size
            while ptr < block_end and ptr < len(data):
                flagbyte = data[ptr]
                ptr += 1
                for _ in range(8):
                    if ptr >= block_end:
                        break
                    is_copy = flagbyte & 1
                    flagbyte >>= 1
                    if not is_copy:
                        # literal token
                        result.append(data[ptr])
                        ptr += 1
                        continue
                    # copy token: offset/length split depends on how much of the
                    # chunk has been decompressed so far
                    token, = struct.unpack("<H", data[ptr:ptr+2])
                    ptr += 2
                    difference = len(result) - chunk_start
                    if difference <= 0:
                        raise ValueError("Invalid copy token at {}: nothing decompressed yet in chunk".format(ptr - 2))
                    # ceil(log2(difference)) computed exactly with integers (the float
                    # math.log(x, 2) is not guaranteed exact at powers of two)
                    bit_count = max(4, (difference - 1).bit_length())
                    length_mask = 0xFFFF >> bit_count
                    length = (token & length_mask) + 3
                    offset = ((token & ~length_mask) >> (16 - bit_count)) + 1
                    start_point = len(result) - offset
                    # byte-by-byte copy: source and destination may overlap
                    for src in range(start_point, start_point + length):
                        result.append(result[src])
        return bytes(result)


class Nrv2BDecompress(Transform):
    """
    Decompress data using the NRV2B algorithm, used by Zeus
    """
    category = "compress"
    name = "nrv2b decompress"
    icon = "wxART_FILE_COMPRESSED"


    def getbit(self, pos, data, fourBytes, count):
        """Return the next control bit, refilling the 32-bit bit buffer when it is empty.

        Bits are taken from the most significant end of a little-endian uint32
        read at ``data[pos:pos+4]``. Returns ``(bit, pos, fourBytes, count)`` with
        the updated reader state.
        """
        if count == 0:
            count = 31
            fourBytes = struct.unpack('<L', data[pos:pos+4])[0]
            pos += 4
        else:
            count -= 1
        bit = ((fourBytes >> count ) & 1)
        return (bit, pos, fourBytes, count)

    def run(self, data:bytes):
        # Build the output in a bytearray (amortized O(1) appends instead of
        # re-creating a bytes object for every output byte).
        out = bytearray()
        self._decode(data, out)
        return bytes(out)

    def _decode(self, data, out):
        """Decode ``data`` into the bytearray ``out``.

        Stops at the end-of-stream marker, when the input runs out, or when a
        match points before the start of the output. Whatever was decoded so
        far stays in ``out``. Some truncated inputs can still raise
        IndexError / struct.error.
        """
        getbit = self.getbit
        src_len = len(data)
        src_pos = 0
        dst_pos = 0          # always equal to len(out)
        last_m_off = 1
        bit_buf = 0
        bit_count = 0
        while True:
            # literal run: each 1-bit is followed by one literal byte
            if src_pos >= src_len:
                return
            gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            while gb != 0:
                out.append(data[src_pos])
                src_pos += 1
                dst_pos += 1
                gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            # match offset: interleaved (data bit, stop bit) variable-length code
            if src_pos >= src_len:
                return
            gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            m_off = 2 + gb
            if src_pos >= src_len:
                return
            gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            while gb == 0:
                if src_pos >= src_len:
                    return
                gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
                m_off = 2 * m_off + gb
                if src_pos >= src_len:
                    return
                gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            if m_off == 2:
                # code 2: reuse the previous match offset
                m_off = last_m_off
            else:
                m_off = (m_off - 3) * 256 + data[src_pos]
                src_pos += 1
                # End-of-stream marker: the reference C decoder uses unsigned 32-bit
                # arithmetic and stops on m_off == 0xffffffff. Python ints never
                # wrap (an "== -1" test can never match), so compare the low 32 bits.
                if (m_off & 0xFFFFFFFF) == 0xFFFFFFFF:
                    break
                m_off += 1
                last_m_off = m_off
            # match length
            if src_pos >= src_len:
                return
            m_len, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            if src_pos >= src_len:
                return
            gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
            m_len = m_len * 2 + gb
            if m_len == 0:
                m_len += 1
                if src_pos >= src_len:
                    return
                gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
                m_len = 2 * m_len + gb
                if src_pos >= src_len:
                    return
                gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
                while gb == 0:
                    if src_pos >= src_len:
                        return
                    gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
                    m_len = 2 * m_len + gb
                    gb, src_pos, bit_buf, bit_count = getbit(src_pos, data, bit_buf, bit_count)
                m_len += 2
            if m_off > 0xd00:
                # far matches carry one extra byte of length
                m_len += 1
            m_pos = dst_pos - m_off
            if m_pos < 0:
                return
            # copy m_len + 1 bytes one at a time: source and destination may overlap
            for _ in range(m_len + 1):
                out.append(out[m_pos])
                m_pos += 1
            dst_pos += m_len + 1


class Zlb1Decompress(Transform):
    """
    Decompress data using ZLB1 (header 1zlb)
    """
    category = "compress"
    name = "zlb1 decompress"
    icon = "wxART_FILE_COMPRESSED"

    def __init__(self):
        self._reset()

    def _reset(self):
        """Reset all per-run decoder state so that an instance can be reused."""
        self.state = None        # remaining control bits of the current 16-bit word (MSB first)
        self.ptr = 0             # read offset into self.data
        self.last_cmd_ptr = 0    # offset of the most recently loaded control word
        self.counter = 0         # not used by the decoder
        self.data = None         # compressed payload (input without the 24-byte header)
        self.result = bytearray()
        self.result_ptr = 0      # write offset into self.result

    def __next(self):
        """Return the next control bit, loading a new little-endian 16-bit word when exhausted."""
        if not self.state:
            word, = struct.unpack("<H", self.data[self.ptr:self.ptr+2])
            self.state = list(map(int, "{:016b}".format(word)))
            self.last_cmd_ptr = self.ptr
            self.ptr += 2
        return self.state.pop(0)

    def __next_varint(self):
        """Read a variable-length number: start at 1, then shift in one data bit
        per step for as long as the following continuation bit is set."""
        r = 1
        cont = True
        while cont:
            r = r*2 + self.__next()
            cont = self.__next()
        return r

    def __output(self, data):
        """Store one byte at the current output position."""
        self.result[self.result_ptr] = data
        self.result_ptr += 1

    def run(self, data:bytes):
        # 24-byte header: magic, unused u32, unpacked size u64, dictionary size u32, unused u32
        header_size = struct.calcsize("<4sIQII")
        if len(data) < header_size:
            raise ValueError("Not a valid zlb1 archive")
        header, _, size_unp, dicsize, __ = struct.unpack("<4sIQII", data[:header_size])
        if header != b"1zlb":
            raise ValueError("Not a valid zlb1 archive")
        if size_unp and not dicsize:
            raise ValueError("Not a valid zlb1 archive")
        # the decoder keeps its cursors on the instance: start from a clean state
        self._reset()
        self.data = data[header_size:]
        self.result = bytearray(size_unp)
        while self.result_ptr < size_unp:
            if self.result_ptr % dicsize == 0:
                # every dicsize output bytes: one raw literal byte, then a fresh control word
                self.__output(self.data[self.ptr])
                self.ptr += 1
                self.state = None
                if self.result_ptr >= size_unp:
                    break
            if not self.__next():
                # literal byte
                self.__output(self.data[self.ptr])
                self.ptr += 1
            else:
                # back-reference: length, then offset (high part varint + low byte)
                length = self.__next_varint() + 2
                offset_high = self.__next_varint() - 2
                offset_low = self.data[self.ptr]
                self.ptr += 1
                start = self.result_ptr - (offset_high * 256 + offset_low + 1)
                # never write past the declared unpacked size
                length = min(length, size_unp - self.result_ptr)
                for i in range(length):
                    self.__output(self.result[i + start])
        return self.result
