from filetypes.base import *
from filetypes.types import *
import struct


# Beacon configuration entries, keyed by the entry's Index field.
# Each value is (field name, keyword options forwarded to ConfigValue).
# CobaltHeader stops parsing when it meets an index missing from this table.
#
# Every index must appear exactly once: a dict literal silently keeps only
# the LAST value of a duplicated key. Index 7 is labelled "PublicKey" because
# ConfigValue exposes the raw bytes (nothing here hashes them, so an "_MD5"
# label would be misleading).
# NOTE(review): index 54 is labelled "HostHeader"; it has also been described
# as "SSH_Banner", presumably its meaning for SSH beacons — confirm.
INDEX_TO_DESC = {
    1: ("BeaconType", {"values": [
        ("HTTP", 0),
        ("Hybrid HTTP DNS", 1),
        ("SMB", 2),
        ("TCP", 4),
        ("HTTPS", 8),
        ("BIND TCP", 16),
    ]}),
    2: ("Port", {}),
    3: ("SleepTime", {}),
    4: ("MaxGetSize", {}),
    5: ("Jitter", {}),
    6: ("MaxDNS", {}),
    7: ("PublicKey", {}),
    8: ("C2Server", {"is_string": True}),
    9: ("UserAgent", {"is_string": True}),
    10: ("HttpPostUri", {"is_string": True}),
    11: ("Malleable_C2_Instructions", {"is_malleable": True}),
    12: ("HttpGet_Metadata", {"is_headers": True}),
    13: ("HttpPost_Metadata", {"is_headers": True}),
    14: ("SpawnTo", {"is_string": True}),
    15: ("PipeName", {"is_string": True}),
    19: ("DNS_Idle", {}),
    20: ("DNS_Sleep", {}),
    21: ("SSH_Host", {"is_string": True}),
    22: ("SSH_Port", {}),
    23: ("SSH_Username", {}),
    24: ("SSH_Password_Plaintext", {}),
    25: ("SSH_Password_Pubkey", {}),
    26: ("HttpGet_Verb", {}),
    27: ("HttpPost_Verb", {}),
    28: ("HttpPostChunk", {}),
    29: ("Spawnto_x86", {"is_string": True}),
    30: ("Spawnto_x64", {"is_string": True}),
    31: ("CryptoScheme", {}),
    32: ("Proxy_Config", {}),
    33: ("Proxy_User", {}),
    34: ("Proxy_Password", {}),
    35: ("Proxy_Behavior", {"values": [
        ("Use proxy server (manual)", 1),
        ("Use IE settings", 2),
        ("Use proxy server (credentials)", 4),
    ]}),
    36: ("Watermark_Hash", {}),
    37: ("Watermark", {}),
    38: ("bStageCleanup", {}),
    39: ("bCFGCaution", {}),
    40: ("KillDate", {}),
    41: ("TextSectionEnd", {}),
    42: ("ObfuscateSectionsInfo", {}),
    43: ("bProcInject_StartRWX", {}),
    44: ("bProcInject_UseRWX", {}),
    45: ("bProcInject_MinAllocSize", {}),
    46: ("ProcInject_PrependAppend_x86", {}),
    47: ("ProcInject_PrependAppend_x64", {}),
    50: ("bUsesCookies", {}),
    51: ("ProcInject_Execute", {
        "values": [
            ("CreateThread", 1),
            ("SetThreadContext", 2),
            ("CreateRemoteThread", 3),
            ("RtlCreateUserThread", 4),
            ("NtQueueApcThread", 5),
            ("NtQueueApcThread-s", 8),
        ],
        "is_blob": True,
    }),
    52: ("ProcInject_AllocationMethod", {"values": [
        ("VirtualAllocEx", 0),
        ("NtMapViewOfSection", 3),
    ]}),
    53: ("ProcInject_Stub", {"is_blob": True}),
    54: ("HostHeader", {}),
    55: ("HostUnknown", {}),
    57: ("smbFrameHeader", {}),
    58: ("tcpFrameHeader", {}),
    59: ("headersToRemove", {}),
    60: ("DNS_Beaconing", {}),
    61: ("DNS_get_TypeA", {}),
    62: ("DNS_get_TypeAAAA", {}),
    63: ("DNS_get_TypeTXT", {}),
    64: ("DNS_put_metadata", {}),
    65: ("DNS_put_output", {}),
    66: ("DNS_resolver", {}),
    # Strategy labels are numbered by their position in this list.
    67: ("DNS_strategy", {"values": [
        (e.upper(), i) for i, e in enumerate([
            "round-robin", "random",
            "failover", "failover-5x", "failover-50x", "failover-100x",
            "failover-1m", "failover-5m", "failover-15m", "failover-30m",
            "failover-1h", "failover-3h", "failover-6h", "failover-12h",
            "failover-1d",
            "rotate-1m", "rotate-5m", "rotate-15m", "rotate-30m",
            "rotate-1h", "rotate-3h", "rotate-6h", "rotate-12h",
            "rotate-1d",
        ])
    ]}),
    68: ("DNS_strategy_rotate_seconds", {}),
    69: ("DNS_strategy_fail_x", {}),
    70: ("DNS_strategy_fail_seconds", {}),
    71: ("Retry_Max_Attempts", {}),
    72: ("Retry_Increase_Attempts", {}),
    73: ("Retry_Duration", {}),
    74: ("Masked_Watermark", {}),
    75: ("Unknown1", {}),
    76: ("Unknown2", {}),
}

# Field names that must all be present in the parsed CobaltHeader for the
# data to be accepted as a beacon configuration.
MANDATORY_VALUES = {"C2Server", "Port", "Malleable_C2_Instructions"}

# Opcodes of the Malleable_C2_Instructions program (see MalleableOpcode).
# Each opcode reuses the TSTEPS id of the encoding step it undoes:
# 1 append -> REMOVE_SUFFIX, 2 prepend -> REMOVE_PREFIX,
# 3 base64 -> BASE64_DECODE, 8 netbios, 11 netbiosu, 13 base64url, 15 mask.
# (BASE64_DECODE was previously listed as 4, which is TSTEPS "print".)
MALLEABLE_OPCODES = [
    ("REMOVE_SUFFIX", 1),
    ("REMOVE_PREFIX", 2),
    ("BASE64_DECODE", 3),
    ("NETBIOS_DECODE_a", 8),
    ("NETBIOS_DECODE_A", 11),
    ("BASE64_URLDECODE_SAFE", 13),
    ("XOR_MASK_RAND_KEY", 15),
]

# Transform step ids of the HttpGet/HttpPost metadata programs, used to
# label HeaderOpcode.Tstep. A 0 ("stop") step ends a program.
TSTEPS = {
    0: "stop",
    1: "append",
    2: "prepend",
    3: "base64",
    4: "print",
    5: "parameter",
    6: "header",
    7: "build",
    8: "netbios",
    9: "const_parameter",
    10: "const_header",
    11: "netbiosu",
    12: "uri_append",
    13: "base64url",
    14: "strrep",
    15: "mask",
    16: "const_host_header",
}

class ConfigValue(Struct):
    """One Index/Type/Size/Value entry of a Cobalt Strike beacon config.

    Layout: Index (UInt16BE), Type (UInt16BE: 0 NONE, 1 SHORT, 2 INT,
    3 DATA), Size (UInt16BE), then the value. SHORT and INT values are read
    as UInt16BE / UInt32BE. DATA values are decoded according to the keyword
    options below. Any other Type is kept as ``Size`` raw bytes. Whatever
    part of the ``Size`` bytes the value does not consume is emitted as
    Unused, so the next entry starts right after them (unless the value
    itself overran ``Size``).

    Keyword options (CobaltHeader takes them from INDEX_TO_DESC). For DATA
    values they are checked in this order:
        is_headers:   DATA is an HTTP transform program (HeaderConfigValue).
        is_blob:      DATA is a list of UInt8 items. It is only used together
                      with ``values``; reading stops right after the first
                      byte that is not one of the listed values.
        is_malleable: DATA is a Malleable C2 program (MalleableConfigValue).
        is_string:    DATA is a CString of at most ``Size`` bytes.
        values:       (label, value) pairs used to label numeric values.
    All remaining positional/keyword arguments are forwarded to Struct.
    """

    def __init__(self, *args, values=None, is_headers=False, is_blob=False, is_malleable=False, is_string=False, **kwargs):
        Struct.__init__(self, *args, **kwargs)
        # None default instead of [] avoids a list shared by all instances.
        self.__values = [] if values is None else values
        self.__is_headers = is_headers
        self.__is_blob = is_blob
        self.__is_malleable = is_malleable
        self.__is_string = is_string

    def parse(self):
        yield UInt16BE(name="Index")
        tp = yield UInt16BE(name="Type", values=[
           ("NONE", 0),
           ("SHORT", 1),
           ("INT", 2),
           ("DATA", 3),
        ])
        sz = yield UInt16BE(name="Size")
        start_value = len(self)
        if tp == 1:
            yield UInt16BE(name="Value", values=self.__values)
        elif tp == 2:
            yield UInt32BE(name="Value", values=self.__values)
        elif tp == 3:
            if self.__is_headers:
                yield HeaderConfigValue(sz, name="Value")
            elif self.__is_blob and self.__values:
                values_allowed = {v for _, v in self.__values}
                while len(self) - start_value < sz:
                    e = yield UInt8(name="Value", values=self.__values)
                    # The first unknown byte is still emitted, then we stop.
                    if e not in values_allowed:
                        break
            elif self.__is_malleable:
                yield MalleableConfigValue(sz, name="Value")
            elif self.__is_string:
                yield CString(name="Value", max_size=sz)
            else:
                yield Bytes(sz, name="Value")
        else:
            yield Bytes(sz, name="Value")
        # Pad out whatever the value did not consume. Besides the DATA
        # variants, this keeps SHORT/INT entries aligned when Size is larger
        # than the integer width, instead of misparsing the next entry.
        consumed = len(self) - start_value
        if consumed < sz:
            yield Unused(sz - consumed)


class HeaderConfigValue(Struct):
    """HTTP transform program stored in a DATA config entry.

    Reads HeaderOpcode steps until a step with Tstep 0 ("stop") is met or
    ``size`` bytes have been consumed. Any bytes still left within ``size``
    are emitted as Unused.
    """

    def __init__(self, size, *args, **kwargs):
        Struct.__init__(self, *args, **kwargs)
        # Byte length of the enclosing config entry's value.
        self.__size = size

    def parse(self):
        limit = self.__size
        while len(self) < limit:
            step = yield HeaderOpcode()
            if not step["Tstep"]:
                # A zero step terminates the program.
                break
        leftover = limit - len(self)
        if leftover > 0:
            yield Unused(leftover)


class HeaderOpcode(Struct):
    """One step of an HTTP transform program (see HeaderConfigValue).

    Layout: Tstep (UInt32BE, labelled from TSTEPS), then an operand for
    some steps:
      - 7 (build): a UInt32BE "Name" selecting which data is built;
      - 1, 2, 5, 6, 9, 10, 16: a UInt32BE length followed by a String of
        that length;
      - any other step: no operand is read here.
    """

    def parse(self):
        tstep = yield UInt32BE(name="Tstep", values=[(v, k) for k, v in TSTEPS.items()])
        if tstep == 7:
            yield UInt32BE(name="Name", values=[
                ("Metadata (GET) / SessionID (POST)", 0),
                ("Output", 1),
            ])
        elif tstep in (1, 2, 5, 6, 9, 10, 16):
            # append, prepend, parameter, header, const_parameter,
            # const_header and const_host_header carry a string argument.
            sz = yield UInt32BE(name="Size")
            yield String(sz, name="String")


class MalleableConfigValue(Struct):
    """Malleable C2 instruction program stored in a DATA config entry.

    Reads MalleableOpcode items (each named "Operation") until an opcode of
    0 is met or ``size`` bytes have been consumed. Any bytes still left
    within ``size`` are emitted as Unused.
    """

    def __init__(self, size, *args, **kwargs):
        Struct.__init__(self, *args, **kwargs)
        # Byte length of the enclosing config entry's value.
        self.__size = size

    def parse(self):
        end = self.__size
        while len(self) < end:
            instruction = yield MalleableOpcode(name="Operation")
            if not instruction["Opcode"]:
                # A zero opcode terminates the program.
                break
        padding = end - len(self)
        if padding > 0:
            yield Unused(padding)


class MalleableOpcode(Struct):
    """Single instruction of the Malleable_C2_Instructions program.

    Reads a UInt32BE opcode labelled from MALLEABLE_OPCODES. Opcodes 1 and 2
    (REMOVE_SUFFIX / REMOVE_PREFIX) are followed by a UInt32BE byte count.
    """

    def parse(self):
        opcode = yield UInt32BE(name="Opcode", values=MALLEABLE_OPCODES)
        takes_count = opcode == 1 or opcode == 2
        if takes_count:
            yield UInt32BE(name="BytesCount")



class CobaltHeader(Struct):
    """Sequence of ConfigValue entries forming a beacon configuration.

    Before each entry the 4-byte Index/Type pair is peeked at. Parsing ends
    by raising EOFError (caught by CobaltStrikeConfigAnalyzer.parse) when:
      - fewer than 4 bytes are available for that pair;
      - the Index is not listed in INDEX_TO_DESC; or
      - the Type is above 3.
    """

    def parse(self):
        while self.remaining():
            peek = self.look_ahead(4)
            if len(peek) < 4:
                # Too short for an entry header: stop with the same EOFError
                # signal instead of letting struct.unpack raise struct.error,
                # which the analyzer does not catch.
                raise EOFError("Truncated entry header: {} byte(s) left".format(len(peek)))
            index, value_type = struct.unpack(">HH", peek)
            name, options = INDEX_TO_DESC.get(index, ("", {}))
            if value_type > 3 or not name:
                raise EOFError("Unrecognized index/type pair: {}+{}".format(index, value_type))
            yield ConfigValue(name=name, **options)




class CobaltStrikeConfigAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for Cobalt Strike beacon configurations stored in clear.

    The data is accepted only if it parses into a CobaltHeader that contains
    every field listed in MANDATORY_VALUES.
    """
    category = malcat.FileType.DATABASE
    name = "CobaltStrikeConfig"
    # Matches the first entries in clear:
    #   (lookbehind) 00 01 00   Index=1 (BeaconType) + high byte of Type
    #   01                      low byte of Type=1 (SHORT)
    #   00 02                   Size=2
    #   00 [00-10]              BeaconType value 0..16
    #   00 02 00 01 00 02       Index=2 (Port), Type=1, Size=2
    #   ..                      port value (any two bytes)
    #   00 [03-10]              Index of the next entry (3..16)
    # NOTE(review): `.` does not match 0x0A unless the engine applies DOTALL,
    # so a port value containing a 0x0A byte may be missed — confirm how
    # Malcat compiles this pattern.
    regexp = r"(?<=\x00\x01\x00)\x01\x00\x02\x00[\x00-\x10]\x00\x02\x00\x01\x00\x02..\x00[\x03-\x10]"

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        # Step back over the 3-byte lookbehind so parsing starts at the Index
        # field of the BeaconType entry (assuming offset_magic is the match
        # start).
        return offset_magic - 3, ""

    def parse(self, hint):
        try:
            yield CobaltHeader(category=Type.HEADER)
        except EOFError:
            # CobaltHeader raises EOFError to mark the end of its entries.
            # The checks below rely on the entries parsed so far being kept.
            pass
        if "CobaltHeader" not in self:
            raise FatalError("No header")
        ch = self["CobaltHeader"]
        values = {v.name for v in ch}
        missing = MANDATORY_VALUES - values
        if missing:
            # Sorted so the message does not depend on set/hash ordering.
            raise FatalError("Values missing: {}".format(", ".join(sorted(missing))))
