"""
name: Latrodectus
category: config extractors
author: malcat

Decrypt the AES-encrypted strings of Latrodectus samples
"""
import json
import struct
import malcat

malcat.setup()  # Add Malcat's data/ and bindings/ directories to sys.path when called in headless mode
from transforms.block import AesDecrypt



############################ config extraction

def find_decryption_function_and_key(a):
    """
    Locate the string decryption function and its AES key.

    Heuristic: walk a.fns in order and keep functions with more than 100 incoming
    references; the first of them containing a dynamic (stack-built) string of at
    least 32 bytes is the decryption function, and the last 32 bytes of that string
    are the key.

    Returns a (function, key_bytes) pair, or (None, b"") when no function matches.
    """
    heavily_referenced = (fn for fn in a.fns if len(fn.inrefs) > 100)
    for candidate in heavily_referenced:
        big_stack_strings = (
            found
            for found in a.strings[candidate.start:candidate.end]
            if found.type == malcat.FoundString.Type.DYNAMIC and found.size >= 32
        )
        first_match = next(big_stack_strings, None)
        if first_match is not None:
            return candidate, first_match.bytes[-32:]
    return None, b""

def iter_encrypted_strings(a, decrypt_fn):
    """
    Yield an (offset, IV, ciphertext) tuple for every encrypted string passed to decrypt_fn.

    For each incoming reference of decrypt_fn, the basic block containing the referencing
    address is scanned for two-operand instructions whose first operand is a register and
    which have exactly one outgoing reference. That reference is taken to be the address of
    an encrypted blob (heuristic). The blob is read from a.file at the matching file offset
    and is laid out as:

        uint16 little-endian size | 16-byte IV | `size` bytes of ciphertext

    `offset` is the file offset of the start of the blob (its size field).

    Each blob offset is yielded at most once. Several references can resolve to the same
    basic block, or load the same blob. Yielding a blob twice would make in-place patching
    re-process plaintext it had already written, reading a bogus size field from it.
    Blobs that are truncated by the end of the file are skipped.
    """
    seen = set()
    for inref in decrypt_fn.inrefs:
        for instr in a.cfg[inref.address]:
            if len(instr) != 2 or len(instr.outrefs) != 1:
                continue
            if instr[0].type != malcat.InstructionOperand.Type.REGISTER:
                continue
            offset = a.a2p(instr.outrefs[0].address)
            # a2p yields a falsy value when the address is not backed by file data
            # (presumably None) -- such references cannot be read or patched
            if not offset or offset in seen:
                continue
            seen.add(offset)
            header = a.file[offset:offset + 2]
            if len(header) < 2:
                continue
            size, = struct.unpack("<H", header)
            iv = a.file[offset + 2:offset + 2 + 16]
            ciphertext = a.file[offset + 2 + 16:offset + 2 + 16 + size]
            # a blob running past the end of the file cannot be a genuine string
            if len(iv) < 16 or len(ciphertext) < size:
                continue
            yield offset, iv, ciphertext

def _decode_plaintext(data):
    """
    Best-effort conversion of a decrypted string to text.

    Data longer than 2 bytes whose second byte is NUL is decoded as UTF-16LE; anything
    else is decoded as ASCII. A single trailing NUL terminator is removed from the text.
    Returns the decoded str, or `data` unchanged (bytes) when it cannot be decoded.
    """
    # a NUL second byte is what an ASCII character looks like once widened to UTF-16LE
    encoding = "utf-16le" if len(data) > 2 and data[1] == 0 else "ascii"
    try:
        text = data.decode(encoding)
    except UnicodeDecodeError:
        return data
    if text.endswith("\x00"):
        # remove null byte terminator
        text = text[:-1]
    return text

def _patch_plaintext(a, offset, ciphertext, plaintext):
    """
    Overwrite an encrypted blob (2-byte size + 16-byte IV + ciphertext) at `offset` in
    a.file with `plaintext`, then zero-fill whatever is left of the original blob.
    """
    blob_size = 2 + 16 + len(ciphertext)
    a.file[offset:offset + len(plaintext)] = plaintext
    if len(plaintext) < blob_size:
        a.file[offset + len(plaintext):offset + blob_size] = b"\x00" * (blob_size - len(plaintext))

def latrodectus_decrypt_strings(a, in_place=False):
    """
    Decrypt every encrypted string referenced by the sample's decryption function.

    Each string is decrypted with AES in CTR mode, using its own IV and the key found
    next to the decryption function.

    :param a: Malcat analysis object of the sample.
    :param in_place: when True, each encrypted blob in a.file is replaced by its
                     plaintext followed by zero padding.
    :return: list with one entry per encrypted string: a str when the plaintext could be
             decoded (UTF-16LE or ASCII, trailing NUL removed), otherwise the raw bytes.
    :raises ValueError: if the decryption function cannot be located.
    """
    decrypt_fn, key = find_decryption_function_and_key(a)
    if not decrypt_fn:
        raise ValueError("Could not locate decryption function")
    print(f"Found decryption function: {a.ppa(decrypt_fn)}, key: {key.hex()}")

    decryptor = AesDecrypt()
    res = []
    for offset, iv, ciphertext in iter_encrypted_strings(a, decrypt_fn):
        decrypted = decryptor.run(ciphertext, mode="ctr", iv=iv, key=key)
        if in_place:
            # patch the decrypted string directly
            _patch_plaintext(a, offset, ciphertext, decrypted)
        res.append(_decode_plaintext(decrypted))
    return res

def latrodectus_patch_api_calls(a):
    """
    Label the global variables that receive dynamically resolved API addresses.

    Each constant of category "apihash" is named "hash(<api>)". For each one, the scan
    starts at the instruction containing the constant. Instructions within the next
    40 addresses are searched for `lea <reg>, [global]` instructions that have exactly
    one outgoing reference. The first such lea refers to the DLL. The second one refers
    to the slot where the resolved API address is stored. That slot gets a user label
    equal to the API name.

    :param a: Malcat analysis object to annotate (labels are written to a.syms).
    """
    for cst in a.constants:
        if cst.category != "apihash":
            continue
        apiname = cst.name[5:-1]    # remove the "hash(" and ")"
        # the constant is in the middle of an instruction, find the start of the instruction
        instr_address = a.cfg.align(cst.address)
        count_lea = 0
        for instr in a.asm[instr_address:instr_address + 40]:
            # lea xx, [global address] ?
            if instr.opcode == "lea" and instr[1].type == malcat.InstructionOperand.Type.GLOBAL and len(instr.outrefs) == 1:
                count_lea += 1
                if count_lea == 2:  # first one is dll, second one stores the resolved API address
                    a.syms[instr.outrefs[0].address] = apiname    # add user label
                    break

################################ MAIN
    
if __name__ == "__main__":

    def _json_default(obj):
        """json.dumps fallback: strings that could not be decoded are kept as raw bytes, emit them as hex."""
        if isinstance(obj, (bytes, bytearray)):
            return obj.hex()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    decstringsets = []
    if "analysis" in globals():
        # called from the GUI: the analysis object is already instantiated with the current file
        with analysis.history.group():
            decstringsets.append(latrodectus_decrypt_strings(analysis, in_place=True))
            latrodectus_patch_api_calls(analysis)
        print("Strings have been decrypted in-place, api calls have been patched in place")
    else:
        # called in headless mode: we need to analyse the file(s) first
        import optparse
        usage = "usage: %prog <file1> [file2] ... [fileN]"
        parser = optparse.OptionParser(usage=usage, description="""Extract strings for (unpacked) Latrodectus samples""")
        _, args = parser.parse_args()
        if len(args) < 1:
            parser.error("Please give path to a file")

        for fname in args:
            a = malcat.analyse(fname)
            decstringsets.append(latrodectus_decrypt_strings(a, in_place=False))

    for decstringset in decstringsets:
        print("\nLATRODECTUS_strings = ", end="")
        # undecodable strings are bytes, which json cannot serialize natively
        print(json.dumps(decstringset, indent=4, default=_json_default))
