import requests
import struct
import asyncio
import sys
import itertools
import re
from malcat import KesakodeMatch, Detection, Function, FoundString
from typing import Dict, List, Union


# Translates the textual "level" field of a lookup result into the numeric
# level stored in the result tuples. lookup_hashes() falls back to 1
# ("unknown") when the server reports a level name missing from this table.
DetectionMapping = {
    "clean": 0,
    "unknown": 1,
    "lib": 2,
    "suspicious": 30,
    # multi_lookup() / fuzzy_lookup() only fetch intelligence for hits whose
    # level is >= 100, i.e. this one
    "malicious": 100,
}


class KesakodeExternalResult:
    """
    A class used by external Kesakode providers to give back their result.

    A provider creates one instance, fills it through the set_* methods
    (or the ``result[obj] = hits`` shortcut) and hands it back.
    """

    def __init__(self):
        self.verdict = {}        # family name -> score, see set_verdict()
        self.intelligence = {}   # family name -> free text, see set_intelligence()
        self.matches = []        # KesakodeMatch objects, see set_object_matches()
        self.quota_left = 0      # see set_quota()
        self.quota_total = 0     # see set_quota()

    def set_quota(self, quota_left: int, quota_total: int):
        """
        Tells the user how many more queries he can do this month

        Raises ValueError if quota_left is greater than quota_total; in that
        case the stored quota is left untouched.
        """
        if quota_left > quota_total:
            raise ValueError("Quota left must be less than quota total")
        self.quota_left = quota_left
        self.quota_total = quota_total

    def set_verdict(self, verdict: Dict[str, float]):
        """
        a dict str->float that gives a score (0-100) for each malware family. if you leave it  empty, Malcat will compute scores for you
        """
        self.verdict = verdict

    def set_intelligence(self, intelligence: Dict[str, str]):
        """
        a dict str-> str that give some textual information for each discovered family
        """
        # Store into the dedicated attribute: the verdict scores must not be
        # overwritten by the intelligence texts.
        self.intelligence = intelligence

    def set_object_matches(self, obj: Union[Function, FoundString], matches: Union[KesakodeMatch.Hit, List[KesakodeMatch.Hit]]):
        """
        Adds one or several Kesakode identification hits for a given function or string

        obj is the Function or FoundString the hits refer to; its start and
        size are copied into the new KesakodeMatch. matches may be a single
        KesakodeMatch.Hit or a list of them.

        Raises ValueError if obj is neither a Function nor a FoundString, or
        if one of the hits is not a KesakodeMatch.Hit. Nothing is appended to
        self.matches when an error is raised.
        """
        if isinstance(obj, Function):
            match_type = KesakodeMatch.Type.FUNCTION
        elif isinstance(obj, FoundString):
            match_type = KesakodeMatch.Type.STRING
        else:
            raise ValueError("You can only add matches for identified functions/strings")
        km = KesakodeMatch(obj.start, obj.size, match_type)
        # accept a lone hit as a convenience
        if isinstance(matches, KesakodeMatch.Hit):
            matches = [matches]
        for hit in matches:
            if not isinstance(hit, KesakodeMatch.Hit):
                raise ValueError("Not a valid hit type: {}".format(type(hit)))
            km.add_hit(hit)
        self.matches.append(km)

    def __setitem__(self, o, hits):
        """
        Shortcut for set_object_matches
        """
        self.set_object_matches(o, hits)


def lookup_hashes(endpoint, service, hashes, license, options=None, ssl_verify=True):
    """
    Look up a batch of hashes against one Kesakode service and decode the hits.

    Args:
        endpoint: base URL of the server; the request goes to ``{endpoint}/{service}``.
        service: name of the lookup service (this module uses "pic", "str", "cst" and "fuz").
        hashes: iterable of hashes. An int is packed as a little-endian
            unsigned 64-bit value, bytes are sent as they are (hex-encoded).
        license: API key, sent in the ``ApiKey`` header.
        options: optional extra form fields; they are merged over the default
            ``retry`` / ``timeout`` fields and may override them.
        ssl_verify: forwarded to requests as ``verify``.

    Returns:
        A list of tuples ``(hash, level, score, name, paths, symbols)``.
        8-byte hashes are returned as ints, any other length as bytes.
        ``level`` is translated through DetectionMapping (1 for unknown names).

    Raises:
        ValueError: a hash is neither int nor bytes, the server answered with
            an error carrying a "detail" message, or the job status is not "ok".
        requests.HTTPError: the server answered with an error without detail.
    """
    # Encode everything up front so that bad input fails before any network I/O.
    to_lookup = []
    for h in hashes:
        if isinstance(h, int):
            to_lookup.append(struct.pack("<Q", h).hex())
        elif isinstance(h, bytes):
            to_lookup.append(h.hex())
        else:
            raise ValueError(f"Invalid hash format: {h}")

    payload = {"hashes": ",".join(to_lookup), "retry": 1, "timeout": 30}
    if options:
        payload.update(options)

    # The context manager releases the session's pooled connections.
    with requests.Session() as session:
        session.headers.update({
            'accept': 'application/json',
            'ApiKey': license,
            'user-agent': "malcat",
        })
        r = session.post(f"{endpoint}/{service}", data=payload, verify=ssl_verify)

    if not r.ok:
        # Prefer the server-provided error message when there is one.
        try:
            detail = r.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # body is not JSON, or JSON without a "detail" entry
            detail = ""
        if detail:
            raise ValueError(detail)
    r.raise_for_status()

    answer = r.json()
    status = answer.get("status", "missing")
    if status != "ok":
        raise ValueError(f"Invalid job status {status}: {answer.get('error', '')}")

    res = []
    for d in answer.get("result", []):
        h = bytes.fromhex(d["hash"])
        if len(h) == 8:
            # 64-bit hashes are handed back in the same int form they were given in
            h, = struct.unpack("<Q", h)
        lvl = DetectionMapping.get(d["level"], 1)
        res.append((h, lvl, d["score"], d["name"], d.get("paths", []), d.get("symbols", [])))
    return res



def get_stats(endpoint, license, ssl_verify=True):
    """
    Query ``{endpoint}/status`` and return the version info of the newest worker.

    Args:
        endpoint: base URL of the server.
        license: API key, sent in the ``ApiKey`` header.
        ssl_verify: forwarded to requests as ``verify``.

    Returns:
        The ``(key, value)`` pairs of the "version" dict belonging to the
        worker whose "Version" entry compares highest (first one wins on a
        tie). Workers without a "Version" entry are ignored; the list is
        empty when no worker reports one.

    Raises:
        requests.HTTPError: the server answered with an error status.
    """
    # The context manager releases the session's pooled connections.
    with requests.Session() as session:
        session.headers.update({
            'accept': 'application/json',
            'ApiKey': license,
            'user-agent': "malcat",
        })
        r = session.get(f"{endpoint}/status", verify=ssl_verify)
    r.raise_for_status()

    stats = {}
    for worker in r.json().get("workers", []):
        v = worker.get("version", {})
        if "Version" not in v:
            continue
        # strict comparison: an equal version does not replace the current pick
        if "Version" not in stats or stats["Version"] < v["Version"]:
            stats = v
    return list(stats.items())


def get_quota(endpoint, license, ssl_verify=True):
    """
    Query ``{endpoint}/quota`` for the monthly quota of the given license.

    Args:
        endpoint: base URL of the server.
        license: API key, sent in the ``ApiKey`` header.
        ssl_verify: forwarded to requests as ``verify``.

    Returns:
        A tuple ``(monthly_left, monthly_total)``; each value defaults to 0
        when the server does not report it.

    Raises:
        ValueError: the server answered with an error carrying a "detail" message.
        requests.HTTPError: the server answered with an error without detail.
    """
    # The context manager releases the session's pooled connections.
    with requests.Session() as session:
        session.headers.update({
            'accept': 'application/json',
            'ApiKey': license,
            'user-agent': "malcat",
        })
        r = session.get(f"{endpoint}/quota", verify=ssl_verify)

    if not r.ok:
        # Prefer the server-provided error message when there is one.
        try:
            detail = r.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # body is not JSON, or JSON without a "detail" entry
            detail = ""
        if detail:
            raise ValueError(detail)
    r.raise_for_status()

    quota = r.json()
    return quota.get("monthly_left", 0), quota.get("monthly_total", 0)


def get_intel(endpoint, families, prefix, license, ssl_verify=True):
    """
    Fetch textual intelligence about malware families from ``{endpoint}/families``.

    Args:
        endpoint: base URL of the server.
        families: iterable of family names; sent comma-separated.
        prefix: sent (stripped of surrounding whitespace) as ``force_prefix``.
        license: API key, sent in the ``ApiKey`` header.
        ssl_verify: forwarded to requests as ``verify``.

    Returns:
        The decoded JSON answer, or an empty dict when the answer is empty.

    Raises:
        requests.HTTPError: the server answered with an error status.
    """
    form = {"families": ",".join(families), "force_prefix": prefix.strip()}
    # The context manager releases the session's pooled connections.
    with requests.Session() as session:
        session.headers.update({
            'accept': 'application/json',
            'ApiKey': license,
            'user-agent': "malcat",
        })
        r = session.post(f"{endpoint}/families", data=form, verify=ssl_verify)
    r.raise_for_status()
    return r.json() or {}



def multi_lookup(endpoint, filetype, architecture, pic_hashes, string_hashes, constant_hash, license, ssl_verify=True):
    """
    Run the "pic", "str" and "cst" lookups plus the stats and quota queries
    concurrently, then fetch intelligence for the malicious families found.

    Args:
        endpoint: base URL of the server; one trailing "/" is removed.
        filetype: "PE" selects the family prefix "win.", "ELF" selects "elf.",
            anything else selects no prefix.
        architecture: sent as the ``architecture`` option of the "pic" lookup.
        pic_hashes: hashes for the "pic" service.
        string_hashes: hashes for the "str" service.
        constant_hash: single hash for the "cst" service.
        license: API key.
        ssl_verify: forwarded to every request as ``verify``.

    Returns:
        A tuple ``(pic, str, cst, stats, intel, quota_left, quota_total)``:
        the three lookup_hashes() result lists, the get_stats() result, the
        get_intel() result (``{}`` when no hit is malicious) and the two
        values returned by get_quota().

    Raises:
        Exception: the first exception raised by any of the five requests, re-raised
            once all of them have finished; also anything raised by get_intel().
    """
    # Use the selector event loop on Windows (the proactor loop is the default
    # there since Python 3.8). Note that this changes the process-wide policy.
    if sys.platform == "win32" and sys.version_info >= (3, 8, 0):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    if endpoint.endswith("/"):
        endpoint = endpoint[:-1]

    # kesakode lookup
    async def multi_poll():
        # The requests-based helpers are blocking, so each one runs in the
        # loop's default executor.
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(None, lookup_hashes, endpoint, "pic", pic_hashes, license, {"architecture": architecture}, ssl_verify),
            loop.run_in_executor(None, lookup_hashes, endpoint, "str", string_hashes, license, {}, ssl_verify),
            loop.run_in_executor(None, lookup_hashes, endpoint, "cst", [constant_hash], license, {}, ssl_verify),
            loop.run_in_executor(None, get_stats, endpoint, license, ssl_verify),
            loop.run_in_executor(None, get_quota, endpoint, license, ssl_verify),
        ]
        # return_exceptions=True lets every request finish before the first
        # failure (in task order) is re-raised.
        all_returns = await asyncio.gather(*tasks, return_exceptions=True)
        for r in all_returns:
            if isinstance(r, Exception):
                raise r
        return all_returns

    pic, strings, cst, stats, quota = asyncio.run(multi_poll())

    # gather intel
    if filetype == "PE":
        prefix = "win."
    elif filetype == "ELF":
        prefix = "elf."
    else:
        prefix = ""
    # x[1] is the numeric detection level, x[3] the family name
    malicious = DetectionMapping["malicious"]
    all_families = {x[3] for x in itertools.chain(pic, strings, cst) if x[1] >= malicious}
    if all_families:
        intel = get_intel(endpoint, all_families, prefix, license, ssl_verify=ssl_verify)
    else:
        intel = {}
    return pic, strings, cst, stats, intel, quota[0], quota[1]


def fuzzy_lookup(endpoint, filetype, architecture, fuzzy_hashes, threshold, license, ssl_verify=True):
    """
    Run a "fuz" lookup, then fetch intelligence for the malicious families found.

    Args:
        endpoint: base URL of the server; one trailing "/" is removed.
        filetype: "PE" selects the family prefix "win.", "ELF" selects "elf.",
            anything else selects no prefix.
        architecture: sent as the ``architecture`` option of the lookup.
        fuzzy_hashes: hashes for the "fuz" service.
        threshold: sent as the ``threshold`` option of the lookup.
        license: API key.
        ssl_verify: forwarded to every request as ``verify``.

    Returns:
        A tuple ``(fuz, intel)``: the lookup_hashes() result list and the
        get_intel() result (``{}`` when no hit is malicious).

    Raises:
        Exception: whatever lookup_hashes() or get_intel() raise.
    """
    # Use the selector event loop on Windows (the proactor loop is the default
    # there since Python 3.8). Note that this changes the process-wide policy.
    if sys.platform == "win32" and sys.version_info >= (3, 8, 0):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    if endpoint.endswith("/"):
        endpoint = endpoint[:-1]

    # kesakode lookup
    async def multi_poll():
        # lookup_hashes() is blocking, so it runs in the loop's default executor.
        loop = asyncio.get_running_loop()
        # this lookup overrides the default "timeout" field with 60
        options = {"architecture": architecture, "threshold": threshold, "timeout": 60}
        tasks = [
            loop.run_in_executor(None, lookup_hashes, endpoint, "fuz", fuzzy_hashes, license, options, ssl_verify),
        ]
        all_returns = await asyncio.gather(*tasks, return_exceptions=True)
        for r in all_returns:
            if isinstance(r, Exception):
                raise r
        return all_returns

    fuz, = asyncio.run(multi_poll())

    # gather intel
    if filetype == "PE":
        prefix = "win."
    elif filetype == "ELF":
        prefix = "elf."
    else:
        prefix = ""
    # x[1] is the numeric detection level, x[3] the family name
    malicious = DetectionMapping["malicious"]
    all_families = {x[3] for x in fuz if x[1] >= malicious}
    if all_families:
        intel = get_intel(endpoint, all_families, prefix, license, ssl_verify=ssl_verify)
    else:
        intel = {}
    return fuz, intel
    

def submit_sample(endpoint, file, license, is_malware=True, family="", comment="", ssl_verify=True):
    """
    Upload a sample to ``{endpoint}/submit_sample``.

    Args:
        endpoint: base URL of the server.
        file: object whose full slice ``file[:]`` yields the content to upload.
        license: API key, sent in the ``ApiKey`` header.
        is_malware: sent as the ``is_malware`` form field.
        family: sent as the ``malware_family`` form field.
        comment: sent as the ``comment`` form field.
        ssl_verify: forwarded to requests as ``verify``. Certificate checking
            used to be unconditionally disabled here; pass ``ssl_verify=False``
            to get that behaviour back (e.g. for a self-signed endpoint).

    Raises:
        ValueError: the server answered with an error carrying a "detail" message.
        requests.HTTPError: the server answered with an error without detail.
    """
    form = {"is_malware": is_malware, "malware_family": family, "comment": comment}
    # The context manager releases the session's pooled connections.
    with requests.Session() as session:
        session.headers.update({
            'accept': 'application/json',
            'ApiKey': license,
            'user-agent': "malcat",
        })
        # The API key and the sample travel over this connection, so the
        # server certificate is verified unless the caller opts out.
        r = session.post(f"{endpoint}/submit_sample", files={"file": file[:]}, data=form, verify=ssl_verify)

    if not r.ok:
        # Prefer the server-provided error message when there is one.
        try:
            detail = r.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # body is not JSON, or JSON without a "detail" entry
            detail = ""
        if detail:
            raise ValueError(detail)
    r.raise_for_status()

