from filetypes.base import *
from filetypes.TIFF import TiffAnalyzer, TIFF_TAGS_FOR_META
import malcat 
import struct
from filetypes.BMP import RGB



RDF_NS = '{http://www.w3.org/1999/02/22-rdf-syntax-ns#}'
XML_NS = '{http://www.w3.org/XML/1998/namespace}'
# namespace URI -> short prefix used as the first-level key of XmpParser.meta
NS_MAP = {
    'http://www.w3.org/1999/02/22-rdf-syntax-ns#': 'rdf',
    'http://purl.org/dc/elements/1.1/': 'dc',
    'http://ns.adobe.com/xap/1.0/': 'xap',
    'http://ns.adobe.com/pdf/1.3/': 'pdf',
    'http://ns.adobe.com/xap/1.0/mm/': 'xapmm',
    'http://ns.adobe.com/pdfx/1.3/': 'pdfx',
    'http://prismstandard.org/namespaces/basic/2.0/': 'prism',
    'http://crossref.org/crossmark/1.0/': 'crossmark',
    'http://ns.adobe.com/xap/1.0/rights/': 'rights',
    'http://www.w3.org/XML/1998/namespace': 'xml'
}


class XmpParser(object):
    """Minimal XMP (RDF/XML) packet parser.

    Only the properties of the ``rdf:Description`` children of ``rdf:RDF`` are
    extracted, whether serialized as child elements or as attributes of the
    description (the abbreviated RDF/XML form). rdf:Bag / rdf:Seq values become
    lists, rdf:Alt values become ``{xml:lang: text}`` dicts and any other
    property is the element's text.
    """

    def __init__(self, xmp):
        """Parse the XMP document *xmp* (an XML str or bytes).

        Raises xml.etree.ElementTree.ParseError if *xmp* is not well-formed.
        """
        # NOTE(review): xmp comes from the analysed file, i.e. untrusted data.
        # ElementTree's resistance to entity-expansion attacks depends on the
        # expat version Python is built with (see the "XML vulnerabilities" docs).
        from xml.etree import ElementTree as ET
        self.tree = ET.XML(xmp)
        if self.tree.tag == RDF_NS + 'RDF':
            # packet without the usual x:xmpmeta wrapper element
            self.rdftree = self.tree
        else:
            # normally a direct child of x:xmpmeta, but accept any depth;
            # stays None when the document has no rdf:RDF at all
            self.rdftree = self.tree.find('.//' + RDF_NS + 'RDF')

    @property
    def meta(self):
        """ A dictionary of all the parsed metadata.

        Maps a namespace prefix from NS_MAP (or the raw namespace URI when it
        is unknown) to a ``{property name: value}`` dict. Returns an empty dict
        when the document contains no rdf:RDF element.
        """
        from collections import defaultdict
        if self.rdftree is None:
            return {}
        meta = defaultdict(dict)
        for desc in self.rdftree.findall(RDF_NS + 'Description'):
            # abbreviated form: simple properties written as attributes
            for attr_name, attr_value in desc.attrib.items():
                ns, tag = self._split_name(attr_name)
                # rdf:about, xml:lang... are RDF/XML syntax, not XMP properties
                if ns is None or ns in ('rdf', 'xml'):
                    continue
                meta[ns][tag] = attr_value
            # regular form: one child element per property
            for el in desc:
                ns, tag = self._parse_tag(el)
                meta[ns][tag] = self._parse_value(el)
        return dict(meta)

    @staticmethod
    def _split_name(name):
        """ Split a '{uri}local' name into (prefix or uri, local); (None, name) if unqualified. """
        ns = None
        tag = name
        if tag[0] == "{":
            ns, tag = tag[1:].split('}', 1)
            ns = NS_MAP.get(ns, ns)
        return ns, tag

    def _parse_tag(self, el):
        """ Extract the namespace and tag from an element. """
        return self._split_name(el.tag)

    def _parse_value(self, el):
        """ Extract the metadata value from an element.

        rdf:Bag / rdf:Seq -> list of item texts, rdf:Alt -> {xml:lang: text},
        anything else -> the element's own text (possibly None).
        """
        if el.find(RDF_NS+'Bag') is not None:
            value = []
            for li in el.findall(RDF_NS+'Bag/'+RDF_NS+'li'):
                value.append(li.text)
        elif el.find(RDF_NS+'Seq') is not None:
            value = []
            for li in el.findall(RDF_NS+'Seq/'+RDF_NS+'li'):
                value.append(li.text)
        elif el.find(RDF_NS+'Alt') is not None:
            value = {}
            for li in el.findall(RDF_NS+'Alt/'+RDF_NS+'li'):
                value[li.get(XML_NS+'lang')] = li.text
        else:
            value = el.text
        return value



class App0Header(Struct):
    """Common header of an APP0 segment: marker, segment size and 5-byte identifier.

    The payload that follows is parsed separately by the analyzer, depending on
    the identifier (JFIFHeader or JFXXHeader).
    """

    def parse(self):
        yield UInt16BE(name="Tag", comment="segment type")
        # counts the 2-byte size field itself plus the payload, not the marker
        yield UInt16BE(name="Size", comment="segment size")
        # includes the trailing NUL, e.g. "JFIF\0" / "JFXX\0"
        yield String(5, zero_terminated=False, name="Name", comment="should be 'JFIF' or 'JFXX'")
        

class JFIFHeader(Struct):
    """Payload of a JFIF APP0 segment: version, pixel density and optional RGB thumbnail."""

    def parse(self):
        yield UInt8(name="MajorVersion")
        yield UInt8(name="MinorVersion")
        density_units = [
            ("PIXEL_ASPECT_RATIO", 0),
            ("DOTS_PER_INCH", 1),
            ("DOTS_PER_CM", 2),
        ]
        yield UInt8(name="Units", comment="Units for the X and Y densities", values=density_units)
        yield UInt16BE(name="Xdensity", comment="horizontal pixel density")
        yield UInt16BE(name="Ydensity", comment="vertical pixel density")
        thumb_w = yield UInt8(name="Xthumbnail", comment="thumbnail horizontal pixel count")
        thumb_h = yield UInt8(name="Ythumbnail", comment="thumbnail vertical pixel count")
        # the uncompressed thumbnail is omitted when either dimension is zero
        if thumb_w and thumb_h:
            pixel_count = thumb_w * thumb_h
            yield Array(pixel_count, RGB(), name="Thumbnail", comment="thumbnail pixels")


class JFXXHeader(Struct):
    """Payload of a JFXX (JFIF extension) APP0 segment, i.e. an embedded thumbnail.

    `maxsz` is the number of bytes this structure may span. It is only used to
    size the JPEG-compressed thumbnail (code 0x10), which has no length field.
    Zero or negative sizes (malformed or empty thumbnails) yield no data field,
    like the other segment parsers of this module.
    """

    def __init__(self, maxsz, *args, **kwargs):
        Struct.__init__(self, *args, **kwargs)
        self.maxsz = maxsz

    def parse(self):
        code = yield UInt8(name="ExtensionCode", comment="identifies the extension", values=[
            ("THUMBNAIL_JPEG", 0x10),
            ("THUMBNAIL_8", 0x11),
            ("THUMBNAIL_RGB", 0x13),
            ])
        if code == 0x10:
            # todo: subparse
            # len(self) presumably is the number of bytes parsed so far (the code byte)
            jpeg_size = self.maxsz - len(self)
            if jpeg_size > 0:
                yield Bytes(jpeg_size, name="Thumbnail")
        elif code == 0x11:
            x = yield UInt8(name="Xthumbnail", comment="thumbnail horizontal pixel count")
            y = yield UInt8(name="Ythumbnail", comment="thumbnail vertical pixel count")
            yield Array(256, RGB(), name="Palette", comment="thumbnail palette")
            # one palette index byte per pixel
            if x and y:
                yield Bytes(x*y, name="Thumbnail", comment="thumbnail pixels")
        elif code == 0x13:
            x = yield UInt8(name="Xthumbnail", comment="thumbnail horizontal pixel count")
            y = yield UInt8(name="Ythumbnail", comment="thumbnail vertical pixel count")
            if x and y:
                yield Array(x*y, RGB(), name="Thumbnail", comment="thumbnail pixels")



class App1Header(Struct):
    """Header of an APP1 segment: marker, size and, for Exif/XMP, the identifier.

    Only the identifier is parsed here; the payload is handled by the analyzer,
    which only processes it when the identifier is exactly Exif or the XAP URI.
    """

    def parse(self):
        yield UInt16BE(name="Tag", comment="segment type")
        yield UInt16BE(name="Size", comment="segment size")
        n = self.look_ahead(4)
        if n == b"Exif":
            name = yield String(6, zero_terminated=False, name="Name", comment="should be 'EXIF\\0\\0'")
            if name != "Exif\0\0":
                raise FatalError("invalid app1 name")
        elif n == b"http" or n == b"XMP\0":
            name = yield String(29, zero_terminated=False, name="Name", comment="should be 'http://ns.adobe.com/xap/1.0/\\0'")
            # Other http:// identifiers are legitimate APP1 payloads (e.g. Adobe
            # ExtendedXMP "http://ns.adobe.com/xmp/extension/\0"). They are not
            # interpreted further, but must not invalidate the whole file.
            # The non-standard "XMP\0" prefix is only accepted with the XAP URI.
            if n == b"XMP\0" and name[4:] != "://ns.adobe.com/xap/1.0/\0":
                raise FatalError("invalid app1 name")
     

class ComponentInfo(Struct):
    """One 3-byte component specification of a SOFn frame header."""

    def parse(self):
        component_ids = [("Y", 1), ("Cb", 2), ("Cr", 3), ("I", 4), ("Q", 5)]
        yield UInt8(name="Id", values=component_ids)
        # packed horizontal / vertical sampling factors, one nibble each
        yield UInt8(name="SamplingFactor", comment="bit 0-3 vertical., 4-7 horizontal")
        yield UInt8(name="QTNumber", comment="quantization table number")

class StartOfFrame(Struct):
    """SOFn segment: sample precision, image dimensions and per-component info.

    Any bytes declared by Size beyond the component table are exposed as "Extra".
    """

    def parse(self):
        yield UInt16BE(name="Tag", comment="segment type")
        # Size counts the 2-byte size field and the payload, not the 2-byte marker
        sz = yield UInt16BE(name="Size", comment="segment size")
        yield UInt8(name="DataPrecision", comment="bits per sample (usually 8)")
        yield UInt16BE(name="Height", comment="image height")
        yield UInt16BE(name="Width", comment="image width")
        bpp = yield UInt8(name="Components", comment="bpp", values=[
            ("Greyscale", 1),
            ("YcbCr", 3),
            ("CMYK", 4),
            ])
        yield Array(bpp, ComponentInfo(), name="ComponentInfos")
        # len(self) includes the marker, so the full segment spans sz + 2 bytes
        # (same convention as UnknownSegment: 4 header bytes + sz - 2 payload)
        extra = sz + 2 - len(self)
        if extra > 0:
            yield Bytes(extra, name="Extra")

class Marker(Struct):
    """A bare marker (0xFF followed by the marker code) with no size field or payload."""

    def parse(self):
        # read as a single big-endian word, e.g. 0xFFD8 for SOI
        yield UInt16BE(name="Tag", comment="segment type")


class CommentSegment(Struct):
    """COM segment: marker, size and the free-text comment it carries."""

    def parse(self):
        yield UInt16BE(name="Tag", comment="segment type")
        segment_len = yield UInt16BE(name="Size", comment="segment size")
        # Size includes its own two bytes; an empty comment has no text field
        if segment_len <= 2:
            return
        yield String(segment_len - 2, name="Comment")

class UnknownSegment(Struct):
    """Generic size-prefixed segment whose payload is kept as opaque bytes."""

    def parse(self):
        yield UInt16BE(name="Tag", comment="segment type")
        segment_len = yield UInt16BE(name="Size", comment="segment size")
        payload_len = segment_len - 2  # Size counts its own two bytes
        if payload_len > 0:
            yield Bytes(payload_len, name="Data")
    

# from https://github.com/bennoleslie/pexif/blob/master/pexif.py
# Marker code (byte following 0xFF) -> (section name, Struct class, category).
# Entries are inserted in ascending code order; the regular RSTn and APPn
# runs are generated rather than spelled out.
JPEG_MARKERS = {
    0xc0: ("SOF0", StartOfFrame, Type.HEADER),
    0xc1: ("SOF1", StartOfFrame, Type.HEADER),
    0xc2: ("SOF2", StartOfFrame, Type.HEADER),
    0xc4: ("DHT", UnknownSegment, Type.FIXUP),
    0xc9: ("SOF9", UnknownSegment, Type.FIXUP),
    0xcc: ("DAC", UnknownSegment, Type.FIXUP),
}
# RST0..RST7 restart markers: no size field
JPEG_MARKERS.update(
    (0xd0 + n, ("RST{}".format(n), Marker, Type.HEADER)) for n in range(8)
)
JPEG_MARKERS.update({
    0xd8: ("SOI", Marker, Type.HEADER),
    0xd9: ("EOI", Marker, Type.HEADER),
    0xda: ("SOS", Marker, Type.HEADER),
    0xdb: ("DQT", UnknownSegment, Type.FIXUP),
    0xdd: ("DRI", UnknownSegment, Type.FIXUP),
})
# APP0..APP15: APP0 and APP1 have dedicated header parsers, the others are
# opaque; APP0-APP4 and APP13 are categorised as headers, the rest as data
JPEG_MARKERS.update(
    (
        0xe0 + n,
        (
            "APP{}".format(n),
            App0Header if n == 0 else App1Header if n == 1 else UnknownSegment,
            Type.HEADER if n <= 4 or n == 13 else Type.DATA,
        ),
    )
    for n in range(16)
)
JPEG_MARKERS[0xfe] = ("COM", CommentSegment, Type.META)



class JPEGAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for JPEG images.

    Walks the marker segments, registers one section per segment, parses the
    APP0 (JFIF/JFXX) and APP1 (Exif/XMP) payloads and extracts metadata from
    COM, Exif and XMP segments. Once the file is confirmed, a FatalError in
    the middle of the stream is tolerated: everything up to the next EOI is
    marked as an anomaly.
    """
    category = malcat.FileType.IMAGE
    name = "JPEG"
    regexp = r"\xff\xD8\xff[\xc0-\xcc\xda-\xef\xfe]"


    def __init__(self):
        FileTypeAnalyzer.__init__(self)


    def _add_xmp_metadata(self, xap):
        """Register the TIFF-like properties of an XMP packet as "XAP" metadata.

        Best effort: a malformed packet only loses its metadata and must not
        abort the parsing of the image itself.
        """
        try:
            properties = XmpParser(xap).meta
        except (SyntaxError, ValueError, AttributeError):
            # ElementTree.ParseError derives from SyntaxError; AttributeError
            # covers packets without an rdf:RDF element
            return
        for namespace_properties in properties.values():
            for tag, value in namespace_properties.items():
                if tag in TIFF_TAGS_FOR_META and value is not None:
                    self.add_metadata(TIFF_TAGS_FOR_META[tag], str(value), category="XAP")


    def parse(self, hint):
        yield Marker(name="StartOfImage", category=Type.HEADER)
        markers_seen = set()
        try:
            while self.remaining() > 0:
                segment_start, size = self.search(r"\xff[\x01-\xff]", self.tell())
                if size == 0:
                    raise FatalError("No more marker found")
                if segment_start > self.tell():
                    # bytes that do not belong to any segment
                    yield Bytes(segment_start - self.tell(), name="Overlay", category=Type.ANOMALY)
                marker, tag = struct.unpack("<BB", self.read(self.tell(), 2))
                if marker != 0xFF:
                    raise FatalError("Invalid segment marker at #{:x}".format(self.tell()))
                if tag not in JPEG_MARKERS:
                    raise FatalError("Invalid segment tag at #{:x}".format(self.tell()))
                name, class_, cat = JPEG_MARKERS[tag]
                markers_seen.add(name)

                segment_header = yield class_(name=name, category=cat)
                if class_ is CommentSegment and "Comment" in segment_header:
                    self.add_metadata("Comment", segment_header["Comment"])
                start = self.tell()
                if "Size" in segment_header:
                    # Size covers the size field and payload, not the 2-byte marker
                    segment_size = segment_header["Size"] + 2
                elif tag != 0xd9:
                    # size-less marker: runs up to the next marker (EOI for SOS,
                    # so the whole entropy-coded data belongs to the scan)
                    regexp = r"\xff[\x01-\xfe]"
                    if tag == 0xda:
                        regexp = r"\xff\xd9"
                    where, size = self.search(regexp, start)
                    if size == 0:
                        raise FatalError("Cannot find next segment marker")
                    segment_size = where - segment_start
                elif self.remaining() < 2 or self.read(self.tell(), 2) != b"\xff\xda":
                    # EOI not immediately followed by another scan: end of image
                    break
                else:
                    segment_size = 2
                self.add_section(name, segment_start, segment_size)
                if tag == 0xDA and segment_size:
                    data_size = segment_size - (self.tell() - segment_start)
                    if data_size:
                        yield Bytes(data_size, name="ImageData", category=Type.DATA)
                # parse App0 segment content
                if class_ is App0Header:
                    if segment_header["Name"] == "JFIF\0":
                        yield JFIFHeader(category=Type.HEADER, parent=segment_header)
                    elif segment_header["Name"] == "JFXX\0":
                        # the extension payload is what remains of the segment after
                        # the header (marker, size field and 5-byte identifier)
                        jfxx_size = segment_start + segment_size - self.tell()
                        yield JFXXHeader(jfxx_size, category=Type.HEADER, parent=segment_header)
                    else:
                        raise FatalError("invalid app0 name")
                # parse App1 segment content
                elif class_ is App1Header:
                    if "Name" in segment_header and segment_header["Name"][4:] == "://ns.adobe.com/xap/1.0/\0":
                        # the XMP packet fills the rest of the segment
                        xap_size = segment_header.offset + 2 + segment_header["Size"] - self.tell()
                        if xap_size > 0:
                            xap = yield StringUtf8(xap_size, name="Xap", parent=segment_header, category=Type.META)
                            self._add_xmp_metadata(xap)
                    elif "Name" in segment_header and segment_header["Name"] == "Exif\0\0":
                        # embedded TIFF structure; its offsets are relative to its start
                        base_offset = self.tell()
                        tiff = TiffAnalyzer()
                        for thestruct, fa in self.subparse(tiff, hint="JPEG"):
                            thestruct.parent = segment_header
                            self.jump(base_offset + fa.offset)
                            yield thestruct
                        for k, v in tiff.infos.items():
                            self.add_metadata(k, str(v), category="EXIF")
                self.jump(segment_start + segment_size)
        except FatalError:
            # Allow partially invalid JPEG
            if not self.is_confirmed():
                raise
            where, size = self.search(r"\xff\xd9", self.tell())
            if not size:
                raise
            if where > self.tell():
                yield Bytes(where - self.tell(), name="UnknownData", category=Type.ANOMALY)
            yield Marker(name="EndOfImage", category=Type.HEADER)
        has_frame_header = not markers_seen.isdisjoint(("SOF0", "SOF1", "SOF2"))
        if "SOS" not in markers_seen or not has_frame_header or "DHT" not in markers_seen:
            raise FatalError("Missing important marker")



