from filetypes.base import *
import malcat
import struct

def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    Rounds up by default, or down when ``down`` is true. Values that are
    already a multiple of ``what`` are returned unchanged.
    """
    remainder = val % what
    if not remainder:
        return val
    return val - remainder if down else val + (what - remainder)


class VolumeDescriptor(Struct):
    """Generic ISO 9660 volume descriptor: 7-byte header plus an opaque body.

    Used as the fallback for descriptor types that have no dedicated
    structure. The header (type, 5-byte standard identifier, version) is
    followed by 2041 bytes of descriptor-specific data, 2048 bytes in total.

    Raises FatalError if the version byte is not 1.
    """

    def parse(self):
        yield UInt8(name="Type", values=[
            ("BOOT", 0),
            ("PRIMARY", 1),
            ("SECONDARY", 2),
            ("VOLUME", 3),
            ("TERMINATOR", 255),
        ])
        # The standard identifier is not validated here: the analyzer checks
        # the parsed "Magic" field after each descriptor is read.
        yield String(5, zero_terminated=False, name="Magic")
        ver = yield UInt8(name="Version", comment="always 1")
        if ver != 1:
            raise FatalError("Invalid version")
        yield Bytes(2041, name="Data")

class PrimaryVolumeDescriptor(Struct):
    """ISO 9660 Primary Volume Descriptor (type 1), one 2048-byte sector.

    Field layout follows ECMA-119 section 8.4. Both-endian integers are
    exposed through UInt16LsbMsb / UInt32LsbMsb (read the "Lsb" member).
    The descriptor is padded with unused bytes up to 0x800.

    Raises FatalError if the descriptor version byte is not 1.
    """

    def parse(self):
        yield UInt8(name="Type", values=[
            ("BOOT", 0),
            ("PRIMARY", 1),
            ("SECONDARY", 2),
            ("VOLUME", 3),
            ("TERMINATOR", 255),
        ])
        yield String(5, zero_terminated=False, name="Magic")
        ver = yield UInt8(name="Version", comment="always 1")
        if ver != 1:
            raise FatalError("Invalid version")
        yield Unused(1)
        yield String(32, name="SystemIdentifier", comment="name of the system that can act upon sectors 0x00-0x0F for the volume")
        yield String(32, name="VolumeIdentifier", comment="identification of this volume")
        yield Unused(8)
        yield UInt32LsbMsb(name="VolumeSpaceSize", comment="number of Logical Blocks in which the volume is recorded")
        yield Unused(32)
        yield UInt16LsbMsb(name="VolumeSetSize", comment="size of the set in this logical volume (number of disks)")
        yield UInt16LsbMsb(name="VolumeSequenceNumber", comment="number of this disk in the Volume Set")
        yield UInt16LsbMsb(name="LogicalBlockSize", comment="size in bytes of a logical block. NB: This means that a logical block on a CD could be something other than 2 KiB!")
        yield UInt32LsbMsb(name="PathTableSize", comment="size in bytes of the path table")
        yield UInt32(name="PathTableLsb", comment="LBA location of the path table. The path table pointed to contains only little-endian values")
        yield UInt32(name="OptionalPathTableLsb", comment="LBA location of the optional path table. The path table pointed to contains only little-endian values. Zero means that no optional path table exists.")
        yield UInt32BE(name="PathTableMsb", comment="LBA location of the path table. The path table pointed to contains only big-endian values")
        yield UInt32BE(name="OptionalPathTableMsb", comment="LBA location of the optional path table. The path table pointed to contains only big-endian values. Zero means that no optional path table exists.")
        yield Directory(name="RootDirectory")
        yield String(128, name="VolumeSetIdentifier", comment="volume set of which this volume is a member")
        yield String(128, name="PublisherIdentifier", comment="volume publisher. For extended publisher information, the first byte should be 0x5F, followed by the filename of a file in the root directory. If not specified, all bytes should be 0x20")
        yield String(128, name="DataPreparerIdentifier", comment="identifier of the person(s) who prepared the data for this volume. For extended preparation information, the first byte should be 0x5F, followed by the filename of a file in the root directory. If not specified, all bytes should be 0x20")
        yield String(128, name="ApplicationIdentifier", comment="identifies how the data are recorded on this volume. For extended information, the first byte should be 0x5F, followed by the filename of a file in the root directory. If not specified, all bytes should be 0x20")
        yield String(37, name="CopyrightFileIdentifier", comment="name of a file in the root directory that contains copyright information for this volume set. If not specified, all bytes should be 0x20")
        yield String(37, name="AbstractFileIdentifier", comment="name of a file in the root directory that contains abstract information for this volume set. If not specified, all bytes should be 0x20")
        yield String(37, name="BibliographicFileIdentifier", comment="name of a file in the root directory that contains bibliographic information for this volume set. If not specified, all bytes should be 0x20")
        yield DecDateTime(name="VolumeCreationTime", comment="date and time of when the volume was created")
        yield DecDateTime(name="VolumeModificationTime", comment="date and time of when the volume was modified")
        yield DecDateTime(name="VolumeExpirationTime", comment="date and time after which this volume is considered to be obsolete. If not specified, then the volume is never considered to be obsolete")
        yield DecDateTime(name="VolumeEffectiveTime", comment="date and time after which the volume may be used. If not specified, the volume may be used immediately")
        # Distinct name: the descriptor header above already has a "Version"
        # field, and two fields with the same name make lookups ambiguous.
        yield UInt8(name="FileStructureVersion", comment="directory records and path table version (always 0x01)")
        yield Unused(1)
        # ECMA-119 leaves the content of this area to the application.
        yield Bytes(512, name="ApplicationUsed")
        if len(self) < 0x800:
            yield Unused(0x800 - len(self))


class BootVolumeDescriptor(Struct):
    """ISO 9660 Boot Record volume descriptor (type 0, identifier "CD001").

    ECMA-119 (8.2) defines only the 7-byte header, two 32-byte identifiers
    and a 1977-byte boot-system-use area, for 2048 bytes in total. When the
    boot system identifier is the El Torito one, the first dword of the
    system-use area is the LBA of the boot catalog, so it is exposed as a
    separate field.

    Raises FatalError if the version byte is not 1.
    """

    def parse(self):
        yield UInt8(name="Type", values=[
            ("BOOT", 0),
            ("PRIMARY", 1),
            ("SECONDARY", 2),
            ("VOLUME", 3),
            ("TERMINATOR", 255),
        ])
        yield String(5, zero_terminated=False, name="Magic")
        ver = yield UInt8(name="Version", comment="always 1")
        if ver != 1:
            raise FatalError("Invalid version")
        system_id = yield String(32, name="BootSystemIdentifier", comment="ID of the system which can act on and boot the system from the boot record")
        yield String(32, name="BootIdentifier", comment="Identification of the boot system defined in the rest of this descriptor")
        # The fields previously parsed here (extent, load/start address,
        # timestamp, flags) belong to the ECMA-167 "BOOT2" descriptor and do
        # not exist in an ECMA-119 boot record.
        if isinstance(system_id, str) and system_id.startswith("EL TORITO SPECIFICATION"):
            yield UInt32(name="BootCatalogLocation", comment="LBA of the El Torito boot catalog")
        if len(self) < 0x800:
            yield Bytes(0x800 - len(self), name="BootSystem")


class DecDateTime(Struct):
    """17-byte ISO 9660 volume date/time (ECMA-119 8.4.26.1).

    Sixteen ASCII digit characters (year, month, day, hours, minutes,
    seconds, hundredths) followed by a signed GMT offset byte.
    """

    def parse(self):
        # (field name, number of ASCII characters)
        digit_fields = (
            ("Year", 4),
            ("Month", 2),
            ("Day", 2),
            ("Hours", 2),
            ("Minutes", 2),
            ("Seconds", 2),
            ("HundredsSeconds", 2),
        )
        for field_name, width in digit_fields:
            yield String(width, name=field_name)
        yield Int8(name="GmtOffset", comment="time zone offset from GMT in 15 minute intervals, starting at interval -48 (west) and running up to interval 52 (east)")


class BinDateTime(Struct):
    """7-byte binary recording date/time used inside directory records.

    One byte each for year offset, month, day, hours, minutes and seconds,
    then a signed GMT offset byte.
    """
    def parse(self):
        yield UInt8(name="Year", comment="years since 1900")
        for part in ("Month", "Day", "Hours", "Minutes", "Seconds"):
            yield UInt8(name=part)
        yield Int8(name="GmtOffset", comment="time zone offset from GMT in 15 minute intervals, starting at interval -48 (west) and running up to interval 52 (east)")

class ExtendedTimestamp(Struct):
    """12-byte ECMA-167 (UDF) timestamp (ECMA-167 1/7.3).

    Layout: 16-bit type/time zone word, 16-bit full year, then one byte each
    for month, day, hours, minutes, seconds, centiseconds, hundreds of
    microseconds and microseconds.
    """
    def parse(self):
        # The old comment ("years since 1900") was copy-pasted from
        # BinDateTime and described the wrong field.
        yield UInt16(name="TypeAndTimezone", comment="bits 0-11: signed time zone offset in minutes from UTC; bits 12-15: timestamp type")
        yield UInt16(name="Year", comment="full year (1 to 9999)")
        yield UInt8(name="Month")
        yield UInt8(name="Day")
        yield UInt8(name="Hours")
        yield UInt8(name="Minutes")
        yield UInt8(name="Seconds")
        yield UInt8(name="Centiseconds")
        yield UInt8(name="HundredsMicroseconds")
        yield UInt8(name="Microseconds")


class UInt32LsbMsb(Struct):
    """Both-byte-order 32-bit integer: the value little-endian, then big-endian.

    Consumers read the "Lsb" member; "Msb" holds the same value in big-endian.
    """

    def parse(self):
        halves = (
            (UInt32, "Lsb", "LSB representation of value"),
            (UInt32BE, "Msb", "MSB representation of value"),
        )
        for field_type, field_name, description in halves:
            yield field_type(name=field_name, comment=description)


class UInt16LsbMsb(Struct):
    """Both-byte-order 16-bit integer: the value little-endian, then big-endian.

    Consumers read the "Lsb" member; "Msb" holds the same value in big-endian.
    """

    def parse(self):
        halves = (
            (UInt16, "Lsb", "LSB representation of value"),
            (UInt16BE, "Msb", "MSB representation of value"),
        )
        for field_type, field_name, description in halves:
            yield field_type(name=field_name, comment=description)


class Directory(Struct):
    """Variable-length ISO 9660 directory record.

    The first byte ("Size") holds the total record length. Any bytes left
    after the file identifier and its alignment padding (e.g. system-use
    data) are exposed as unused space so the structure spans the full record.
    """

    def parse(self):
        record_len = yield UInt8(name="Size", comment="length of Directory Record")
        yield UInt8(name="ExtendedSize", comment="Extended Attribute Record length")
        yield UInt32LsbMsb(name="ExtentLocation", comment="location of extent (LBA) in both-endian format.")
        yield UInt32LsbMsb(name="ExtentSize", comment="data length (size of extent) in both-endian format")
        yield BinDateTime(name="DateTime", comment="Recording date and time")
        flags = BitsField(
            Bit(name="Hidden"),
            Bit(name="IsDirectory"),
            Bit(name="AssociatedFile"),
            Bit(name="ExtendedFormat"),
            Bit(name="ExtendedOwners"),
            NullBits(2),
            Bit(name="NotFinal"),
            name="Flags")
        yield flags
        yield UInt8(name="FileUnitSize", comment="file unit size for files recorded in interleaved mode, zero otherwise")
        yield UInt8(name="GapSize", comment="interleave gap size for files recorded in interleaved mode, zero otherwise")
        yield UInt16LsbMsb(name="VolumeSequenceNumber", comment="volume sequence number - the volume that this extent is recorded on, in 16 bit both-endian format")
        ident_len = yield UInt8(name="FileIdentifierSize")
        yield String(ident_len, zero_terminated=False, name="FileIdentifier")
        # Pad to an even boundary after the identifier.
        yield Align(2)
        leftover = record_len - len(self)
        if leftover > 0:
            yield Unused(leftover)

















# Dedicated structures keyed by (descriptor type byte, standard identifier).
# Any other combination is parsed with the generic VolumeDescriptor.
DESCRIPTORS = dict([
    ((0, b"CD001"), BootVolumeDescriptor),
    ((1, b"CD001"), PrimaryVolumeDescriptor),
])


class ISOAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for ISO 9660 (ECMA-119) filesystem images.

    The image starts with a 0x8000-byte system area, followed by 2048-byte
    volume descriptors that end with a terminator (type 255). The Primary
    Volume Descriptor provides the logical block size and the root directory
    record. The directory tree is walked from there, and every non-directory
    entry is registered both as a virtual file (served by ``open``) and as
    a section.
    """
    category = malcat.FileType.FILESYSTEM
    name = "ISO"
    # Type byte 1 (primary), standard identifier "CD001", version byte 1;
    # the match offset points at "CD001".
    regexp = r"(?<=\x01)CD001\x01"

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        """Map a regexp match to the image start.

        The matched descriptor is assumed to be the first one, at 0x8000, so
        "CD001" sits at image offset 0x8001.
        """
        if offset_magic < 0x8001:
            return None
        return offset_magic - 0x8001, ""

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # "/"-joined virtual path -> parsed Directory record.
        self.filesystem = {}
        # Logical block size in bytes; replaced by the PVD value in parse().
        self.sector_size = 0x800

    @staticmethod
    def _extent(entry):
        """Return (first data sector, data length in bytes) for a directory record.

        ExtendedSize is treated as a number of logical blocks that precede
        the data inside the extent.
        """
        data_sector = entry["ExtentLocation"]["Lsb"] + entry["ExtendedSize"]
        return data_sector, entry["ExtentSize"]["Lsb"]

    def open(self, vfile, password=None):
        """Return the content of a virtual file registered by parse()."""
        entry = self.filesystem[vfile.path]
        data_sector, data_size = self._extent(entry)
        data = self.read(data_sector * self.sector_size, data_size)
        return bytearray(data)

    def parse(self, hint):
        """Parse the volume descriptor set, then walk the directory tree.

        Raises FatalError on a truncated descriptor set, a descriptor with a
        bad standard identifier, or a zero logical block size.
        """
        self.add_section("SystemArea", 0, 0x8000, discardable=True)
        self.jump(0x8000)
        while True:
            header = self.read(self.tell(), 6)
            if len(header) < 6:
                # Ran out of data before reaching a terminator descriptor.
                raise FatalError("Truncated volume descriptor set")
            typ, tag = struct.unpack("<B5s", header)
            class_ = DESCRIPTORS.get((typ, tag), VolumeDescriptor)
            vd = yield class_()
            if vd["Magic"] != "CD001":
                raise FatalError("Invalid magic")
            if vd["Type"] == 0xFF:
                break
        # Cover exactly the descriptors that were parsed. A fixed 0x8000-byte
        # section could overlap path tables, directories or file sections
        # located right after the terminator.
        self.add_section("ISO", 0x8000, self.tell() - 0x8000, discardable=True)

        hdr = self["PrimaryVolumeDescriptor"]
        self.sector_size = hdr["LogicalBlockSize"]["Lsb"]
        if self.sector_size <= 0:
            # align() and the sector arithmetic below need a positive size.
            raise FatalError("Invalid logical block size")
        self.set_eof(hdr["VolumeSpaceSize"]["Lsb"] * self.sector_size)
        self.confirm()

        # Depth-first walk of the directory tree. Sectors are marked as seen
        # when queued, so a directory reachable through several records (or
        # through a loop) is only visited once.
        root_sector, root_size = self._extent(hdr["RootDirectory"])
        todo = [([], root_sector, root_size)]
        seen = {root_sector}
        while todo:
            parent_path, dir_sector, dir_size = todo.pop()
            dir_offset = dir_sector * self.sector_size
            dir_end = dir_offset + dir_size
            self.jump(dir_offset)
            while self.tell() < dir_end:
                cur = self.tell()
                nxt_sector = align(cur + 1, self.sector_size)
                first = self.read(cur, 1)
                if not first:
                    # Directory extent runs past the end of the data.
                    break
                # Records never straddle a sector boundary: a zero length byte,
                # or too little room left for a minimal 34-byte record, means
                # the rest of this sector is padding.
                if first[0] == 0 or cur + 34 > nxt_sector:
                    self.jump(nxt_sector)
                    continue
                entry = yield Directory()
                name = entry["FileIdentifier"]
                # "\x00" and "\x01" are the "." and ".." self/parent entries.
                if name in ("\x00", "\x01"):
                    continue
                if ";" in name:
                    # Strip the ";<version>" suffix.
                    name = name.split(";")[0]
                path = parent_path + [name]
                data_sector, data_size = self._extent(entry)
                if entry["Flags"]["IsDirectory"]:
                    # Directories are never registered as files, even when
                    # their extent has already been visited.
                    if data_sector not in seen:
                        seen.add(data_sector)
                        todo.append((path, data_sector, data_size))
                    continue
                full_path = "/".join(path)
                self.filesystem[full_path] = entry
                self.add_file(full_path, data_size, "open")
                self.add_section(name, data_sector * self.sector_size, data_size)


                
