from filetypes.base import *
from filetypes.PYC import MAGIC_TO_VERSION
import malcat
import os
import struct
from filetypes.PYC import PythonAnalyzer, TYPE_LIST




class Py2exeHeader(Struct):
    """Header that precedes the marshalled module list in a py2exe resource.

    Layout: four 32-bit unsigned integers followed by a C string holding the
    archive name.
    """

    def parse(self):
        # The fixed-size integer part, emitted in on-disk order.
        for field_name in ("Magic", "Major", "Minor", "Size"):
            yield UInt32(name=field_name)
        # NOTE(review): py2exe's own C-side struct appears to call these
        # integers tag / optimize / unbuffered / data_bytes; confirm before
        # relying on the "Major"/"Minor" labels meaning a Python version.
        yield CString(name="ArchiveName")



class Py2exeAnalyzer(PythonAnalyzer):
    """Analyzer for the marshalled module list that py2exe embeds in executables.

    The data is a Py2exeHeader followed by a marshalled list of code objects.
    Each list element becomes a section named after the basename of its code
    object's filename.
    """
    category = malcat.FileType.PROGRAM
    name = "PY2EXE"
    regexp = r"\x12\x34\x56\x78"


    def __init__(self):
        PythonAnalyzer.__init__(self)

    def _build_pyc(self, code_bytes):
        """Return *code_bytes* prefixed with a .pyc header for ``self.version``.

        The header is: the version's magic number, a zeroed 4-byte word
        (mtime before 3.7), a second zeroed word for >= 3.7 (PEP 552 layout),
        then, for >= 3.3, a 4-byte size field set to ``len(code_bytes)``.
        Falls back to the 3.6 magic when ``self.version`` has no known magic.
        """
        version_to_magic = {v: k for k, v in MAGIC_TO_VERSION.items()}
        magic = version_to_magic.get(self.version, version_to_magic[(3, 6)])
        header = struct.pack("<I", magic)
        header += b"\x00\x00\x00\x00"
        if self.version >= (3, 7):
            header += b"\x00\x00\x00\x00"
        if self.version >= (3, 3):
            header += struct.pack("<I", len(code_bytes))
        return bytearray(header + code_bytes)

    def open(self, vfile, password=None):
        """Extract one embedded module as a standalone .pyc (not implemented).

        The .pyc image is assembled, but extraction is refused because marshal
        references (objects flagged with FLAG_REF, 0x80) are not resolved yet.
        A module sliced out of the shared list may therefore not unmarshal on
        its own.

        Raises:
            NotImplementedError: always, after the .pyc image is built.
        """
        entry = self.filesystem[vfile.path]
        pyc = self._build_pyc(entry.bytes)
        # TODO: resolve marshal back-references inside `pyc`, then `return pyc`.
        raise NotImplementedError(
            "py2exe module extraction is not implemented: marshal references are not resolved"
        )


    def parse(self, hint):
        """Parse the py2exe header and register one section per embedded module.

        Raises:
            FatalError: if the first code object's flags cannot be located or
                sit at an unexpected offset, or if the root object is not a list.
        """
        # py2exe only runs on Windows, so code-object filenames are Windows
        # paths. os.path.basename would not split on "\" on non-Windows hosts.
        import ntpath

        self.set_architecture(malcat.Architecture.PY27)
        yield Py2exeHeader()

        # Guess the Python version (imprecise). Find the first code object's
        # co_flags == 0x40 (CO_NOFREE), immediately followed by the bytecode
        # string's type byte 's' (0x73), optionally with FLAG_REF (0x80) set.
        # The +5 skips the root list's type byte and 4-byte element count.
        # The distance covers the code object's type byte plus its leading
        # int32 fields, and that field count changed between versions.
        off, sz = self.search(rb"\x40\x00\x00\x00[\x73\xf3]", self.tell() + 5)
        if not sz:
            raise FatalError("Could not locate flags position")
        distance = off - (self.tell() + 5)
        if distance == 0x15:
            self.version = (3, 8)   # 3.8 <= version (adds posonlyargcount)
        elif distance == 0x11:
            # 3.0 <= version < 3.8 (adds kwonlyargcount)
            # NOTE(review): 3.11+ dropped co_nlocals and likely lands here too.
            self.version = (3, 6)
        elif distance == 0x0d:
            self.version = (2, 7)   # version < 3.0
        else:
            raise FatalError("Invalid first flags position")
        self.update_arch(self.version)

        # Parse the root object: a list of module code objects.
        for obj in self.parse_object(name="Modules", comment="modules defined in this file"):
            modules = yield obj
            # Mask off FLAG_REF (0x80) before comparing the type code.
            if modules["Type"] & 0x7f != TYPE_LIST:
                raise FatalError("Root object is not a list")
            for i in range(modules["Size"]):
                # List elements presumably follow the "Type" and "Size" fields,
                # hence the +2 index offset.
                module = modules.at(2 + i)
                name = self.read_object(module["Filename"])
                self.add_section(ntpath.basename(name), module.offset, module.size)

            
