from filetypes.base import *
import malcat
import struct
from filetypes.PE import Section

# CodeView signature values. OBJAnalyzer.parse reads this value as the first
# 32-bit word of a ".debug$S" section and only decodes the sub-streams when it
# equals CV_SIGNATURE_C13; the other values are currently not acted upon.
CV_SIGNATURE_C6  = 0  # Actual signature is >64K
CV_SIGNATURE_C7  = 1  # First explicit signature
CV_SIGNATURE_C11 = 2  # C11 (vc5.x) 32-bit types
CV_SIGNATURE_C13 = 4  # C13 (vc7.x) zero terminated names

class COFF(Struct):
    """COFF file header: seven fixed-size fields parsed in file order."""

    def parse(self):
        """Yield the header fields one after another."""
        # (label, value) pairs used to annotate the Machine field
        machine_types = [
            ("IMAGE_FILE_MACHINE_UNKNOWN", 0x0),
            ("IMAGE_FILE_MACHINE_AM33", 0x1d3),
            ("IMAGE_FILE_MACHINE_AMD64", 0x8664),
            ("IMAGE_FILE_MACHINE_ARM", 0x1c0),
            ("IMAGE_FILE_MACHINE_ARM64", 0xaa64),
            ("IMAGE_FILE_MACHINE_ARMNT", 0x1c4),
            ("IMAGE_FILE_MACHINE_EBC", 0xebc),
            ("IMAGE_FILE_MACHINE_I386", 0x14c),
            ("IMAGE_FILE_MACHINE_IA64", 0x200),
            ("IMAGE_FILE_MACHINE_M32R", 0x9041),
            ("IMAGE_FILE_MACHINE_MIPS16", 0x266),
            ("IMAGE_FILE_MACHINE_MIPSFPU", 0x366),
            ("IMAGE_FILE_MACHINE_MIPSFPU16", 0x466),
            ("IMAGE_FILE_MACHINE_POWERPC", 0x1f0),
            ("IMAGE_FILE_MACHINE_POWERPCFP", 0x1f1),
            ("IMAGE_FILE_MACHINE_R4000", 0x166),
            ("IMAGE_FILE_MACHINE_RISCV32", 0x5032),
            ("IMAGE_FILE_MACHINE_RISCV64", 0x5064),
            ("IMAGE_FILE_MACHINE_RISCV128", 0x5128),
            ("IMAGE_FILE_MACHINE_SH3", 0x1a2),
            ("IMAGE_FILE_MACHINE_SH3DSP", 0x1a3),
            ("IMAGE_FILE_MACHINE_SH4", 0x1a6),
            ("IMAGE_FILE_MACHINE_SH5", 0x1a8),
            ("IMAGE_FILE_MACHINE_THUMB", 0x1c2),
            ("IMAGE_FILE_MACHINE_WCEMIPSV2", 0x169),
        ]
        yield UInt16(
            name="Machine",
            values=machine_types,
            comment="supported architecture",
        )
        yield UInt16(
            name="NumberOfSections",
            comment="number of sections",
        )
        yield Timestamp(
            name="TimeDateStamp",
            comment="file creation time",
        )
        yield Offset32(
            name="PointerToSymbolTable",
            comment="pointer to symbol table",
        )
        yield UInt32(
            name="NumberOfSymbols",
            comment="number of symbols",
        )
        yield UInt16(
            name="SizeOfOptionalHeader",
            comment="size of optional header",
        )
        yield UInt16(
            name="Characteristics",
            comment="file characteristics",
        )


class SymbolTable(Struct):
    """COFF symbol table: ``nument`` fixed-size records of 18 bytes each.

    A record is either a primary symbol (parsed as ``SymbolEntry``) or an
    auxiliary record belonging to the symbol before it (kept as opaque bytes).
    Every record is yielded as its own field so that a field's position in
    this structure equals its symbol-table index, which is how relocation
    entries look symbols up (``symtable[SymbolTableIndex]``).

    Args:
        nument: total number of 18-byte records (primary + auxiliary), as
            given by the COFF header's NumberOfSymbols field.
    """

    def __init__(self, nument, **args):
        Struct.__init__(self, **args)
        self.nument = nument

    def parse(self):
        """Yield one field per 18-byte record until the table is consumed."""
        record_size = 18
        sz = self.nument * record_size
        # the string table is located immediately after the last record
        string_table = self.offset + sz
        while len(self) < sz:
            se = yield SymbolEntry(string_table, name="Symbol")
            # NumAux counts the auxiliary records that follow this symbol.
            # Emit exactly one 18-byte field per auxiliary record: yielding
            # numaux*18 bytes per iteration would consume numaux**2 records
            # and misalign every entry after it.
            for _ in range(se["NumAux"]):
                if len(self) >= sz:
                    # NumAux claims more records than the table holds
                    break
                yield Bytes(record_size, name="Symbol.AuxData", comment="auxilliary data")


class StringTable(Struct):
    """String table: a 32-bit size field followed by the raw string bytes."""

    def parse(self):
        """Yield the size field, then the string payload when there is one."""
        total = yield UInt32(name="Size", comment="string table size")
        # the declared size includes the 4 bytes of the size field itself
        payload = total - 4
        if payload > 0:
            yield Bytes(payload, name="Strings", comment="strings")




class SymbolEntry(Struct):
    """A single primary record of the symbol table.

    Args:
        string_table: file offset of the string table, used as the base for
            the NameOffset field of symbols whose name is stored there.
    """

    def __init__(self, string_table, **kwargs):
        Struct.__init__(self, **kwargs)
        self.string_table = string_table

    def parse(self):
        """Yield the name (inline or by offset) followed by the fixed fields."""
        # Four leading zero bytes mean the name lives in the string table.
        name_in_string_table = self.look_ahead(4) == b"\x00\x00\x00\x00"
        if not name_in_string_table:
            yield String(8, name="Name", comment="symbol name")
        else:
            yield Unused(4, name="Zero")
            yield Offset32(
                name="NameOffset",
                base=self.string_table,
                hint=String(0, zero_terminated=True),
                comment="offset of symbol name inside string table",
            )
        yield UInt32(
            name="Value",
            comment="value of symbol, type-dependant",
        )
        yield Int16(
            name="SectionNumber",
            comment="section number this symbol belongs to. could also be 0 (pubic symbol), -1 (absolute symbol) or -2 (debug symbol",
        )
        yield UInt16(
            name="Type",
            comment="symbol type",
        )
        yield UInt8(
            name="StorageClass",
            comment="tells where and what the symbol represents",
        )
        yield UInt8(
            name="NumAux",
            comment="how many equivalent SYMENTs are used for aux entries",
        )



class CVSymbolStream13(Struct):
    """CodeView C13 symbol subsection (subsection type 0xF1).

    Layout: 32-bit subsection type, 32-bit payload size, payload, then padding
    up to the next 4-byte boundary.

    Args:
        size: maximum number of bytes this structure may span (the caller
            passes the bytes left in the enclosing section).

    Raises:
        FatalError: if the subsection type is not 0xF1.
    """

    def __init__(self, size, **kwargs):
        Struct.__init__(self, **kwargs)
        self.maxsize = size

    def parse(self):
        """Yield header, symbol payload and alignment padding."""
        subtype = yield UInt32(name="SubSectionType", comment="type of subsection")
        if subtype != 0xF1:
            raise FatalError("Invalid subsection type: {:x}".format(subtype))
        todo = yield UInt32(name="SubsectionSize", comment="size of subsection")

        # Never read past the limit given by the caller. A declared size of 0
        # is treated as "everything that is left", like CVUnknownStream13.
        available = max(0, self.maxsize - len(self))
        sz = min(available, todo) if todo != 0 else available
        if sz > 0:
            # The clamped size is used here; yielding the raw declared size
            # would overrun the section on a bogus SubsectionSize.
            yield Bytes(sz, name="Symbols", comment="symbols data")
        # TODO: decode the individual symbol records (16-bit size followed by
        # a 16-bit record type) instead of exposing them as one opaque blob.

        # Padding is derived from the declared size, but only emitted for the
        # part that still fits inside the limit.
        padding = min((-todo) % 4, max(0, self.maxsize - len(self)))
        if padding:
            yield Unused(padding, name="Padding")


class CVUnknownStream13(Struct):
    """CodeView C13 subsection of a type that is not decoded further.

    Layout: 32-bit subsection type, 32-bit payload size, payload, then padding
    up to the next 4-byte boundary.

    Args:
        size: maximum number of bytes this structure may span (the caller
            passes the bytes left in the enclosing section).
    """

    def __init__(self, size, **kwargs):
        Struct.__init__(self, **kwargs)
        self.maxsize = size

    def parse(self):
        """Yield header, opaque payload and alignment padding."""
        yield UInt32(name="SubSectionType", comment="type of subsection")
        todo = yield UInt32(name="SubsectionSize", comment="size of subsection")

        # Bytes still allowed after the 8-byte header. The caller may hand in
        # a limit smaller than the header, so clamp at zero instead of
        # passing a negative size to Bytes.
        available = max(0, self.maxsize - len(self))
        # A declared size of 0 means "everything that is left".
        sz = min(available, todo) if todo != 0 else available
        if sz > 0:
            yield Bytes(sz, name="StreamData")

        # Padding is derived from the declared size, but only emitted for the
        # part that still fits inside the limit (nothing is left when the
        # payload was clamped).
        padding = min((-todo) % 4, max(0, self.maxsize - len(self)))
        if padding:
            yield Unused(padding, name="Padding")


class CoffRelocation(Struct):
    """One 10-byte relocation record (32-bit, 32-bit and 16-bit fields).

    Args:
        reloc_base: base offset that the VirtualAddress field is relative to.
    """

    def __init__(self, reloc_base, *args, **kwargs):
        Struct.__init__(self, *args, **kwargs)
        self.reloc_base = reloc_base

    def parse(self):
        """Yield the three fields of the relocation record in file order."""
        base = self.reloc_base
        yield Offset32(
            name="VirtualAddress",
            comment="address of the relocation",
            base=base,
        )
        yield UInt32(
            name="SymbolTableIndex",
            comment="zero-based index into the symbol table. this symbol gives the address that is to be used for the relocation",
        )
        yield UInt16(
            name="Type",
            comment="indicates the kind of relocation that should be performed",
        )


    

        


class OBJAnalyzer(FileTypeAnalyzer):
    """File-type analyzer for COFF object files.

    Parses the COFF header, string table, section headers, symbol table,
    per-section relocations and the CodeView C13 sub-streams of ".debug$S".
    """
    category = malcat.FileType.PROGRAM
    name = "OBJ"
    # Detection pattern. The leading alternation matches the little-endian
    # Machine values 0x8664, 0xaa64, 0x014c and 0x0c13.
    regexp = r"(?:\x64\x86|\x64\xaa|\x4c\x01|\x13\x0c).[\x00-\x20].......\x00...\x00\x00\x00\x00\x00\..{39}\."

    def __init__(self):
        FileTypeAnalyzer.__init__(self)

   
    def parse(self, hint):
        """Parse the object file, yielding structures and registering metadata.

        Raises:
            FatalError: when there is no room for a string table, the section
                count is outside 1..5000, a section is truncated or has
                invalid characteristics, no section has any data, or a
                relocation references a symbol index outside the table.
        """
        coff = yield COFF(category=Type.HEADER)
        # every machine other than AMD64 / ARM64 is treated as x86
        if coff["Machine"] == 0x8664:
            self.set_architecture(malcat.Architecture.X64)
        elif coff["Machine"] == 0xaa64:
            self.set_architecture(malcat.Architecture.AARCH64)
        else:
            self.set_architecture(malcat.Architecture.X86)
        # section headers start right after the COFF header
        section_start = self.tell()

        machine = coff["Machine"]
        numsyms = coff["NumberOfSymbols"]
        syms = coff["PointerToSymbolTable"]
        # the string table follows the symbol table (18 bytes per record)
        string_table = syms + numsyms * 18
        self.jump(string_table)
        if self.remaining() < 4:
            raise FatalError("No string table")
        yield StringTable(category=Type.DATA)
        # the end of the string table is taken as the end of the file
        self.set_eof(self.tell())

        # go back and parse the section headers
        self.jump(section_start)
        sections = yield Array(coff["NumberOfSections"], Section(), name="Sections", category=Type.HEADER)
        if sections.count < 1 or sections.count > 5000:
            raise FatalError("Invalid number of sections") 
        maximum_physical_address = self.tell()
        has_section = False
        for s in sections:
            # section name, cut at the first NUL byte
            sname = s["Name"]
            si = sname.find("\x00")
            if si >= 0:
                sname = sname[:si]
            # physical size: raw data plus 10 bytes per relocation and 6 bytes
            # per line-number entry, truncated to 32 bits
            # NOTE(review): this presumably assumes relocations and line
            # numbers are stored right after the raw data - confirm
            psize = (s["SizeOfRawData"] + 10*s["NumberOfRelocations"] + 6*s["NumberOfLinenumbers"]) & 0xFFFFFFFF
            eos = s["PointerToRawData"] + psize
            if eos > self.size():
                raise FatalError("Truncated section")
            #elif s["VirtualSize"] == 0 and s["SizeOfRawData"] == 0:
            #    raise FatalError("Empty section")
            # a section that has raw data must carry at least one of these flags
            elif not s["Characteristics"]["MemRead"] and not s["Characteristics"]["MemWrite"] and not s["Characteristics"]["MemExecute"] and not s["Characteristics"]["MemDiscardable"] and not s["Characteristics"]["ScnLnkInfo"] and s["SizeOfRawData"]:
                raise FatalError("invalid characteristics for {}".format(sname))
            if psize:
                maximum_physical_address = max(maximum_physical_address, eos)
                self.add_section(sname, s["PointerToRawData"], psize,
                    r = s["Characteristics"]["MemRead"],
                    w = s["Characteristics"]["MemWrite"],
                    x = s["Characteristics"]["MemExecute"],
                    discardable = s["Characteristics"]["MemDiscardable"]
                )
                has_section = True
        if not has_section:
            raise FatalError("No section")
        # header, string table and sections all passed validation
        self.confirm()

        # symbols
        symtable = None
        if syms and numsyms:
            self.jump(syms)
            symtable = yield SymbolTable(numsyms, category=Type.DEBUG)

        # register symbols of type 0x20 that belong to a real section
        if symtable is not None:
            for s in symtable:
                # auxiliary records are raw bytes, not symbol entries
                if type(s.value) == bytes:
                    continue
                # NOTE(review): type 0x20 presumably denotes a function symbol - confirm
                if not "Value" in s or s["Type"] != 0x20 or s["SectionNumber"] <= 0:
                    continue
                # the name is either stored inline or referenced in the string table
                if "Name" in s:
                    name = s["Name"].replace("\x00", "")
                # NOTE(review): string_table is computed unconditionally above,
                # so the final else branch looks unreachable - confirm
                elif string_table is not None:
                    name = self.read_cstring_ascii(string_table + s["NameOffset"], 512)
                else:
                    raise FatalError("Cannot get name for symbol: no string table")
                # SectionNumber is 1-based
                section = sections[s["SectionNumber"] - 1]
                off = section["PointerToRawData"] + s["Value"]
                # NOTE(review): symbols whose target starts with bytes 03 30 are
                # skipped; the reason is not evident from this file
                if self.read(off, 2) != b"\x03\x30":
                    self.add_symbol(off, name, malcat.FileSymbol.EXPORT)

        for s in sections:
            # relocations of this section
            if s["PointerToRelocations"] and s["NumberOfRelocations"]:
                self.jump(s["PointerToRelocations"])
                reloc_base = s["PointerToRawData"] + s["VirtualAddress"]
                relocs = yield Array(s["NumberOfRelocations"], CoffRelocation(reloc_base), name="{}.Relocations".format(s["Name"].replace("\0","")), category=Type.FIXUP)
                for reloc in relocs:
                    rtype = reloc["Type"]
                    # only relocation types 0x4, 0x6 and 0x14 produce an import symbol
                    if symtable is not None and (
                            rtype in (0x4, 0x6, 0x14)):
                        rindex = reloc["SymbolTableIndex"]
                        if rindex >= symtable.count:
                            raise FatalError("{} > {} for {:x}".format(rindex, symtable.count, reloc.offset))
                        symentry = symtable[rindex]
                        # the index points at an auxiliary record: report and skip
                        if type(symentry) == bytes:
                            print(f"Cannot read symbol {rindex:x} at {symtable.at(rindex).offset:x} for reloc #{reloc.offset:x} ({reloc_base:x})")
                            continue
                        # same name resolution as for the symbol table above
                        if "Name" in symentry:
                            name = symentry["Name"].replace("\x00", "")
                        elif string_table is not None:
                            name = self.read_cstring_ascii(string_table + symentry["NameOffset"], 512)
                        else:
                            raise FatalError("Cannot get name for symbol: no string table")
                        self.add_symbol(reloc_base + reloc["VirtualAddress"], name, malcat.FileSymbol.IMPORT)

            # CodeView debug information
            if s["Name"] == ".debug$S" and s["SizeOfRawData"] > 12:
                self.jump(s["PointerToRawData"])
                end = self.tell() + s["SizeOfRawData"]
                sig = yield UInt32(name="Signature", comment="cvres version", category=Type.HEADER)
                if sig == CV_SIGNATURE_C13:
                    # msvc 13
                    # walk the sub-streams until fewer than 5 bytes remain
                    while self.tell() + 4 < end:
                        # peek at the subsection type without consuming it
                        subtype, = struct.unpack("<I", self.read(self.tell(), 4))
                        if subtype == 0xF1:
                           yield CVSymbolStream13(end - self.tell(), name="SymbolStream", category=Type.DEBUG)
                        else:
                            yield CVUnknownStream13(end - self.tell(), name="Stream_{:2X}".format(subtype), category=Type.DEBUG)
