from filetypes.base import *
import malcat

# Size (16 KiB) of each input chunk fed to the LZMA decompressor while
# LzmaAnalyzer.parse scans forward to find where the compressed stream ends.
BUFFER_INCREMENT = 1024*16

class AloneHeader(Struct):
    """Header of the legacy LZMA "alone" container: 1 + 4 + 8 = 13 bytes.

    Fields are yielded in on-disk order; the order must not change.
    """

    def parse(self):
        # LZMA properties byte (the analyzer regexp only accepts 0x5D here)
        yield UInt8(name="Properties", comment="LZMA properties")
        # Dictionary size. The misspelled field name is kept on purpose: it is a
        # runtime identifier that other code may look up.
        yield UInt32(name="DictionnarySize", comment="")
        # Uncompressed size. The analyzer regexp also accepts all 0xFF bytes here
        # (presumably the format's "size unknown" marker -- TODO confirm).
        yield UInt64(name="UncompressedSize", comment="")

class XzHeader(Struct):
    """Stream header of the XZ container: 6-byte magic, 2-byte flags, 4-byte CRC32.

    Fields are yielded in on-disk order; the order must not change.
    """

    def parse(self):
        # 6-byte magic; the analyzer regexp matches it as b"\xFD7zXZ\x00"
        yield String(6, zero_terminated=True, name="Magic")
        # Stream flags
        yield UInt16(name="Flags", comment="")
        # NOTE(review): presumably the CRC32 of the Flags field per the XZ
        # specification -- it is not verified by this parser.
        yield UInt32(name="Crc32", comment="")

class LzmaAnalyzer(FileTypeAnalyzer):
    """Carve LZMA-alone and XZ compressed streams and expose their content.

    Neither header records the compressed length, so ``parse`` finds the end of
    the stream by actually decompressing it until the decoder reports
    end-of-stream.
    """
    category = malcat.FileType.ARCHIVE
    name = "LZMA"
    # Matches either:
    #   * an LZMA-alone header: properties byte 0x5D, a dictionary size whose
    #     last byte is <= 0x08 and whose top two bytes are not both zero, and an
    #     uncompressed size that is either all 0xFF or has non-zero low 4 bytes
    #     and zero top 3 bytes; or
    #   * the XZ magic b"\xFD7zXZ\x00".
    regexp = r"(?:\x5d\x00..[\x00-\x08](?<!\x00\x00)(?:....(?<!\x00\x00\x00\x00).\x00\x00\x00|\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF))|\xFD7zXZ\x00"

    # Upper bound on the decompressed bytes materialized per decompress() call
    # while measuring the stream, so highly compressible input (decompression
    # bombs) cannot exhaust memory during parsing.
    _OUTPUT_CHUNK = 1024 * 1024

    def unpack(self, vfile, password=None):
        """Return the fully decompressed content of the carved stream.

        ``vfile`` and ``password`` are part of the analyzer interface and are
        unused here. The header is included in the data handed to
        ``lzma.decompress`` so the container format can be auto-detected.
        """
        import lzma
        stream = self.at("Lzma stream")
        # NOTE(review): reads from offset 0, i.e. assumes the header starts at
        # the beginning of the analyzed data.
        return lzma.decompress(self.read(0, stream.offset + stream.size))

    @staticmethod
    def _decompress_chunk(decompressor, data, max_length):
        """Feed ``data`` to ``decompressor`` and return the number of bytes produced.

        Output is pulled in slices of at most ``max_length`` bytes and discarded
        immediately; only its length is kept.
        """
        produced = len(decompressor.decompress(data, max_length))
        # needs_input is False while the decoder still has pending output that
        # was held back by max_length; drain it with empty input.
        while not decompressor.eof and not decompressor.needs_input:
            produced += len(decompressor.decompress(b"", max_length))
        return produced

    def parse(self, hint):
        """Parse the header and the compressed stream, then register the packed file.

        Raises:
            FatalError: if the data cannot be decoded as LZMA/XZ, or ends before
                the decoder reaches end-of-stream.
        """
        import lzma
        start = self.tell()
        remaining = self.remaining()
        ptr = start
        if self.read(start + 1, 4) == b"7zXZ":
            hdr = yield XzHeader(category=Type.HEADER)
        else:
            hdr = yield AloneHeader(category=Type.HEADER)
        # Bruteforce the LZMA stream length: the decoder tells us where the
        # stream ends (eof + unused_data). Reading restarts at the header
        # because the decompressor needs it to detect the container format.
        obj = lzma.LZMADecompressor()
        decompressed_size = 0
        while remaining and not obj.eof:
            todo = min(remaining, BUFFER_INCREMENT)
            try:
                decompressed_size += self._decompress_chunk(
                    obj, self.read(ptr, todo), self._OUTPUT_CHUNK)
            except Exception as e:
                raise FatalError("Not a valid Lzma stream ({})".format(e)) from e
            ptr += todo
            remaining -= todo
        if not obj.eof:
            raise FatalError("Truncated Lzma stream")
        # Bytes of the whole container actually consumed by the decoder
        # (header included); anything fed past the end is in unused_data.
        consumed = (ptr - start) - len(obj.unused_data)
        yield Bytes(consumed - hdr.size, name="Lzma stream", category=Type.DATA)
        self.add_file("<packed content>", decompressed_size, "unpack")


