from filetypes.base import *
import malcat
import struct

# https://en.wikipedia.org/wiki/Design_of_the_FAT_file_system

def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    By default the value is rounded up to the next multiple; with
    ``down=True`` it is rounded down instead. Values that are already
    multiples of ``what`` are returned unchanged.
    """
    remainder = val % what
    if not remainder:
        return val
    return val - remainder if down else val + (what - remainder)



class BootSector(Struct):
    """Boot sector at the very start of the volume.

    Layout: a 3-byte jump instruction, an 8-byte OEM name, the extended BIOS
    parameter block, then boot code padded so that the 16-bit signature ends
    the 512-byte sector (the analyzer regexp expects 0x55 0xAA at offset 510).
    """

    def parse(self):
        yield Bytes(3, name="InitialJump")
        yield String(8, name="OemName")
        yield ExtendedBiosParameter()
        # boot code fills whatever remains before the signature at offset 510
        yield Bytes(510 - len(self), name="BootloaderCode")
        yield UInt16(name="Signature")



class ExtendedBiosParameter(Struct):
    """NTFS BIOS parameter block, located at offset 0x0B of the boot sector.

    The first part mirrors the legacy FAT BPB layout. The NTFS-specific part
    starts at NumberOfSectors and locates the MFT and its mirror (in clusters).
    It also gives the size of MFT entries and index buffers.
    """

    def parse(self):
        # MFT entry / index buffer sizes are signed bytes: a positive value
        # counts clusters, a negative one is log2 of the size in bytes.
        size_comment = "Values 0 to 127 represent sizes of 0 to 127 clusters. Values 128 to 255 represent sizes of 2^(256-n) bytes; or 2^(-n) if considered as a signed byte."
        yield UInt16(name="BytesPerSector")
        yield UInt8(name="SectorsPerCluster")
        yield UInt16(name="ReservedSectors")
        yield UInt8(name="NumberOfFAT")
        yield UInt16(name="MaximumNumberOfRootDirectoryEntries")
        yield UInt16(name="TotalLogicalSectors")
        yield UInt8(name="MediaType", values=[
            ("FixedDisk", 0xf8),
        ])
        yield UInt16(name="SectorsPerFat")
        yield UInt16(name="SectorsPerTrack")
        yield UInt16(name="NumberOfHeads")
        yield UInt32(name="HiddenSectors")
        # 32-bit sector count: it must not reuse the name of the 16-bit field
        # above, otherwise lookups by field name are ambiguous
        yield UInt32(name="LargeTotalLogicalSectors")
        yield UInt8(name="DriveNumber")
        yield Unused(1)
        yield UInt8(name="ExtendedBootSignature")
        yield Unused(1)
        yield UInt64(name="NumberOfSectors")
        yield UInt64(name="MFTCluster")
        yield UInt64(name="MirrorMFTCluster")
        yield UInt8(name="MFTEntrySize", comment=size_comment)
        yield Unused(3)
        yield UInt8(name="IndexEntrySize", comment=size_comment)
        yield Unused(3)
        yield UInt64(name="VolumeSerialNumber")
        yield Unused(4)
        




# Inode (an MFT file reference: MFT index + sequence number) is defined further
# below, right after DataRun. An empty non-generator placeholder that used the
# same name here was always shadowed by that definition and has been removed.

class MFTEntry(Struct):
    """One MFT entry (file record).

    The entry starts with a "FILE" (or "BAAD" for a record flagged as corrupt)
    signature. The fixed header is followed by the update sequence (fix-up)
    array, then the attribute records starting at AttributesOffset. The entry
    is padded up to TotalEntrySize.

    Raises:
        FatalError: if the signature is neither "FILE" nor "BAAD".
    """

    def parse(self):
        sig = yield String(4, zero_terminated=False, name="Signature")
        if sig not in ("FILE", "BAAD"):
            raise FatalError("Invalid MFT signature")
        fixup_offset = yield Offset16(name="FixupValuesOffset", base=self.offset)
        fixup_count = yield UInt16(name="FixupCount")
        yield UInt64(name="JournalSequenceNumber")
        yield UInt16(name="SequenceIndex")
        yield UInt16(name="ReferenceCount")
        attr_offset = yield Offset16(name="AttributesOffset", base=self.offset)
        yield BitsField(
            Bit(name="MFT_RECORD_IN_USE"),
            Bit(name="MFT_RECORD_IS_DIRECTORY"),
            Bit(name="MFT_RECORD_IN_EXTEND"),
            Bit(name="MFT_RECORD_IS_VIEW_INDEX"),
            NullBits(12),
            name="Flags")
        yield UInt32(name="UsedEntrySize")
        tsz = yield UInt32(name="TotalEntrySize")
        yield UInt64(name="BaseRecordFileReference")
        if fixup_count:
            if len(self) < fixup_offset:
                yield Unused(fixup_offset - len(self))
            yield UInt16(name="FixupPlaceHolderValue")
            # FixupCount counts the placeholder itself, so only
            # fixup_count - 1 saved sector-end values follow it
            if fixup_count > 1:
                yield Array(fixup_count - 1, UInt16(), name="FixupOriginalValues")
        if len(self) < attr_offset:
            yield Unused(attr_offset - len(self))
        yield MFTAttributes(name="Attributes")
        if len(self) < tsz:
            yield Unused(tsz - len(self))
        

class MFTAttributes(Struct):
    """Sequence of attribute records inside an MFT entry.

    Each record's 32-bit type code is peeked and looked up in ATTRIBUTE_TYPES,
    which gives the display name and the structure class used to parse it.
    Parsing stops, without consuming anything, at the first type code that is
    not in the table.
    """

    def parse(self):
        while True:
            (type_code,) = struct.unpack("<I", self.look_ahead(4))
            known = ATTRIBUTE_TYPES.get(type_code)
            if known is None:
                return
            attr_name, attr_class = known
            yield attr_class(name=attr_name)
     

# Byte width -> unsigned integer field type. DataRun uses it for the run
# length ("Count") and run offset ("Offset") fields, whose widths are encoded
# in the run header nibbles and may range from 1 to 8 bytes.
SIZE_TO_UINT = {
    1: UInt8,
    2: UInt16,
    3: UInt24,
    4: UInt32,
    6: UInt48,
    8: UInt64,
}

# Byte width -> signed integer field type (not referenced by the code visible
# in this file section).
SIZE_TO_INT = {
    1: Int8,
    2: Int16,
    4: Int32,
    8: Int64,
}


class DataRuns(Struct):
    """Run list of a non-resident attribute.

    DataRun entries are parsed until the next byte is zero. The zero byte
    terminates the list and is left unconsumed.
    """

    def parse(self):
        end_marker = b"\x00"
        while True:
            if self.look_ahead(1) == end_marker:
                break
            yield DataRun()


class DataRun(Struct):
    """One entry of a non-resident attribute's run list.

    The header byte packs two field widths: the low nibble is the byte width of
    the run length ("Count"), the high nibble the byte width of the run offset
    ("Offset"). A width of 0 omits the field. In the NTFS run-list format the
    offset is a signed delta from the previous run's cluster number; it is
    shown here as an unsigned value.

    Raises:
        FatalError: if a field width exceeds 8 bytes.
    """

    def parse(self):
        hdr = yield UInt8(name="CountValueSize")
        count_size = hdr & 0xF
        value_size = hdr >> 4
        if count_size:
            yield self._sized_field(count_size, "count", "Count")
        if value_size:
            yield self._sized_field(value_size, "value", "Offset")

    def _sized_field(self, size, label, field_name):
        """Build the field of ``size`` bytes named ``field_name``.

        Widths without a dedicated integer type in SIZE_TO_UINT (e.g. 5 or 7
        bytes) are legal in run lists, so they fall back to raw Bytes instead
        of aborting the whole parse.
        """
        if size > 8:
            raise FatalError(f"Invalid {label} size {size} at #{self.offset:x}")
        field_type = SIZE_TO_UINT.get(size)
        if field_type is None:
            return Bytes(size, name=field_name)
        return field_type(name=field_name)


class Inode(StaticStruct):
    """MFT file reference: a 48-bit MFT entry index followed by a 16-bit
    sequence number (8 bytes in total)."""

    @classmethod
    def parse(cls):
        yield UInt48(name="MFTIndex")
        yield UInt16(name="SequenceNumber")

class AttributeRecordHeader(Struct):
    """Common header of an attribute record inside an MFT entry.

    After the common part (type, record length, resident flag, name location,
    flags, identifier) comes a resident or non-resident specific part, then the
    optional UTF-16LE attribute name.

    - Resident: the header is extended up to DataOffset, so that whatever the
      caller parses next starts at the attribute content.
    - Non-resident: the header continues up to and including the run list.
    """

    def parse(self):
        yield UInt32(name="Type", values=[(v[0], k) for k, v in ATTRIBUTE_TYPES.items()])
        record_length = yield UInt32(name="RecordLength")
        res = yield UInt8(name="Flag", values=[
            ("RESIDENT", 0),
            ("NON_RESIDENT", 1),
        ])
        nsz = yield UInt8(name="NameSize")
        no = yield Offset16(name="NameOffset", base=self.offset)
        yield BitsField(
            Bit(name="ATTRIBUTE_FLAG_COMPRESSED"),
            NullBits(13),
            Bit(name="ATTRIBUTE_FLAG_ENCRYPTED"),
            Bit(name="ATTRIBUTE_FLAG_SPARSE"),
            name="Flags")
        yield UInt16(name="Identifier")
        if res == 0:
            yield UInt32(name="DataSize")
            data_offset = yield Offset16(name="DataOffset", base=self.offset)
            yield BitsField(
                Bit(name="Indexed"),
                NullBits(7),
            name="IndexationFlags")
            yield Unused(1)
            yield from self._parse_name(nsz, no)
            # resident content begins at DataOffset (e.g. after a name and its
            # alignment padding); only trust offsets that stay inside the record
            if len(self) < data_offset <= record_length:
                yield Unused(data_offset - len(self))
        else:
            yield UInt64(name="DataFirstClusterNumber")
            yield UInt64(name="DataLastClusterNumber")
            data_runs = yield Offset16(name="DataRunsOffset", base=self.offset)
            comp_sz = yield UInt16(name="CompressionUnitSize")
            yield Unused(4)
            yield UInt64(name="AllocatedDataSize")
            yield UInt64(name="DataSize")
            yield UInt64(name="ValidDataSize")
            if comp_sz > 0:
                yield UInt64(name="TotalAllocatedSize")
            yield from self._parse_name(nsz, no)
            if len(self) < data_runs:
                yield Unused(data_runs - len(self))
            yield DataRuns(name="DataRuns")

    def _parse_name(self, name_size, name_offset):
        """Yield the optional attribute name (``name_size`` UTF-16 characters),
        preceded by padding up to ``name_offset``. Yields nothing when the
        attribute is unnamed."""
        if not name_size:
            return
        if len(self) < name_offset:
            yield Unused(name_offset - len(self))
        yield StringUtf16le(name_size, zero_terminated=False, name="Name")


class GenericAttributeRecord(AttributeRecordHeader):
    """Attribute record whose content is not decoded.

    Only the header is parsed. The rest of the record, up to RecordLength, is
    covered by a single Unused field.
    """

    def parse(self):
        header = yield AttributeRecordHeader(name="Header")
        remaining = header["RecordLength"] - len(self)
        if remaining > 0:
            yield Unused(remaining)


class StandardInformationAttribute(AttributeRecordHeader):
    """$STANDARD_INFORMATION (0x10) attribute record.

    After the header come four timestamps, the 32-bit file attribute flags and
    three 32-bit version/class fields. Four optional trailing fields are only
    present in larger records. Each one is read only if it fits entirely
    within RecordLength.
    """

    def parse(self):
        hdr = yield AttributeRecordHeader(name="Header")
        rl = hdr["RecordLength"]
        yield Filetime(name="CreationTime")
        yield Filetime(name="LastWrittenTime", comment="last time the data was written")
        yield Filetime(name="LastModificationTime", comment="last time the MFT entry was modified")
        yield Filetime(name="LastAccessTime", comment="last time the file was accesssed")
        yield BitsField(
            Bit(name="ReadOnly"),
            Bit(name="Hidden"),
            Bit(name="System"),
            NullBits(1),
            Bit(name="Directory"),
            Bit(name="Archive"),
            Bit(name="Device"),
            Bit(name="Normal"),
            Bit(name="Temporary"),
            Bit(name="SparseFile"),
            Bit(name="ReparsePoint"),
            Bit(name="Compressed"),
            Bit(name="Offline"),
            Bit(name="NotContentIndexed"),
            Bit(name="Encrypted"),
            NullBits(1),
            Bit(name="Virtual"),
            NullBits(11),
            Bit(name="DirectoryIndexed"),
            Bit(name="IndexView"),
            # pad to the full 32-bit field (the bits above only cover 0..29)
            NullBits(2),
        name="FileAttributes")
        yield UInt32(name="MaximumVersions")
        yield UInt32(name="VersionNumber")
        yield UInt32(name="ClassIdentifer")
        # optional trailing fields: stop at the first one that would overrun
        # the record (they are sequential, so skipping one would misalign the rest)
        for size, field_type, field_name in (
            (4, UInt32, "OwnerIdentifer"),
            (4, UInt32, "SecurityDescriptorIdentifer"),
            (8, UInt64, "QuotaChanged"),
            (8, UInt64, "UpdateSequenceNumber"),
        ):
            if len(self) + size > rl:
                break
            yield field_type(name=field_name)
        if len(self) < rl:
            yield Unused(rl - len(self))


class FileNameAttribute(AttributeRecordHeader):
    """$FILE_NAME (0x30) attribute record.

    After the header come the parent directory reference, four timestamps,
    allocated and real sizes, the 32-bit file attribute flags, extended data,
    and the UTF-16LE name with its length and namespace. The record is padded
    up to RecordLength.
    """

    def parse(self):
        hdr = yield AttributeRecordHeader(name="Header")
        record_length = hdr["RecordLength"]
        yield Inode(name="ParentInode")
        yield Filetime(name="CreationTime")
        yield Filetime(name="LastWrittenTime", comment="last time the data was written")
        yield Filetime(name="LastModificationTime", comment="last time the MFT entry was modified")
        yield Filetime(name="LastAccessTime", comment="last time the file was accesssed")
        yield UInt64(name="AllocatedFileSize")
        yield UInt64(name="FileSize")
        yield BitsField(
            Bit(name="ReadOnly"),
            Bit(name="Hidden"),
            Bit(name="System"),
            NullBits(1),
            Bit(name="Directory"),
            Bit(name="Archive"),
            Bit(name="Device"),
            Bit(name="Normal"),
            Bit(name="Temporary"),
            Bit(name="SparseFile"),
            Bit(name="ReparsePoint"),
            Bit(name="Compressed"),
            Bit(name="Offline"),
            Bit(name="NotContentIndexed"),
            Bit(name="Encrypted"),
            NullBits(1),
            Bit(name="Virtual"),
            NullBits(11),
            Bit(name="DirectoryIndexed"),
            Bit(name="IndexView"),
            # pad to the full 32-bit field (the bits above only cover 0..29)
            NullBits(2),
        name="FileAttributes")
        yield UInt32(name="ExtendedData")
        ns = yield UInt8(name="NameStringSize")
        yield UInt8(name="Namespace", values=[
            ("POSIX", 0),
            ("WINDOWS", 1),
            ("DOS", 2),
            ("DOS_AND_WINDOWS", 3),
        ])
        yield StringUtf16le(ns, zero_terminated=False, name="Name")
        if len(self) < record_length:
            yield Unused(record_length - len(self))


# Attribute type code -> (display name, structure class parsing the record).
# MFTAttributes stops walking an MFT entry at the first type code missing from
# this table (such as the 0xFFFFFFFF end-of-attributes marker), so every
# attribute type that can appear in an entry must be listed here.
ATTRIBUTE_TYPES = {
    0x00000010: ("$STANDARD_INFORMATION", StandardInformationAttribute),
    0x00000020: ("$ATTRIBUTE_LIST", GenericAttributeRecord),
    0x00000030: ("$FILE_NAME", FileNameAttribute),
    0x00000040: ("$OBJECT_ID", GenericAttributeRecord),
    0x00000050: ("$SECURITY_DESCRIPTOR", GenericAttributeRecord),
    0x00000060: ("$VOLUME_NAME", GenericAttributeRecord),
    0x00000070: ("$VOLUME_INFORMATION", GenericAttributeRecord),
    0x00000080: ("$DATA", GenericAttributeRecord),
    0x00000090: ("$INDEX_ROOT", GenericAttributeRecord),
    0x000000a0: ("$INDEX_ALLOCATION", GenericAttributeRecord),
    0x000000b0: ("$BITMAP", GenericAttributeRecord),
    0x000000c0: ("$REPARSE_POINT", GenericAttributeRecord),
    0x000000d0: ("$EA_INFORMATION", GenericAttributeRecord),
    0x000000e0: ("$EA", GenericAttributeRecord),
    0x00000100: ("$LOGGED_UTILITY_STREAM", GenericAttributeRecord),
}

class NTFSAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for NTFS volume images.

    Parses the boot sector, then the MFT as a contiguous array of entries
    starting at MFTCluster. The MFT size comes from the $FILE_NAME attribute of
    MFT entry 0. Finally, the first four entries of the MFT mirror are parsed.
    """
    category = malcat.FileType.FILESYSTEM
    name = "NTFS"
    regexp = r"\xEB[\x52-\x60]\x90NTFS    \x00[\x01\x02\x04\x08\x10][\x01\x02\x04\x08\x10\x20\x40\x80]\x00\x00\x00\x00\x00..\xf8\x00\x00\x3f\x00\xff\x00.{4}\x00{4}\x80\x00\x80\x00.{470}\x55\xaa"

    # attribute type codes used when scanning MFT entry 0 by hand
    _ATTR_FILE_NAME = 0x30
    _ATTR_END = 0xFFFFFFFF
    # sanity bounds for the decoded MFT entry size: the lower bound is the fixed
    # header parsed by MFTEntry (0x28 bytes); the upper bound is a generous cap
    # rejecting absurd encodings such as 2**128
    _MIN_MFT_ENTRY_SIZE = 0x28
    _MAX_MFT_ENTRY_SIZE = 1 << 16

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        self.filesystem = {}
        self.mft = []

    @staticmethod
    def _as_int(value):
        """Return a run-list field value as int (raw little-endian bytes are
        converted; integers are returned unchanged)."""
        if isinstance(value, (bytes, bytearray)):
            return int.from_bytes(value, "little")
        return value

    def iterate_data_runs(self, mfte):
        """Yield ``(cluster_count, lcn)`` for each run of a parsed run list.

        ``mfte`` is expected to be a parsed structure holding a "DataRuns" field
        (presumably a non-resident AttributeRecordHeader; verify against callers).
        Run offsets are signed deltas relative to the previous run's LCN, so they
        are sign-extended according to the width in the run header and
        accumulated. Runs without an offset field (sparse runs) yield
        ``lcn=None``. Nothing is yielded when there is no run list.
        """
        if "DataRuns" not in mfte:
            return
        lcn = 0
        for datarun in mfte["DataRuns"]:
            if "Count" not in datarun:
                break
            count = self._as_int(datarun["Count"])
            if "Offset" not in datarun:
                yield count, None
                continue
            width = datarun["CountValueSize"] >> 4
            delta = self._as_int(datarun["Offset"])
            if delta & (1 << (8 * width - 1)):
                delta -= 1 << (8 * width)
            lcn += delta
            yield count, lcn

    def open_file(self, vfile, password=None):
        """Return the content of ``vfile`` as a bytearray.

        Raises:
            ValueError: if the file is not registered in ``self.filesystem``.
        """
        # NOTE(review): self.filesystem is never populated in this file section,
        # and iter_file_intervals is not defined in this class; confirm it is
        # provided elsewhere before relying on this path.
        cluster, filesize = self.filesystem.get(vfile.path, (None, 0))
        if not filesize:
            raise ValueError("Cannot locate file")
        res = bytearray()
        for offset, sz in self.iter_file_intervals(cluster, filesize):
            res.extend(self.read(offset, sz))
        return res

    @staticmethod
    def _decode_record_size(raw, cluster_size):
        """Decode a BPB record-size byte into a size in bytes.

        Interpreted as a signed byte: a positive value (0..127) is a number of
        clusters, a negative one (128..255) encodes 2**(256 - raw) bytes.
        """
        if raw < 0x80:
            return raw * cluster_size
        return 1 << (256 - raw)

    @classmethod
    def _find_file_name_size(cls, attributes):
        """Return the real file size stored in the first $FILE_NAME attribute
        of the raw attribute area ``attributes``, or 0 if none can be read.

        The walk stops at the end marker, at a zero record length (which would
        otherwise loop forever) and at the end of the buffer.
        """
        aoff = 0
        while aoff + 8 <= len(attributes):
            atype, asize = struct.unpack_from("<II", attributes, aoff)
            if atype == cls._ATTR_END or asize == 0:
                break
            if atype == cls._ATTR_FILE_NAME:
                # 0x48 = 0x18-byte unnamed resident header + offset 0x30 of the
                # FileSize field inside the $FILE_NAME content
                if aoff + 0x50 > len(attributes):
                    break
                size, = struct.unpack_from("<Q", attributes, aoff + 0x48)
                return size
            aoff += asize
        return 0

    def parse(self, hint):
        boot_sector = yield BootSector()
        ebp = boot_sector["ExtendedBiosParameter"]
        self.sector_size = ebp["BytesPerSector"]
        self.set_eof(ebp["NumberOfSectors"] * self.sector_size)
        self.cluster_size = ebp["SectorsPerCluster"] * self.sector_size
        self.mft_entry_size = self._decode_record_size(ebp["MFTEntrySize"], self.cluster_size)
        if not self._MIN_MFT_ENTRY_SIZE <= self.mft_entry_size <= self._MAX_MFT_ENTRY_SIZE:
            raise FatalError("Invalid MFT entry size")
        mft_offset = ebp["MFTCluster"] * self.cluster_size

        # compute mft effective size from the $FILE_NAME attribute of entry 0
        mft0data = self.read(mft_offset, self.mft_entry_size)
        attributes_delta, = struct.unpack("<H", mft0data[0x14:0x16])
        if self.mft_entry_size <= attributes_delta:
            raise FatalError("Invalid attribute delta size")
        attributes = self.read(mft_offset + attributes_delta, self.mft_entry_size - attributes_delta)
        mft_effective_size = self._find_file_name_size(attributes)
        if not mft_effective_size:
            raise FatalError("Cannot compute size of MFT")
        self.confirm()

        num_mft_entry = mft_effective_size // self.mft_entry_size
        self.jump(mft_offset)
        self.mft = yield VariableArray(num_mft_entry, MFTEntry, name="MFT")
        mirror_offset = ebp["MirrorMFTCluster"] * self.cluster_size
        self.jump(mirror_offset)
        yield VariableArray(4, MFTEntry, name="MFTMirror")
        
