import malcat
from malcat import PythonHelper
import sys

from .types import *

class ParsingError(ValueError):
    """Common base class of the parsing errors defined below."""


class FatalError(ParsingError):
    """Fatal error raised by a file parser: the parsing result should be discarded."""


class OutOfBoundError(ParsingError):
    """A read went past the file bounds: the result is kept only if oob_soft mode is set."""


class SuperposingError(ParsingError):
    """The same file range was mapped to two different structures: keep the result but abort parsing."""


class EarlyExitError(ParsingError):
    """Used to stop parsing early in lazy mode."""


class InvalidPassword(ParsingError):
    """Password error for file extraction (presumably a wrong or missing password)."""

   

###################################################################        

class FileTypeAnalyzer:
    """
    Base class of malcat's Python file-format analyzers.

    A subclass describes its format through the class attributes below and
    implements parse(hint) as a generator yielding field and structure
    descriptions (types from .types: Struct, StaticStruct, Array,
    DynamicArray, DelayedType, ...). The analyzer places every yielded element
    at the current file offset, sends the parsed value back into the
    generator, throws parsing errors back into it, and collects everything
    (structures, sections, symbols, fixups, metadata, virtual files) into a
    malcat.FileType object.
    """
    name = None                           # analyzer name, passed to malcat.FileType
    regexp = ""                           # NOTE(review): presumably the detection signature used by malcat — confirm
    category = malcat.FileType.UNKNOWN    # file category, passed to malcat.FileType
    priority = 0                          # NOTE(review): presumably orders candidate analyzers — confirm


    def __init__(self):
        self._end_of_file = None
        self.__imagebase = 0
        self._parsed = None
        # Parsing state, reset by _run(). Initialised here as well so that the
        # accessors (tell(), remaining(), is_confirmed(), ...) behave sanely
        # instead of raising AttributeError on an analyzer that never ran.
        self._file = None
        self._offset = 0
        self._stop = 0
        self._remaining = 0
        self._confirmed = False
        self._lazy = False

    def parse(self, hint):
        """
        Main parsing method, to be implemented by subclasses as a generator.

        Each yielded element is placed at the current offset; the value of the
        yield expression is the parsed value of that element. Errors related
        to an element are thrown back into the generator at that yield.
        """
        raise NotImplementedError

    def add_file(self, vpath, size=0, unpack_method_name="", type="", hint=""):
        """
        Add a new file to the virtual file system.
        All arguments are forwarded to malcat.VirtualFile.
        """
        self._parsed.add_file(malcat.VirtualFile(vpath, size, unpack_method_name, type, hint))

    @property
    def files(self):
        """Tuple of the virtual files added so far."""
        return tuple(self._parsed.files)

    def add_symbol(self, memory_address, name, type=malcat.FileSymbol.EXPORT):
        """
        Add a new symbol named <name> at <memory_address> (default type: export).
        """
        self._parsed.add_symbol(malcat.FileSymbol(memory_address, type, name))

    @property
    def symbols(self):
        """
        Return an iterator over the list of defined symbols (type malcat.FileSymbol)
        """
        # symbols are grouped per address: flatten the groups
        for _, symlist in self._parsed.symbols.items():
            yield from symlist

    def add_fixup(self, memory_address, value=0xffffffffffffffff, size=4):
        """
        Add a new fixup at <memory_address>.
        Forwarded to malcat.FileFixup(memory_address, size, value).
        """
        self._parsed.add_fixup(malcat.FileFixup(memory_address, size, value))

    @property
    def fixups(self):
        """
        Return an iterator over the list of fixups (type malcat.FileFixup)
        """
        return self._parsed.fixups.values()

    def add_section(self, 
            name, offset, size,
            va=None, vsize=None,
            r=True, w=False, x=False, discardable=False):
        """
        Define a new section in the file.
        <va> and <vsize> default to <offset> and <size> respectively.
        """
        if va is None:
            va = offset
        if vsize is None:
            vsize = size
        self._parsed.add_section(malcat.Section(offset, size, va, vsize, name, r, w, x, discardable))

    @property
    def sections(self):
        """Tuple of the sections defined so far."""
        return tuple(self._parsed.sections)

    def add_metadata(self, key, value, category=""):
        """
        Define a new metadata entry, to be displayed in the summary view
        """
        self._parsed.add_metadata(category, key, value)

    def set_imagebase(self, memory_address):
        """
        Define the address in memory where the file is usually loaded (for programs)
        """
        self.__imagebase = memory_address
        self._parsed.set_imagebase(memory_address)

    @property
    def is_msb(self):
        """True if the file is big-endian (most significant byte first)."""
        return self._parsed.is_msb 

    @property
    def is_lsb(self):
        """True if the file is little-endian (least significant byte first)."""
        return not self._parsed.is_msb 

    def set_endianness(self, is_msb=False):
        """
        Define the byte order of the file: big-endian (MSB first) if <is_msb>
        is True, little-endian otherwise.
        """
        self._parsed.set_endianness(is_msb)

    @property
    def imagebase(self):
        """Image base defined by set_imagebase()."""
        return self.__imagebase # cached locally: avoids querying the native object

    def set_architecture(self, architecture):
        """
        Set the main CPU architecture of the file (for programs).
        None is mapped to malcat.Architecture.NONE.
        """
        if architecture is None:
            architecture = malcat.Architecture.NONE
        self._parsed.set_architecture(architecture)

    @property
    def architecture(self):
        """Main CPU architecture of the file."""
        return self._parsed.architecture

    def set_eof(self, eof):
        """
        Call this method to set the end of the file.

        The end is clamped to the current end of data and the remaining byte
        count is reduced accordingly. In lazy mode, once the file type has been
        confirmed, this stops the parsing by raising EarlyExitError.
        """
        self._end_of_file = min(eof, self._stop)
        self._remaining -= self._stop - self._end_of_file
        self._stop = self._end_of_file
        if self._lazy and self._confirmed:
            raise EarlyExitError("")        

    def confirm(self):
        """
        Call this method to confirm that the file you are parsing is what you are expecting.
        After this call, exceptions raised by the parser will be changed to warnings (parsing
        will still be interrupted, but the result is kept).
        In lazy mode, if the end of file has already been set, this stops the parsing by
        raising EarlyExitError.
        """
        self._confirmed = True
        if self._end_of_file is not None and self._lazy:
            raise EarlyExitError("")        

    def run(self, file, hint="", confirmed=False, lazy=False):
        """
        Do the complete parsing and return the cpp structure (malcat.FileType).
        To be called from malcat only.

        Parsing errors are handled as follows:
          * EarlyExitError: parsing stops silently (lazy mode);
          * FatalError / OutOfBoundError: re-raised if the file type was not
            confirmed, otherwise printed and the partial result is returned;
          * SuperposingError: always printed, re-raised if not confirmed.
        """
        try:
            for _ in self._run(file, hint, confirmed, lazy):
                pass
        except EarlyExitError:
            pass
        except (FatalError, OutOfBoundError) as e:
            if not self._confirmed:
                raise
            print(e)
        except SuperposingError as e:
            print(e)    # abort parsing
            if not self._confirmed:
                raise
        return self._parsed

    def subparse(self, parser, size=None, hint="", confirmed=False):
        """
        Parse current position using a different analyzer, for a maximum of <size> bytes read
        (default: all remaining bytes).
        Yields the (element, accessor) pairs of all the structures parsed by this analyzer;
        every element is flagged as already parsed, so that re-yielding it from parse()
        does not parse its content again.
        To be called from another analyser, wanting to parse a subfile for instance.
        """
        if size is None:
            size = self.remaining()
        res = list(parser._run(malcat.File.subfile(self._file, self._offset, size, ""), hint, confirmed))    
        for el, fa in res:
            el._is_parsed = True
            yield el, fa

    def jump(self, foffs):
        """
        Jump to another offset so that the next yield XXX put the structure at this address.
        Only works from within the Analyzer parse() method.

        Raises ValueError if <foffs> is None, OutOfBoundError if it lies outside [0, size()].
        """
        if foffs is None:
            raise ValueError("Attempting to jump to None address")
        if foffs < 0 or foffs > self._stop:
            raise OutOfBoundError("Jumping to invalid offset #{:x}".format(foffs))
        self._offset = foffs
        self._remaining = self._stop - foffs

    def tell(self):
        """Current absolute offset, where the next yielded element will be placed."""
        return self._offset

    def size(self):
        """End of data: the file size, possibly reduced by set_eof()."""
        return self._stop

    def remaining(self):
        """Number of bytes left between the current offset and the end of data."""
        return self._remaining

    def eof(self):
        """True when no byte remains to be parsed."""
        return self._remaining <= 0

    def is_confirmed(self):
        """True once confirm() has been called (or the run was started as confirmed)."""
        return self._confirmed

    def search(self, regexp, start=0, size=None):
        """
        Search <regexp> from absolute offset <start> over <size> bytes
        (default and maximum: up to the end of data).

        Returns (absolute offset of match, size of match).
        If no match, returns (None, 0).
        Raises ValueError if the resulting range is negative (e.g. <start> past the end of data).
        """
        if size is None or start + size > self._stop:
            size = self._stop - start
        if size < 0:
            raise ValueError("Invalid search range")
        off, sz = self._file.search(regexp, start, size)
        if sz == 0:
            off = None
        return off, sz

    def read(self, where=None, size=1):
        """
        Read <size> bytes at absolute offset <where> (default: current offset).
        Raises OutOfBoundError if the read would go past the end of data.
        """
        if where is None:
            where = self.tell()
        if where + size > self._stop:
            raise OutOfBoundError("Reading beyond end of buffer at {:x}-{:x}".format(where, where + size))
        return self._file.read(where, size)

    def look_ahead(self, size=1):
        """Read <size> bytes at the current offset without moving it."""
        return self.read(self.tell(), size)

    def _cstring_offset(self, where):
        """Resolve <where> (default: current offset) and check it lies before the end of data."""
        if where is None:
            where = self.tell()
        if where >= self._stop:
            raise OutOfBoundError("Reading beyond end of buffer at {:x}".format(where))
        return where

    def read_cstring_ascii(self, where=None, max_bytes=512):
        """
        Read a zero-terminated ASCII string at <where> (default: current offset).
        <max_bytes> is passed through to the underlying file reader.
        """
        return self._file.read_cstring_ascii(self._cstring_offset(where), max_bytes)

    def read_cstring_utf8(self, where=None, max_bytes=512):
        """
        Read a zero-terminated UTF-8 string at <where> (default: current offset).
        <max_bytes> is passed through to the underlying file reader.
        """
        return self._file.read_cstring_utf8(self._cstring_offset(where), max_bytes)

    def read_cstring_utf16le(self, where=None, max_bytes=512):
        """
        Read a zero-terminated UTF-16LE string at <where> (default: current offset).
        <max_bytes> is passed through to the underlying file reader.
        Returns a placeholder text if the data cannot be decoded.
        """
        where = self._cstring_offset(where)
        try:
            return self._file.read_cstring_utf16le(where, max_bytes)
        except UnicodeDecodeError:
            return "invalid ut16-le"

    def read_cstring_utf16be(self, where=None, max_bytes=512):
        """
        Read a zero-terminated UTF-16BE string at <where> (default: current offset).
        <max_bytes> is passed through to the underlying file reader.
        Returns a placeholder text if the data cannot be decoded.
        """
        where = self._cstring_offset(where)
        try:
            return self._file.read_cstring_utf16be(where, max_bytes)
        except UnicodeDecodeError:
            return "invalid ut16-be"        

    def compute_size_prefixed_elements_array_size(self, number_of_element=0, size_field_width=4, little_endian=True):
        """
        Delegate to PythonHelper.compute_size_prefixed_elements_array_size at the current offset.
        NOTE(review): presumably returns the total byte size of <number_of_element> elements,
        each prefixed by a <size_field_width>-byte size field — confirm.
        """
        return PythonHelper.compute_size_prefixed_elements_array_size(self._file, self.tell(), number_of_element, size_field_width, little_endian)

    def __iter__(self):
        """Iterate over the parsed structures, by index."""
        for i in range(self._parsed.number_of_structures):
            yield self._parsed[i]
    
    def at(self, key):
        """Parsed structure for <key>, as returned by malcat.FileType.at()."""
        return self._parsed.at(key)

    def __getitem__(self, key):
        return self._parsed[key]

    def __contains__(self, key):
        return key in self._parsed

    def _run(self, file, hint, confirmed=False, lazy=False):
        """
        Generator doing the actual parsing of <file>.

        Resets the parsing state, drives parse(hint) through _iterate() and
        registers every yielded element as a malcat.Structure at the current
        offset. Yields (element, accessor) pairs, the accessor being
        self._parsed.at(index); the accessor's value is sent back into parse().
        """
        self._parsed = malcat.FileType(file, self.__class__.name, self.__class__.category)
        self._offset = 0
        self._file = file
        self._stop = self._file.size
        self._remaining = self._stop
        self._confirmed = confirmed
        self._lazy = lazy
        self.set_imagebase(0)
        iterator = self._iterate(self.parse(hint))
        try:
            element = iterator.send(None)
            struc_index = 0
            while True:
                struc = malcat.Structure(element.category, element.name, element.native)
                try:
                    res = self._parsed.add_structure(self._offset, struc)
                except Exception:
                    # forward structure placement errors to the structure creation location
                    element = iterator.throw(OutOfBoundError("structure {} has invalid offset {:x}-{:x}".format(element.name, self._offset, self._offset + element.native.size)))
                    continue
                if not res:
                    element = iterator.throw(SuperposingError("structure {} at offset {:x}-{:x} would overwrite an existing structure".format(element.name, self._offset, self._offset + element.native.size)))
                    continue
                if element.parent:
                    struc.set_parent_hint(element.parent)
                fa = self._parsed.at(struc_index)
                yield element, fa
                element = iterator.send(fa.value)
                struc_index += 1
        except StopIteration:
            pass    

    def _iterate(self, iterator, virtual=False):
        """
        Coroutine plumbing that makes every field accessible as soon as it is
        yielded.

        Drives <iterator> (a parse() generator, or the field generator of a
        structure when recursing) and re-yields each element it produces to
        the caller after placing/sizing it at the current offset:
          * Struct: re-yielded, then its own fields are iterated recursively
            and added to its native object (skipped if already parsed);
          * Array: its native ArrayField is built from a dry run of the
            subtype, with element.count elements;
          * DynamicArray: same, starting with one element and growing until
            element.fn_terminator(last_element, count) returns true;
          * anything else (after DelayedType.build() if applicable): a plain
            field of fixed native size.
        The value the caller sends back for an element is forwarded into
        <iterator>. When <virtual> is true, the values sent into a structure's
        field generator are None.

        Any exception raised while handling an element is thrown back into
        <iterator>, so that the parser sees it at the yield that produced the
        element.
        """
        element = None
        try:
            while True:
                try:
                    if element is None:
                        element = iterator.send(None)

                    if isinstance(element, Struct):
                        if not hasattr(element, "native"):
                            raise ValueError(f"element {element} is not a valid field")
                        element.offset = self._offset
                        fa = yield element
                        element.parser = self
                        if not hasattr(element, "_is_parsed"):
                            # place the structure's own fields one by one and add them to its native object
                            if isinstance(element, StaticStruct):
                                subiterator = self._iterate(element.__class__.static_iterate(), True)
                            else:
                                subiterator = self._iterate(element.parse(), virtual)
                            i = 0
                            prev = None
                            while True:
                                try:
                                    sub = subiterator.send(prev)
                                except StopIteration:
                                    break
                                except BaseException as exc:
                                    sub = subiterator.throw(exc)
                                element.native.add(sub.name, sub.native)
                                if not virtual:
                                    prev = fa[i]
                                i += 1
                        else:
                            # already parsed (e.g. re-yielded after subparse()): just skip over it
                            self._offset += element.native.size
                            self._remaining = self._stop - self._offset
                    elif isinstance(element, Array):
                        if not hasattr(element, "_is_parsed"):
                            native_subtype, _ = self._probe_array_subtype(element.subtype)
                            element.native = malcat.ArrayField(native_subtype, element.count, element.comment)
                        if not hasattr(element, "native"):
                            raise ValueError(f"element {element} is not a valid field")
                        if self._offset + element.native.size > self._stop:
                            element = iterator.throw(OutOfBoundError("Parsing beyond EOB for {} at #{:x} on 0x{:x} bytes (0x{:x} remaining)".format(element, self._offset, element.native.size, self._remaining)))
                            continue
                        fa = yield element
                        self._offset += element.native.size
                        self._remaining = self._stop - self._offset
                    elif isinstance(element, DynamicArray):
                        old_loc = self._offset
                        already_parsed = hasattr(element, "_is_parsed")
                        if not already_parsed:
                            native_subtype, size_element = self._probe_array_subtype(element.subtype)
                            element.native = malcat.ArrayField(native_subtype, 1, element.comment)
                        if not hasattr(element, "native"):
                            raise ValueError(f"element {element} is not a valid field")
                        if self._offset + element.native.size > self._stop:
                            element = iterator.throw(OutOfBoundError("Parsing beyond EOB for {} at #{:x} on 0x{:x} bytes (0x{:x} remaining)".format(element, self._offset, element.native.size, self._remaining)))
                            continue
                        fa = yield element
                        # Elements flagged by subparse() were fully sized during the
                        # sub-analyzer's run: only fresh arrays need the terminator scan
                        # (which also needs size_element / native_subtype from the probe).
                        if not already_parsed:
                            # grow the array one element at a time until the terminator accepts the last one
                            count = 1
                            error_thrown = False
                            while True:
                                last = fa[count - 1]
                                if element.fn_terminator(last, count):
                                    break
                                if old_loc + size_element * (count + 1) > self._stop:
                                    element = iterator.throw(OutOfBoundError("Reached EOB before termination condition at 0x{:x} ({})".format(old_loc + size_element * (count + 1), native_subtype)))
                                    error_thrown = True
                                    break
                                count += 1
                                element.native.resize(count)
                            if error_thrown:
                                continue
                        self._offset += element.native.size
                        self._remaining = self._stop - self._offset
                    else:
                        if isinstance(element, DelayedType):
                            try:
                                element.build(self)
                            except BaseException as exc:
                                element = iterator.throw(exc)
                                continue
                            if element.native is None:
                                # nothing to place: ask the parser for its next element
                                element = iterator.send(None)
                                continue
                        if not hasattr(element, "native"):
                            raise ValueError(f"element {element} is not a valid field")
                        if self._offset + element.native.size > self._stop:
                            element = iterator.throw(OutOfBoundError("Parsing beyond EOB for {} at #{:x} on 0x{:x} bytes (0x{:x} remaining)".format(element, self._offset, element.native.size, self._remaining)))
                            continue
                        fa = yield element
                        self._offset += element.native.size
                        self._remaining = self._stop - self._offset
                    element = iterator.send(fa)
                except StopIteration:
                    break
                except BaseException as exc:
                    # forward field errors to structure creation location
                    element = iterator.throw(exc)
                    continue
        except StopIteration:
            pass

    def _probe_array_subtype(self, subtype):
        """
        Dry-run the placement of one array element of type <subtype> at the
        current offset, without emitting any structure.

        Returns (native_subtype, element_size). The current offset and the
        remaining byte count are restored before returning (also on error).
        """
        def single(value):
            yield value

        start = self._offset
        remaining = self._remaining
        native_subtype = None
        subiterator = self._iterate(single(subtype), virtual=True)
        try:
            while True:
                try:
                    sub = subiterator.send(None)
                except StopIteration:
                    break
                except BaseException as exc:
                    sub = subiterator.throw(exc)
                native_subtype = sub.native
            element_size = self._offset - start
        finally:
            self._offset = start
            self._remaining = remaining
        return native_subtype, element_size

class DummyAnalyzer(FileTypeAnalyzer):
    """Analyzer that recognises nothing: its parse() generator yields no element."""
    name = "DummyAnalyzer"
    regexp = ""
    category = malcat.FileType.UNKNOWN

    def parse(self, hint):
        # delegating to an empty iterable keeps this method a generator
        # while producing no element at all
        yield from ()
