import argparse
import os
import re
import stat
import subprocess
import time

def _check_pcie_link():
    """Print lspci details for Xilinx (vendor 10ee) devices and warn on link/BAR issues.

    Advisory only: problems are printed as warnings and nothing is returned.
    Only the first LnkSta / "Region 0" entry in the lspci output is inspected.
    """
    try:
        lspci_output = subprocess.run(
            ["lspci", "-d", "10ee:", "-vv"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
    except subprocess.CalledProcessError as e:
        print(f"Error checking lspci: {e}\nStderr: {e.stderr}")
        return
    except OSError as e:
        # lspci itself could not be executed (e.g. not installed).
        print(f"Error checking lspci: {e}")
        return

    print(f"PCIe Devices:\n{lspci_output.strip()}")

    # Older pciutils print "Speed 8GT/s, Width x4, TrErr- ...", newer ones
    # "Speed 8GT/s (ok), Width x4 (ok)". Capture only the rate and the lane
    # count so trailing commas and "(ok)"/"(downgraded)" annotations never
    # leak into the comparison below.
    link_status = re.search(
        r"LnkSta:\s*Speed\s+([\d.]+GT/s)[^,]*,\s*Width\s+(x\d+)", lspci_output
    )
    if link_status:
        speed, width = link_status.groups()
        if speed != "8GT/s" or width != "x4":
            print(f"Warning: PCIe link is not Gen3 x4 (Speed: {speed}, Width: {width}). Expected 8GT/s, x4.")
    else:
        print("Warning: Could not parse PCIe link status.")
        if "<access denied>" in lspci_output:
            # lspci hides the capability list (and thus LnkSta) from non-root users.
            print("  (lspci reports '<access denied>' for capabilities; rerun as root.)")

    # Treat BAR0 as mapped when lspci shows a concrete address that is not
    # flagged [disabled]. Width/prefetchability can differ between designs,
    # so they are deliberately not required here.
    bar0 = re.search(r"Region 0: Memory at (\S+)[^\n]*", lspci_output)
    if not bar0 or bar0.group(1).startswith("<") or "[disabled]" in bar0.group(0):
        print("Warning: PCIe BAR0 not properly mapped.")


def _check_dmesg_xdma():
    """Print kernel log lines containing "xdma", or note that there are none."""
    try:
        dmesg_output = subprocess.run(
            ["dmesg"],
            capture_output=True,
            text=True,
            errors="replace",  # kernel log may contain bytes invalid in the locale encoding
            check=True,
        ).stdout
    except subprocess.CalledProcessError as e:
        # e.g. reading the kernel log is not permitted for this user.
        print(f"Error checking dmesg: {e}\nStderr: {e.stderr}")
        return
    except OSError as e:
        print(f"Error checking dmesg: {e}")
        return

    # Filter in Python instead of piping through grep: grep exits 1 when
    # nothing matches, which check=True would report as an error.
    xdma_lines = [line for line in dmesg_output.splitlines() if "xdma" in line]
    if xdma_lines:
        print("XDMA Driver Logs:\n" + "\n".join(xdma_lines).strip())
    else:
        print("No XDMA driver logs found in dmesg.")


def check_system_status():
    """Check PCIe device, XDMA driver, and Gen3 x4 link status.

    The lspci and dmesg checks only print information and warnings; the
    return value depends solely on the XDMA device nodes.

    Returns:
        True if /dev/xdma0_h2c_0 and /dev/xdma0_c2h_0 both exist and are
        character devices, otherwise False (after printing what is wrong).
    """
    print("Checking system status...")
    _check_pcie_link()
    _check_dmesg_xdma()

    # os.path.exists() would also accept a regular file, e.g. one created by
    # `dd of=/dev/xdma0_h2c_0` run as root while the driver was not loaded,
    # so require an actual character device.
    problems = []
    for node in ("/dev/xdma0_h2c_0", "/dev/xdma0_c2h_0"):
        try:
            mode = os.stat(node).st_mode
        except OSError:
            problems.append(f" - {node} missing")
            continue
        if not stat.S_ISCHR(mode):
            problems.append(f" - {node} exists but is not a character device")

    if problems:
        print("Error: Xilinx DMA device nodes not found.")
        for problem in problems:
            print(problem)
        return False
    return True

def run_transfer(transfer_type, transfer_size, iterations):
    """Run H2C or C2H transfers with ``dd`` and calculate throughput.

    Each iteration is one ``dd`` invocation that copies a single block of
    ``transfer_size`` bytes. The timing therefore also includes dd's process
    start-up and buffer allocation, which matters more for small sizes.

    Args:
        transfer_type: 'h2c' (write zeros to /dev/xdma0_h2c_0) or
            'c2h' (read /dev/xdma0_c2h_0 into /dev/null).
        transfer_size: Bytes per iteration; must be >= 1.
        iterations: Number of timed dd runs; must be >= 1.

    Returns:
        True if every transfer succeeded. False on invalid arguments, a missing
        or non-character device node, or a dd failure.
    """
    if transfer_type not in ("h2c", "c2h"):
        print(f"Error: unknown transfer type {transfer_type!r}; expected 'h2c' or 'c2h'.")
        return False
    if transfer_size < 1 or iterations < 1:
        print(f"Error: transfer size ({transfer_size}) and iterations ({iterations}) must both be positive.")
        return False

    device = f"/dev/xdma0_{transfer_type}_0"
    # Refuse to run unless the node is a real character device: dd opens of=
    # with O_CREAT, so as root an H2C run without the driver would silently
    # create a regular file under /dev and time tmpfs writes instead of DMA.
    try:
        is_char_device = stat.S_ISCHR(os.stat(device).st_mode)
    except OSError:
        is_char_device = False
    if not is_char_device:
        print(f"Error: {device} is missing or not a character device.")
        return False

    # iflag=fullblock makes dd keep reading until the whole block is filled, so
    # a short read() from the input cannot end the single-block copy early and
    # inflate the computed throughput.
    if transfer_type == "h2c":
        cmd = ["dd", "if=/dev/zero", f"of={device}", f"bs={transfer_size}", "count=1", "iflag=fullblock"]
    else:
        cmd = ["dd", f"if={device}", "of=/dev/null", f"bs={transfer_size}", "count=1", "iflag=fullblock"]

    mib = 1024 * 1024
    total_time = 0.0
    total_bytes = transfer_size * iterations

    print(f"\nStarting {transfer_type.upper()} transfer: {iterations} iterations, {transfer_size / mib:.2f} MB each")

    for i in range(iterations):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            print(f"Error during {transfer_type.upper()} transfer: {e}")
            print(f"Command: {' '.join(e.cmd)}")
            print(f"Return code: {e.returncode}")
            print(f"Stderr: {e.stderr}")
            return False
        except OSError as e:
            # dd could not be executed at all (e.g. not installed).
            print(f"Error during {transfer_type.upper()} transfer: {e}")
            return False
        elapsed = time.perf_counter() - start_time
        total_time += elapsed

        throughput = (transfer_size / elapsed) / mib  # MB/s (MiB-based)
        print(f"Iteration {i+1}/{iterations}: {throughput:.2f} MB/s")
        # dd reports its own record/byte counts on stderr; echoing them lets
        # the user confirm the full block was actually copied.
        if result.stderr:
            print(f"Command stderr: {result.stderr.strip()}")

    avg_throughput = (total_bytes / total_time) / mib  # MB/s (MiB-based)
    print(f"\nAverage {transfer_type.upper()} Throughput: {avg_throughput:.2f} MB/s")

    # Nominal expectation for PCIe Gen3 x4; warn only when the average falls
    # below 70% of it, to tolerate protocol/driver overhead.
    expected_min = 3000  # MB/s
    if avg_throughput < expected_min * 0.7:
        print(f"Warning: Throughput significantly below expected ~{expected_min} MB/s for PCIe Gen3 x4.")

    return True

def main(argv=None):
    """Parse CLI options, verify the system, then run the requested transfers.

    Args:
        argv: Arguments to parse instead of sys.argv[1:] (handy for tests or
            calling from another script). Defaults to None.

    Returns:
        Process exit status: 0 if every requested transfer succeeded, 1 if the
        system check or a transfer failed. Invalid options make argparse exit
        with status 2.
    """
    parser = argparse.ArgumentParser(description="PCIe Gen3 x4 Throughput Test for Numato Aller Artix UltraScale+")
    parser.add_argument("--size", type=float, default=100.0, help="Transfer size in MB (default: 100)")
    parser.add_argument("--iterations", type=int, default=10, help="Number of iterations (default: 10)")
    parser.add_argument("--type", choices=['h2c', 'c2h', 'both'], default='both', help="Transfer type: h2c, c2h, or both (default: both)")
    args = parser.parse_args(argv)

    # Reject values that would make dd fail (bs=0), overflow int() (inf), or
    # divide by zero when averaging. The chained comparison is False for NaN.
    if not 0 < args.size < float("inf"):
        parser.error("--size must be a positive, finite number of MB")
    if args.iterations < 1:
        parser.error("--iterations must be at least 1")

    # Convert size from MB (MiB) to bytes
    transfer_size = int(args.size * 1024 * 1024)
    if transfer_size < 1:
        parser.error("--size is smaller than one byte")

    if not check_system_status():
        print("System check failed. Please ensure XDMA driver is loaded and FPGA is configured for PCIe Gen3 x4.")
        return 1

    # H2C runs before C2H, matching the order of the 'both' option.
    for transfer_type in ("h2c", "c2h"):
        if args.type not in (transfer_type, "both"):
            continue
        if not run_transfer(transfer_type, transfer_size, args.iterations):
            print(f"{transfer_type.upper()} transfer failed.")
            return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell scripts/CI can detect failures.
    raise SystemExit(main())