from filetypes.base import *
import malcat
import struct

# https://github.com/libyal/libvhdi/blob/main/documentation/Virtual%20Hard%20Disk%20%28VHD%29%20image%20format.asciidoc

def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    Values that are already a multiple of ``what`` are returned unchanged.
    Otherwise the value is rounded up to the next multiple, or down to the
    previous one when ``down`` is true.
    """
    remainder = val % what
    if not remainder:
        return val
    if down:
        return val - remainder
    return val + (what - remainder)


# Size in bytes of one disk sector. The analyzer multiplies block allocation
# table entries by this value to get file offsets, and pads sector bitmaps to it.
SECTOR_SIZE = 512

class DiskGeometry(Struct):
    """Cylinders / heads / sectors-per-track record embedded in the footer."""

    def parse(self):
        """Yield the three geometry fields in their declared order."""
        layout = (
            (UInt16BE, "NumberOfCylinders"),
            (UInt8, "NumberOfHeads"),
            (UInt8, "NumberOfSectorsPerTrack"),
        )
        # each field object is created right before it is handed to the parser
        for field_type, field_name in layout:
            yield field_type(name=field_name)

class Footer(Struct):
    """Structure describing the VHD footer record.

    The analyzer parses it at the very start of the file (named
    "CopyOfFooter") and again after the last data block when a matching
    signature is found there.
    """

    def parse(self):
        """Yield the footer fields one after the other, in declaration order."""
        # signature; the analyzer's detection regexp expects "conectix" here
        yield String(8, name="Magic")
        # 32 flag bits, of which only two are named
        yield BitsField(
            NullBits(30),
            Bit(name="Reserved", comment="Always 1"),
            Bit(name="Temporary", comment="Indicates that this disk is a candidate for deletion on shutdown"),
            name="Features")
        # format version; the detection regexp only accepts major=1, minor=0
        yield UInt16BE(name="VersionMajor")
        yield UInt16BE(name="VersionMinor")
        # presumably the file offset of the next metadata structure
        # (the dynamic disk header) -- TODO confirm against the format spec
        yield Offset64BE(name="NextMetadata")
        yield Timestamp2000(name="ModificationTime")
        # identification of the tool and operating system that created the image
        yield String(4, name="CreatorApplication")
        yield UInt16BE(name="CreatorVersionMajor")
        yield UInt16BE(name="CreatorVersionMinor")
        yield String(4, name="CreatorOperatingSystem")
        # disk sizes; CurrentSize is the one the analyzer uses to size the
        # virtual disk and its block allocation table
        yield UInt64BE(name="OriginalSize")
        yield UInt64BE(name="CurrentSize")
        yield DiskGeometry(name="DiskGeometry")
        # the detection regexp only accepts types 3 and 4 (dynamic / differential)
        yield UInt32BE(name="DiskType", values=[
            ("Fixed", 2),
            ("Dynamic", 3),
            ("Differential", 4),
        ])
        yield UInt32BE(name="Checksum")
        yield GUID(name="Identifier")
        yield UInt8(name="Saved", comment="Flag to indicate the image is in saved state")
        # trailing padding; the detection regexp expects the next structure
        # exactly 512 bytes after the start of this one
        yield Unused(427)


class DynamicDiskHeader(Struct):
    """Structure describing the header of a dynamic or differencing VHD.

    The analyzer parses it immediately after the leading copy of the footer.
    """

    def parse(self):
        """Yield the header fields one after the other, in declaration order."""
        # signature; the analyzer's detection regexp expects "cxsparse" here
        yield String(8, name="Magic")
        yield Offset64BE(name="NextMetadata")
        # presumably the file offset of the block allocation table. Note that
        # the analyzer does not jump to it: it reads the table right after
        # this header -- TODO confirm both always coincide
        yield Offset64BE(name="BlockTable")
        yield UInt16BE(name="VersionMajor")
        yield UInt16BE(name="VersionMinor")
        yield UInt32BE(name="NumberOfBlocks", comment="Maximum number of block allocation table entries")
        # size in bytes of the data part of one block; the analyzer divides it
        # by SECTOR_SIZE to get the number of sectors per block
        yield UInt32BE(name="BlockSize", comment="")
        yield UInt32BE(name="Checksum")
        # the Parent* fields presumably only matter for differencing disks
        yield GUID(name="ParentIdentifier")
        yield Timestamp2000(name="ParentModificationTime")
        yield Unused(4)
        yield StringUtf16be(256, name="ParentName")
        # NOTE(review): modelled here as 24 raw 64-bit offsets (192 bytes); the
        # format specification may describe this area as fewer, larger locator
        # records -- confirm before relying on individual values
        yield Array(24, Offset64BE(), name="ParentLocatorArray")
        yield Unused(256)


class Block:
    """Plain record describing one allocated data block of the image.

    As built by the analyzer's ``parse``: ``offset`` is the file offset of
    the block data (just past its sector bitmap), ``size`` is the block size
    in bytes and ``bitmap`` is the parsed sector bitmap field.
    """

    def __init__(self, offset, size, bitmap):
        self.offset, self.size, self.bitmap = offset, size, bitmap


class VHDAnalyzer(FileTypeAnalyzer):
    """Analyzer for dynamic and differencing VHD (Virtual Hard Disk) images.

    ``parse`` maps the leading footer copy, the dynamic disk header, the block
    allocation table (BAT) and every allocated block, then registers a virtual
    file named "used_space" whose content is rebuilt by ``open_dynamic``.

    Attributes:
        bat: one entry per BAT slot, in table order; a ``Block`` for an
            allocated block, ``None`` for an unallocated or unreadable one.
        block_size: size in bytes of the data part of a block.
        max_sector: number of sectors exposed through "used_space".
    """
    category = malcat.FileType.FILESYSTEM
    name = "VHD"
    regexp = r"conectix....\x00\x01\x00\x00.{44}\x00\x00\x00[\x03\x04].{448}cxsparse"

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        self.filesystem = {}
        self.bat = []
        # filled in by parse(); initialised here so that the attributes always exist
        self.block_size = 0
        self.max_sector = 0

    def open_dynamic(self, vfile, password=None):
        """Rebuild the used part of the virtual disk as one contiguous buffer.

        Sectors belonging to unallocated blocks are left zero-filled.

        Args:
            vfile: the virtual file being opened (unused).
            password: unused.

        Returns:
            A ``bytearray`` of exactly ``max_sector * SECTOR_SIZE`` bytes.

        Raises:
            ValueError: if the disk is too large to rebuild in memory, if the
                block size is invalid, or if a sector lies beyond the BAT.
        """
        if self.max_sector > 1000000000 // 512:
            raise ValueError("File too big")
        sectors_per_block = self.block_size // SECTOR_SIZE
        if sectors_per_block <= 0:
            raise ValueError("Invalid block size")
        res = bytearray(self.max_sector * SECTOR_SIZE)
        for i in range(self.max_sector):
            blockid, sector_in_block = divmod(i, sectors_per_block)
            if blockid >= len(self.bat):
                raise ValueError("Invalid BAT")
            block = self.bat[blockid]
            if block is None:
                continue
            data = self.read(block.offset + sector_in_block * SECTOR_SIZE, SECTOR_SIZE)
            # Size the target slice on what was actually read: assigning a
            # short read to a full-sector slice would shrink the buffer and
            # shift every following sector.
            start = i * SECTOR_SIZE
            res[start:start + len(data)] = data
        return res

    def parse(self, hint):
        """Map the VHD structures and register the "used_space" virtual file.

        Args:
            hint: parsing hint supplied by the framework (unused).

        Raises:
            ValueError: if the header declares a block size smaller than one
                sector.
        """
        self.bat = []
        hdr = yield Footer(name="CopyOfFooter")
        dyn = yield DynamicDiskHeader()
        self.block_size = dyn["BlockSize"]
        if self.block_size < SECTOR_SIZE:
            # would otherwise end in a division by zero further down
            raise ValueError("Invalid block size")
        sectors_per_block = self.block_size // SECTOR_SIZE
        disk_size = hdr["CurrentSize"]
        total_sectors = disk_size // SECTOR_SIZE

        # One BAT entry per block, rounding up so that a final partial block
        # is not dropped when the disk size is not a multiple of the block size.
        # NOTE(review): the table is read right after the header; the header's
        # BlockTable offset is not followed.
        num_entries = (disk_size + self.block_size - 1) // self.block_size
        blocks = yield Array(num_entries, UInt32BE(), name="BlockAllocationTable", category=Type.FIXUP)

        # one bit per sector, rounded up to whole bytes and then to a whole sector
        bitmap_size = align((sectors_per_block + 7) // 8, SECTOR_SIZE)
        last_address = None
        for i, entry in enumerate(blocks):
            sector = entry.value
            offset = sector * SECTOR_SIZE
            if sector == 0xffffffff or offset > self.size():
                # Unallocated block, or entry pointing past the end of the
                # file: keep a placeholder so that positions in self.bat keep
                # matching BAT indices.
                self.bat.append(None)
                continue
            self.jump(offset)
            bitmap = yield Bytes(bitmap_size, name="SectorBitmap", category=Type.FIXUP)
            self.bat.append(Block(offset + bitmap_size, self.block_size, bitmap))
            self.add_section("block#{}".format(i), offset, bitmap_size + self.block_size)
            # blocks are not necessarily stored in BAT order: track the highest end
            block_end = offset + bitmap_size + self.block_size
            if last_address is None or block_end > last_address:
                last_address = block_end

        if last_address is not None:
            # sometimes there is padding before footer ...
            footer_start, sz = self.search(r"conectix....\x00\x01\x00\x00.{44}\x00\x00\x00[\x03\x04]", last_address)
            if sz and footer_start - last_address <= 0x1000:
                self.jump(footer_start)
                yield Footer()

        # compute maximal file size: drop the run of unallocated blocks at the
        # end of the disk, never exceeding the size declared in the footer
        used_blocks = len(self.bat)
        while used_blocks and self.bat[used_blocks - 1] is None:
            used_blocks -= 1
        self.max_sector = min(total_sectors, used_blocks * sectors_per_block)
        self.add_file("used_space", self.max_sector * SECTOR_SIZE, "open_dynamic")
        
                

