2019KCTF总决赛 第四题:西部乐园 WP
2019-12-17 15:38:34 Author: bbs.pediy.com(查看原文) 阅读量:119 收藏

import struct
import time
from capstone import *
from capstone.x86 import *


# Directory that test_decrypt() reads 'CrackMe.exe' from and writes
# 'CM_fix.exe' to; the trailing separator is required because file names are
# appended by plain string concatenation.
DATA_DIR = 'D:\\work\\2019\\pediy_Q4\\6\\'


def hex2bin(s):
    """Decode a hexadecimal string into the raw bytes it represents.

    ``binascii.unhexlify`` is used instead of the ``'hex'`` string codec,
    which only exists on Python 2, so the helper behaves the same on
    Python 2 and Python 3.

    :param s: hexadecimal text with an even number of digits, e.g. ``'E800'``.
    :return: the decoded bytes (``str`` on Python 2, ``bytes`` on Python 3).
    """
    import binascii  # local import keeps the file's import header untouched
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as a lowercase hexadecimal text string.

    ``binascii.hexlify`` is used instead of the ``'hex'`` string codec, which
    only exists on Python 2.  On Python 3 ``hexlify`` returns ``bytes``, so
    the result is decoded to keep the return type a native ``str`` (the
    callers interpolate it with ``%s``).

    :param s: the bytes to encode.
    :return: two lowercase hex digits per input byte, as ``str``.
    """
    import binascii  # local import keeps the file's import header untouched
    encoded = binascii.hexlify(s)
    if not isinstance(encoded, str):
        encoded = encoded.decode('ascii')
    return encoded


def load_file(filename):
    """Return the complete binary content of *filename*.

    The file is opened in a ``with`` block so the handle is closed even when
    reading raises.

    :param filename: path of the file to read.
    :return: the file content as a byte string.
    """
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the byte string *s* to *filename*, replacing any existing file.

    The file is opened in a ``with`` block so the handle is closed (and the
    data flushed) even when writing raises.

    :param filename: path of the file to create or overwrite.
    :param s: the bytes to write.
    :return: None.
    """
    with open(filename, 'wb') as f:
        f.write(s)


def rol8(v, n):
    """Rotate the 8-bit value *v* left by *n* bit positions.

    Only the low three bits of *n* are significant.  When the effective
    count is zero, *v* is handed back untouched (no masking is applied).
    """
    shift = n & 7
    if not shift:
        return v
    spilled_left = v << shift
    wrapped_around = v >> (8 - shift)
    return (spilled_left | wrapped_around) & 0xFF


def ror8(v, n):
    """Rotate the 8-bit value *v* right by *n* bit positions.

    Only the low three bits of *n* are significant.  When the effective
    count is zero, *v* is handed back untouched (no masking is applied).
    """
    shift = n & 7
    if not shift:
        return v
    shifted_down = v >> shift
    wrapped_around = v << (8 - shift)
    return (shifted_down | wrapped_around) & 0xFF


def unwrap_u32(s, offset):
    """Read the little-endian unsigned 32-bit integer at ``s[offset:offset+4]``."""
    window = s[offset:offset + 4]
    (value,) = struct.unpack('<I', window)
    return value


def unwrap_u8(s, offset):
    """Read the unsigned 8-bit integer stored at ``s[offset]``."""
    window = s[offset:offset + 1]
    (value,) = struct.unpack('<B', window)
    return value


def va_to_offset(va):
    """Translate a virtual address into a file offset.

    The mapping is a fixed shift: virtual address 0x401000 corresponds to
    file offset 0x400.
    """
    mapped_va = 0x401000
    file_offset = 0x400
    delta = mapped_va - file_offset
    return va - delta


def offset_to_va(offset):
    """Translate a file offset into a virtual address.

    The mapping is a fixed shift: file offset 0x400 corresponds to virtual
    address 0x401000.
    """
    mapped_va = 0x401000
    file_offset = 0x400
    delta = mapped_va - file_offset
    return offset + delta


class Pattern(object):
    """One fragment of a search pattern: literal bytes or a wildcard gap.

    offset  -- position recorded for the fragment by its creator
    size    -- number of bytes the fragment covers
    pattern -- the literal bytes; the empty string marks a wildcard gap
    """

    def __init__(self, offset, size, pattern=''):
        self.pattern = pattern
        self.size = size
        self.offset = offset

    def place_holder(self):
        """Tell whether the fragment is a wildcard gap (no literal bytes)."""
        is_gap = self.pattern == ''
        return is_gap

    def __str__(self):
        # Gaps are described by their size, literals by their hex dump.
        if self.place_holder():
            return 'PatternInfo(offset:%d, size:%d)' % (self.offset, self.size)
        hex_text = bin2hex(self.pattern)
        return 'PatternInfo(offset:%d, hex:%s)' % (self.offset, hex_text)

    def __repr__(self):
        return str(self)


class PatternObject(object):
    """Parsed form of a hex search pattern containing ``*`` wildcards.

    The pattern text is a string of hex digits in which every ``*`` stands
    for one wildcard digit, e.g. ``'81EE******00'``.  It is split into an
    alternating list of ``Pattern`` fragments (literal bytes and wildcard
    gaps) stored in ``self.patterns``.

    Wildcard runs are converted to a byte count with integer division, so
    they are expected to cover whole bytes (an even number of ``*``).
    """

    def __init__(self, pattern=''):
        """Split *pattern* into literal and wildcard fragments.

        :param pattern: hex digits with optional ``*`` wildcard digits.
        """
        self.patterns = []
        length = len(pattern)
        offset = 0
        while True:
            i = pattern.find('*', offset)
            if i == -1:
                # No further wildcard: the remainder is one literal fragment.
                right_pattern = hex2bin(pattern[offset:])
                # '//' keeps offsets and sizes integral on Python 3 as well.
                self.patterns.append(Pattern(offset // 2, len(right_pattern), right_pattern))
                break
            # Literal fragment in front of the wildcard run.
            left_pattern = hex2bin(pattern[offset:i])
            self.patterns.append(Pattern(offset // 2, len(left_pattern), left_pattern))
            # Measure the wildcard run; the bounds check lets a pattern end
            # with '*' instead of raising IndexError.
            k = i + 1
            while k < length and pattern[k] == '*':
                k += 1
            self.patterns.append(Pattern(i // 2, (k - i) // 2))
            offset = k

    def first_pattern(self):
        """Return the literal bytes of the first fragment.

        This is the empty byte string when the pattern starts with a wildcard.
        """
        return self.patterns[0].pattern

    def match(self, buf, offset):
        """Check whether the whole pattern matches *buf* starting at *offset*.

        :param buf: the byte string to test.
        :param offset: index in *buf* at which the pattern must start.
        :return: True when every literal fragment equals the bytes at its
            position; wildcard gaps are skipped without comparison.
        """
        i = offset
        for pat in self.patterns:
            # Gaps always pass; a literal must equal the bytes it covers.
            if not pat.place_holder() and pat.pattern != buf[i:i+pat.size]:
                return False
            i += pat.size
        return True


def find_pattern(buf, pattern, offset=0):
    """Locate the first occurrence of a wildcard hex pattern in *buf*.

    :param buf: the byte string to search.
    :param pattern: hex digits with ``*`` wildcard digits, as accepted by
        ``PatternObject``.
    :param offset: index in *buf* at which the search starts.
    :return: index of the first match at or after *offset*, or -1.
    """
    pat = PatternObject(pattern)
    first = pat.first_pattern()  # loop-invariant: literal prefix of the pattern
    size = len(buf)
    while offset < size:
        # Jump to the next candidate: an occurrence of the literal prefix.
        i = buf.find(first, offset, size)
        if i == -1:
            return -1
        if pat.match(buf, i):
            return i
        # Resume right after the rejected candidate.  Advancing by a single
        # byte from the old offset would just find the same candidate again
        # until the offset caught up with it.
        offset = i + 1
    return -1


class Instruction(object):
    """Base class for one emulated instruction of a decryptor stub.

    Holds the address the instruction was found at and its immediate
    operand; subclasses override ``do`` with the actual operation.
    """

    def __init__(self, address=0, imm=0):
        self.imm = imm
        self.address = address

    def get_imm(self):
        """Return the immediate operand as stored."""
        return self.imm

    def get_address(self):
        """Return the address recorded for the instruction."""
        return self.address

    def do(self, v, counter):  # type:(int, int) -> int
        """Apply the instruction to *v*; this base implementation yields 0."""
        return 0


class Instruction8(Instruction):
    """Instruction whose immediate operand is read as an 8-bit value."""

    def get_imm(self):
        """Return the low byte of the stored immediate."""
        low_byte = self.imm & 0xFF
        return low_byte


class InstructionMov8(Instruction8):
    """Emulates ``mov al, imm8``: the value is replaced by the immediate."""

    def __init__(self, address, imm):
        super(InstructionMov8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: mov al,%02x' % fields

    def do(self, v, counter):
        # The incoming value is ignored; a load overwrites it.
        loaded = self.get_imm()
        return loaded


class InstructionNeg8(Instruction8):
    """Emulates ``neg al``: arithmetic negation (result is not masked here)."""

    def __init__(self, address):
        super(InstructionNeg8, self).__init__(address=address)

    def __repr__(self):
        return '%08X: neg al' % (self.get_address(),)

    def do(self, v, counter):
        negated = 0 - v
        return negated


class InstructionNot8(Instruction8):
    """Emulates ``not al``: bitwise complement (result is not masked here)."""

    def __init__(self, address):
        super(InstructionNot8, self).__init__(address=address)

    def __repr__(self):
        return '%08X: not al' % (self.get_address(),)

    def do(self, v, counter):
        complemented = ~v
        return complemented


class InstructionSub8(Instruction8):
    """Emulates ``sub al, imm8`` (result is not masked here)."""

    def __init__(self, address, imm):
        super(InstructionSub8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: sub al,%02X' % fields

    def do(self, v, counter):
        subtrahend = self.get_imm()
        return v - subtrahend


class InstructionSub8Counter(Instruction8):
    """Emulates ``sub al, cl`` with the loop counter as ``cl`` (unmasked result)."""

    def __init__(self, address):
        super(InstructionSub8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: sub al,cl' % (self.get_address(),)

    def do(self, v, counter):
        difference = v - counter
        return difference


class InstructionAdd8(Instruction8):
    """Emulates ``add al, imm8`` (result is not masked here)."""

    def __init__(self, address, imm):
        super(InstructionAdd8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: add al,%02X' % fields

    def do(self, v, counter):
        addend = self.get_imm()
        return v + addend


class InstructionAdd8Counter(Instruction8):
    """Emulates ``add al, cl`` with the loop counter as ``cl`` (unmasked result)."""

    def __init__(self, address):
        super(InstructionAdd8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: add al,cl' % (self.get_address(),)

    def do(self, v, counter):
        total = v + counter
        return total


class InstructionXor8(Instruction8):
    """Emulates ``xor al, imm8``."""

    def __init__(self, address, imm):
        super(InstructionXor8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: xor al,%02X' % fields

    def do(self, v, counter):
        key = self.get_imm()
        return v ^ key


class InstructionXor8Counter(Instruction8):
    """Emulates ``xor al, cl`` with the loop counter as ``cl``."""

    def __init__(self, address):
        super(InstructionXor8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: xor al,cl' % (self.get_address(),)

    def do(self, v, counter):
        mixed = v ^ counter
        return mixed


class InstructionMul8Counter(Instruction8):
    """Emulates ``mul cl`` with the loop counter as ``cl`` (unmasked product)."""

    def __init__(self, address):
        super(InstructionMul8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: mul cl' % (self.get_address(),)

    def do(self, v, counter):
        product = v * counter
        return product


class InstructionXor8Expression(Instruction8):
    """Emulates ``xor al, bl`` where ``bl`` is produced by a sub-program.

    ``expr`` is a sequence of instructions that is evaluated from an initial
    value of 0 to obtain the byte that is xor-ed into the value.
    """

    def __init__(self, address, expr):
        super(InstructionXor8Expression, self).__init__(address=address)
        self.expr = expr

    def __repr__(self):
        # Header line followed by one tab-indented line per sub-instruction.
        pieces = ['%08X: xor al,bl' % (self.get_address(),)]
        pieces.extend('\n\t%s' % ins for ins in self.expr)
        return ''.join(pieces)

    def do(self, v, counter):
        key = 0
        for step in self.expr:  # type:Instruction
            # Every intermediate result is truncated to one byte.
            key = step.do(key, counter) & 0xFF
        return v ^ key


class InstructionRor8(Instruction8):
    """Emulates ``ror al, imm8`` through the ``ror8`` helper."""

    def __init__(self, address, imm):
        super(InstructionRor8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: ror al,%02X' % fields

    def do(self, v, counter):
        count = self.get_imm()
        return ror8(v, count)


class InstructionRor8Counter(Instruction8):
    """Emulates ``ror al, cl`` with the loop counter as ``cl``."""

    def __init__(self, address):
        super(InstructionRor8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: ror al,cl' % (self.get_address(),)

    def do(self, v, counter):
        rotated = ror8(v, counter)
        return rotated


class InstructionRol8(Instruction8):
    """Emulates ``rol al, imm8`` through the ``rol8`` helper."""

    def __init__(self, address, imm):
        super(InstructionRol8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        fields = (self.get_address(), self.get_imm())
        return '%08X: rol al,%02X' % fields

    def do(self, v, counter):
        count = self.get_imm()
        return rol8(v, count)


class InstructionRol8Counter(Instruction8):
    """Emulates ``rol al, cl`` with the loop counter as ``cl``."""

    def __init__(self, address):
        super(InstructionRol8Counter, self).__init__(address=address)

    def __repr__(self):
        return '%08X: rol al,cl' % (self.get_address(),)

    def do(self, v, counter):
        rotated = rol8(v, counter)
        return rotated


def pe_get_code_partial(ary, va, size=0x800):
    """Return an immutable copy of up to *size* bytes of *ary* starting at *va*.

    :param ary: byte buffer holding the file image.
    :param va: virtual address of the first byte wanted.
    :param size: maximum number of bytes to copy.
    :return: the copied bytes (shorter than *size* near the end of *ary*).
    """
    start = va_to_offset(va)
    return bytes(ary[start:start + size])


def simplify_inst(ins_ary):
    """Fold neighbouring immediate instructions of compatible kinds into one.

    Each instruction is compared with the last entry of the result: add/sub
    pairs, xor pairs and rol/ror pairs are replaced by a single equivalent
    instruction that keeps the address of the earlier one.  Anything else is
    appended unchanged.

    :param ins_ary: sequence of instruction objects.
    :return: *ins_ary* itself when it is empty, otherwise a new list.
    """
    if len(ins_ary) == 0:
        return ins_ary
    folded = []
    for cur in ins_ary:
        if not folded:
            folded.append(cur)
            continue
        prev = folded[-1]
        merged = None
        if isinstance(prev, InstructionAdd8):
            if isinstance(cur, InstructionAdd8):
                merged = InstructionAdd8(prev.get_address(), prev.get_imm() + cur.get_imm())
            elif isinstance(cur, InstructionSub8):
                merged = InstructionAdd8(prev.get_address(), prev.get_imm() - cur.get_imm())
        elif isinstance(prev, InstructionSub8):
            if isinstance(cur, InstructionAdd8):
                merged = InstructionSub8(prev.get_address(), prev.get_imm() - cur.get_imm())
            elif isinstance(cur, InstructionSub8):
                merged = InstructionSub8(prev.get_address(), prev.get_imm() + cur.get_imm())
        elif isinstance(prev, InstructionXor8):
            if isinstance(cur, InstructionXor8):
                merged = InstructionXor8(prev.get_address(), prev.get_imm() ^ cur.get_imm())
        elif isinstance(prev, InstructionRol8):
            if isinstance(cur, InstructionRol8):
                merged = InstructionRol8(prev.get_address(), prev.get_imm() + cur.get_imm())
            elif isinstance(cur, InstructionRor8):
                # Opposite rotations cancel; keep the direction of the larger.
                left, right = prev.get_imm(), cur.get_imm()
                if left > right:
                    merged = InstructionRol8(prev.get_address(), left - right)
                else:
                    merged = InstructionRor8(prev.get_address(), right - left)
        elif isinstance(prev, InstructionRor8):
            if isinstance(cur, InstructionRor8):
                merged = InstructionRor8(prev.get_address(), prev.get_imm() + cur.get_imm())
            elif isinstance(cur, InstructionRol8):
                # Opposite rotations cancel; keep the direction of the larger.
                right, left = prev.get_imm(), cur.get_imm()
                if right > left:
                    merged = InstructionRor8(prev.get_address(), right - left)
                else:
                    merged = InstructionRol8(prev.get_address(), left - right)
        if merged is None:
            folded.append(cur)
        else:
            folded[-1] = merged
    return folded


def pe_decrypt_code_partial(ary, va, size, ins_ary, counter_type=0):
    """Replay a recovered instruction sequence over a region of *ary* in place.

    Every instruction is printed first, the sequence is simplified, and then
    each of the *size* bytes starting at *va* is run through it.

    :param ary: mutable byte buffer holding the file image.
    :param va: virtual address of the first byte to transform.
    :param size: number of bytes to transform.
    :param ins_ary: instruction objects to apply, in order.
    :param counter_type: how the per-byte counter is derived from the byte
        index ``i``: 0 -> ``size - i``, 1 -> ``i + 1``, 2 -> ``size - 1 - i``,
        anything else -> 0.  The counter is truncated to 8 bits.
    """
    base = va_to_offset(va)
    # print('counter_type: %d' % counter_type)
    for ins in ins_ary:
        print('%s' % ins)

    program = simplify_inst(ins_ary)

    for index in range(size):
        position = base + index
        value = ary[position]
        if counter_type == 0:
            counter = size - index
        elif counter_type == 1:
            counter = index + 1
        elif counter_type == 2:
            counter = size - 1 - index
        else:
            counter = 0
        counter &= 0xFF
        for step in program:  # type: Instruction
            # Keep the running value within one byte after every step.
            value = step.do(value, counter) & 0xFF
        ary[position] = value


# Maps an encrypted-region virtual address to a hit counter maintained by
# pe_smc_decrypt() for its 0xE8 patch special case.
patch_enc_va_info = {}


def pe_smc_decrypt(ary, enc_va, enc_size, initial_va, counter_type):
    """Recover one self-modifying-code decryptor stub and apply it to its target.

    The stub at ``initial_va`` is disassembled with capstone.  Starting after
    the first ``lodsb``, the operations applied to ``al``/``eax`` are collected
    as ``Instruction*`` objects until the store ``mov [esi+disp], al`` is
    reached; the collected sequence is then replayed over ``enc_size`` bytes
    at ``enc_va`` by ``pe_decrypt_code_partial``.

    :param ary: mutable byte buffer holding the file image; modified in place.
    :param enc_va: virtual address of the encrypted region.
    :param enc_size: number of encrypted bytes.
    :param initial_va: virtual address at which the decryptor stub starts.
    :param counter_type: counter scheme forwarded to
        ``pe_decrypt_code_partial``; replaced by 1 when the store uses a
        displacement of 1.
    :return: the virtual address just past the store instruction on success,
        or -1 when no ``lodsb`` is found in the first 0x80 bytes, when the
        store is reached before any operation was collected, or when no store
        is found before ``initial_va + 0x800``.
    """
    va = initial_va
    ins_ary = []
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    md.detail = True  # operand details are needed for the checks below

    # search for decrypt start
    t = -1
    for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
        if i.id == X86_INS_LODSB:
            # print('%08X: lodsb' % i.address)
            t = i.address + i.size
            break
    if t == -1:
        return -1
    va = t

    # begin(inclusive), end
    # Address ranges skipped by direct jumps; instructions inside are ignored.
    jmp_ary = []

    # push ecx; mov cl,0xA7; xor al,cl; pop ecx
    # NOTE: these flags are only ever set to True in this function, never
    # cleared, so only the first push/pop pair of each register is tracked.
    ecx_pushed = False
    ecx_popped = False
    ecx_imm = 0  # TODO need inst_ary? like ebx?

    # push eax; op eax; ...; mov ebx,eax; pop eax
    eax_pushed = False
    eax_popped = False
    # Operations seen between the push and the pop of eax; they form the
    # expression attached to a later `xor al, bl`.
    ebx_inst_ary = []

    # Disassemble in 0x80-byte windows, resuming after the last decoded
    # instruction each time.
    # NOTE(review): if disasm() yields no instruction for a window, va is not
    # advanced and this loop would not terminate -- confirm that cannot happen.
    while va < (initial_va + 0x800):
        for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
            va = i.address + i.size
            # skip code in jmp area
            in_jmp_area = False
            for begin_address, end_address in jmp_ary:
                if begin_address <= i.address < end_address:
                    in_jmp_area = True
                    break
            if in_jmp_area:
                continue

            # print("%x:\t%s\t%s" % (i.address, i.mnemonic, i.op_str))

            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_JMP and op0.type == X86_OP_IMM:
                    # Record the bytes between the jmp and its target as skipped.
                    # print('jmp: %08X-%08X' % (i.address + i.size, op0.imm))
                    jmp_ary.append((i.address + i.size, op0.imm))
                    continue

            if len(i.operands) == 0:
                # pushal/popal are treated like a push/pop of eax.
                if i.id == X86_INS_PUSHAL:
                    eax_pushed = True
                elif i.id == X86_INS_POPAL:
                    eax_popped = True

            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_PUSH and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_pushed = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_pushed = True
                    continue
                if i.id == X86_INS_POP and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_popped = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_popped = True
                    continue

            # Between `push eax` and `pop eax`: collect into ebx_inst_ary
            # instead of the main instruction list.
            if eax_pushed and (not eax_popped):
                if len(i.operands) == 2:
                    op0 = i.operands[0]  # type: X86Op
                    op1 = i.operands[1]  # type:X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        # OP EAX/AL, IMM/REG
                        if op1.type == X86_OP_IMM:
                            if i.id == X86_INS_MOV:
                                ebx_inst_ary.append(InstructionMov8(i.address, op1.imm))
                            elif i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8(i.address, op1.imm))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8(i.address, op1.imm))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8(i.address, op1.imm))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8(i.address, op1.imm))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8(i.address, op1.imm))
                        # OP EAX/AL, ECX/CL
                        elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                            if i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8Counter(i.address))
                elif len(i.operands) == 1:
                    op0 = i.operands[0]  # type: X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        if i.id == X86_INS_NEG:
                            ebx_inst_ary.append(InstructionNeg8(i.address))
                        elif i.id == X86_INS_NOT:
                            ebx_inst_ary.append(InstructionNot8(i.address))
                    if i.id == X86_INS_MUL and op0.type == X86_OP_REG and (op0.reg in (X86_REG_CL, X86_REG_ECX)):
                        ebx_inst_ary.append(InstructionMul8Counter(i.address))
                continue  # until we meet pop eax

            if len(i.operands) == 2:
                op0 = i.operands[0]  # type: X86Op
                op1 = i.operands[1]  # type:X86Op
                # mov [esi-1],al (CLD)
                # mov [esi+1],al (STD)
                if i.id == X86_INS_MOV and op0.type == X86_OP_MEM and op0.mem.base == X86_REG_ESI and op1.type == X86_OP_REG and op1.reg == X86_REG_AL:
                    # collect decrypt instruction finished
                    if len(ins_ary) == 0:
                        return -1
                    # A store at [esi+1] selects counter scheme 1.
                    if op0.mem.disp == 1:
                        counter_type = 1
                    tmp_ins = ins_ary[-1]
                    # Count, per target address, the stubs whose last collected
                    # operation is `ror al, 0xD8`.
                    if isinstance(tmp_ins, InstructionRor8) and tmp_ins.get_imm() == 0xD8:
                        if enc_va in patch_enc_va_info:
                            patch_enc_va_info[enc_va] += 1
                        else:
                            patch_enc_va_info[enc_va] = 1
                    # While that count is exactly 3, the region is not decrypted;
                    # only its first byte is overwritten with 0xE8.
                    if enc_va in patch_enc_va_info and patch_enc_va_info[enc_va] == 3:
                        ary[va_to_offset(enc_va)] = 0xE8  # ugly hack
                        return va
                    pe_decrypt_code_partial(ary, enc_va, enc_size, ins_ary, counter_type=counter_type)
                    return va
                elif i.id == X86_INS_MOV and op0.type == X86_OP_REG and op0.reg == X86_REG_CL:
                    # mov cl, A7
                    # mov cl, ah
                    # Only relevant once ecx has been saved with `push ecx`.
                    if not ecx_pushed:
                        continue
                    if op1.type == X86_OP_IMM:
                        ecx_imm = op1.imm
                    elif op1.type == X86_OP_REG and op1.reg == X86_REG_AH:
                        # NOTE(review): ah is modelled as 0 here -- confirm.
                        ecx_imm = 0
                elif op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                    # OP EAX/AL, IMM/REG
                    if op1.type == X86_OP_IMM:
                        if i.id == X86_INS_SUB:
                            ins_ary.append(InstructionSub8(i.address, op1.imm))
                        elif i.id == X86_INS_ADD:
                            ins_ary.append(InstructionAdd8(i.address, op1.imm))
                        elif i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8(i.address, op1.imm))
                        elif i.id == X86_INS_ROL:
                            ins_ary.append(InstructionRol8(i.address, op1.imm))
                        elif i.id == X86_INS_ROR:
                            ins_ary.append(InstructionRor8(i.address, op1.imm))
                    # OP EAX/AL, EBX/BL
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_BL,)):
                        if i.id == X86_INS_XOR:
                            # The list object is shared, not copied.
                            ins_ary.append(InstructionXor8Expression(i.address, ebx_inst_ary))
                    # OP EAX/AL, ECX/CL
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                        # Between `push ecx` and `pop ecx`, cl holds the constant
                        # loaded by `mov cl, ...`; otherwise it is the loop counter.
                        if ecx_pushed and (not ecx_popped):
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8(i.address, ecx_imm))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8(i.address, ecx_imm))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8(i.address, ecx_imm))
                        else:
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8Counter(i.address))
            elif len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if op0.type == X86_OP_REG and op0.reg in (X86_REG_AL, X86_REG_EAX):
                    if i.id == X86_INS_NEG:
                        ins_ary.append(InstructionNeg8(i.address))
                    elif i.id == X86_INS_NOT:
                        ins_ary.append(InstructionNot8(i.address))
    # The scan window was exhausted without reaching the store instruction.
    return -1


def get_current_time():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    whole_seconds = int(time.time())
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(whole_seconds))


def pe_smc_decrypt_repeated(ary, va_start, va_end=0x506000):  # type:(bytearray, int, int) -> None
    """Decrypt a chain of self-modifying-code stubs in place.

    Starting at ``va_start`` the code is searched for the position-independent
    stub prologue ``call $+5; pop esi; sub esi,imm1; add esi,imm2`` followed
    (optionally after a short ``jmp``) by ``mov ecx,imm3``.  From it the
    encrypted region (``enc_va``, ``enc_size``) is derived and handed to
    ``pe_smc_decrypt``; the search then continues at the address that call
    returns.  The loop stops when no prologue is found, when
    ``pe_smc_decrypt`` returns -1, or when ``va_end`` is reached.

    Byte comparisons use one-byte slices against byte literals and ``//`` is
    used for the pattern size, so the function behaves the same on Python 2
    and 3 and a prologue cut off by the end of the buffer is treated as a
    non-match instead of raising IndexError.

    :param ary: mutable byte buffer holding the file image; modified in place.
    :param va_start: virtual address at which the search begins.
    :param va_end: virtual address at which the search stops.
    """
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; mov ecx,imm3
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; jmp xx; mov ecx,imm3
    pat = 'E8000000005E81EE******0081C6******00'
    pat_size = len(pat) // 2  # pattern length in bytes
    va = va_start
    while va < va_end:
        # it's repeated-smc, no point to get full code
        buf = pe_get_code_partial(ary, va)
        i = find_pattern(buf, pat)
        if i == -1:
            break
        jmp_offset = i + pat_size
        if buf[jmp_offset:jmp_offset + 1] == b'\xEB':  # jmp $+XX
            # NOTE(review): the displacement is read as unsigned, so only
            # forward short jumps are followed correctly -- confirm.
            mov_offset = jmp_offset + unwrap_u8(buf, jmp_offset + 1) + 2
        else:
            mov_offset = jmp_offset
        # mov ecx,imm3 (opcode B9, and the top byte of imm3 must be zero)
        if buf[mov_offset:mov_offset + 1] != b'\xB9' or buf[mov_offset + 4:mov_offset + 5] != b'\x00':
            va += jmp_offset
            continue
        # `call $+5` is five bytes long, so esi receives the address of the
        # `pop esi` that follows it; imm1 is at ip+3 and imm2 at ip+9.
        ip = i + 5
        enc_va = (va + ip) + unwrap_u32(buf, ip + 9) - unwrap_u32(buf, ip + 3)
        enc_size = unwrap_u32(buf, mov_offset + 1)

        # sub ecx,1 -- follow any short jumps placed after the mov first
        sub_ecx_offset = mov_offset + 5
        while buf[sub_ecx_offset:sub_ecx_offset + 1] == b'\xEB':
            sub_ecx_offset = sub_ecx_offset + unwrap_u8(buf, sub_ecx_offset + 1) + 2
        if buf[sub_ecx_offset:sub_ecx_offset + 3] == b'\x83\xE9\x01':
            counter_type = 2
        else:
            counter_type = 0

        # position independent code
        last_pic_va = va + i
        print('[%s] pic_va: %08x, enc_va: %08x, enc_size: %08x' % (get_current_time(), last_pic_va, enc_va, enc_size))
        new_va = pe_smc_decrypt(ary, enc_va, enc_size, last_pic_va, counter_type)
        print('[%s] decrypt done, new_va: %08x' % (get_current_time(), new_va))
        if new_va == -1:
            break
        va = new_va


def test_decrypt():
    """Decrypt ``CrackMe.exe`` from ``DATA_DIR`` and save it as ``CM_fix.exe``.

    The stub chain is followed starting at virtual address 0x00401629.
    """
    start_va = 0x00401629
    image = bytearray(load_file(DATA_DIR + 'CrackMe.exe'))
    pe_smc_decrypt_repeated(image, start_va)
    save_file(DATA_DIR + 'CM_fix.exe', bytes(image))


# Script entry point: the decryption runs as soon as this file is executed
# (it is not guarded, so importing the file runs it as well).
test_decrypt()

2. 提取函数

4. z3求解


文章来源: https://bbs.pediy.com/thread-256696.htm
如有侵权请联系:admin#unsafe.sh