"""Statically undo the repeated self-modifying-code (SMC) layers of CrackMe.exe.

Each layer is entered through a stub matching
``call $+5; pop esi; sub esi,imm1; add esi,imm2; [jmp xx;] mov ecx,size`` and
transforms bytes between a ``lodsb`` and a ``mov [esi-1],al`` /
``mov [esi+1],al`` store. This script finds each stub, lifts its byte
transform into ``Instruction`` objects with capstone, replays it over the
encrypted range in an in-memory copy of the file, and saves the result as
CM_fix.exe.
"""
import binascii
import struct
import time

from capstone import *
from capstone.x86 import *

# Folder holding CrackMe.exe; the patched CM_fix.exe is written here as well.
DATA_DIR = 'D:\\work\\2019\\pediy_Q4\\6\\'


def hex2bin(s):
    """Decode a hex string such as 'E800' into raw bytes."""
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as lowercase hex digits."""
    return binascii.hexlify(s)


def load_file(filename):
    """Return the whole content of *filename*, read in binary mode."""
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the bytes *s* to *filename*, replacing any existing content."""
    with open(filename, 'wb') as f:
        f.write(s)


def rol8(v, n):
    """Rotate the byte *v* left by *n* bits; *n* is taken modulo 8."""
    n &= 7
    if n == 0:
        return v
    return ((v << n) | (v >> (8 - n))) & 0xFF


def ror8(v, n):
    """Rotate the byte *v* right by *n* bits; *n* is taken modulo 8."""
    n &= 7
    if n == 0:
        return v
    return ((v >> n) | (v << (8 - n))) & 0xFF


def unwrap_u32(s, offset):
    """Read a little-endian unsigned 32-bit integer from *s* at *offset*."""
    return struct.unpack('<I', s[offset:offset + 4])[0]


def unwrap_u8(s, offset):
    """Read an unsigned byte from *s* at *offset*."""
    return struct.unpack('<B', s[offset:offset + 1])[0]


def va_to_offset(va):
    """Convert a virtual address to a file offset using a fixed delta.

    NOTE(review): the delta assumes VA 0x401000 maps to raw offset 0x400
    (presumably the target's code section) - confirm against the PE headers.
    """
    return va - (0x401000 - 0x400)


def offset_to_va(offset):
    """Inverse of va_to_offset (same fixed-delta assumption)."""
    return offset + (0x401000 - 0x400)


class Pattern(object):
    """One segment of a hex search pattern: literal bytes or a wildcard run."""

    def __init__(self, offset, size, pattern=b''):
        self.offset = offset    # byte offset of this segment inside the pattern
        self.size = size        # number of buffer bytes the segment covers
        self.pattern = pattern  # literal bytes; empty for a wildcard run

    def place_holder(self):
        """Return True when this segment is a wildcard (has no literal bytes)."""
        return not self.pattern

    def __str__(self):
        if not self.place_holder():
            return 'PatternInfo(offset:%d, hex:%s)' % (self.offset, bin2hex(self.pattern))
        return 'PatternInfo(offset:%d, size:%d)' % (self.offset, self.size)

    def __repr__(self):
        return str(self)


class PatternObject(object):
    """A hex byte pattern in which each pair of '*' stands for one wildcard byte.

    The pattern is split into alternating literal and wildcard segments;
    literal segments may be empty (e.g. for a leading or doubled wildcard run).
    """

    def __init__(self, pattern=''):
        self.patterns = []
        offset = 0
        while True:
            i = pattern.find('*', offset)
            if i == -1:
                # Trailing literal bytes (possibly empty).
                right_pattern = hex2bin(pattern[offset:])
                self.patterns.append(Pattern(offset // 2, len(right_pattern), right_pattern))
                break
            # Literal bytes before the wildcard run (possibly empty).
            left_pattern = hex2bin(pattern[offset:i])
            self.patterns.append(Pattern(offset // 2, len(left_pattern), left_pattern))
            # Wildcard run; the bounds check keeps a trailing '*' from raising IndexError.
            k = i + 1
            while k < len(pattern) and pattern[k] == '*':
                k += 1
            self.patterns.append(Pattern(i // 2, (k - i) // 2))
            offset = k

    def first_pattern(self):
        """Return the leading literal bytes (empty if the pattern starts with '*')."""
        return self.patterns[0].pattern

    def match(self, buf, offset):
        """Return True if the whole pattern matches *buf* starting at *offset*."""
        i = offset
        for pat in self.patterns:
            if pat.place_holder():
                i += pat.size
            elif pat.pattern == buf[i:i + pat.size]:
                i += pat.size
            else:
                return False
        # A trailing wildcard must still cover bytes that exist in the buffer.
        return i <= len(buf)


def find_pattern(buf, pattern, offset=0):
    """Return the first index >= *offset* where the hex *pattern* matches *buf*, or -1."""
    pat = PatternObject(pattern)
    first = pat.first_pattern()
    size = len(buf)
    while offset < size:
        # Jump straight to the next occurrence of the leading literal bytes.
        i = buf.find(first, offset)
        if i == -1:
            return -1
        if pat.match(buf, i):
            return i
        # Resume right after the rejected candidate; restarting at offset + 1
        # would rediscover the same candidate again and again.
        offset = i + 1
    return -1


class Instruction(object):
    """Base model of one lifted x86 instruction acting on the byte in `al`."""

    def __init__(self, address=0, imm=0):
        self.address = address  # VA of the instruction in the image
        self.imm = imm          # immediate operand (0 when unused)

    def get_address(self):
        return self.address

    def get_imm(self):
        return self.imm

    def do(self, v, counter):
        # type: (int, int) -> int
        """Return the new `al` given `al` = *v* and `cl` = *counter*.

        The result may exceed 8 bits; callers mask it with 0xFF.
        """
        return 0


class Instruction8(Instruction):
    """Instruction whose immediate is truncated to 8 bits."""

    def get_imm(self):
        return self.imm & 0xFF


class InstructionMov8(Instruction8):
    """mov al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: mov al,%02x' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return self.get_imm()


class InstructionNeg8(Instruction8):
    """neg al"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: neg al' % self.get_address()

    def do(self, v, counter):
        return 0 - v


class InstructionNot8(Instruction8):
    """not al"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: not al' % self.get_address()

    def do(self, v, counter):
        return ~v


class InstructionSub8(Instruction8):
    """sub al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: sub al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return v - self.get_imm()


class InstructionSub8Counter(Instruction8):
    """sub al, cl"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: sub al,cl' % self.get_address()

    def do(self, v, counter):
        return v - counter


class InstructionAdd8(Instruction8):
    """add al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: add al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return v + self.get_imm()


class InstructionAdd8Counter(Instruction8):
    """add al, cl"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: add al,cl' % self.get_address()

    def do(self, v, counter):
        return v + counter


class InstructionXor8(Instruction8):
    """xor al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: xor al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return v ^ self.get_imm()


class InstructionXor8Counter(Instruction8):
    """xor al, cl"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: xor al,cl' % self.get_address()

    def do(self, v, counter):
        return v ^ counter


class InstructionMul8Counter(Instruction8):
    """mul cl (only the low byte of the product is kept by callers)."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: mul cl' % self.get_address()

    def do(self, v, counter):
        return v * counter


class InstructionXor8Expression(Instruction8):
    """xor al, bl where `bl` is computed by a recorded instruction sequence.

    *expr* holds the instructions seen between `push eax`/`pushad` and the
    matching pop; they are evaluated from an initial value of 0 (presumably
    the sequence starts with `mov al, imm` - TODO confirm).
    """

    def __init__(self, address, expr):
        Instruction.__init__(self, address=address)
        self.expr = expr

    def __repr__(self):
        return '%08X: xor al,bl' % self.get_address() + ''.join(
            '\n\t%s' % ins for ins in self.expr)

    def do(self, v, counter):
        t = 0
        for ins in self.expr:  # type: Instruction
            t = ins.do(t, counter) & 0xFF
        return v ^ t


class InstructionRor8(Instruction8):
    """ror al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: ror al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return ror8(v, self.get_imm())


class InstructionRor8Counter(Instruction8):
    """ror al, cl"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: ror al,cl' % self.get_address()

    def do(self, v, counter):
        return ror8(v, counter)


class InstructionRol8(Instruction8):
    """rol al, imm8"""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: rol al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return rol8(v, self.get_imm())


class InstructionRol8Counter(Instruction8):
    """rol al, cl"""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: rol al,cl' % self.get_address()

    def do(self, v, counter):
        return rol8(v, counter)


def pe_get_code_partial(ary, va, size=0x800):
    """Return up to *size* bytes of the image starting at virtual address *va*."""
    offset = va_to_offset(va)
    return bytes(ary[offset:offset + size])


def simplify_inst(ins_ary):
    """Merge neighbouring immediate operations into one equivalent instruction.

    Only adjacent pairs are folded: add/sub with add/sub, xor with xor, and
    rol/ror with rol/ror. Counter-based and other instructions are kept as-is.
    Returns a new list, or *ins_ary* itself when it is empty.
    """
    if len(ins_ary) == 0:
        return ins_ary
    simplified_ary = []
    for old_ins in ins_ary:
        if len(simplified_ary) == 0:
            simplified_ary.append(old_ins)
            continue
        ins = simplified_ary[-1]
        if isinstance(ins, InstructionAdd8) and isinstance(old_ins, InstructionAdd8):
            simplified_ary[-1] = InstructionAdd8(ins.get_address(), ins.get_imm() + old_ins.get_imm())
            continue
        if isinstance(ins, InstructionAdd8) and isinstance(old_ins, InstructionSub8):
            simplified_ary[-1] = InstructionAdd8(ins.get_address(), ins.get_imm() - old_ins.get_imm())
            continue
        if isinstance(ins, InstructionSub8) and isinstance(old_ins, InstructionAdd8):
            simplified_ary[-1] = InstructionSub8(ins.get_address(), ins.get_imm() - old_ins.get_imm())
            continue
        if isinstance(ins, InstructionSub8) and isinstance(old_ins, InstructionSub8):
            simplified_ary[-1] = InstructionSub8(ins.get_address(), ins.get_imm() + old_ins.get_imm())
            continue
        if isinstance(ins, InstructionXor8) and isinstance(old_ins, InstructionXor8):
            simplified_ary[-1] = InstructionXor8(ins.get_address(), ins.get_imm() ^ old_ins.get_imm())
            continue
        if isinstance(ins, InstructionRol8) and isinstance(old_ins, InstructionRol8):
            simplified_ary[-1] = InstructionRol8(ins.get_address(), ins.get_imm() + old_ins.get_imm())
            continue
        if isinstance(ins, InstructionRor8) and isinstance(old_ins, InstructionRor8):
            simplified_ary[-1] = InstructionRor8(ins.get_address(), ins.get_imm() + old_ins.get_imm())
            continue
        if isinstance(ins, InstructionRol8) and isinstance(old_ins, InstructionRor8):
            # rol v1 then ror v2 == rotate by the difference, in the dominant direction.
            v1 = ins.get_imm()
            v2 = old_ins.get_imm()
            if v1 > v2:
                simplified_ary[-1] = InstructionRol8(ins.get_address(), v1 - v2)
            else:
                simplified_ary[-1] = InstructionRor8(ins.get_address(), v2 - v1)
            continue
        if isinstance(ins, InstructionRor8) and isinstance(old_ins, InstructionRol8):
            v1 = ins.get_imm()
            v2 = old_ins.get_imm()
            if v1 > v2:
                simplified_ary[-1] = InstructionRor8(ins.get_address(), v1 - v2)
            else:
                simplified_ary[-1] = InstructionRol8(ins.get_address(), v2 - v1)
            continue
        simplified_ary.append(old_ins)
    return simplified_ary


def pe_decrypt_code_partial(ary, va, size, ins_ary, counter_type=0):
    """Decrypt *size* bytes at *va* in the mutable *ary* in place.

    The lifted instructions are printed, simplified, then replayed for every
    byte with the byte as `al` and an 8-bit counter as `cl`, where i is the
    byte index:
      counter_type 0 -> size - i, 1 -> i + 1, 2 -> size - 1 - i, other -> 0.
    """
    offset = va_to_offset(va)
    for ins in ins_ary:
        print('%s' % ins)
    simplified_ary = simplify_inst(ins_ary)
    for i in range(size):
        v = ary[offset + i]
        if counter_type == 0:
            counter = size - i
        elif counter_type == 1:
            counter = i + 1
        elif counter_type == 2:
            counter = size - 1 - i
        else:
            counter = 0
        counter &= 0xFF
        for ins in simplified_ary:  # type: Instruction
            v = ins.do(v, counter)
            v &= 0xFF
        ary[offset + i] = v


# Per encrypted VA: how many stubs ending in `ror al,0xD8` were seen
# (drives the special case in pe_smc_decrypt).
patch_enc_va_info = dict()

# Register groups the lifter treats as `al` and as `cl`.
_AL_REGS = (X86_REG_AL, X86_REG_EAX)
_CL_REGS = (X86_REG_CL, X86_REG_ECX)

# `op al, imm8` mnemonics -> instruction model replaying them.
_AL_IMM_OPS = {
    X86_INS_SUB: InstructionSub8,
    X86_INS_ADD: InstructionAdd8,
    X86_INS_XOR: InstructionXor8,
    X86_INS_ROL: InstructionRol8,
    X86_INS_ROR: InstructionRor8,
}

# `op al, cl` mnemonics -> instruction model using the per-byte counter.
_AL_COUNTER_OPS = {
    X86_INS_SUB: InstructionSub8Counter,
    X86_INS_ADD: InstructionAdd8Counter,
    X86_INS_XOR: InstructionXor8Counter,
    X86_INS_ROL: InstructionRol8Counter,
    X86_INS_ROR: InstructionRor8Counter,
}


def _is_reg(op, regs):
    """Return True if the capstone operand *op* is one of the registers *regs*."""
    return op.type == X86_OP_REG and op.reg in regs


def _append_al_imm(dst, ins_id, address, imm):
    """Append the `op al, imm` model for *ins_id* to *dst*; ignore other mnemonics."""
    cls = _AL_IMM_OPS.get(ins_id)
    if cls is not None:
        dst.append(cls(address, imm))


def _append_al_counter(dst, ins_id, address):
    """Append the `op al, cl` model for *ins_id* to *dst*; ignore other mnemonics."""
    cls = _AL_COUNTER_OPS.get(ins_id)
    if cls is not None:
        dst.append(cls(address))


def _collect_ebx_inst(i, ebx_inst_ary):
    """Record one instruction from a `push eax ... pop eax` region into *ebx_inst_ary*."""
    if len(i.operands) == 2:
        op0 = i.operands[0]  # type: X86Op
        op1 = i.operands[1]  # type: X86Op
        if not _is_reg(op0, _AL_REGS):
            return
        if op1.type == X86_OP_IMM:
            if i.id == X86_INS_MOV:
                ebx_inst_ary.append(InstructionMov8(i.address, op1.imm))
            else:
                _append_al_imm(ebx_inst_ary, i.id, i.address, op1.imm)
        elif _is_reg(op1, _CL_REGS):
            _append_al_counter(ebx_inst_ary, i.id, i.address)
    elif len(i.operands) == 1:
        op0 = i.operands[0]  # type: X86Op
        if _is_reg(op0, _AL_REGS):
            if i.id == X86_INS_NEG:
                ebx_inst_ary.append(InstructionNeg8(i.address))
            elif i.id == X86_INS_NOT:
                ebx_inst_ary.append(InstructionNot8(i.address))
        if i.id == X86_INS_MUL and _is_reg(op0, _CL_REGS):
            ebx_inst_ary.append(InstructionMul8Counter(i.address))


def pe_smc_decrypt(ary, enc_va, enc_size, initial_va, counter_type):
    """Lift the decrypt stub at *initial_va* and apply it to *enc_size* bytes at *enc_va*.

    Starting after the first `lodsb` found within 0x80 bytes of *initial_va*,
    instructions are linearly disassembled (skipping ranges jumped over by
    `jmp imm`) and every operation on `al` is recorded, until the store
    `mov [esi-1],al` / `mov [esi+1],al`. The recorded transform is then
    replayed over the encrypted bytes in *ary* (mutated in place).

    Modelled idioms (from the stub shapes seen in the target):
      * `push ecx; mov cl,imm; op al,cl; pop ecx` -> `op al,imm`
      * `push eax/pushad; ...; pop eax/popad` followed by `xor al,bl`
      * `op al,cl` outside such a region -> uses the per-byte counter

    Returns the VA just after the store instruction (the caller resumes
    scanning there), or -1 if no `lodsb` was found, nothing was recorded
    before the store, or the store was not reached within 0x800 bytes.
    """
    va = initial_va
    ins_ary = []
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    md.detail = True

    # The byte transform begins right after the first `lodsb`.
    t = -1
    for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type: CsInsn
        if i.id == X86_INS_LODSB:
            t = i.address + i.size
            break
    if t == -1:
        return -1
    va = t

    # (begin inclusive, end exclusive) ranges skipped by `jmp imm`; code inside is junk.
    jmp_ary = []
    # push ecx; mov cl,0xA7; xor al,cl; pop ecx
    ecx_pushed = False
    ecx_popped = False
    ecx_imm = 0
    # TODO need inst_ary? like ebx?
    # push eax; op eax; ...; mov ebx,eax; pop eax
    eax_pushed = False
    eax_popped = False
    ebx_inst_ary = []
    while va < (initial_va + 0x800):
        decoded_any = False
        for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type: CsInsn
            decoded_any = True
            va = i.address + i.size
            if any(begin <= i.address < end for begin, end in jmp_ary):
                continue
            n_ops = len(i.operands)

            if n_ops == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_JMP and op0.type == X86_OP_IMM:
                    jmp_ary.append((i.address + i.size, op0.imm))
                    continue

            if n_ops == 0:
                if i.id == X86_INS_PUSHAL:
                    eax_pushed = True
                elif i.id == X86_INS_POPAL:
                    eax_popped = True
            if n_ops == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_PUSH and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_pushed = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_pushed = True
                    continue
                if i.id == X86_INS_POP and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_popped = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_popped = True
                    continue

            if eax_pushed and not eax_popped:
                # Inside push eax ... pop eax: these compute `bl`, not the output byte.
                _collect_ebx_inst(i, ebx_inst_ary)
                continue

            if n_ops == 2:
                op0 = i.operands[0]  # type: X86Op
                op1 = i.operands[1]  # type: X86Op
                # mov [esi-1],al (CLD) / mov [esi+1],al (STD): transform complete.
                if (i.id == X86_INS_MOV and op0.type == X86_OP_MEM
                        and op0.mem.base == X86_REG_ESI
                        and op1.type == X86_OP_REG and op1.reg == X86_REG_AL):
                    if len(ins_ary) == 0:
                        return -1
                    if op0.mem.disp == 1:
                        counter_type = 1
                    tmp_ins = ins_ary[-1]
                    # Special case kept from the original author ("ugly hack"): on the
                    # third stub for the same enc_va ending in `ror al,0xD8`, write 0xE8
                    # to the first encrypted byte instead of decrypting.
                    # NOTE(review): the nesting of this check was reconstructed from a
                    # whitespace-collapsed source - confirm it matches the intent.
                    if isinstance(tmp_ins, InstructionRor8) and tmp_ins.get_imm() == 0xD8:
                        patch_enc_va_info[enc_va] = patch_enc_va_info.get(enc_va, 0) + 1
                        if patch_enc_va_info[enc_va] == 3:
                            ary[va_to_offset(enc_va)] = 0xE8  # ugly hack
                            return va
                    pe_decrypt_code_partial(ary, enc_va, enc_size, ins_ary, counter_type=counter_type)
                    return va
                elif i.id == X86_INS_MOV and op0.type == X86_OP_REG and op0.reg == X86_REG_CL:
                    # mov cl, A7 / mov cl, ah (only tracked after push ecx)
                    if not ecx_pushed:
                        continue
                    if op1.type == X86_OP_IMM:
                        ecx_imm = op1.imm
                    elif op1.type == X86_OP_REG and op1.reg == X86_REG_AH:
                        ecx_imm = 0
                elif _is_reg(op0, _AL_REGS):
                    if op1.type == X86_OP_IMM:
                        _append_al_imm(ins_ary, i.id, i.address, op1.imm)
                    elif op1.type == X86_OP_REG and op1.reg == X86_REG_BL:
                        if i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8Expression(i.address, ebx_inst_ary))
                    elif _is_reg(op1, _CL_REGS):
                        if ecx_pushed and not ecx_popped:
                            # cl holds a constant loaded between push ecx and pop ecx.
                            _append_al_imm(ins_ary, i.id, i.address, ecx_imm)
                        else:
                            _append_al_counter(ins_ary, i.id, i.address)
            elif n_ops == 1:
                op0 = i.operands[0]  # type: X86Op
                if _is_reg(op0, _AL_REGS):
                    if i.id == X86_INS_NEG:
                        ins_ary.append(InstructionNeg8(i.address))
                    elif i.id == X86_INS_NOT:
                        ins_ary.append(InstructionNot8(i.address))
        if not decoded_any:
            # capstone stops at the first byte it cannot decode; without this
            # the same window would be re-disassembled forever.
            va += 1
    return -1


def get_current_time():
    """Return the local wall-clock time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time())))


def pe_smc_decrypt_repeated(ary, va_start, va_end=0x506000):
    # type: (bytearray, int, int) -> None
    """Find and undo successive SMC stubs in *ary* from *va_start* up to *va_end*.

    Each iteration searches an 0x800-byte window for the stub prologue, derives
    the encrypted range from its immediates, decrypts it via pe_smc_decrypt and
    continues from the VA that call returns. Stops when no prologue is found
    or a stub cannot be lifted.
    """
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; mov ecx,imm3
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; jmp xx; mov ecx,imm3
    pat = 'E8000000005E81EE******0081C6******00'
    pat_len = len(pat) // 2
    va = va_start
    while va < va_end:
        # It's repeated SMC: only a window after the current position is needed.
        buf = pe_get_code_partial(ary, va)
        i = find_pattern(buf, pat)
        if i == -1:
            break
        jmp_offset = i + pat_len
        # Byte slices compare the same way on Python 2 (str) and 3 (bytes).
        if buf[jmp_offset:jmp_offset + 1] == hex2bin('EB'):
            # jmp $+XX  (NOTE: rel8 is read unsigned, so only forward jumps are followed)
            mov_offset = jmp_offset + unwrap_u8(buf, jmp_offset + 1) + 2
        else:
            mov_offset = jmp_offset
        # mov ecx,imm3 with a zero high byte
        if (buf[mov_offset:mov_offset + 1] != hex2bin('B9')
                or buf[mov_offset + 4:mov_offset + 5] != hex2bin('00')):
            va += jmp_offset
            continue
        # esi = address after `call $+5`, then esi - imm1 + imm2 is the encrypted start.
        ip = i + 5
        enc_va = (va + ip) + unwrap_u32(buf, ip + 9) - unwrap_u32(buf, ip + 3)
        enc_size = unwrap_u32(buf, mov_offset + 1)
        # An optional `sub ecx,1` (after any short jumps) selects the counter scheme.
        sub_ecx_offset = mov_offset + 5
        while buf[sub_ecx_offset:sub_ecx_offset + 1] == hex2bin('EB'):
            sub_ecx_offset = sub_ecx_offset + unwrap_u8(buf, sub_ecx_offset + 1) + 2
        if buf[sub_ecx_offset:sub_ecx_offset + 3] == hex2bin('83E901'):
            counter_type = 2
        else:
            counter_type = 0
        # position independent code
        last_pic_va = va + i
        print('[%s] pic_va: %08x, enc_va: %08x, enc_size: %08x' % (get_current_time(), last_pic_va, enc_va, enc_size))
        new_va = pe_smc_decrypt(ary, enc_va, enc_size, last_pic_va, counter_type)
        print('[%s] decrypt done, new_va: %08x' % (get_current_time(), new_va))
        if new_va == -1:
            break
        va = new_va


def test_decrypt():
    """Decrypt CrackMe.exe from the first stub at 0x401629 and save CM_fix.exe."""
    va = 0x00401629
    buf = load_file(DATA_DIR + 'CrackMe.exe')
    ary = bytearray(buf)
    pe_smc_decrypt_repeated(ary, va)
    save_file(DATA_DIR + 'CM_fix.exe', bytes(ary))


if __name__ == '__main__':
    test_decrypt()
2. 提取函数
4. z3求解