#!/bin/bash

# Logging helpers: print one line to stdout as "[TAG]: <bold message>",
# with the tag colored via ANSI escapes.
# The message is passed through printf's %s, so backslashes inside it are
# printed literally instead of being interpreted as escapes (echo -e did that).
log.warn()  { printf '[\033[33mWARN\033[0m]:  \033[1m%s\033[0m\n' "$*"; }
log.error() { printf '[\033[31mERROR\033[0m]: \033[1m%s\033[0m\n' "$*"; }
log.info()  { printf '[\033[96mINFO\033[0m]:  \033[1m%s\033[0m\n' "$*"; }
log.debug() { printf '[\033[32mDEBUG\033[0m]: \033[1m%s\033[0m\n' "$*"; }

# Preflight: refuse to run unless the real user ID is root.
if (( UID != 0 )); then
    log.error "需要以root权限运行 Need to be run as root."
    exit 1
fi

# Preflight: the ace-env directory must be given as the first argument.
if [[ -z ${1:-} ]]; then
    log.error "需要把ace-env所在的路径设置为第一个参数"
    exit 1
fi

# Helper functions

# Strip leading and trailing whitespace ([[:space:]]) from $1 and print it
# followed by a newline.
# printf is used instead of echo so values such as "-n" or "-e" are printed
# verbatim rather than being swallowed as echo options.
trim() {
    local str="$1"
    str="${str#"${str%%[![:space:]]*}"}"  # drop leading whitespace
    str="${str%"${str##*[![:space:]]}"}"  # drop trailing whitespace
    printf '%s\n' "$str"
}

# Succeed (status 0) when $1 is a regular file (symlinks followed) or is
# itself a symlink, even a dangling one; otherwise status 1.
is_regular_or_symlink() {
    local candidate="$1"
    [[ -f $candidate || -L $candidate ]]
}

# Succeed (status 0) when $1 names a directory (symlinks followed).
is_directory() {
    local candidate="$1"
    [[ -d $candidate ]]
}

# Succeed (status 0) when $1 is a character or block special file.
is_char_or_block_device() {
    local node="$1"
    [[ -c $node || -b $node ]]
}

# Print the running NVIDIA kernel driver version.
#
# Arguments (both optional; default to the real kernel interfaces):
#   $1 - sysfs version file  (default /sys/module/nvidia/version)
#   $2 - procfs banner file  (default /proc/driver/nvidia/version)
# Outputs: the version string (e.g. 535.86.05 or 390.157) on stdout
# Returns: 0 if a version was found, 1 otherwise
read_driver_version() {
    local sys_file="${1:-/sys/module/nvidia/version}"
    local proc_file="${2:-/proc/driver/nvidia/version}"

    # Preferred source: the module's version attribute, whitespace-trimmed.
    if [[ -f $sys_file ]]; then
        local version
        version=$(cat "$sys_file" 2>/dev/null)
        version=$(trim "$version")
        if [[ -n $version ]]; then
            printf '%s\n' "$version"
            return 0
        fi
    fi

    # Fallback: the procfs banner. Accept two or more dotted components.
    # Legacy driver branches use two-component versions (e.g. 390.157). With
    # a three-component-only regex, the first match would have been the
    # "GCC version ... 9.4.0" text on the following line.
    if [[ -f $proc_file ]]; then
        local banner
        banner=$(cat "$proc_file" 2>/dev/null)
        if [[ $banner =~ [0-9]+(\.[0-9]+)+ ]]; then
            printf '%s\n' "${BASH_REMATCH[0]}"
            return 0
        fi
    fi

    return 1
}

# Extract the version suffix of a file name that starts with a literal prefix.
#
# Arguments: $1 - file name (e.g. libcuda.so.535.86.05)
#            $2 - prefix to strip (e.g. libcuda.so.)
# Outputs:   the remainder after the prefix (e.g. 535.86.05)
# Returns:   1 if $1 does not start with $2 or nothing follows the prefix
version_from_filename() {
    local filename="$1"
    local prefix="$2"

    if [[ "$filename" != "$prefix"* ]]; then
        return 1
    fi

    # Quote the prefix inside the expansion so it is removed literally, the
    # same way the guard above compares it (unquoted it would be a glob).
    local version="${filename#"$prefix"}"
    if [[ -z "$version" ]]; then
        return 1
    fi

    printf '%s\n' "$version"
    return 0
}

# Compare two dotted version strings component by component.
# Missing components count as 0 (so 1.0 == 1). Non-digit characters inside a
# component are ignored.
#
# Arguments: $1, $2 - versions to compare
# Outputs:   -1 if $1 < $2, 0 if equal, 1 if $1 > $2
compare_versions() {
    local ver1="$1"
    local ver2="$2"
    local -a v1_parts=() v2_parts=()

    IFS='.' read -r -a v1_parts <<< "$ver1"
    IFS='.' read -r -a v2_parts <<< "$ver2"

    local max_len=$(( ${#v1_parts[@]} > ${#v2_parts[@]} ? ${#v1_parts[@]} : ${#v2_parts[@]} ))
    local i a b

    for ((i = 0; i < max_len; i++)); do
        a=${v1_parts[i]//[!0-9]/}
        b=${v2_parts[i]//[!0-9]/}
        # 10# forces decimal: components like "08"/"09" (common in NVIDIA
        # versions) would otherwise be rejected as invalid octal.
        a=$(( 10#${a:-0} ))
        b=$(( 10#${b:-0} ))

        if (( a < b )); then
            echo "-1"
            return
        elif (( a > b )); then
            echo "1"
            return
        fi
    done

    echo "0"
}

# Recursively find regular files whose name matches a glob under a list of
# directories, and print their de-duplicated canonical paths one per line.
#
# Arguments: $1 - name of an array to search, in "name[@]" form
#            $2 - find -name glob (e.g. "libcuda.so.*")
# Outputs:   unique `readlink -f` paths, in discovery order, one per line
# Notes:     at most 100 find hits are taken from each directory;
#            missing directories are skipped.
collect_files() {
    local search_dirs=("${!1}")
    local pattern="$2"
    local -r max_per_dir=100
    local -a results=()
    local dir file count

    for dir in "${search_dirs[@]}"; do
        if [ ! -d "$dir" ]; then
            continue
        fi

        # find -print0 pairs with read -d '': both sides must agree on NUL as
        # the delimiter. With newline output, read never saw a delimiter and
        # the loop body never ran.
        count=0
        while IFS= read -r -d '' file; do
            if is_regular_or_symlink "$file"; then
                results+=("$file")
            fi
            count=$((count + 1))
            if (( count >= max_per_dir )); then
                break
            fi
        done < <(find "$dir" -name "$pattern" -type f -print0 2>/dev/null)
    done

    # De-duplicate by canonical path (search dirs can overlap or be symlinked).
    local -A seen=()
    local -a unique_results=()
    local resolved
    for file in "${results[@]}"; do
        resolved=$(readlink -f "$file" 2>/dev/null || echo "$file")
        if [[ -z "${seen[$resolved]:-}" ]]; then
            seen["$resolved"]=1
            unique_results+=("$resolved")
        fi
    done

    if (( ${#unique_results[@]} > 0 )); then
        printf '%s\n' "${unique_results[@]}"
    fi
}

# Pick one library file from a list by its version suffix.
#
# Arguments: $1 - name of an array of file paths, in "name[@]" form
#            $2 - file-name prefix preceding the version (e.g. "libcuda.so.")
#            $3 - preferred version (may be empty)
# Outputs:   the first file whose version equals $3 exactly. If there is none
#            (or $3 is empty), the file with the highest version, with ties
#            going to the earliest. Prints an empty line if no file matches
#            the prefix.
select_best_versioned_lib() {
    local files=("${!1}")
    local prefix="$2"
    local prefer_version="$3"

    local preferred=""
    local best=""
    local best_version=""
    local file version

    # Single pass: remember the first exact match and the overall highest.
    for file in "${files[@]}"; do
        version=$(version_from_filename "${file##*/}" "$prefix")
        if [ -z "$version" ]; then
            continue
        fi

        if [ -n "$prefer_version" ] && [ -z "$preferred" ] && [ "$version" = "$prefer_version" ]; then
            preferred="$file"
        fi

        if [ -z "$best" ] || [ "$(compare_versions "$version" "$best_version")" -gt 0 ]; then
            best="$file"
            best_version="$version"
        fi
    done

    if [ -n "$preferred" ]; then
        echo "$preferred"
    else
        echo "$best"
    fi
}

# Detect the NVIDIA driver version and the host directory holding its libraries.
#
# The version comes from read_driver_version. If that yields nothing, it is
# taken from the suffix of the chosen libcuda.so.* / libnvidia-ml.so.* file.
# The library directory is the directory of that chosen file. If none was
# found, it is the first existing directory in the search list.
#
# Outputs: "<version>:<lib_dir>" on stdout (either side may be empty)
detect_driver_info() {
    local driver_version
    driver_version=$(read_driver_version)

    # Candidate library directories, searched in this order.
    local search_paths=(
        "/usr/lib64"
        "/usr/lib/x86_64-linux-gnu"
        "/usr/lib/i386-linux-gnu"
        "/usr/lib/aarch64-linux-gnu"
        "/usr/lib/x86_64-linux-gnu/nvidia/current"
        "/usr/lib/i386-linux-gnu/nvidia/current"
        "/usr/lib/aarch64-linux-gnu/nvidia/current"
        "/lib64"
        "/lib/x86_64-linux-gnu"
        "/lib/i386-linux-gnu"
        "/lib/aarch64-linux-gnu"
        "/lib/x86_64-linux-gnu/nvidia/current"
        "/lib/i386-linux-gnu/nvidia/current"
        "/lib/aarch64-linux-gnu/nvidia/current"
        "/usr/lib"
        "/lib"
    )

    # libcuda first, libnvidia-ml as the fallback anchor library.
    local selected_lib=""
    local -a candidates=()

    candidates=($(collect_files search_paths[@] "libcuda.so.*"))
    if (( ${#candidates[@]} > 0 )); then
        selected_lib=$(select_best_versioned_lib candidates[@] "libcuda.so." "$driver_version")
    fi

    if [[ -z $selected_lib ]]; then
        candidates=($(collect_files search_paths[@] "libnvidia-ml.so.*"))
        if (( ${#candidates[@]} > 0 )); then
            selected_lib=$(select_best_versioned_lib candidates[@] "libnvidia-ml.so." "$driver_version")
        fi
    fi

    local lib_dir=""
    if [[ -n $selected_lib ]]; then
        lib_dir=$(dirname "$selected_lib")

        # No version from the kernel interfaces: derive it from the file name.
        if [[ -z $driver_version ]]; then
            local lib_name prefix suffix
            lib_name=$(basename "$selected_lib")
            for prefix in "libcuda.so." "libnvidia-ml.so."; do
                suffix=$(version_from_filename "$lib_name" "$prefix")
                if [[ -n $suffix ]]; then
                    driver_version=$suffix
                    break
                fi
            done
        fi
    fi

    # Nothing found: fall back to the first search directory that exists.
    if [[ -z $lib_dir ]]; then
        local dir
        for dir in "${search_paths[@]}"; do
            if is_directory "$dir"; then
                lib_dir=$dir
                break
            fi
        done
    fi

    echo "$driver_version:$lib_dir"
}

# Print the SONAME recorded in an ELF shared object's dynamic section.
# readelf is tried first, then objdump; the first non-empty result wins.
#
# Arguments: $1 - path to the ELF file
# Outputs:   the SONAME on stdout
# Returns:   0 if a SONAME was found, 1 otherwise (including when neither
#            tool is installed)
read_elf_soname() {
    local elf="$1"
    local found=""

    if command -v readelf >/dev/null 2>&1; then
        found=$(readelf -d "$elf" 2>/dev/null | grep -E "SONAME.*\[.*\]" | sed -E 's/.*\[(.*)\].*/\1/')
    fi

    if [[ -z $found ]] && command -v objdump >/dev/null 2>&1; then
        found=$(objdump -p "$elf" 2>/dev/null | grep -E "SONAME" | awk '{print $2}')
    fi

    [[ -n $found ]] || return 1
    echo "$found"
}

# Succeed when $1 is a regular file starting with the ELF magic (7f 'E' 'L' 'F')
# whose class byte (offset 4) is 01, i.e. a 32-bit ELF object.
is_elf32() {
    local target="$1"

    [[ -f $target ]] || return 1

    # First five bytes as contiguous lowercase hex, e.g. "7f454c4601".
    local hex
    hex=$(head -c 5 "$target" 2>/dev/null | od -An -t x1 | tr -d ' \n')

    [[ $hex == 7f454c46* && ${hex:8:2} == "01" ]]
}

# Make link_path a symlink whose link text is exactly `target`.
#
# Arguments: $1 - link text to store (typically a container path "/host/...")
#            $2 - where to create the symlink; parent dirs are created
# Returns:   0 if the link already had this text or was (re)created,
#            otherwise the exit status of ln
ensure_symlink() {
    local target="$1"
    local link_path="$2"

    mkdir -p "$(dirname "$link_path")" 2>/dev/null

    # Compare the raw link text, not `readlink -f`. Targets are container
    # paths that generally do not resolve on the host, and a canonicalized
    # path is not what we are asked to store.
    if [ -L "$link_path" ]; then
        local current_target
        current_target=$(readlink "$link_path")
        if [ "$current_target" = "$target" ]; then
            return 0
        fi
    fi

    # Remove whatever is there (stale link or regular file), then link.
    rm -f "$link_path" 2>/dev/null
    ln -sf "$target" "$link_path" 2>/dev/null
}

# Main routine.
#
# Links the host's NVIDIA user-space driver libraries, X.Org driver/GLX
# modules and Vulkan/EGL/GLVND/X11 config files into ACE_DIR as symlinks
# pointing at "/host<original path>". It then writes ACE_DIR/etc/ld.so.conf,
# ACE_DIR/nvidia_env.sh, ACE_DIR/devices.info and the version marker
# ACE_DIR/amber-ce-tools/nvidia_current_version.
#
# Arguments: $1 - the ace-env directory (ACE_DIR)
# Globals:   ACE_DIR (written)
# Exits:     1 if ACE_DIR is not a directory. Otherwise 0, including the
#            quiet "nothing to do" cases: no driver version detected, or the
#            recorded version is unchanged.
main() {
    ACE_DIR="$1"

    # Validate the target directory.
    if [[ ! -d "${ACE_DIR}" ]]; then
        log.error "ACE_DIR不存在或不是目录，退出。 ACE_DIR does not exist or is not a directory: ${ACE_DIR}"
        exit 1
    fi

    # Detect the driver; detect_driver_info prints "<version>:<lib_dir>".
    local driver_info
    driver_info=$(detect_driver_info)
    if [[ -z "$driver_info" ]]; then
#        log.error "无法检测到NVIDIA驱动信息 Cannot detect NVIDIA driver information"
        exit 0
    fi

    # Split at the first ':'.
    local nvidia_version="${driver_info%%:*}"
    local lib_dir="${driver_info#*:}"

    # No version detected: exit quietly (the error log is deliberately disabled).
    if [[ -z "$nvidia_version" ]]; then
#        log.error "无法获取NVIDIA驱动版本 Cannot determine NVIDIA driver version"
        exit 0
    fi

    # Skip all work if the recorded version matches the detected one.
    local version_file="$ACE_DIR/amber-ce-tools/nvidia_current_version"
    if [[ -f "$version_file" ]]; then
        local existing_version
        existing_version=$(cat "$version_file")
        if [[ "$existing_version" == "$nvidia_version" ]]; then
#            log.info "NVIDIA驱动版本未变化，跳过链接操作 NVIDIA Driver version unchanged, skipping linking."
            exit 0
        else
            log.info "检测到NVIDIA驱动版本变化: $existing_version -> $nvidia_version"
        fi
    fi

    log.info "正在链接NVIDIA驱动库和GLX组件 Linking NVIDIA Driver Libs and GLX components"

    # Prepare the directory layout.
    mkdir -p "$ACE_DIR/usr/lib" "$ACE_DIR/usr/lib32"
    mkdir -p "$ACE_DIR/orig" "$ACE_DIR/orig/32"
    mkdir -p "$ACE_DIR/etc"
    # The version marker is written into this directory at the end.
    mkdir -p "$ACE_DIR/amber-ce-tools"

    # Cleanup of old links (disabled)
#    find "$ACE_DIR/usr/lib" -type l -name "*.so*" -delete 2>/dev/null
#    find "$ACE_DIR/usr/lib32" -type l -name "*.so*" -delete 2>/dev/null
#    find "$ACE_DIR/orig" -type l -name "*.so*" -delete 2>/dev/null
#    find "$ACE_DIR/orig/32" -type l -name "*.so*" -delete 2>/dev/null

    # Library search directories; the detected lib_dir is searched first.
    local default_search_paths=(
        "$lib_dir"
        "/usr/lib64"
        "/usr/lib/x86_64-linux-gnu"
        "/usr/lib/i386-linux-gnu"
        "/usr/lib/aarch64-linux-gnu"
        "/usr/lib/x86_64-linux-gnu/nvidia/current"
        "/usr/lib/i386-linux-gnu/nvidia/current"
        "/usr/lib/aarch64-linux-gnu/nvidia/current"
        "/lib64"
        "/lib/x86_64-linux-gnu"
        "/lib/i386-linux-gnu"
        "/lib/aarch64-linux-gnu"
        "/lib/x86_64-linux-gnu/nvidia/current"
        "/lib/i386-linux-gnu/nvidia/current"
        "/lib/aarch64-linux-gnu/nvidia/current"
        "/usr/lib"
        "/lib"
    )

    # 1. Core NVIDIA libraries
    log.debug "收集核心NVIDIA库..."

    local core_libs=(
        "libnvidia-ml.so.*"
        "libcuda.so.*"
        "libnvidia-ptxjitcompiler.so.*"
        "libnvidia-fatbinaryloader.so.*"
        "libnvidia-opencl.so.*"
        "libnvidia-compiler.so.*"
        "libnvidia-encode.so.*"
        "libnvidia-opticalflow.so.*"
        "libnvcuvid.so.*"
        "libnvidia-cfg.so.*"
        "libnvidia-allocator.so.*"
        "libnvidia-nvvm.so.*"
    )

    # 2. Graphics libraries (including GLX)
    log.debug "收集图形库..."
    local graphics_libs=(
        "libGLX_nvidia.so.*"
        "libEGL_nvidia.so.*"
        "libGLESv1_CM_nvidia.so.*"
        "libGLESv2_nvidia.so.*"
        "libnvidia-glcore.so.*"
        "libnvidia-glsi.so.*"
        "libnvidia-tls.so.*"
        "libnvidia-egl-gbm.so.*"
        "libnvidia-egl-wayland.so.*"
        "libnvidia-vulkan-producer.so.*"
        "libEGL.so*"
        "libGL.so*"
        "libGLESv1_CM.so*"
        "libGLESv2.so*"
        "libGLX.so*"
        "libGLdispatch.so*"
        "libOpenCL.so*"
        "libOpenGL.so*"
        "libnvidia-api.so*"
        "libnvidia-egl-xcb.so*"
        "libnvidia-egl-xlib.so*"
    )

    # Collect every matching library file.
    local all_libs=()
    local pattern

    for pattern in "${core_libs[@]}" "${graphics_libs[@]}"; do
        local files=($(collect_files default_search_paths[@] "$pattern"))
        all_libs+=("${files[@]}")
    done

    # De-duplicate by canonical path.
    declare -A seen_libs
    local unique_libs=()
    local lib resolved
    for lib in "${all_libs[@]}"; do
        resolved=$(readlink -f "$lib" 2>/dev/null || echo "$lib")
        if [ -z "${seen_libs[$resolved]}" ]; then
            seen_libs["$resolved"]=1
            unique_libs+=("$resolved")
        fi
    done

    # Create the links: 64-bit libraries go to orig/, 32-bit ones to orig/32/.
    local has_32bit=false
    local has_64bit=false
    local has_glx=false
    local lib_path

    for lib_path in "${unique_libs[@]}"; do
        if [ ! -f "$lib_path" ] && [ ! -L "$lib_path" ]; then
            continue
        fi

        local filename=$(basename "$lib_path")
        local is_32bit=false

        # Classify by ELF class; anything not 32-bit is counted as 64-bit.
        if is_elf32 "$lib_path"; then
            is_32bit=true
            has_32bit=true
        else
            has_64bit=true
        fi

        local orig_dest_dir="$ACE_DIR/orig"
        if [ "$is_32bit" = true ]; then
            orig_dest_dir="$ACE_DIR/orig/32"
        fi

        # Primary link: file name -> path as seen from inside the container.
        local container_target="/host$lib_path"
        local orig_link_path="$orig_dest_dir/$filename"

        if ensure_symlink "$container_target" "$orig_link_path"; then
            if [[ "$filename" == libGLX_nvidia.so.* ]]; then
                has_glx=true
            fi

            # Also expose the library under its SONAME when that differs.
            local soname=$(read_elf_soname "$lib_path")
            if [ -n "$soname" ] && [ "$soname" != "$filename" ]; then
                local soname_link_path="$orig_dest_dir/$soname"
                ensure_symlink "$container_target" "$soname_link_path"
            fi
        fi
    done

    # 3. X.Org modules
    log.debug "收集X.Org模块..."
    local xorg_paths=(
        "$lib_dir/nvidia/xorg"
        "$lib_dir/xorg/modules/drivers"
        "$lib_dir/xorg/modules/extensions"
        "$lib_dir/xorg/modules/updates/drivers"
        "$lib_dir/xorg/modules/updates/extensions"
        "/usr/lib/xorg/modules/drivers"
        "/usr/lib/xorg/modules/extensions"
        "/usr/lib/xorg/modules/updates/drivers"
        "/usr/lib/xorg/modules/updates/extensions"
        "/usr/lib64/xorg/modules/drivers"
        "/usr/lib64/xorg/modules/extensions"
        "/usr/lib64/xorg/modules/updates/drivers"
        "/usr/lib64/xorg/modules/updates/extensions"
    )
    local xorg_dir

    # X.Org driver: first nvidia_drv.so found.
    local xorg_driver=""
    for xorg_dir in "${xorg_paths[@]}"; do
        if [ -f "$xorg_dir/nvidia_drv.so" ]; then
            xorg_driver="$xorg_dir/nvidia_drv.so"
            break
        fi
    done

    # GLX server module: prefer the exact driver version...
    local glx_server=""
    for xorg_dir in "${xorg_paths[@]}"; do
        if [ -f "$xorg_dir/libglxserver_nvidia.so.$nvidia_version" ]; then
            glx_server="$xorg_dir/libglxserver_nvidia.so.$nvidia_version"
            break
        fi
    done

    # ...otherwise take any versioned copy.
    if [ -z "$glx_server" ]; then
        for xorg_dir in "${xorg_paths[@]}"; do
            local found=$(find "$xorg_dir" -name "libglxserver_nvidia.so.*" -type f 2>/dev/null | head -1)
            if [ -n "$found" ]; then
                glx_server="$found"
                break
            fi
        done
    fi

    # Mirror the X.Org files at the same path under ACE_DIR.
    if [ -n "$xorg_driver" ]; then
        local dest_dir="$ACE_DIR$(dirname "$xorg_driver")"
        mkdir -p "$dest_dir"
        local container_target="/host$xorg_driver"
        ensure_symlink "$container_target" "$dest_dir/$(basename "$xorg_driver")"
    fi

    if [ -n "$glx_server" ]; then
        has_glx=true
        local dest_dir="$ACE_DIR$(dirname "$glx_server")"
        mkdir -p "$dest_dir"
        local container_target="/host$glx_server"
        ensure_symlink "$container_target" "$dest_dir/$(basename "$glx_server")"
    fi

    # 4. Config and helper files (linked, mirrored at the same path)
    log.debug "处理配置和辅助文件..."

    # Vulkan ICD / layer manifests
    local vulkan_files=(
        "/usr/share/vulkan/icd.d/nvidia_icd.json"
        "/usr/share/vulkan/icd.d/nvidia_icd.x86_64.json"
        "/usr/share/vulkan/icd.d/nvidia_icd.aarch64.json"
        "/usr/share/vulkan/implicit_layer.d/nvidia_layers.json"
    )

    # EGL external platform configs
    local egl_files=(
        "/usr/share/egl/egl_external_platform.d/10_nvidia_wayland.json"
        "/usr/share/egl/egl_external_platform.d/15_nvidia_gbm.json"
        "/usr/share/egl/egl_external_platform.d/20_nvidia_xcb.json"
    )

    # GLVND vendor config
    local glvnd_files=(
        "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    )

    # X11 configs
    local x11_files=(
        "/usr/share/X11/xorg.conf.d/10-nvidia.conf"
        "/usr/share/X11/xorg.conf.d/nvidia-drm-outputclass.conf"
    )

    local file
    for file in "${vulkan_files[@]}" "${egl_files[@]}" "${glvnd_files[@]}" "${x11_files[@]}"; do
        if [ -f "$file" ]; then
            local dest_dir="$ACE_DIR$(dirname "$file")"
            mkdir -p "$dest_dir"
            local container_target="/host$file"
            ensure_symlink "$container_target" "$dest_dir/$(basename "$file")"
        fi
    done

    # 5. ld.so.conf listing the in-container library directories
    if [ "$has_64bit" = true ] || [ "$has_32bit" = true ]; then
        echo "/opt/extensions/nvidia/orig" > "$ACE_DIR/etc/ld.so.conf"
        if [ "$has_32bit" = true ]; then
            echo "/opt/extensions/nvidia/orig/32" >> "$ACE_DIR/etc/ld.so.conf"
        fi
    fi

    # 6. Record the linked version (the directory was created above).
    echo "$nvidia_version" > "$version_file"

    # 7. Environment script. Unescaped $vars are expanded now; \$ ones when
    # the generated script runs. LD_LIBRARY_PATH uses ${VAR:+:$VAR} so no
    # empty entry (which ld.so treats as the current directory) is appended
    # when it was unset.
    cat > "$ACE_DIR/nvidia_env.sh" << EOF
#!/bin/bash
# NVIDIA驱动环境变量

export NVIDIA_DRIVER_VERSION="$nvidia_version"

# 库路径
if [ -d "/opt/extensions/nvidia/orig" ]; then
    export LD_LIBRARY_PATH="/opt/extensions/nvidia/orig\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}"
fi
if [ -d "/opt/extensions/nvidia/orig/32" ]; then
    export LD_LIBRARY_PATH="/opt/extensions/nvidia/orig/32\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}"
fi

# GLX和EGL配置
if [ "$has_glx" = true ]; then
    export __GLX_VENDOR_LIBRARY_NAME="nvidia"
    export __NV_PRIME_RENDER_OFFLOAD="1"
fi

# Vulkan ICD文件
if [ -f "/opt/extensions/nvidia/usr/share/vulkan/icd.d/nvidia_icd.json" ]; then
    export VK_ICD_FILENAMES="/opt/extensions/nvidia/usr/share/vulkan/icd.d/nvidia_icd.json"
    export VK_ADD_DRIVER_FILES="\${VK_ICD_FILENAMES}"
fi

# EGL外部平台配置
EGL_CONF_DIRS=""
for dir in /opt/extensions/nvidia/usr/share/egl/egl_external_platform.d \
           /usr/share/egl/egl_external_platform.d; do
    if [ -d "\$dir" ]; then
        EGL_CONF_DIRS="\$dir:\${EGL_CONF_DIRS}"
    fi
done
if [ -n "\${EGL_CONF_DIRS}" ]; then
    export EGL_EXTERNAL_PLATFORM_CONFIG_DIRS="\${EGL_CONF_DIRS%:}"
    export __EGL_EXTERNAL_PLATFORM_CONFIG_DIRS="\${EGL_CONF_DIRS%:}"
fi

# EGL供应商库目录
EGL_VENDOR_DIRS=""
for dir in /opt/extensions/nvidia/usr/share/glvnd/egl_vendor.d \
           /usr/share/glvnd/egl_vendor.d; do
    if [ -d "\$dir" ]; then
        EGL_VENDOR_DIRS="\$dir:\${EGL_VENDOR_DIRS}"
    fi
done
if [ -n "\${EGL_VENDOR_DIRS}" ]; then
    export __EGL_VENDOR_LIBRARY_DIRS="\${EGL_VENDOR_DIRS%:}"
fi

export NVIDIA_CTK_LIBCUDA_DIR="/opt/extensions/nvidia/orig"

EOF

    chmod +x "$ACE_DIR/nvidia_env.sh"

    log.info "NVIDIA驱动库和GLX组件已成功链接 Nvidia Driver Libs and GLX components are successfully linked."
    log.info "驱动版本: $nvidia_version"
    log.info "64位库: $has_64bit, 32位库: $has_32bit, GLX支持: $has_glx"
    log.info "环境变量脚本已生成: $ACE_DIR/nvidia_env.sh"

    # Device node list for the container runtime (patterns written literally).
    cat > "$ACE_DIR/devices.info" << EOF
# NVIDIA设备节点
/dev/nvidiactl
/dev/nvidia-uvm
/dev/nvidia-uvm-tools
/dev/nvidia-modeset
/dev/nvidia[0-9]*
/dev/dri/card*
/dev/dri/renderD*
EOF

    exit 0
}

# Run the main routine; it reads only the first argument (the ace-env path).
main "$@"
