#!/usr/bin/env python3
"""
GenX UDS Flashing Script
ANCIT Consulting - SmartWheels GenX

Flashes a compiled .hex file to an ECU using the ANCIT UDS Bootloader
via the standard UDS programming sequence over CAN.

UDS Sequence:
    1.  DiagnosticSessionControl  (0x10 0x02) - enter programming session
    2.  SecurityAccess            (0x27 0x05) - seed/key via GenerateKeyEx.dll
    3.  RoutineControl - Erase    (0x31 0x01 0xF2 0x00) - erase flash
    4.  RequestDownload           (0x34)      - specify address and size
    5.  TransferData              (0x36)      - send image in blocks
    6.  RequestTransferExit       (0x37)      - end transfer
    7.  RoutineControl - CRC      (0x31 0x01 0x02 0x01) - post-program CRC check
    8.  RoutineControl - GetResult(0x31 0x03 0x02 0x01) - get CRC result
    9.  ECUReset                  (0x11 0x01) - reset ECU

Requirements:
    pip install python-can udsoncan==1.25.2 can-isotp intelhex pyserial
    For gs_usb: pip install gs-usb

Usage:
    python genx_uds_flash.py
"""

import ctypes
import os
import time

import can
import isotp
import udsoncan
import udsoncan.configs
from udsoncan.client import Client
from udsoncan.connections import PythonIsoTpConnection
from udsoncan.services import DiagnosticSessionControl, ECUReset
from udsoncan import MemoryLocation
from intelhex import IntelHex

# =============================================================================
# CONFIGURATION - edit these values for your setup
# =============================================================================

# python-can interface name and channel handed to can.interface.Bus().
# CAN_CHANNEL is not used when CAN_INTERFACE is "gs_usb" (device is auto-detected).
CAN_INTERFACE = "pcan"
CAN_CHANNEL   = "PCAN_USBBUS1"
CAN_BITRATE   = 500000    # bits per second

# ISO-TP addressing (ECU physical addressing)
ECU_TX_ID = 0x7E0         # Tester -> ECU
ECU_RX_ID = 0x7E8         # ECU -> Tester

# File paths. These are raw strings, so each backslash is written exactly once;
# doubling them inside r"..." would put two separators into the actual path.
HEX_FILE_PATH     = r"E:\WorkSpace\nxp_New_Demo\JSW_UDS\Debug_FLASH\JSW_UDS.hex"
SECURITY_DLL_PATH = r"E:\WorkSpace\nxp_New_Demo\JSW_UDS\GenerateKeyEx_64.dll"

# UDS Security Access level (must match bootloader configuration)
SECURITY_LEVEL = 0x05     # service 0x27 subfunction 0x05

# Transfer block size in bytes per TransferData request
BLOCK_SIZE = 256

# =============================================================================


def create_can_bus():
    """Open the CAN bus described by the module-level CAN_* settings.

    When ``CAN_INTERFACE`` is ``"gs_usb"`` the first device reported by
    ``GsUsb.scan()`` is used and ``CAN_CHANNEL`` is ignored. Every other
    interface name is passed straight to python-can together with
    ``CAN_CHANNEL``.

    Returns:
        The bus object created by ``can.interface.Bus``.

    Raises:
        RuntimeError: If ``"gs_usb"`` is selected but no device is found.
    """
    if CAN_INTERFACE == "gs_usb":
        # Imported lazily so the optional gs-usb package is only required
        # when that interface is actually selected.
        from gs_usb.gs_usb import GsUsb
        devs = GsUsb.scan()
        if not devs:
            raise RuntimeError("gs_usb device not found - check USB connection")
        # ``interface=`` replaces the deprecated ``bustype=`` keyword.
        return can.interface.Bus(interface="gs_usb", channel=devs[0], index=0, bitrate=CAN_BITRATE)
    return can.interface.Bus(channel=CAN_CHANNEL, interface=CAN_INTERFACE, bitrate=CAN_BITRATE)


def load_security_dll(dll_path):
    """Load the seed/key DLL and declare the prototype of ``GenerateKeyEx``.

    Args:
        dll_path: Filesystem path of the DLL to load.

    Returns:
        The loaded ``ctypes.CDLL`` object, with ``GenerateKeyEx.argtypes`` and
        ``GenerateKeyEx.restype`` configured.

    Raises:
        FileNotFoundError: If ``dll_path`` does not name an existing file.
    """
    if os.path.isfile(dll_path):
        library = ctypes.CDLL(dll_path)
        ubyte_ptr = ctypes.POINTER(ctypes.c_ubyte)
        uint_ptr = ctypes.POINTER(ctypes.c_uint)

        generate = library.GenerateKeyEx
        generate.restype = ctypes.c_int
        # Parameter order follows the names used for the exported function:
        #   ipSeedArray, iSeedArraySize, iSecurityLevel, iVariant,
        #   iopKeyArray (out), iMaxKeyArraySize, oActualKeyArraySize (out)
        generate.argtypes = [
            ubyte_ptr,
            ctypes.c_uint,
            ctypes.c_uint,
            ctypes.c_char_p,
            ubyte_ptr,
            ctypes.c_uint,
            uint_ptr,
        ]
        return library

    raise FileNotFoundError(f"Security DLL not found: {dll_path}")


def compute_key(dll, seed_bytes, level):
    """Derive the security-access key for a seed by calling ``dll.GenerateKeyEx``.

    Args:
        dll: Object exposing a ``GenerateKeyEx`` callable (normally the DLL
            returned by ``load_security_dll``).
        seed_bytes: Seed received from the ECU, as a sequence of byte values.
        level: Security level forwarded to ``GenerateKeyEx``.

    Returns:
        The key as ``bytes``; its length is the size reported back by the DLL.

    Raises:
        RuntimeError: If ``GenerateKeyEx`` returns a non-zero status.
    """
    max_key_len = 64  # capacity of the output buffer offered to the DLL
    seed_len = len(seed_bytes)

    seed_buffer = (ctypes.c_ubyte * seed_len)(*seed_bytes)
    key_buffer = (ctypes.c_ubyte * max_key_len)()
    actual_len = ctypes.c_uint(0)

    status = dll.GenerateKeyEx(
        seed_buffer,
        seed_len,
        level,
        b"",  # variant string is passed empty
        key_buffer,
        max_key_len,
        ctypes.byref(actual_len),
    )

    if status == 0:
        return bytes(key_buffer[:actual_len.value])
    raise RuntimeError(f"GenerateKeyEx returned error code {status}")


def load_hex_chunks(hex_path, block_size):
    """Parse an Intel HEX file into fixed-size transfer blocks.

    The image is treated as ONE contiguous region running from the lowest to
    the highest address present in the file. Any gaps between HEX segments
    are filled with IntelHex's padding byte (0xFF by default), so that the
    returned base address, total size and block sequence all describe the
    same memory range. For a file with a single segment this is exactly the
    file's data.

    Args:
        hex_path: Path of the Intel HEX file.
        block_size: Maximum number of data bytes per block; must be positive.

    Returns:
        A tuple ``(chunks, start_addr, total_bytes)`` where ``chunks`` is a
        list of ``(address, data)`` pairs in ascending address order (``data``
        is ``bytes``; only the final block may be shorter than ``block_size``),
        ``start_addr`` is the lowest address in the file and ``total_bytes``
        is the size of the contiguous region.

    Raises:
        FileNotFoundError: If ``hex_path`` does not name an existing file.
        ValueError: If ``block_size`` is not positive or the file has no data.
    """
    if not os.path.isfile(hex_path):
        raise FileNotFoundError(f"HEX file not found: {hex_path}")
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    ih = IntelHex(hex_path)
    start_addr = ih.minaddr()
    if start_addr is None:
        # minaddr() is None when the file holds no data records.
        raise ValueError(f"HEX file contains no data: {hex_path}")

    # Size of the whole span, NOT the sum of segment lengths: summing would
    # silently drop the gaps and shift every later segment to a wrong address.
    total = ih.maxaddr() - start_addr + 1
    image = bytes(ih.tobinarray(start=start_addr, size=total))

    chunks = [
        (start_addr + offset, image[offset:offset + block_size])
        for offset in range(0, total, block_size)
    ]
    return chunks, start_addr, total


def crc16(data, poly=0x1021, init_crc=0xFFFF, xor_out=0x0000):
    """Compute a 16-bit, MSB-first CRC over ``data``.

    With the default parameters this is the CCITT variant using polynomial
    0x1021 and initial value 0xFFFF.

    Args:
        data: Iterable of byte values (e.g. ``bytes``).
        poly: Generator polynomial.
        init_crc: Initial register value.
        xor_out: Value XOR-ed into the register before it is returned.

    Returns:
        The checksum as an ``int``.
    """
    register = init_crc
    for octet in data:
        # Fold the next byte into the upper half of the register ...
        register ^= octet << 8
        # ... then clock it through one bit at a time.
        for _bit in range(8):
            top_bit_set = register & 0x8000
            shifted = register << 1
            if top_bit_set:
                shifted ^= poly
            register = shifted & 0xFFFF
    return register ^ xor_out


def _join_contiguous_chunks(chunks, start_addr, total_bytes):
    """Return all chunk data as one bytes object after checking it is gap-free."""
    if not chunks:
        raise ValueError("HEX image contains no data to flash")
    expected_addr = start_addr
    for addr, data in chunks:
        if addr != expected_addr:
            raise ValueError(
                f"HEX image is not contiguous: expected a block at "
                f"0x{expected_addr:08X} but found 0x{addr:08X}"
            )
        expected_addr += len(data)
    if expected_addr - start_addr != total_bytes:
        raise ValueError(
            f"HEX image size mismatch: blocks cover {expected_addr - start_addr} "
            f"bytes but {total_bytes} bytes were declared"
        )
    return b"".join(data for _, data in chunks)


def flash(hex_path=HEX_FILE_PATH, dll_path=SECURITY_DLL_PATH):
    """Flash a HEX image to the ECU using the nine-step UDS programming sequence.

    The sequence is: programming session, security access, erase routine,
    RequestDownload, TransferData blocks, RequestTransferExit, CRC routine,
    CRC result query, ECU reset. Progress is printed to stdout.

    The image is downloaded with a single RequestDownload, and the blocks
    carry no address of their own, so the image must be one contiguous
    region. This is verified before the CAN bus is opened, so a bad image is
    rejected before anything on the ECU is erased.

    Args:
        hex_path: Path of the Intel HEX file to flash.
        dll_path: Path of the DLL providing ``GenerateKeyEx``.

    Raises:
        FileNotFoundError: If the HEX file or the DLL does not exist.
        ValueError: If the image is empty or not contiguous.
        RuntimeError: If key generation fails or the ECU reports that CRC
            verification failed (in which case no reset is sent).
        Exception: Errors raised by the CAN / ISO-TP / UDS layers propagate
            unchanged.
    """
    print("=" * 60)
    print("  ANCIT UDS Bootloader - Flashing Sequence")
    print("=" * 60)
    print(f"  HEX  : {hex_path}")
    print(f"  DLL  : {dll_path}")
    print(f"  CAN  : {CAN_INTERFACE} / {CAN_CHANNEL if CAN_INTERFACE != 'gs_usb' else 'gs_usb (auto)'} @ {CAN_BITRATE} bps")

    dll = load_security_dll(dll_path)
    chunks, start_addr, total_bytes = load_hex_chunks(hex_path, BLOCK_SIZE)

    # The CRC must cover exactly the bytes that will be transferred, so it is
    # built from the chunks themselves rather than from a second file parse.
    image = _join_contiguous_chunks(chunks, start_addr, total_bytes)

    print(f"  Image: {total_bytes} bytes, base 0x{start_addr:08X}, {len(chunks)} blocks")
    print()

    bus = create_can_bus()
    # Everything after bus creation sits inside the try so the bus is shut
    # down even when filter or ISO-TP setup fails.
    try:
        bus.set_filters([{"can_id": ECU_RX_ID, "can_mask": 0x7FF, "extended": False}])
        tp_addr = isotp.Address(isotp.AddressingMode.Normal_11bits, txid=ECU_TX_ID, rxid=ECU_RX_ID)
        stack = isotp.CanStack(bus, address=tp_addr, params={"stmin": 0, "blocksize": 0})
        conn  = PythonIsoTpConnection(stack)

        config = udsoncan.configs.default_client_config.copy()
        config["data_identifiers"] = {}

        with Client(conn, request_timeout=5, config=config) as client:

            # Step 1: Programming Session
            print("[1/9] DiagnosticSessionControl - Programming session (0x10 0x02)")
            client.change_session(DiagnosticSessionControl.Session.programmingSession)
            time.sleep(0.1)

            # Step 2: Security Access - the key is computed from the ECU's seed
            # by the GenerateKeyEx DLL.
            print("[2/9] SecurityAccess - Requesting seed (0x27 0x05)")
            seed_resp = client.request_seed(SECURITY_LEVEL)
            seed = seed_resp.service_data.seed
            key  = compute_key(dll, seed, SECURITY_LEVEL)
            print(f"      Seed : {seed.hex().upper()}")
            print(f"      Key  : {key.hex().upper()}")
            client.send_key(SECURITY_LEVEL, key)
            print("      Security access granted")

            # Step 3: Routine Control - Erase Flash
            print("[3/9] RoutineControl - Erase flash (0x31 0x01 0xF2 0x00)")
            client.start_routine(0xF200)
            # Fixed settling delay after the erase routine has been started.
            time.sleep(5.0)
            print("      Erase complete")

            # Step 4: Request Download
            print(f"[4/9] RequestDownload - 0x{start_addr:08X}, {total_bytes} bytes (0x34)")
            mem_loc = MemoryLocation(
                address=start_addr,
                memorysize=total_bytes,
                address_format=32,
                memorysize_format=32
            )
            client.request_download(mem_loc)

            # Step 5: Transfer Data
            print(f"[5/9] TransferData - {len(chunks)} blocks (0x36)")
            for idx, (_addr, data) in enumerate(chunks, start=1):
                # Block sequence counter starts at 1 and wraps 0xFF -> 0x00.
                client.transfer_data(idx & 0xFF, data)
                pct = idx * 100 // len(chunks)
                bar = "#" * (pct // 5) + "-" * (20 - pct // 5)
                print(f"\r      [{bar}] {pct:3d}%  block {idx}/{len(chunks)}", end="", flush=True)
            print("\n      Transfer complete")

            # Step 6: Transfer Exit
            print("[6/9] RequestTransferExit (0x37)")
            client.request_transfer_exit()

            # Step 7: Routine Control - CRC Validation
            # Payload: 4 bytes start address + 4 bytes end address + 2 bytes CRC16
            # (all big-endian; end address is inclusive)
            print("[7/9] RoutineControl - CRC validation (0x31 0x01 0x02 0x01)")
            end_addr = start_addr + total_bytes - 1
            checksum = crc16(image)
            print(f"      Start : 0x{start_addr:08X}")
            print(f"      End   : 0x{end_addr:08X}")
            print(f"      CRC16 : 0x{checksum:04X}")
            crc_payload = (
                start_addr.to_bytes(4, "big") +
                end_addr.to_bytes(4, "big")   +
                bytes([(checksum >> 8) & 0xFF,
                       checksum & 0xFF])
            )
            client.start_routine(0x0201, data=crc_payload)
            print("      CRC sent")

            # Step 8: Get CRC Result - first status byte 0x01 means success
            print("[8/9] RoutineControl - Get CRC result (0x31 0x03 0x02 0x01)")
            get_crc_response = client.get_routine_result(0x0201)
            status = get_crc_response.service_data.routine_status_record
            print(f"      Status: {status.hex().upper() if status else 'empty'}")
            if status and status[0] == 0x01:
                print("      CRC Verification SUCCESS")
            else:
                raise RuntimeError(
                    f"CRC Verification FAILED - aborting before reset. "
                    f"Status: {status.hex().upper() if status else 'no data'}"
                )

            # Step 9: ECU Reset
            print("[9/9] ECUReset - Hard reset (0x11 0x01)")
            client.ecu_reset(ECUReset.ResetType.hardReset)
            print("      ECU reset complete")

    finally:
        bus.shutdown()

    print()
    print("=" * 60)
    print("  Flashing complete.")
    print("=" * 60)


# Script entry point: run the flashing sequence with the paths and CAN
# settings configured at the top of this file.
if __name__ == "__main__":
    flash()
