diff -ur cuda-12.9.org/targets/x86_64-linux/include/cublas_api.h cuda-12.9/targets/x86_64-linux/include/cublas_api.h --- cuda-12.9.org/targets/x86_64-linux/include/cublas_api.h 2026-04-30 07:09:48.874819190 +0000 +++ cuda-12.9/targets/x86_64-linux/include/cublas_api.h 2026-04-30 07:11:20.581260507 +0000 @@ -3372,7 +3372,7 @@ void* C, cudaDataType Ctype, int ldc, - cublasComputeType_t computeType, + cudaDataType computeType, cublasGemmAlgo_t algo); CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmEx_64(cublasHandle_t handle, @@ -4879,7 +4879,7 @@ cudaDataType Ctype, int ldc, int batchCount, - cublasComputeType_t computeType, + cudaDataType computeType, cublasGemmAlgo_t algo); CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx_64(cublasHandle_t handle, @@ -4924,7 +4924,7 @@ int ldc, long long int strideC, int batchCount, - cublasComputeType_t computeType, + cudaDataType computeType, cublasGemmAlgo_t algo); CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx_64(cublasHandle_t handle, @@ -5670,190 +5670,6 @@ #if defined(__cplusplus) } - -static inline cublasStatus_t cublasMigrateComputeType(cublasHandle_t handle, - cudaDataType_t dataType, - cublasComputeType_t* computeType) { - cublasMath_t mathMode = CUBLAS_DEFAULT_MATH; - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - status = cublasGetMathMode(handle, &mathMode); - if (status != CUBLAS_STATUS_SUCCESS) { - return status; - } - - bool isPedantic = ((mathMode & 0xf) == CUBLAS_PEDANTIC_MATH); - - switch (dataType) { - case CUDA_R_32F: - case CUDA_C_32F: - *computeType = isPedantic ? CUBLAS_COMPUTE_32F_PEDANTIC : CUBLAS_COMPUTE_32F; - return CUBLAS_STATUS_SUCCESS; - case CUDA_R_64F: - case CUDA_C_64F: - *computeType = isPedantic ? CUBLAS_COMPUTE_64F_PEDANTIC : CUBLAS_COMPUTE_64F; - return CUBLAS_STATUS_SUCCESS; - case CUDA_R_16F: - *computeType = isPedantic ? CUBLAS_COMPUTE_16F_PEDANTIC : CUBLAS_COMPUTE_16F; - return CUBLAS_STATUS_SUCCESS; - case CUDA_R_32I: - *computeType = isPedantic ? CUBLAS_COMPUTE_32I_PEDANTIC : CUBLAS_COMPUTE_32I; - return CUBLAS_STATUS_SUCCESS; - default: - return CUBLAS_STATUS_NOT_SUPPORTED; - } -} -/* wrappers to accept old code with cudaDataType computeType when referenced from c++ code */ -static inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void* alpha, /* host or device pointer */ - const void* A, - cudaDataType Atype, - int lda, - const void* B, - cudaDataType Btype, - int ldb, - const void* beta, /* host or device pointer */ - void* C, - cudaDataType Ctype, - int ldc, - cudaDataType computeType, - cublasGemmAlgo_t algo) { - cublasComputeType_t migratedComputeType = CUBLAS_COMPUTE_32F; - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); - if (status != CUBLAS_STATUS_SUCCESS) { - return status; - } - - return cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - B, - Btype, - ldb, - beta, - C, - Ctype, - ldc, - migratedComputeType, - algo); -} - -static inline cublasStatus_t cublasGemmBatchedEx(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void* alpha, /* host or device pointer */ - const void* const Aarray[], - cudaDataType Atype, - int lda, - const void* const Barray[], - cudaDataType Btype, - int ldb, - const void* beta, /* host or device pointer */ - void* const Carray[], - cudaDataType Ctype, - int ldc, - int batchCount, - cudaDataType computeType, - cublasGemmAlgo_t algo) { - cublasComputeType_t migratedComputeType; - cublasStatus_t status; - status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); - if (status != CUBLAS_STATUS_SUCCESS) { - return status; - } - - return cublasGemmBatchedEx(handle, - transa, - transb, - m, - n, - k, - alpha, - Aarray, - Atype, - lda, - Barray, - Btype, - ldb, - beta, - Carray, - Ctype, - ldc, - batchCount, - migratedComputeType, - algo); -} - -static inline cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void* alpha, /* host or device pointer */ - const void* A, - cudaDataType Atype, - int lda, - long long int strideA, /* purposely signed */ - const void* B, - cudaDataType Btype, - int ldb, - long long int strideB, - const void* beta, /* host or device pointer */ - void* C, - cudaDataType Ctype, - int ldc, - long long int strideC, - int batchCount, - cudaDataType computeType, - cublasGemmAlgo_t algo) { - cublasComputeType_t migratedComputeType; - cublasStatus_t status; - status = cublasMigrateComputeType(handle, computeType, &migratedComputeType); - if (status != CUBLAS_STATUS_SUCCESS) { - return status; - } - - return cublasGemmStridedBatchedEx(handle, - transa, - transb, - m, - n, - k, - alpha, - A, - Atype, - lda, - strideA, - B, - Btype, - ldb, - strideB, - beta, - C, - Ctype, - ldc, - strideC, - batchCount, - migratedComputeType, - algo); -} #endif /* __cplusplus */ #endif /* !defined(CUBLAS_API_H_) */