from filetypes.base import *
import malcat
import binascii
import struct

# Upper bound (in bytes) of the window scanned for the next node magic once a
# node has been parsed; nodes are separated by a variable amount of padding.
MAX_ALIGNEMENT = 0x100

# (label, value) pairs used to annotate the "Type" field of a node header.
INODE_TYPES = [
    ("DIRENT", 0xe001),
    ("INODE", 0xe002),
    ("CLEAN_MARKER", 0x2003),
    ("PADDING", 0x2004),
    ("SUMMARY", 0x2006),
    ("XATTR", 0xe008),
    ("XREF", 0xe009),
]

# (label, value) pairs used to annotate the compression fields of an inode.
COMPRESSION_METHODS = [
    ("NONE", 0),
    ("ZERO", 1),
    ("RTIME", 2),
    ("RUBINMIPS", 3),
    ("COPY", 4),
    ("DYNRUBIN", 5),
    ("ZLIB", 6),
    ("LZO", 7),
    ("LZMA", 8),
]

class Jffs2Node(Struct):
    """Generic JFFS2 node: common header followed by a type-specific payload.

    The header is made of a 16-bit magic (0x1985), a 16-bit node type, a
    32-bit node size and a 32-bit CRC computed over the first 8 header bytes.
    Parsing aborts with ``FatalError`` when the magic or the CRC is wrong.
    """

    # Optimisation: these field descriptors are built once at class level and
    # reused for every node instead of being re-created on each parse() call.
    field1 = UInt16(name="Type", values = INODE_TYPES)
    field2 = UInt32(name="Size", comment="Node size")
    field3 = UInt32(name="HeaderCrc", comment="header crc")

    def parse(self):
        """Yield the header fields, validate them, then yield the payload."""
        # Grab the bytes covered by the header CRC before consuming any field.
        crc_input = self.look_ahead(8)

        signature = yield UInt16(name="Magic", comment="magic")
        if signature != 0x1985:
            raise FatalError("Invalid magic")

        node_type = yield Jffs2Node.field1
        node_size = yield Jffs2Node.field2
        stored_crc = yield Jffs2Node.field3

        # Seeding with all ones and xoring the result with all ones cancels
        # the pre- and post-inversion applied by binascii.crc32.
        expected_crc = binascii.crc32(crc_input, 0xFFFFFFFF) ^ 0xFFFFFFFF
        if expected_crc != stored_crc:
            raise FatalError("Invalid checksum")

        if node_type == 0xe001:
            payload = DirectoryEntry()
        elif node_type == 0xe002:
            payload = Inode()
        else:
            # Any other node type: expose whatever is left of the node as raw
            # bytes, provided the declared size exceeds what was parsed.
            leftover = node_size - len(self)
            payload = Bytes(leftover, name="Data") if leftover > 0 else None

        if payload is not None:
            yield payload



class DirectoryEntry(Struct):
    """Payload of a DIRENT node: links a name to an inode inside a parent.

    Layout: fixed fields, a one-byte name length, more fixed fields, then the
    name itself, whose length is given by the "NameSize" field.
    """

    # Optimisation: the constant-size fields located before and after the
    # "NameSize" byte are instantiated once and shared by every entry.
    fields1 = [
        UInt32(name="Parent", comment="header parent inode"),
        UInt32(name="Version", comment=""),
        UInt32(name="Id", comment="inode id number of this directory entry"),
        Timestamp(name="Timestamp", comment="inode MCtime"),
    ]
    fields2 = [
        UInt8(name="DirectoryType", comment=""),
        Unused(2),
        UInt32(name="NodeCrc", comment="node crc"),
        UInt32(name="NameCrc", comment="name crc"),
    ]

    def parse(self):
        """Yield every field of the entry in on-disk order."""
        leading = DirectoryEntry.fields1
        trailing = DirectoryEntry.fields2

        # Explicit loops (rather than delegation) so that values sent back by
        # the framework for each yielded field are simply discarded here.
        for descriptor in leading:
            yield descriptor

        # The name length is needed to size the final string field.
        name_length = yield UInt8(name="NameSize", comment="")

        for descriptor in trailing:
            yield descriptor

        yield StringUtf8(name_length, name="Name")


class Inode(StaticStruct):
    """Payload of an INODE node (yielded by Jffs2Node for node type 0xe002).

    Only fixed-size metadata is described here; the node's data bytes are not
    part of this structure. JFFS2Analyzer.open reads "CompressedSize" bytes
    located right after it and writes the decoded result at "WriteOffset".
    """

    @classmethod
    def parse(cls):
        """Yield the fixed-size fields of an inode in on-disk order."""
        # Identity and ownership
        yield UInt32(name="Id", comment="inode id")
        yield UInt32(name="Version", comment="")
        # File type and permission bits; JFFS2Analyzer.parse masks this value
        # with 0o0170000 to recognise regular files.
        yield UInt32(name="Mode")
        yield UInt16(name="Uid")
        yield UInt16(name="Gid")
        yield UInt32(name="Size", comment="Inode size")
        # Timestamps
        # NOTE(review): the third timestamp is presumably the POSIX ctime
        # (status change) rather than a creation time -- confirm before relying
        # on the field name.
        yield Timestamp(name="AccessTime")
        yield Timestamp(name="ModifiedTime")
        yield Timestamp(name="CreatedTime")
        # Description of the data chunk carried by this node
        yield UInt32(name="WriteOffset", comment="where to begin to write")
        yield UInt32(name="CompressedSize", comment="compressed data size")
        yield UInt32(name="DecompressedSize", comment="size of node after decompression")
        yield UInt8(name="Compression", values=COMPRESSION_METHODS)
        yield UInt8(name="UserCompression", values=COMPRESSION_METHODS, comment="Compression requested by the user")
        yield UInt16(name="Flags")
        # Checksums
        yield UInt32(name="DataCrc", comment="compressed data crc")
        yield UInt32(name="NodeCrc", comment="raw node crc")



def rtime_decompress(data):
    """Decompress a buffer compressed with the JFFS2 "rtime" algorithm.

    The stream is a sequence of ``(value, repeat)`` byte pairs. ``value`` is
    emitted verbatim, then ``repeat`` bytes are copied from the output,
    starting right after the previous occurrence of ``value`` (or from the
    start of the output when ``value`` has not been seen yet).

    Args:
        data: compressed bytes-like object; its length must be even.

    Returns:
        A ``bytearray`` holding the decompressed data.

    Raises:
        ValueError: if ``data`` has an odd length.
    """
    # https://stackoverflow.com/questions/11663394/rtime-compression-used-in-jffs2
    if len(data) % 2:
        raise ValueError("Invalid data length")
    res = bytearray()
    # For each byte value, the output offset just past its last occurrence.
    # Entries start at 0, as in the reference implementation, so that a repeat
    # on the first occurrence of a value copies from the start of the output.
    positions = [0] * 256
    for value, repeat in zip(data[::2], data[1::2]):
        res.append(value)
        back_offset = positions[value]
        positions[value] = len(res)
        # Copy byte by byte: source and destination may overlap, in which case
        # freshly written bytes must be re-read (run-length behaviour).
        for j in range(repeat):
            res.append(res[back_offset + j])
    return res


class JFFS2Analyzer(FileTypeAnalyzer):
    """Analyzer for JFFS2 flash filesystem images.

    parse() walks the image node by node, records directory entries and the
    inode nodes of regular files, and registers one virtual file per regular
    file. open() rebuilds the content of such a file from its inode nodes.
    Only little-endian images are parsed.
    """

    category = malcat.FileType.FILESYSTEM
    name = "JFFS2"
    regexp = r"\x85\x19[\x01\x02]\xe0.{8}\x01\x00\x00\x00.{13}\x08|\x19\x85\xe0[\x01\x02].{8}\x00\x00\x00\x01.{13}\x08"

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        """Return where the image starts, or None to reject the match.

        Matches found while already inside a JFFS2 image are rejected, since
        every node of an image would otherwise be reported as a new image.
        """
        if parent_parser is not None and parent_parser.name == "JFFS2":
            # no JFFS2 in JFFS2
            return None
        return offset_magic, ""

    def open(self, vfile, password=None):
        """Rebuild the content of a virtual file registered by parse().

        Args:
            vfile: virtual file object; only its ``path`` attribute is used.
            password: unused.

        Returns:
            A ``bytearray`` with the file content.

        Raises:
            IndexError: no inode is known for ``vfile.path``.
            ImportError: LZO data was met but python-lzo is not installed.
            NotImplementedError: unsupported compression method.
            ValueError: a chunk does not decode to its declared size.
        """
        inodes = self.filesystem.get(vfile.path)
        if not inodes:
            raise IndexError(f"file not found: {vfile.path}")
        res = bytearray(inodes[0]["Size"])
        for inode in inodes:
            expected_size = inode["DecompressedSize"]
            # The chunk data is stored right after the inode structure.
            data = self.read(inode.offset + inode.size, inode["CompressedSize"])
            cm = inode["Compression"]
            if cm in (0, 4):
                # NONE / COPY: data is used as stored
                pass
            elif cm == 1:
                # ZERO: the chunk is a run of zero bytes (file hole); nothing
                # meaningful is stored, the content comes from the sizes only.
                data = bytes(expected_size)
            elif cm == 2:
                data = rtime_decompress(data)
            elif cm == 6:
                import zlib
                data = zlib.decompress(data)
            elif cm == 7:
                try:
                    import lzo
                except ImportError as err:
                    raise ImportError("You need to install python-lzo library first. On Windows I suggest that you use some pre-compiled wheel") from err
                # TODO: with or without header ?
                data = lzo.decompress(data)
            elif cm == 8:
                # NOTE(review): lzma.decompress expects a container header
                # (.xz or .lzma); confirm JFFS2 LZMA chunks carry one.
                import lzma
                data = lzma.decompress(data)
            else:
                raise NotImplementedError("Unsupported compression algorithm {}".format(cm))
            if len(data) != expected_size:
                raise ValueError(f"Invalid decompressed size: {len(data)} vs {inode['DecompressedSize']} ")
            offset = inode["WriteOffset"]
            res[offset:offset + len(data)] = data
        return res

    def get_full_name(self, dir_id, dirs):
        """Build the path of an inode by walking up its directory entries.

        Args:
            dir_id: inode id to start from.
            dirs: mapping of inode id to the directory entry naming it.

        Returns:
            The "/"-joined path. It starts with "/" when the walk ends on
            inode id 1, and is relative (possibly empty) otherwise.
        """
        names = []
        # Parent links come from the (untrusted) image: remember visited ids
        # so that a cyclic chain cannot loop forever.
        visited = set()
        de = dirs.get(dir_id)
        while de is not None and dir_id not in visited:
            visited.add(dir_id)
            names.append(de["Name"])
            dir_id = de["Parent"]
            de = dirs.get(dir_id)
        if dir_id == 1:
            names.append("")
        return "/".join(names[::-1])

    def parse(self, hint):
        """Parse every node of the image and register the regular files.

        Raises:
            NotImplementedError: the image is not little-endian.
        """
        self.filesystem = {}
        magic = self.read(0, 2)
        lsb = magic == b"\x85\x19"

        if lsb:
            node_class = Jffs2Node
        else:
            raise NotImplementedError("MSB JFFS2 not supported yet")
        count = 0

        dirs = {}       # inode id -> directory entry naming that inode
        nodes = {}      # inode id -> inode nodes of that regular file

        while self.remaining() > 2:
            hdr = yield node_class(name="Node")
            node_type = hdr["Type"]

            if node_type == 0xe001:
                de = hdr["DirectoryEntry"]
                dirs[de["Id"]] = de

            elif node_type == 0xe002:
                inode = hdr["Inode"]
                if hdr.size < hdr["Size"]:
                    # skip the data stored after the inode structure
                    data_size = hdr["Size"] - hdr.size
                    self.jump(self.tell() + data_size)
                mode = inode["Mode"]
                if (mode & 0o0170000) == 0x8000:
                    # Regular file: group its nodes by their own inode id, so
                    # that the result does not depend on a directory entry
                    # immediately preceding the inode nodes in the image.
                    nodes.setdefault(inode["Id"], []).append(inode)

            count += 1
            if count > 1:
                self.confirm()

            # look for next header, skipping (variable) alignment
            nxt = self.read(size=min(MAX_ALIGNEMENT + 1, self.remaining()))
            index = nxt.find(magic)
            if index < 0:
                break
            self.jump(self.tell() + index)

        self.add_section("<FS>", 0, self.tell(), r=True, x=True)

        for inode_id, inodes in nodes.items():
            if not inodes:
                continue
            name = self.get_full_name(inode_id, dirs)
            self.filesystem[name] = inodes
            sz = inodes[0]["Size"]
            if sz:
                self.add_file(name, sz, "open")
