from filetypes.base import *
import malcat
import struct

# https://en.wikipedia.org/wiki/Design_of_the_FAT_file_system

def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    Rounds up by default, or down when ``down`` is true. Values that are
    already multiples of ``what`` are returned unchanged.
    """
    remainder = val % what
    if not remainder:
        return val
    return val - remainder if down else val + (what - remainder)



class BootSector(Struct):
    """Boot sector of the volume.

    Layout: 3-byte jump instruction, 8-byte OEM name, the FAT32 parameter
    block (ExtendedBiosParameter), boot code padding up to offset 510 and a
    final 16-bit signature. The signature is not validated here; the
    analyzer's regexp requires the bytes 55 AA at offset 510.
    """

    def parse(self):
        yield Bytes(3, name="InitialJump")
        yield String(8, name="OemName")
        yield ExtendedBiosParameter()
        # pad with boot code up to offset 510 so the signature always lands at
        # 510 (assumes len(self) is the number of bytes parsed so far)
        yield Bytes(510 - len(self), name="BootloaderCode")
        yield UInt16(name="Signature")



class ExtendedBiosParameter(Struct):
    """FAT32 BIOS parameter block that follows the jump and OEM name.

    Offsets in the comments are relative to the start of the boot sector
    (this block starts at offset 11) and follow from the field sizes; the
    block covers offsets 11-89.
    """

    def parse(self):
        yield UInt16(name="BytesPerSector")  # 11
        yield UInt8(name="SectorsPerCluster")  # 13
        yield UInt16(name="ReservedSectors")  # 14: the first FAT starts right after them
        yield UInt8(name="NumberOfFAT")  # 16
        # 17: must be 0 on FAT32 (Fat32Analyzer.parse rejects non-zero values)
        yield UInt16(name="MaximumNumberOfRootDirectoryEntries")
        # 19: 16-bit sector count, 0 on volumes matched by the analyzer's
        # regexp; the 32-bit TotalLogicalSectors below is the one used
        yield UInt16(name="TotalLogicalSectorsOld")
        # 21: media descriptor; the analyzer's regexp only accepts 0xF8
        yield UInt8(name="MediaType", values=[
            ("FixedDisk", 0xf8),
        ])
        # 22: 16-bit FAT size, 0 on matched volumes; SectorsPerFat is used
        yield UInt16(name="SectorsPerFatOld")
        yield UInt16(name="SectorsPerTrack")  # 24
        yield UInt16(name="NumberOfHeads")  # 26
        yield UInt32(name="HiddenSectors")  # 28
        yield UInt32(name="TotalLogicalSectors")  # 32: volume size, used to compute the EOF
        yield UInt32(name="SectorsPerFat")  # 36: size of one FAT
        # 40: NOTE(review): per the FAT32 spec this word holds the FAT
        # mirroring flags / active FAT number; it is not interpreted here
        yield UInt16(name="DriveDescription")
        yield UInt8(name="VersionMinor")  # 42
        yield UInt8(name="VersionMajor")  # 43
        yield UInt32(name="RootDirectoryCluster")  # 44: first cluster of the root directory
        yield UInt16(name="FSInformationSector")  # 48: sector number of the FS information sector
        yield UInt16(name="BootCopySector")  # 50: sector number of the backup boot sector
        yield Unused(12)  # 52-63: reserved
        yield UInt8(name="DriveNumber")  # 64
        yield Unused(1)  # 65
        # 66: 0x29 on volumes matched by the analyzer's regexp
        yield String(1, name="ExtendedBootSignature")
        yield UInt32(name="VolumeId")  # 67
        yield String(11, name="VolumeLabel")  # 71
        # 82: "FAT32   " on volumes matched by the analyzer's regexp
        yield String(8, name="FileSystemType")


class InformationSector(Struct):
    """FAT32 FS information sector (512 bytes).

    Layout, with offsets derived from the field sizes:
    - "RRaA" lead signature at 0
    - 480 unused bytes
    - "rrAa" signature at 484
    - two 32-bit counters at 488 and 492
    - 12 unused bytes
    - a trailing 32-bit signature at 508

    Raises:
        FatalError: either text signature does not match.
    """

    def parse(self):
        # assumes the parser sends the field's value back through `yield`,
        # as the direct string comparison implies
        sig = yield String(4, name="Signature")
        if sig != "RRaA":
            raise FatalError("Invalid FS signature")
        yield Unused(480)
        # NOTE(review): three fields of this struct are all named "Signature";
        # confirm Struct tolerates duplicate names, otherwise rename them
        sig = yield String(4, name="Signature")
        if sig != "rrAa":
            raise FatalError("Invalid FS signature")
        yield UInt32(name="NumberOfFreeClusters")
        yield UInt32(name="LastAllocatedCluster")
        yield Unused(12)
        # trailing signature; not validated
        yield UInt32(name="Signature")


class DirectoryEntry(StaticStruct):
    """One FAT directory entry in its short (8.3) form.

    Assumes BitsField occupies 1 byte, DosDateTime 4 bytes and DosDate
    2 bytes, which makes the entry 32 bytes long as Fat32Analyzer expects.
    Long-name entries occupy the same 32-byte slots; Fat32Analyzer
    recognises them from FileAttributes and reads their name bytes directly.
    """

    @classmethod
    def parse(cls):
        yield String(8, name="ShortFileName")  # 0: space-padded name
        yield String(3, name="FileExtension")  # 8: space-padded extension
        # 11: attribute flags; ReadOnly+Hidden+System+VolumeLabel all set is
        # how Fat32Analyzer recognises a long-name entry
        yield BitsField(
            Bit(name="ReadOnly"),
            Bit(name="Hidden"),
            Bit(name="System"),
            Bit(name="VolumeLabel"),
            Bit(name="Directory"),
            Bit(name="Archive"),
            Bit(name="Device"),
            name="FileAttributes")
        yield UInt8(name="ExtraAttributes")  # 12
        yield UInt8(name="CreationTimeFine")  # 13
        yield DosDateTime(name="CreationTime")  # 14
        yield DosDate(name="LastAccess")  # 18
        yield UInt16(name="FirstClusterHigh")  # 20: high 16 bits of the first cluster
        yield DosDateTime(name="ModificationTime")  # 22
        yield UInt16(name="FirstClusterLow")  # 26: low 16 bits of the first cluster
        yield UInt32(name="FileSize")  # 28: size in bytes (0 for directories)


class Fat32Analyzer(FileTypeAnalyzer):
    """Analyzer for raw FAT32 volume images.

    ``parse`` maps the boot area, the FATs and the cluster region, walks the
    directory tree from the root cluster and registers every non-empty file
    with ``add_file``. The file content is later rebuilt by ``open_file``
    from the file's cluster chain.
    """
    category = malcat.FileType.FILESYSTEM
    name = "FAT32"
    # FAT32 boot sector pattern:
    # - EB xx 90 jump
    # - 512-byte sectors and a power-of-two sectors-per-cluster
    # - 2 FATs, with the FAT12/16-only fields zeroed
    # - media 0xF8
    # - extended boot signature 0x29 at offset 66
    # - "FAT32   " at offset 82
    # - 55 AA at offset 510
    # NOTE(review): "." does not match a 0x0A byte unless the pattern is
    # applied with DOTALL semantics -- confirm how malcat compiles it.
    regexp = r"\xEB.\x90.{8}\x00\x02[\x01\x02\x04\x08\x10\x20\x40\x80]..\x02\x00\x00\x00\x00\xf8\x00\x00.{42}\x29.{15}FAT32   .{420}\x55\xaa"

    # FAT32 FAT entries are 28-bit values: the top 4 bits are reserved and
    # must be ignored when following a chain.
    _FAT_ENTRY_MASK = 0x0FFFFFFF
    # Highest cluster number that can hold data; larger (masked) values are
    # reserved, bad-cluster or end-of-chain markers.
    _LAST_DATA_CLUSTER = 0x0FFFFFEF
    # Values of BootCopySector / FSInformationSector meaning "not present".
    _NO_SECTOR = (0, 0xFFFF)
    # Size in bytes of one directory entry (short or long-name).
    _DIR_ENTRY_SIZE = 32
    # First byte of a deleted directory entry (short or long-name).
    _DELETED_ENTRY = 0xE5

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # path -> (first cluster, size in bytes) of every registered file
        self.filesystem = {}
        # raw entries of the first FAT, indexed by cluster number
        self.fat = []

    def iter_cluster_chain(self, cluster):
        """Yield the cluster numbers of the chain starting at ``cluster``.

        Iteration stops at the first value outside the data-cluster range
        ``[2, 0x0FFFFFEF]`` (free, reserved, bad or end-of-chain marker), or
        when a cluster repeats, which guards against looping chains.

        Raises:
            FatalError: a cluster of the chain has no entry in the FAT.
        """
        seen = set()
        while 2 <= cluster <= self._LAST_DATA_CLUSTER and cluster not in seen:
            seen.add(cluster)
            yield cluster
            if cluster >= len(self.fat):
                raise FatalError("No FAT entry for cluster {:x}".format(cluster))
            # the reserved top nibble must not turn a link into a marker
            cluster = self.fat[cluster] & self._FAT_ENTRY_MASK

    def iter_file_intervals(self, start_cluster, totsz=None):
        """Yield ``(offset, size)`` runs of contiguous clusters of a chain.

        Args:
            start_cluster: first cluster of the chain.
            totsz: number of bytes to cover; when given, the runs are
                truncated to exactly ``totsz`` bytes and the chain is not
                followed further once that many bytes are collected. ``None``
                walks the whole chain (directories).

        Raises:
            FatalError: the chain is shorter than ``totsz`` bytes.
        """
        remaining = totsz
        run_start = None
        run_size = 0
        for cluster in self.iter_cluster_chain(start_cluster):
            if remaining is not None and run_size >= remaining:
                # the pending run already covers every byte still needed:
                # stop before following a possibly corrupt chain tail
                break
            offset = self.user_space + (cluster - 2) * self.cluster_size
            if run_start is not None and offset == run_start + run_size:
                run_size += self.cluster_size
                continue
            if run_start is not None:
                # discontinuity: flush the current run (it is smaller than
                # `remaining` thanks to the check above)
                yield run_start, run_size
                if remaining is not None:
                    remaining -= run_size
            run_start, run_size = offset, self.cluster_size
        if run_start is not None:
            if remaining is not None:
                run_size = min(run_size, remaining)
                remaining -= run_size
            if run_size > 0:
                yield run_start, run_size
        if remaining:
            raise FatalError("Truncated file for cluster {:x}: {} left".format(start_cluster, remaining))

    def parse_dir(self, cluster, parents=None, size=None, seen=None):
        """Parse the directory whose entries start at ``cluster``.

        Args:
            cluster: first cluster of the directory.
            parents: path components of the directory (default: root).
            size: directory size in bytes, or ``None`` to follow the chain.
            seen: start clusters already visited; a directory is never parsed
                twice, which also stops "."/".." and cyclic references.
        """
        if parents is None:
            parents = []
        if seen is None:
            seen = set()
        elif cluster in seen:
            return
        seen.add(cluster)
        # long-name sequences may straddle a cluster boundary, so the pending
        # long-name bytes are carried from one run to the next
        pending = b""
        for offset, sz in self.iter_file_intervals(cluster, size):
            pending = yield from self.parse_dir_range(offset, sz, parents, seen=seen, long_name_bytes=pending)

    def parse_dir_range(self, offset, sz, parents=None, seen=None, long_name_bytes=b""):
        """Parse ``sz`` bytes of directory entries located at ``offset``.

        Sub-directories are parsed recursively through ``parse_dir``. Regular
        non-empty files are recorded in ``self.filesystem`` and registered
        with ``add_file``.

        Args:
            offset: file offset of the first entry.
            sz: number of bytes of entries (clamped to the end of the file).
            parents: path components of this directory (default: root).
            seen: start clusters of directories already visited.
            long_name_bytes: long-name bytes still pending from the previous
                range of the same directory.

        Returns (through ``yield from``):
            The long-name bytes still pending at the end of the range.
        """
        if parents is None:
            parents = []
        available = self.size() - offset
        if available <= 0:
            return b""
        count = min(sz, available) // self._DIR_ENTRY_SIZE
        if count <= 0:
            return long_name_bytes
        self.jump(offset)
        entries = yield Array(count, DirectoryEntry(), name="Directory")
        for entry in entries:
            attributes = entry["FileAttributes"]
            if self.read(entry.offset, 1)[0] == self._DELETED_ENTRY:
                # deleted entry (short or long-name part): drop any pending
                # long name so it cannot leak onto the next file
                long_name_bytes = b""
                continue
            if attributes["VolumeLabel"] and attributes["System"] and attributes["Hidden"] and attributes["ReadOnly"]:
                # long-name entry: UTF-16-LE characters at offsets 1-10,
                # 14-25 and 28-31; parts are stored last part first, so each
                # one is prepended
                base = entry.offset
                long_name_bytes = self.read(base + 1, 10) + self.read(base + 14, 12) + self.read(base + 28, 4) + long_name_bytes
                continue
            # a short entry always terminates the pending long name, even if
            # the entry itself is skipped below
            pending_long_name, long_name_bytes = long_name_bytes, b""
            is_dir = attributes["Directory"]
            filesize = entry["FileSize"]
            if filesize == 0 and not is_dir:
                # empty files, volume label and unused slots
                continue
            entry_name = entry["ShortFileName"].strip()
            ext = entry["FileExtension"].strip()
            if ext:
                entry_name += "." + ext
            if is_dir and entry_name in (".", ".."):
                # self/parent links, never real children
                continue
            if pending_long_name:
                long_name = pending_long_name.decode("utf-16-le", errors="replace")
                # the name ends at the first NUL; 0xFFFF padding follows it
                long_name = long_name.split("\x00", 1)[0].strip()
                if long_name:
                    entry_name = long_name
            cluster = entry["FirstClusterLow"] + (entry["FirstClusterHigh"] << 16)
            if is_dir:
                # directories record a size of 0: follow their whole chain
                yield from self.parse_dir(cluster, parents + [entry_name], filesize or None, seen=seen)
            else:
                fname = "/".join(parents + [entry_name])
                self.filesystem[fname] = (cluster, filesize)
                self.add_file(fname, filesize, "open_file")
        return long_name_bytes

    def open_file(self, vfile, password=None):
        """Return the content of ``vfile`` rebuilt from its cluster chain.

        Raises:
            ValueError: ``vfile.path`` is not a file registered by ``parse``.
            FatalError: the cluster chain is broken or shorter than the file.
        """
        cluster, filesize = self.filesystem.get(vfile.path, (None, 0))
        if not filesize:
            raise ValueError("Cannot locate file")
        res = bytearray()
        for offset, sz in self.iter_file_intervals(cluster, filesize):
            res.extend(self.read(offset, sz))
        return res

    def parse(self, hint):
        """Parse the volume: boot sector, backup copy, FS info, FATs and tree.

        Raises:
            FatalError: the backup boot sector differs from the primary one,
                the volume declares a fixed-size root directory, or a nested
                structure/chain is invalid.
        """
        boot_sector = yield BootSector()
        ebp = boot_sector["ExtendedBiosParameter"]
        self.sector_size = ebp["BytesPerSector"]
        self.cluster_size = ebp["SectorsPerCluster"] * self.sector_size
        self.reserved_space = ebp["ReservedSectors"] * self.sector_size
        self.add_section("Boot", 0, self.reserved_space, x=True)
        eof = ebp["TotalLogicalSectors"] * self.sector_size
        self.set_eof(eof)

        # backup boot sector, when the volume declares one
        boot_copy = ebp["BootCopySector"]
        if boot_copy not in self._NO_SECTOR:
            boot_copy_offset = boot_copy * self.sector_size
            if self.read(0, self.sector_size) != self.read(boot_copy_offset, self.sector_size):
                raise FatalError("Boot copy does not match")
            self.jump(boot_copy_offset)
            yield BootSector(name="BootSectorCopy")
            self.confirm()

        # FS information sector, when the volume declares one
        fsinfo_sector = ebp["FSInformationSector"]
        if fsinfo_sector not in self._NO_SECTOR:
            self.jump(fsinfo_sector * self.sector_size)
            yield InformationSector()

        # FATs follow the reserved sectors back to back
        fat_count = ebp["NumberOfFAT"]
        fat_size = ebp["SectorsPerFat"] * self.sector_size
        self.jump(self.reserved_space)
        self.add_section("FAT", self.reserved_space, fat_count * fat_size, discardable=True)
        for i in range(fat_count):
            fat = yield Array(fat_size // 4, UInt32(), name="FAT#{}".format(i), category=Type.FIXUP)
            if not self.fat:
                # only the first FAT is used to follow cluster chains
                self.fat = [x.value for x in fat]
        self.confirm()

        # FAT32 keeps the root directory in the cluster region
        if ebp["MaximumNumberOfRootDirectoryEntries"]:
            raise FatalError("Not a valid root dir")
        self.user_space = self.reserved_space + fat_count * fat_size
        self.add_section("Clusters", self.user_space, eof - self.user_space, r=True, w=True)
        yield from self.parse_dir(ebp["RootDirectoryCluster"])
