import struct
import re
from filetypes.base import *
from filetypes.PE import Resource, PEAnalyzer
from filetypes.PYC import MAGIC_TO_VERSION
import json
import sys

# Key looked up in ``pe.resources`` when probing for a Nuitka onefile payload
# (used by extract_nuitka_onefile_stream).
NUITKA_PE_ONEFILE_RSRC_NAME = "RCDATA/27/unk"
# Key looked up in ``pe.resources`` when probing for the Nuitka constants /
# bytecode blob (used by extract_nuitka_blob_stream).
NUITKA_PE_BLOB_RSRC_NAME = "RCDATA/3/unk"

def read_bytes(data):
    """Read a NUL-terminated byte string from the start of ``data``.

    Args:
        data: bytes-like buffer to read from.

    Returns:
        A ``(value, consumed)`` pair. ``value`` holds the bytes preceding the
        first NUL byte and ``consumed`` counts those bytes plus the
        terminator. ``(None, 0)`` is returned when ``data`` has no NUL byte.
    """
    terminated = re.match(rb"^(.*?)\x00", data, re.DOTALL)
    if terminated is None:
        return None, 0
    payload = terminated.group(1)
    return payload, len(payload) + 1

def read_string_utf16(data):
    """Read a UTF-16LE string terminated by a 2-byte NUL from the start of ``data``.

    The terminator is only searched for on 2-byte boundaries, so a ``00 00``
    pair straddling two code units is not mistaken for the end of the string.

    Args:
        data: bytes-like buffer to read from.

    Returns:
        A ``(text, consumed)`` pair where ``consumed`` includes the 2-byte
        terminator. ``(None, 0)`` is returned when no aligned terminator is
        found or when the bytes are not valid UTF-16LE.
    """
    found = re.match(rb"^((?:..)*?)\x00\x00", data, re.DOTALL)
    if found is None:
        return None, 0
    raw_name = found.group(1)
    try:
        text = raw_name.decode("utf-16le")
    except UnicodeDecodeError:
        return None, 0
    return text, len(raw_name) + 2


def read_string_utf8(data):
    """Read a NUL-terminated UTF-8 string from the start of ``data``.

    Args:
        data: bytes-like buffer to read from.

    Returns:
        A ``(text, consumed)`` pair where ``consumed`` includes the NUL
        terminator. ``(None, 0)`` is returned when there is no NUL byte or
        when the bytes before it are not valid UTF-8.
    """
    found = re.match(rb"^(.*?)\x00", data, re.DOTALL)
    if found is None:
        return None, 0
    raw_text = found.group(1)
    try:
        text = raw_text.decode("utf-8")
    except UnicodeDecodeError:
        return None, 0
    return text, len(raw_text) + 1

###################################################################################    


def extract_nuitka_onefile_stream(pe):
    """Return the (decompressed) Nuitka onefile payload stored in ``pe``.

    The resource is expected to start with ``b"KA"``, followed by one flag
    byte (``b"Y"`` means the payload is zstd-compressed), the payload itself
    and a trailing little-endian 64-bit payload size.

    Args:
        pe: analyzer object exposing ``resources``, ``rva2off`` and ``read``.

    Returns:
        The payload (a ``memoryview`` when stored uncompressed, otherwise
        whatever the zstd decompressor returns), or ``None`` when the
        resource is missing, shorter than 16 bytes, cannot be mapped to a
        file offset, or does not start with the ``b"KA"`` signature.

    Raises:
        FatalError: if the trailing size does not match the resource
            contents, or if the payload is compressed and no zstd module can
            be imported.
    """
    resource = pe.resources.get(NUITKA_PE_ONEFILE_RSRC_NAME)
    if resource is None or resource.size < 16:
        return None
    file_offset = pe.rva2off(resource.rva)
    if not file_offset or pe.read(file_offset, 2) != b"KA":
        return None

    raw = memoryview(pe.read(file_offset, resource.size))
    (declared_size,) = struct.unpack("<Q", raw[-8:])
    is_compressed = raw[2:3] == b"Y"
    # Strip the 3-byte header ("KA" + flag) and the 8-byte size trailer.
    payload = raw[3:-8]
    if len(payload) != declared_size - 3:
        raise FatalError("Invalid nuitka payload size")
    if not is_compressed:
        return payload

    # zstd lives in the standard library from Python 3.14 on; older
    # interpreters need the backports package.
    try:
        if sys.version_info >= (3, 14):
            from compression import zstd
        else:
            from backports import zstd
    except ImportError:
        raise FatalError("Archive is compressed, you need to install backports.zstd python library via pip")
    return zstd.ZstdDecompressor().decompress(payload)


def extract_nuitka_blob_stream(pe):
    """Return the Nuitka constants/bytecode blob stored in ``pe``.

    The resource is recognised by the ``b".bytecode\\x00"`` name found 8
    bytes into it; the little-endian 32-bit value at offset 4 gives the
    number of bytes to read from offset 8 onwards.

    Args:
        pe: analyzer object exposing ``resources``, ``rva2off`` and ``read``.

    Returns:
        A ``memoryview`` over the blob, or ``None`` when the resource is
        missing, shorter than 18 bytes, cannot be mapped to a file offset,
        or does not carry the expected ``.bytecode`` entry name.
    """
    resource = pe.resources.get(NUITKA_PE_BLOB_RSRC_NAME)
    if resource is None or resource.size < 18:
        return None
    base = pe.rva2off(resource.rva)
    if not base or pe.read(base + 8, 10) != b".bytecode\x00":
        return None
    (declared_size,) = struct.unpack("<I", pe.read(base + 4, 4))
    return memoryview(pe.read(base + 8, declared_size))


def nuitka_extract_onefile_entry(pe, vfile, password=None):
    """Return the contents of one file indexed from the Nuitka onefile payload.

    Args:
        pe: analyzer object whose ``nuitka_files`` maps virtual paths to
            ``(offset, size)`` pairs inside the onefile stream.
        vfile: virtual file descriptor; only ``vfile.path`` is used.
        password: unused.

    Returns:
        The slice of the onefile stream covering the requested file.

    Raises:
        KeyError: if ``vfile.path`` is not a known entry.
        ValueError: if the onefile stream cannot be extracted (or is empty).
    """
    if vfile.path not in pe.nuitka_files:
        raise KeyError("Unknown file")
    start, length = pe.nuitka_files[vfile.path]
    # The stream is re-extracted on every call; nothing is cached on ``pe``.
    stream = extract_nuitka_onefile_stream(pe)
    if not stream:
        raise ValueError("Cannot get nuitka stream")
    return stream[start:start + length]

###################################################################################


def nuitka_extract_blob_entry(pe, vfile, password=None):
    """Return a human-readable rendering of one entry of the Nuitka blob.

    Every entry starts with a little-endian 16-bit object count. Entries
    whose path ends with ``/.bytecode`` are rebuilt as a concatenation of
    ``.pyc`` files; every other entry is deserialised as a pool of constants
    and dumped as JSON. Each decoder is first tried with varint-encoded
    integers, then with fixed-width integers.

    Args:
        pe: analyzer object carrying ``nuitka_files`` and
            ``nuitka_python_version``.
        vfile: virtual file descriptor; only ``vfile.path`` is used.
        password: unused.

    Returns:
        ``bytes`` holding the reconstructed content. When decoding fails,
        a diagnostic message followed by the raw entry data is returned
        instead of raising.

    Raises:
        KeyError: if ``vfile.path`` is not a known entry.
        ValueError: if the blob stream cannot be extracted or the entry is
            shorter than its 2-byte header.
    """
    if vfile.path not in pe.nuitka_files:
        raise KeyError("Unknown file")
    start, length = pe.nuitka_files[vfile.path]
    stream = extract_nuitka_blob_stream(pe)
    if not stream:
        raise ValueError("Cannot get nuitka stream")
    entry = stream[start:start + length]
    if len(entry) < 2:
        raise ValueError("stream too short")
    (num_objects,) = struct.unpack("<H", entry[:2])
    # Everything after the object count; also echoed back verbatim on failure.
    payload = entry[2:]
    try:
        if vfile.path.endswith("/.bytecode"):
            # Deliberately catches everything: any failure of the varint
            # flavour triggers a retry with fixed-width integers.
            try:
                pyc_files = nuitka_extract_all_modules(payload, num_objects, pe.nuitka_python_version, True)
            except BaseException:
                pyc_files = nuitka_extract_all_modules(payload, num_objects, pe.nuitka_python_version, False)
            return b"Reconstructed modules:\n" + b"".join(pyc_files)  # let malcat do its magic

        objects = []
        try:
            try:
                objects = nuitka_deserialise_all(payload, num_objects, True)
            except BaseException:
                objects = nuitka_deserialise_all(payload, num_objects, False)
        except BaseException:
            import traceback
            # Partial dump, traceback and raw data, separated by NUL bytes.
            return b"\x00".join([
                json.dumps(objects, indent=4).encode("utf-8"),
                traceback.format_exc().encode("utf8"),
                bytes(payload),
            ])
        return b"Reconstructed object pool:\n" + json.dumps(objects, indent=4).encode("utf-8")
    except BaseException as e:
        return f"Could not deserialise blob: {e}  Original data:".encode("utf-8") + bytes(payload)


def nuitka_parse_varint(mv):
    """Decode a little-endian base-128 varint from the start of ``mv``.

    Each byte contributes its low 7 bits, least-significant group first; a
    set high bit means that another byte follows.

    Args:
        mv: bytes-like buffer (``bytes`` or ``memoryview``) to read from.

    Returns:
        A ``(value, rest)`` pair where ``rest`` is ``mv`` with the consumed
        bytes removed.

    Raises:
        ValueError: if ``mv`` ends before a byte with a clear high bit is
            found (this includes an empty ``mv``).
    """
    value = 0
    shift = 0
    for consumed, byte in enumerate(mv, start=1):
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, mv[consumed:]
        shift += 7
    raise ValueError("Truncated stream")

def nuitka_extract_all_modules(data, num_objects, nuitka_python_version, use_integer_varint = True):
    """Rebuild ``.pyc`` files from the records of a Nuitka ``.bytecode`` entry.

    Each record is a ``b"X"`` marker, a size (varint or little-endian
    32-bit integer depending on ``use_integer_varint``) and ``size`` bytes of
    module data. A ``.pyc`` header matching ``nuitka_python_version`` is
    prepended to every module; a record of size 0 ends the list early.

    Args:
        data: bytes-like buffer positioned on the first record.
        num_objects: maximum number of records to read.
        nuitka_python_version: ``(major, minor)`` tuple selecting the magic
            number and header layout; unknown versions use the (3, 8) magic.
        use_integer_varint: whether sizes are varint-encoded.

    Returns:
        A list of ``bytes`` objects, one reconstructed ``.pyc`` per module.

    Raises:
        FatalError: if a record does not start with the ``b"X"`` marker.
    """
    version_to_magic = {version: number for number, version in MAGIC_TO_VERSION.items()}
    magic = version_to_magic.get(nuitka_python_version, version_to_magic[(3, 8)])
    modules = []
    for index in range(num_objects):
        if not data or data[0:1] != b"X":
            raise FatalError(f"Invalid module header for object #{index}: {bytes(data[:8])}")
        data = data[1:]
        if use_integer_varint:
            size, data = nuitka_parse_varint(data)
        else:
            (size,) = struct.unpack("<I", data[:4])
            data = data[4:]
        if size == 0:
            break
        # Header layout: magic, one zeroed 32-bit field, a second zeroed
        # field for 3.7+, then the size field for 3.3+.
        parts = [struct.pack("<I", magic), b"\x00\x00\x00\x00"]
        if nuitka_python_version >= (3, 7):
            parts.append(b"\x00\x00\x00\x00")
        if nuitka_python_version >= (3, 3):
            parts.append(struct.pack("<I", size))
        parts.append(bytes(data[:size]))
        modules.append(b"".join(parts))
        data = data[size:]
    return modules

def nuitka_deserialise_all(data, num_objects, use_integer_varint = True):
    """Deserialise ``num_objects`` consecutive constants from ``data``.

    Args:
        data: bytes-like buffer positioned on the first object.
        num_objects: number of objects to read.
        use_integer_varint: forwarded to ``nuitka_deserialise``.

    Returns:
        The list of decoded objects, in stream order.

    Raises:
        ValueError: if decoding an object did not consume any input.
    """
    decoded = []
    remaining = data
    for _ in range(num_objects):
        value, rest = nuitka_deserialise(remaining, use_integer_varint)
        # Guard against a decoder that makes no progress.
        if not len(rest) < len(remaining):
            raise ValueError("Deserialisation error")
        decoded.append(value)
        remaining = rest
    return decoded


def nuitka_deserialise(mv, use_integer_varint = True):
    """Deserialise one constant from the start of a Nuitka constants stream.

    The first byte is a type tag that selects the decoding:

    * ``a``, ``u``, ``E``, ``O``: NUL-terminated UTF-8 string.
    * ``c``: NUL-terminated bytes, rendered as ``"bytes(b'...')"``.
    * ``d``: one raw byte, rendered with ``repr``; ``w``: one UTF-8 byte.
    * ``l`` / ``q``: non-negative / negated integer (varint, or fixed
      little-endian 32-bit when ``use_integer_varint`` is false).
    * ``f``: little-endian double.
    * ``T``: tuple; ``L``, ``P``, ``S``: sequence returned as a list.
    * ``D``: dict, stored as all keys followed by all values.
    * ``Z``: special float selected by a one-byte subtype.
    * ``;``, ``:``: three nested objects rendered as ``"[start:stop:step]"``.
    * ``A``: ``("<alias>", a, b)``; ``H``: ``("<union>", x)``.
    * ``M``: one-byte index into a table of builtin type names.
    * ``n``, ``p``: ``None``; ``t``: ``True``; ``F``: ``False``; ``s``: ``""``.

    Args:
        mv: bytes-like buffer (``bytes`` or ``memoryview``) to read from.
        use_integer_varint: whether ``l``/``q`` integers are varint-encoded.

    Returns:
        A ``(value, rest)`` pair where ``rest`` is ``mv`` with the consumed
        bytes removed.

    Raises:
        ValueError: on an unsupported type tag (including an empty ``mv``),
            a truncated varint, an invalid ``Z`` subtype, or a string/bytes
            constant that is unterminated or undecodable.
        IndexError, struct.error, TypeError: may propagate on truncated
            input, an out-of-range ``M`` index or an unhashable dict key.
    """
    res = None
    typ = bytes(mv[0:1])
    mv = mv[1:]
    if typ in (b"a", b"u", b"E", b"O"):
        res, sz = read_string_utf8(mv)
        if not sz:
            # An empty string still consumes its terminator, so 0 always
            # means failure; carrying on would desynchronise the parser.
            raise ValueError("Unterminated or undecodable string constant")
        mv = mv[sz:]
    elif typ == b"c":
        res, sz = read_bytes(mv)
        if not sz:
            raise ValueError("Unterminated bytes constant")
        res = f"bytes({repr(res)})"
        mv = mv[sz:]
    elif typ == b"d":
        res = repr(bytes(mv[:1]))
        mv = mv[1:]
    elif typ == b"w":
        res = bytes(mv[:1]).decode("utf-8")
        mv = mv[1:]
    elif typ in (b"l", b"q"):
        if use_integer_varint:
            res, mv = nuitka_parse_varint(mv)
        else:
            # struct.unpack always returns a tuple: unpack its single item so
            # that the value is an int (and can be negated below).
            res, = struct.unpack("<I", mv[:4])
            mv = mv[4:]
        if typ == b"q":
            res = -res
    elif typ == b"f":
        res, = struct.unpack("<d", mv[:8])
        mv = mv[8:]
    elif typ in (b"T", b"L", b"P", b"S"):
        sz, mv = nuitka_parse_varint(mv)
        res = []
        for _ in range(sz):
            o, mv = nuitka_deserialise(mv, use_integer_varint)
            res.append(o)
        if typ == b"T":
            res = tuple(res)
    elif typ == b"D":
        sz, mv = nuitka_parse_varint(mv)
        keys = []
        values = []
        for _ in range(sz):
            k, mv = nuitka_deserialise(mv, use_integer_varint)
            keys.append(k)
        for _ in range(sz):
            v, mv = nuitka_deserialise(mv, use_integer_varint)
            values.append(v)
        res = dict(zip(keys, values))
    elif typ == b"Z":
        # NOTE(review): subtypes 0-1 and 4 all map to 0.0 and 2-3 to NaN;
        # presumably signed zero / infinity variants are collapsed here.
        # TODO confirm against the Nuitka constants-blob format.
        subtype = mv[0]
        mv = mv[1:]
        if subtype <= 1:
            res = 0.0
        elif subtype <= 3:
            res = float("nan")
        elif subtype == 4:
            res = 0.0
        else:
            raise ValueError("Invalid float constant")
    elif typ in (b";", b":"):
        start, mv = nuitka_deserialise(mv, use_integer_varint)
        stop, mv = nuitka_deserialise(mv, use_integer_varint)
        step, mv = nuitka_deserialise(mv, use_integer_varint)
        res = f"[{start}:{stop}:{step}]"
    elif typ == b"A":
        alias1, mv = nuitka_deserialise(mv, use_integer_varint)
        alias2, mv = nuitka_deserialise(mv, use_integer_varint)
        res = ("<alias>", alias1, alias2)
    elif typ == b"H":
        union, mv = nuitka_deserialise(mv, use_integer_varint)
        res = ("<union>", union)
    elif typ == b"M":
        index = mv[0]
        mv = mv[1:]
        res = [
            "NoneType",
            "EllipsisType",
            "ExceptionType",
            "FunctionType",
            "GeneratorType",
            "CFunctionType",
            "CodeType",
            "ModuleType",
            "FileType",
            "ClassType",
            "UnionType",
            "MethodType",
        ][index]
    elif typ == b"n":
        res = None
    elif typ == b"p":   # prev?
        res = None
    elif typ == b"t":
        res = True
    elif typ == b"F":
        # Real boolean, consistent with the b"t" tag (was the string "False").
        res = False
    elif typ == b"s":
        res = ""
    else:
        raise ValueError(f"Unsupported object type {typ}: {bytes(mv[:16])}")
    return res, mv

def nuitka_get_python_version(pe):
    """Guess the embedded Python version from the imported DLL names.

    The first import whose base name (lower-cased, up to the first dot)
    starts with ``python<major><minor>`` decides, e.g. ``python312.dll``
    gives ``(3, 12)``. A name without minor digits such as ``python3.dll``
    is ignored.

    Args:
        pe: analyzer object whose ``raw_imports`` yields ``(dll_name, _)``
            pairs.

    Returns:
        A ``(major, minor)`` tuple, or ``(3, 8)`` when no import matches.
    """
    for dll_name, _ in pe.raw_imports:
        stem = dll_name.lower().split(".")[0]
        # Raw string: "\d" in a plain literal is an invalid escape sequence.
        m = re.match(r"python([2-5])(\d+)", stem)
        if m:
            return int(m.group(1)), int(m.group(2))
    return (3, 8)
    


###################################################################################

def parse_nuitka(pe):
    """Index the files embedded by Nuitka in ``pe`` and register them.

    Two layouts are probed, in this order:

    * the onefile payload: a sequence of ``UTF-16 name, 64-bit size,
      [32-bit crc32], data`` records, exposed under ``#NUITKA_ONEFILE/``;
    * the constants blob: a sequence of ``UTF-8 name, 32-bit size, data``
      records, exposed under ``#NUITKA_BLOBS/`` (an empty name is shown as
      ``<constants>``).

    For every record found, ``pe.nuitka_files`` maps the virtual path to its
    ``(offset, size)`` inside the stream and ``pe.add_file`` is called with
    the name of the matching extraction callback. ``pe.nuitka_python_version``
    is set from the imports only when the blob layout is found.

    Args:
        pe: analyzer object to inspect and annotate.

    Raises:
        FatalError: if a record name cannot be read, a stream is truncated,
            or a blob record claims more data than is available.
    """
    import zlib
    from types import MethodType

    pe.nuitka_files = {}
    pe.nuitka_python_version = (3, 8)
    nuitka_onefile_stream = extract_nuitka_onefile_stream(pe)
    if nuitka_onefile_stream:
        # The stream is re-sliced several times per record. A decompressed
        # payload may come back as a plain bytes object, for which every
        # slice copies the whole remainder; a memoryview makes them free.
        nuitka_onefile_stream = memoryview(nuitka_onefile_stream)
        pe.nuitka_extract_onefile_entry = MethodType(nuitka_extract_onefile_entry, pe)
        crc32_presence_checked = False
        uses_crc32 = False
        original_size = len(nuitka_onefile_stream)
        while len(nuitka_onefile_stream) > 8:
            fname, fname_byte_size = read_string_utf16(nuitka_onefile_stream)
            if not fname_byte_size:
                raise FatalError("Could not read file name")
            nuitka_onefile_stream = nuitka_onefile_stream[fname_byte_size:]
            if len(nuitka_onefile_stream) < 8:
                raise FatalError("Truncated stream")
            file_size, = struct.unpack("<Q", nuitka_onefile_stream[:8])
            nuitka_onefile_stream = nuitka_onefile_stream[8:]
            # Check for crc32 (custom installs): on the first record large
            # enough to test, see whether the 4 bytes following the size are
            # the crc32 of the data behind them; the verdict then applies to
            # every record.
            if not crc32_presence_checked and file_size + 4 <= len(nuitka_onefile_stream):
                crc32_candidate, = struct.unpack("<I", nuitka_onefile_stream[:4])
                crc32 = zlib.crc32(nuitka_onefile_stream[4:4+file_size])
                uses_crc32 = crc32_candidate == crc32
                crc32_presence_checked = True
            if uses_crc32:
                nuitka_onefile_stream = nuitka_onefile_stream[4:]
            if file_size > len(nuitka_onefile_stream):
                # Best effort: keep the entries indexed so far.
                print(f"File size too large: {file_size:x}")
                break
            file_offset = original_size - len(nuitka_onefile_stream)
            nuitka_onefile_stream = nuitka_onefile_stream[file_size:]
            fname = "/".join(["#NUITKA_ONEFILE", fname])
            pe.nuitka_files[fname] = (file_offset, file_size)
            pe.add_file(fname, file_size, "nuitka_extract_onefile_entry")
        return

    nuitka_blob_stream = extract_nuitka_blob_stream(pe)
    if nuitka_blob_stream:
        pe.nuitka_python_version = nuitka_get_python_version(pe)
        pe.nuitka_extract_blob_entry = MethodType(nuitka_extract_blob_entry, pe)
        original_size = len(nuitka_blob_stream)
        while len(nuitka_blob_stream) > 8:
            fname, fname_byte_size = read_string_utf8(nuitka_blob_stream)
            if not fname_byte_size:
                raise FatalError("Could not read file name")
            if not fname:
                fname = "<constants>"
            nuitka_blob_stream = nuitka_blob_stream[fname_byte_size:]
            if len(nuitka_blob_stream) < 4:
                raise FatalError("Truncated stream")
            file_size, = struct.unpack("<I", nuitka_blob_stream[:4])
            nuitka_blob_stream = nuitka_blob_stream[4:]
            if file_size > len(nuitka_blob_stream):
                raise FatalError(f"File size too large: {file_size}")
            file_offset = original_size - len(nuitka_blob_stream)
            nuitka_blob_stream = nuitka_blob_stream[file_size:]
            fname = "/".join(["#NUITKA_BLOBS", fname])
            pe.nuitka_files[fname] = (file_offset, file_size)
            pe.add_file(fname, file_size, "nuitka_extract_blob_entry")




        
