from filetypes.base import *
import malcat
import struct



def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    Rounds up by default, or down when ``down`` is true. Values that are
    already multiples of ``what`` are returned unchanged.
    """
    remainder = val % what
    if not remainder:
        return val
    if down:
        return val - remainder
    return val + (what - remainder)


class FileHeader(StaticStruct):
    """Fixed-size header that precedes every archive member (60 bytes total).

    All fields are fixed-width text that is not NUL-terminated. ``Ending``
    holds the two-byte header terminator that ArAnalyzer checks before
    parsing a header.
    """

    @classmethod
    def parse(cls):
        # (field name, width in bytes), in on-disk order
        layout = (
            ("FileName", 16),
            ("ModificationDate", 12),
            ("OwnerId", 6),
            ("GroupId", 6),
            ("FileMode", 8),
            ("FileSize", 10),
            ("Ending", 2),
        )
        for field_name, width in layout:
            yield String(width, zero_terminated=False, name=field_name)


class SymbolTable(Struct):
    """System V / GNU archive symbol index (content of the first ``/`` member).

    Layout:
    - a big-endian 32-bit entry count;
    - that many big-endian 32-bit positions;
    - that many NUL-terminated symbol names.
    """

    def parse(self):
        entry_count = yield UInt32BE(name="NumberOfEntries")
        yield Array(entry_count, Offset32BE(), name="SymbolPositions")
        # one shared field template, yielded once per symbol name
        name_template = CString(name="SymbolName")
        for _ in range(entry_count):
            yield name_template


class SymbolTableMicrosoft(Struct):
    """Microsoft-style symbol index (second ``/`` member of COFF archives).

    Layout:
    - a count of member offsets, followed by the offsets themselves;
    - a count of symbols, followed by one 16-bit index per symbol;
    - the NUL-terminated symbol names, one per symbol.
    """

    def parse(self):
        offset_count = yield UInt32(name="NumberOfOffsets")
        yield Array(offset_count, Offset32(), name="SymbolOffsets")
        symbol_count = yield UInt32(name="NumberOfSymbols")
        yield Array(symbol_count, UInt16(), name="SymbolPointers")
        # one shared field template, yielded once per symbol name
        name_template = CString(name="SymbolName")
        for _ in range(symbol_count):
            yield name_template


class ArAnalyzer(FileTypeAnalyzer):
    """Analyzer for ``ar`` archives (System V/GNU and Microsoft COFF variants).

    Every regular member is registered as a section and as a virtual file.
    Its bytes are served by :meth:`open`. The ``/`` symbol-table members and
    the ``//`` long-name table are parsed but not exposed as files.
    """

    category = malcat.FileType.ARCHIVE
    name = "AR"
    # "!<arch>\n" magic followed by a plausible first member header ending in "`\n"
    regexp = r"!<arch>\x0a[^\x00]{16}[0-9 -]{12}.{30}\x60\x0a"

    # Size of a member header (FileHeader: 16+12+6+6+8+10+2 bytes)
    _MEMBER_HEADER_SIZE = 60

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # virtual file name -> (offset of member data, member data size)
        self.filesystem = {}

    def open(self, vfile, password=None):
        """Return a copy of the data of the member registered as ``vfile.path``.

        ``password`` is not used by this analyzer.
        """
        offset, size = self.filesystem[vfile.path]
        return bytearray(self.read(offset, size))

    @staticmethod
    def _resolve_extended_name(name, extended_names):
        """Look up a ``/<offset>`` member name in the ``//`` long-name table.

        Entries end at the first newline (GNU tables use ``name/\\n``) or at
        the first NUL byte (Microsoft tables). A trailing ``/`` is then
        dropped. Names that contain ``/`` (thin-archive paths) are kept
        intact, because the cut happens on the terminator and not on the
        first slash.

        Returns the decoded name. Returns ``None`` if ``name`` is not a
        decimal offset inside ``extended_names``.
        """
        try:
            position = int(name[1:])
        except ValueError:
            # a special "/..." member name that is not an offset
            return None
        if not 0 <= position < len(extended_names):
            return None
        entry = extended_names[position:]
        end = len(entry)
        for terminator in (b"\n", b"\x00"):
            index = entry.find(terminator)
            if index != -1 and index < end:
                end = index
        entry = entry[:end]
        if entry.endswith(b"/"):
            entry = entry[:-1]
        return entry.decode("utf8", errors="replace")

    def parse(self, hint):
        """Walk the archive members, declaring structures, files and symbols.

        Parsing stops at the first position that does not hold a member
        header (no "`\\n" terminator at the end of the 60-byte header), or
        whose size field is not a number.
        """
        header_size = self._MEMBER_HEADER_SIZE
        st_seen = False
        extended_names = b""
        fname_seen = {}

        yield String(8, zero_terminated=False, name="Magic")
        while self.remaining() >= header_size:
            # every member header ends with "`\n" (FileHeader.Ending)
            if self.read(self.tell() + header_size - 2, 2) != b"\x60\x0a":
                break
            start = self.tell()
            fh = yield FileHeader()
            try:
                declared_size = int(fh["FileSize"])
            except ValueError:
                # not a real member header: stop instead of aborting the parse
                break
            # Clamp to the data actually present. Never allow a negative size:
            # jumping backwards could re-parse the same header forever.
            size = max(0, min(declared_size, self.remaining()))
            name = fh["FileName"].strip()

            # special members
            if name == "/":
                if not st_seen:
                    # first "/" member: big-endian symbol index
                    st = yield SymbolTable(category=Type.FIXUP)
                    st_seen = True
                    symbols = st["SymbolPositions"]
                    for i, symname in enumerate(st[2:]):
                        n = symname.value
                        t = malcat.FileSymbol.EXPORT
                        if n.startswith("\x7f"):
                            n = n[1:]
                            t = malcat.FileSymbol.DATA
                        # positions point at a member header; place the symbol on its data
                        self.add_symbol(symbols[i] + header_size, n, t)
                else:
                    # second "/" member: Microsoft symbol index
                    yield SymbolTableMicrosoft(category=Type.DATA)
            elif name == "//":
                extended_names = self.read(self.tell(), size)
            else:
                resolved = None
                if name.startswith("/") and extended_names:
                    resolved = self._resolve_extended_name(name, extended_names)
                if resolved:
                    name = resolved
                elif name.endswith("/"):
                    name = name[:-1]

                # disambiguate duplicated member names: foo, foo#1, foo#2, ...
                count_seen = fname_seen.get(name, 0)
                fname_seen[name] = count_seen + 1
                if count_seen:
                    name = f"{name}#{count_seen:d}"

                self.filesystem[name] = (self.tell(), size)
                self.add_file(name, size, "open")
                self.add_section(name, start, len(fh) + size, r=True)
            self.jump(start + len(fh) + size)
            # members start on even offsets; skip the padding byte, but not past the end
            if self.tell() % 2 and self.remaining() > 0:
                self.jump(self.tell() + 1)
            self.confirm()

