# Malcat feature extractor for capa.
# Native code support is limited to x86 / x64; .NET binaries go through dedicated code paths.
import sys
import types

import capa
import capa.features.extractors.helpers
import capa.features.extractors.strings
from capa.features.extractors.base_extractor import FeatureExtractor, BBHandle, InsnHandle, FunctionHandle
from capa.features.common import Feature, String, Characteristic, MAX_BYTES_FEATURE_SIZE, Bytes, Namespace, Class, FeatureAccess
from capa.features.address import Address, NO_ADDRESS, FileOffsetAddress, AbsoluteVirtualAddress, EffectiveAddress
from capa.features.file import Export, Import, Section, FunctionName
from capa.features.basicblock import BasicBlock
from capa.features.extractors.helpers import MIN_STACKSTRING_LEN
from capa.features.insn import API, MAX_STRUCTURE_SIZE, Number, Offset, Mnemonic, OperandNumber, OperandOffset, Property
from capa.features.common import OS, FORMAT_PE, FORMAT_ELF, FORMAT_DOTNET, Arch, Format, String, ARCH_I386, ARCH_AMD64, ARCH_ANY, OS_WINDOWS, OS_LINUX, OS_ANY
from capa.features.extractors.base_extractor import BBHandle, FunctionHandle

from malcat import Architecture, Symbol, BasicBlockEdge, Instruction, InstructionOperand, FoundString
from filetypes.PE_net import get_string, resolve_symbol_name


class MergedBasicBlock:
    """Present a run of consecutive basic blocks as one single block.

    Entry-side attributes (``address``, ``code``, ``is_loop``, ``incoming``
    and ``start``) are read from ``first``; exit-side attributes
    (``outgoing`` and ``end``) are read from ``last``.
    """

    def __init__(self, first, last):
        self.first, self.last = first, last

    # --- forwarded from the first block of the run ---
    address = property(lambda self: self.first.address)
    code = property(lambda self: self.first.code)
    is_loop = property(lambda self: self.first.is_loop)
    incoming = property(lambda self: self.first.incoming)
    start = property(lambda self: self.first.start)

    # --- forwarded from the last block of the run ---
    outgoing = property(lambda self: self.last.outgoing)
    end = property(lambda self: self.last.end)


def malcat_to_dnlib_name(symbol_name):
    """Convert a dotted symbol name to the ``Type::member`` notation.

    Only the last ``.`` is turned into ``::``. When nothing precedes that
    last dot, or when there is no dot at all, only the trailing component is
    returned.
    """
    split_at = symbol_name.rfind(".")
    member = symbol_name[split_at + 1:]
    if split_at <= 0:
        return member
    return symbol_name[:split_at] + "::" + member


class MalcatFeatureExtractor(FeatureExtractor):
    def __init__(self, analysis):
        """Wrap an analysis object and cache the per-file lookups.

        Args:
            analysis: the analysis object of the file under inspection.

        The five metadata-table attributes (``method_def_table``,
        ``member_ref_table``, ``type_ref_table``, ``type_def_table`` and
        ``impl_map_table``) are always defined: they hold the matching
        structure for .NET files that have it, and ``None`` otherwise.
        """
        super().__init__()
        self.analysis = analysis

        self.is_dotnet = self.analysis.architecture == Architecture.DOTNET
        # pre-compute these because we'll yield them at *every* scope.
        self.__global_features = list(self.__extract_global_features())

        def metadata_table(name):
            # Tables are only looked up for .NET files; a missing table is None.
            if not self.is_dotnet or name not in self.analysis.struct:
                return None
            return self.analysis.struct[name]

        self.method_def_table = metadata_table("MethodDefTable")
        self.member_ref_table = metadata_table("MemberRefTable")
        self.type_ref_table = metadata_table("TypeRefTable")
        self.type_def_table = metadata_table("TypeDefTable")
        self.impl_map_table = metadata_table("ImplMapTable")
            

    def get_base_address(self):
        """Return the image base of the analysed file as an absolute virtual address."""
        image_base = self.analysis.imagebase
        return AbsoluteVirtualAddress(image_base)

################################# GLOBAL FEATURES

    def extract_global_features(self):
        yield from self.__global_features
    
    def __extract_global_features(self):
        """Yield the architecture and OS features of the analysed file.

        Yields:
            ``(feature, NO_ADDRESS)`` pairs: exactly one ``Arch`` feature
            followed by exactly one ``OS`` feature.
        """
        if self.is_dotnet:
            struct = self.analysis.struct
            if "CLR.Header" not in struct or "PE" not in struct:
                # Without both headers the bitness cannot be read; they must
                # not be dereferenced, so this is the only Arch feature emitted.
                yield Arch(ARCH_ANY), NO_ADDRESS
            else:
                is_64bits = struct.PE.Machine.enum == "IMAGE_FILE_MACHINE_AMD64"
                is_clr32 = struct["CLR.Header"]["Flags"]["32Bits"]
                if is_clr32 and not is_64bits:
                    yield Arch(ARCH_I386), NO_ADDRESS
                elif not is_clr32 and is_64bits:
                    yield Arch(ARCH_AMD64), NO_ADDRESS
                else:
                    # contradictory or neutral flags: do not pin an architecture
                    yield Arch(ARCH_ANY), NO_ADDRESS
        elif self.analysis.architecture == Architecture.X64:
            yield Arch(ARCH_AMD64), NO_ADDRESS
        elif self.analysis.architecture == Architecture.X86:
            yield Arch(ARCH_I386), NO_ADDRESS
        else:
            yield Arch(ARCH_ANY), NO_ADDRESS

        # Every file type other than ELF is reported as Windows.
        if self.analysis.type == "ELF":
            yield OS(OS_LINUX), NO_ADDRESS
        else:
            yield OS(OS_WINDOWS), NO_ADDRESS

################################# FILE FEATURES

    def extract_file_features(self):
        for extractor_fn in (self.extract_symbols, self.extract_sections, self.extract_strings, self.extract_embedded_pe, self.extract_format, self.extract_function_names):
            yield from extractor_fn()
        # .NET-specific
        if self.is_dotnet:
            yield from self.extract_dotnet_classes()
            if "CLR.Header" in self.analysis.struct:
                if not self.analysis.struct["CLR.Header"]["Flags"]["ILOnly"]:
                    yield Characteristic("mixed mode"), NO_ADDRESS

    def extract_symbols(self):
        """Yield ``Import`` and ``Export`` features for the file's symbols.

        Native imports are reported under their full name and, unless they are
        ordinal imports, under every variant produced by ``generate_symbols``.
        .NET imports found in the ImplMap table are treated as native imports;
        all other .NET imports are reported once, in dnlib notation.
        """
        for sym in self.analysis.syms:
            kind = sym.type
            if kind == Symbol.Type.EXPORT:
                yield Export(sym.name), EffectiveAddress(sym.address)
                continue
            if kind != Symbol.Type.IMPORT:
                continue

            full_name = sym.name
            if not self.is_dotnet:
                yield Import(full_name), EffectiveAddress(sym.address)
                if "." not in full_name or ".ord" in full_name:
                    continue
                library, _, function = full_name.rpartition(".")
                for variant in capa.features.extractors.helpers.generate_symbols(library, function):
                    yield Import(variant), EffectiveAddress(sym.address)
                continue

            is_native = self.impl_map_table is not None and sym.address in self.impl_map_table
            if not is_native:
                # stay compatible with dnlib nomenclature
                yield Import(malcat_to_dnlib_name(full_name)), EffectiveAddress(sym.address)
                continue
            # native import
            library, _, function = full_name.rpartition(".")
            for variant in capa.features.extractors.helpers.generate_symbols(library, function):
                yield Import(variant), EffectiveAddress(sym.address)

    def extract_sections(self):
        """Yield one ``Section`` feature per region of the file's map."""
        yield from (
            (Section(mapped.name), AbsoluteVirtualAddress(mapped.virt))
            for mapped in self.analysis.map
        )

    def extract_strings(self):
        """Yield one ``String`` feature per string found in the file."""
        yield from (
            (String(found.text), EffectiveAddress(found.address))
            for found in self.analysis.strings
        )

    def extract_embedded_pe(self):
        """Yield an "embedded pe" characteristic for each carved PE file.

        Carved files whose address cannot be converted by ``map.a2p`` are
        skipped.
        """
        for carved in self.analysis.carved:
            if carved.type != "PE":
                continue
            file_offset = self.analysis.map.a2p(carved.address)
            if file_offset is None:
                continue
            yield Characteristic("embedded pe"), FileOffsetAddress(file_offset)

    def extract_format(self):
        """Yield the single ``Format`` feature describing the file.

        .NET files are always ``FORMAT_DOTNET``. Otherwise "PE" and "ELF" map
        to their capa constants and any other type is reported lower-cased.
        """
        if self.analysis.architecture == Architecture.DOTNET:
            yield Format(FORMAT_DOTNET), NO_ADDRESS
            return

        file_type = self.analysis.type
        if file_type == "PE":
            file_format = FORMAT_PE
        elif file_type == "ELF":
            file_format = FORMAT_ELF
        else:
            file_format = file_type.lower()
        yield Format(file_format), NO_ADDRESS
        
    def extract_function_names(self):
        """Yield a ``FunctionName`` feature for every named function.

        Functions whose name starts with ``sub_`` are skipped: the original
        code tested for that prefix but still emitted the feature, so every
        placeholder name ended up as a ``FunctionName``. For .NET files the
        name is converted to dnlib notation first.
        """
        for fn in self.analysis.fns:
            name = fn.fullname
            if name.startswith("sub_"):
                continue
            if self.is_dotnet:
                name = malcat_to_dnlib_name(name)
            yield FunctionName(name), EffectiveAddress(fn.address)

    def extract_dotnet_classes(self):
        """Yield ``Class`` and ``Namespace`` features from the .NET type tables.

        Every row of the TypeDef and TypeRef tables (when present) produces a
        ``Class`` feature located at that row's own address. The distinct,
        non-empty namespaces seen along the way are then yielded as
        ``Namespace`` features without an address. Nothing is yielded for
        non-.NET files.
        """
        if not self.is_dotnet:
            return

        namespaces = set()
        for table_name in ("TypeDefTable", "TypeRefTable"):
            if table_name not in self.analysis.struct:
                continue
            for row in self.analysis.struct[table_name]:
                namespaces.add(get_string(row["TypeNamespace"], self.analysis.parser))
                symname = resolve_symbol_name(row, self.analysis.parser)
                # Use the address of the row being visited (TypeRef rows used
                # to be reported at the address of the last TypeDef row).
                yield Class(symname), EffectiveAddress(row.address)
        namespaces.discard("")

        for namespace in namespaces:
            # namespaces do not have an associated token, so no address is given
            yield Namespace(namespace), NO_ADDRESS

            
################################# FUNCTION FEATURES

    def is_library_function(self, ea: int) -> bool:
        # TODO
        return False

    def get_function_name(self, ea: int) -> str:
        # TODO
        raise KeyError(ea)

    def get_functions(self):
        """Yield a ``FunctionHandle`` for every function of the analysis."""
        for function in self.analysis.fns:
            handle_address = EffectiveAddress(function.address)
            yield FunctionHandle(address=handle_address, inner=function)

    def extract_function_features(self, fh):
        for extractor_fn in (self.extract_function_recursive, self.extract_function_loop, self.extract_function_calls_to):
            yield from extractor_fn(fh)

    def extract_function_loop(self, f):
        """Yield a single "loop" characteristic if the function contains a loop."""
        loops_inside = self.analysis.loops[f.inner.start:f.inner.end]
        # Only the existence of a first entry matters; stop right after it.
        if any(True for _ in loops_inside):
            yield Characteristic("loop"), f.address

    def extract_function_recursive(self, f):
        """Yield "recursive call" for each entry-block edge coming from the entry itself."""
        entry_address = f.inner.address
        entry_block = self.analysis.cfg[entry_address]
        for edge in entry_block.incoming:
            if edge.address != entry_address:
                continue
            # This is not strictly correct: we should look for a call to the
            # function's body anywhere inside the function. IDA's extractor
            # works this way though, so stay consistent with it.
            yield Characteristic("recursive call"), f.address

    def extract_function_calls_to(self, f):
        """Yield a "calls to" characteristic for each call edge reaching the function's entry block."""
        entry_block = self.analysis.cfg[f.inner.address]
        call_edges = (
            edge for edge in entry_block.incoming
            if edge.type == BasicBlockEdge.Type.CALL
        )
        for edge in call_edges:
            yield Characteristic("calls to"), EffectiveAddress(edge.address)


################################# BASIC BLOCKS FEATURES

    def get_basic_blocks(self, f):
        """Yield the function's basic blocks, merging runs of "exotic" blocks.

        Consecutive blocks flagged ``exotic`` are accumulated and emitted
        together with the first non-exotic block that follows them, as one
        ``MergedBasicBlock`` addressed at the start of the run.
        """
        # NOTE(review): a trailing run of exotic blocks (with no non-exotic
        # block after it) is never yielded -- confirm this is intended.
        run_start = None
        for block in self.analysis.cfg[f.inner.start:f.inner.end]:
            if run_start is None:
                run_start = block
            if block.exotic:
                continue
            merged = MergedBasicBlock(run_start, block)
            yield BBHandle(inner=merged, address=EffectiveAddress(run_start.address))
            run_start = None

    def extract_basic_block_features(self, f, bb):
        for extractor_fn in (self.extract_bb_tigh_loop, self.extract_bb_stack_string):
            yield from extractor_fn(f, bb)

    def extract_bb_tigh_loop(self, f, bb):
        """Yield a "tight loop" characteristic when the block is flagged as a loop."""
        if not bb.inner.is_loop:
            return
        yield Characteristic("tight loop"), bb.address

    def extract_bb_stack_string(self, f, bb):
        """Yield one "stack string" characteristic if the block holds a dynamic string."""
        strings_in_block = self.analysis.strings[bb.inner.start:bb.inner.end]
        # A single hit is enough; the scan stops at the first dynamic string.
        if any(found.type == FoundString.Type.DYNAMIC for found in strings_in_block):
            yield Characteristic("stack string"), bb.address


################################# INSTRUCTION FEATURES

    def get_instructions(self, f, bb):
        """Yield an ``InsnHandle`` for each instruction between the block's start and end."""
        block = bb.inner
        for instruction in self.analysis.asm[block.start:block.end]:
            handle_address = EffectiveAddress(instruction.address)
            yield InsnHandle(address=handle_address, inner=instruction)


    def extract_insn_features(self, f, bb, insn):
        if self.is_dotnet:
            fns = (
                self.extract_insn_api_features, 
                self.extract_insn_dotnet_properties,
                self.extract_insn_number_features, 
                self.extract_insn_strings_features, 
                self.extract_insn_class_features, 
                self.extract_insn_unmanaged_call_features, 
            )
        else:
            fns = (
                self.extract_insn_number_features, 
                self.extract_insn_api_features, 
                self.extract_insn_bytes_features, 
                self.extract_insn_strings_features, 
                self.extract_insn_offset_features, 
                self.extract_insn_nzxor_characteristic_features,
                self.extract_insn_mnemonic_features, 
                self.extract_insn_peb_access_characteristic_features,
                self.extract_insn_segment_access_features, 
                self.extract_insn_cross_section_cflow, 
                self.extract_function_calls_from,
                self.extract_function_indirect_call_characteristic_features,
                self.extract_insn_obfs_call_plus_5_characteristic_features,
                )
        for extractor_fn in fns:
            yield from extractor_fn(f, bb, insn)


    def extract_insn_number_features(self, f, bb, insn):
        """Yield ``Number`` / ``OperandNumber`` features for constant operands.

        Return and stack instructions are ignored, and so are constants that
        ``map.from_virt`` resolves to an address of the file.
        """
        instruction = insn.inner
        if instruction.type in (Instruction.Type.RETURN, Instruction.Type.STACK):
            return
        for index, operand in enumerate(instruction):
            if operand.type != InstructionOperand.Type.CONSTANT:
                continue
            value = operand.value
            if value is None or self.analysis.map.from_virt(value) is not None:
                continue
            yield Number(value), insn.address
            yield OperandNumber(index, value), insn.address

    def extract_insn_bytes_features(self, f, bb, insn):
        """Yield ``Bytes`` features for data referenced by the instruction.

        Call instructions are ignored. For every operand whose value resolves
        to a non-code location of the file, up to ``MAX_BYTES_FEATURE_SIZE``
        bytes are read from there and reported unless they are empty or all
        zeros.
        """
        instruction = insn.inner
        if instruction.type == Instruction.Type.CALL:
            return
        for operand in instruction:
            if not operand.value:
                continue
            ea = self.analysis.map.from_virt(operand.value)
            if not ea or self.analysis.cfg[ea].code:
                continue
            file_offset = self.analysis.map.to_phys(ea)
            if not file_offset:
                continue
            data = self.analysis.file.read(file_offset, MAX_BYTES_FEATURE_SIZE)
            if not data or capa.features.extractors.helpers.all_zeros(data):
                continue
            yield Bytes(data), insn.address

    def extract_insn_strings_features(self, f, bb, insn):
        """Yield a ``String`` feature for each operand that points at a known string."""
        for operand in insn.inner:
            if not operand.value:
                continue
            ea = self.analysis.map.from_virt(operand.value)
            if not ea or ea not in self.analysis.strings:
                continue
            referenced = self.analysis.strings[ea]
            yield String(referenced.text), insn.address

    def extract_insn_offset_features(self, f, bb, insn):
        """Yield ``Offset`` / ``OperandOffset`` features for object operands.

        Nothing is produced for .NET files. An operand is reported when it is
        of type ``OBJECT``, has a non-zero value, and that value is resolved
        by ``map.from_virt``.
        """
        # NOTE(review): the value is only reported when it DOES map to an
        # address, the opposite of the test used by
        # extract_insn_number_features -- confirm this is intended.
        if self.is_dotnet:
            return
        for index, operand in enumerate(insn.inner):
            if operand.type != InstructionOperand.Type.OBJECT or not operand.value:
                continue
            if self.analysis.map.from_virt(operand.value) is None:
                continue
            yield Offset(operand.value), insn.address
            yield OperandOffset(index, operand.value), insn.address

    def _import_name_at(self, address):
        """Return the name of the first import symbol at ``address``, or ``None``."""
        for symbol in self.analysis.syms[address]:
            if symbol.type == Symbol.Type.IMPORT:
                return symbol.name
        return None

    def extract_insn_api_features(self, f, bb, insn):
        """Yield ``API`` features for the API called by this instruction.

        Two instruction shapes are recognised:

        * a call whose first operand is a global, symbol or constant. The
          import is resolved from the target address or from the operand's
          symbol, and, failing that, by following a ``jmp <import>`` found at
          the target.
        * an assignment of a global to a register that is immediately followed
          by a call through that same register.

        Native names (and .NET names found in the ImplMap table) are expanded
        with ``generate_symbols``; other .NET names are reported once, in
        dnlib notation.
        """
        instruction = insn.inner
        api = None
        address = None

        if (
            instruction.type == Instruction.Type.CALL
            and len(instruction) > 0
            and instruction[0].type in (
                InstructionOperand.Type.GLOBAL,
                InstructionOperand.Type.SYMBOL,
                InstructionOperand.Type.CONSTANT,
            )
        ):
            target = instruction[0]
            if target.value:
                address = self.analysis.map.from_virt(target.value)
                if address is not None:
                    api = self._import_name_at(address)
            if not api:
                api = target.symbol

            if not api and address is not None:
                # is it a call to jmp api ?
                thunk_block = self.analysis.cfg[address]
                thunk = self.analysis.asm[address]
                if (
                    thunk_block.code
                    and thunk.type == Instruction.Type.JUMP
                    and len(thunk) > 0
                    and thunk[0].type in (InstructionOperand.Type.GLOBAL, InstructionOperand.Type.SYMBOL)
                ):
                    address = self.analysis.map.from_virt(thunk[0].value) if thunk[0].value else None
                    api = thunk[0].symbol
                    if not api and address:
                        api = self._import_name_at(address)
        elif (
            instruction.type == Instruction.Type.ASSIGN
            and len(instruction) == 2
            and instruction[0].type == InstructionOperand.Type.REGISTER
            and instruction[1].type == InstructionOperand.Type.GLOBAL
        ):
            source = instruction[1]
            address = self.analysis.map.from_virt(source.value) if source.value else None
            api = source.symbol
            if not api and address and len(bb.inner.outgoing) == 1:
                api = self._import_name_at(address)

            # The assignment only counts when the very next instruction calls
            # through the register that was just loaded.
            try:
                next_insn = self.analysis.asm[instruction.address + instruction.size]
            except Exception:
                # No instruction at that address. A bare ``except`` would also
                # have swallowed KeyboardInterrupt / SystemExit.
                next_insn = None
            followed_by_call = (
                next_insn is not None
                and next_insn.type == Instruction.Type.CALL
                and len(next_insn) > 0
                and next_insn[0].type == InstructionOperand.Type.REGISTER
                and next_insn[0].register == instruction[0].register
            )
            if not followed_by_call:
                api = None

        if not api:
            return

        if self.is_dotnet:
            # for native imports, the operand value points to the methoddef
            # table, otherwise it would point to the memberref table
            is_native = (
                self.impl_map_table is not None
                and address is not None
                and address in self.impl_map_table
            )
            if not is_native:
                # dnlib syntax
                yield API(malcat_to_dnlib_name(api)), insn.address
                return

        dll, _, symbol = api.rpartition(".")
        for name in capa.features.extractors.helpers.generate_symbols(dll, symbol):
            yield API(name), insn.address

    def extract_insn_nzxor_characteristic_features(self, f, bb, insn):
        """Yield an "nzxor" characteristic for instructions typed as XOR."""
        # malcat already checks that it is not a zero-xor or a stack-cookie xor
        if insn.inner.type != Instruction.Type.XOR:
            return
        yield Characteristic("nzxor"), insn.address

    def extract_insn_mnemonic_features(self, f, bb, insn):
        """Yield the instruction's mnemonic as a ``Mnemonic`` feature."""
        mnemonic = insn.inner.mnemonic
        yield Mnemonic(mnemonic), insn.address

    def extract_insn_obfs_call_plus_5_characteristic_features(self, f, bb, insn):
        """Yield "call $+5" when a block-ending call targets the address right after it."""
        instruction = insn.inner
        ends_block_with_call = (
            instruction.type == Instruction.Type.CALL
            and instruction.address + instruction.size == bb.inner.end
        )
        if not ends_block_with_call:
            return
        next_address = bb.inner.end
        for edge in bb.inner.outgoing:
            if edge.type == BasicBlockEdge.Type.CALL and edge.address == next_address:
                yield Characteristic("call $+5"), insn.address

    def extract_insn_peb_access_characteristic_features(self, f, bb, insn):
        """Yield "peb access" for assignments / pushes reading fs:[0x30] or gs:[0x60].

        The operand check is a cheap filter; the disassembly text is what
        confirms the segment used.
        """
        instruction = insn.inner
        if instruction.type not in (Instruction.Type.ASSIGN, Instruction.Type.PUSH):
            return
        touches_peb_offset = any(
            op.type == InstructionOperand.Type.GLOBAL and op.value in (0x30, 0x60)
            for op in instruction
        )
        if not touches_peb_offset:
            return
        text = str(instruction)
        if "fs:[0x30]" in text or "gs:[0x60]" in text:
            yield Characteristic("peb access"), insn.address

    def extract_insn_segment_access_features(self, f, bb, insn):
        """Yield "fs access" / "gs access" for instructions using those segments.

        Only instructions with at least one global operand are inspected; the
        segment is then detected in the disassembly text.
        """
        has_global_operand = any(
            op.type == InstructionOperand.Type.GLOBAL for op in insn.inner
        )
        if not has_global_operand:
            return
        text = str(insn.inner)
        for marker, label in (("fs:[", "fs access"), ("gs:[", "gs access")):
            if marker in text:
                yield Characteristic(label), insn.address

    def extract_insn_cross_section_cflow(self, f, bb, insn):
        """Yield "cross section flow" for each block edge leaving the instruction's region.

        Only the last instruction of the block is considered.
        """
        instruction = insn.inner
        if instruction.address + instruction.size != bb.inner.end:
            return
        home_region = self.analysis.map[instruction.address]
        if home_region is None:
            return
        for edge in bb.inner.outgoing:
            if edge.address in home_region:
                continue
            yield Characteristic("cross section flow"), insn.address


    def extract_function_calls_from(self, f, bb, insn):
        """Yield "calls from" for each call edge leaving the block.

        Only the last instruction of the block emits these features.
        """
        instruction = insn.inner
        if instruction.address + instruction.size != bb.inner.end:
            return
        call_edges = (
            edge for edge in bb.inner.outgoing
            if edge.type == BasicBlockEdge.Type.CALL
        )
        for edge in call_edges:
            yield Characteristic("calls from"), EffectiveAddress(edge.address)


    def extract_function_indirect_call_characteristic_features(self, f, bb, insn):
        """Yield "indirect call" for calls that do not target a known symbol.

        The call must have a first operand with a non-zero value, and the
        enclosing block must have at most one outgoing edge.
        """
        instruction = insn.inner
        if (
            instruction.type != Instruction.Type.CALL
            or len(instruction) < 1
            or not instruction[0].value
        ):
            return
        # ignore call [import]
        target = self.analysis.map.from_virt(instruction[0].value)
        if target and target in self.analysis.syms:
            return
        if len(bb.inner.outgoing) <= 1:
            yield Characteristic("indirect call"), insn.address

    def extract_insn_dotnet_properties(self, f, bb, insn):
        """Yield ``Property`` features for .NET field and accessor usage.

        * Every named ``OBJECT`` operand yields a READ and/or WRITE property
          according to the operand's action.
        * A call whose target lies in the MemberRef table and whose method
          name starts with ``get_`` / ``set_`` is assumed to be a property
          accessor and yields ``Type::Name`` with READ / WRITE access. When
          the symbol carries no type part, the bare property name is used.
        """
        instruction = insn.inner
        for op in instruction:
            if op.type != InstructionOperand.Type.OBJECT:
                continue
            name = op.symbol
            if not name:
                continue
            name = malcat_to_dnlib_name(name)
            if op.action in (InstructionOperand.Action.R, InstructionOperand.Action.RW):
                yield Property(name, access=FeatureAccess.READ), insn.address
            if op.action in (InstructionOperand.Action.W, InstructionOperand.Action.RW):
                yield Property(name, access=FeatureAccess.WRITE), insn.address

        is_direct_call = (
            instruction.type == Instruction.Type.CALL
            and len(instruction) > 0
            and instruction[0].type in (
                InstructionOperand.Type.GLOBAL,
                InstructionOperand.Type.SYMBOL,
                InstructionOperand.Type.CONSTANT,
            )
        )
        if not is_direct_call or not instruction[0].value:
            return
        address = self.analysis.map.from_virt(instruction[0].value)
        if (
            self.member_ref_table is None
            or address is None
            or address not in self.member_ref_table
        ):
            return

        name = instruction[0].symbol
        if not name:
            # unnamed target: nothing to derive a property from
            return
        type_name, separator, method_name = name.rpartition(".")
        if method_name.startswith("get_"):
            access = FeatureAccess.READ
        elif method_name.startswith("set_"):
            access = FeatureAccess.WRITE
        else:
            return
        # assume it is a property accessor: drop the get_/set_ prefix
        property_name = method_name[4:]
        if separator:
            property_name = "{}::{}".format(type_name, property_name)
        yield Property(property_name, access=access), insn.address

    def extract_insn_class_features(self, f, bb, insn):
        """Yield ``Class`` features for the types referenced by the instruction.

        Object and symbol operands are considered, as well as constant
        operands of call instructions. Operands resolving into the ImplMap
        table and operands without a symbol are skipped. When the operand
        resolves into the TypeDef or TypeRef table the whole symbol is the
        class name; otherwise the last dotted component is dropped. If the
        resulting name has more than two components, the name minus its last
        component is reported as well.
        """
        for operand in insn.inner:
            is_candidate = (
                operand.type in (InstructionOperand.Type.OBJECT, InstructionOperand.Type.SYMBOL)
                or (
                    insn.inner.type == Instruction.Type.CALL
                    and operand.type == InstructionOperand.Type.CONSTANT
                )
            )
            if not is_candidate:
                continue

            symbol = operand.symbol
            address = None
            if operand.value is not None:
                address = self.analysis.map.from_virt(operand.value)
            resolved = address is not None

            if self.impl_map_table is not None and resolved and address in self.impl_map_table:
                # native import
                continue
            if symbol is None:
                continue

            parts = symbol.split(".")
            names_a_type = (
                (self.type_def_table is not None and resolved and address in self.type_def_table)
                or (self.type_ref_table is not None and resolved and address in self.type_ref_table)
            )
            if not names_a_type:
                parts = parts[:-1]
            if not parts:
                continue

            yield Class(".".join(parts)), insn.address
            if len(parts) > 2:
                yield Class(".".join(parts[:-1])), insn.address
                
    def extract_insn_unmanaged_call_features(self, f, bb, insn):
        """Yield "unmanaged call" for each symbol operand of a call that resolves into the ImplMap table."""
        if insn.inner.type != Instruction.Type.CALL:
            return
        for operand in insn.inner:
            if operand.type != InstructionOperand.Type.SYMBOL or operand.value is None:
                continue
            address = self.analysis.map.from_virt(operand.value)
            if self.impl_map_table is None or address not in self.impl_map_table:
                continue
            # native import
            yield Characteristic("unmanaged call"), insn.address
