5 #ifndef FML_GPU_ARCH_HIP_GPUBLAS_H
6 #define FML_GPU_ARCH_HIP_GPUBLAS_H
17 inline std::string get_rocblas_error_msg(cublasStatus_t check)
19 if (check == rocblas_status_success)
21 else if (check == rocblas_status_invalid_handle)
22 return "invalid handle";
23 else if (check == rocblas_status_not_implemented)
24 return "function not implemented";
25 else if (check == rocblas_status_invalid_pointer)
26 return "invalid data";
27 else if (check == rocblas_status_invalid_size)
28 return "invalid size";
29 else if (check == rocblas_status_memory_error)
30 return "failed internal memory operation";
31 else if (check == rocblas_status_internal_error)
32 return "internal library failure";
33 else if (check == rocblas_status_perf_degraded)
34 return "performance degraded from low device memory";
35 else if (check == rocblas_status_size_query_mismatch)
36 return "unmatched start/stop size query";
37 else if (check == rocblas_status_size_increased)
38 return "queried device memory size increased";
39 else if (check == rocblas_status_size_unchanged)
40 return "queried device memory size unchanged";
41 else if (check == rocblas_status_invalid_value)
42 return "invalid paramater";
43 else if (check == rocblas_status_continue)
44 return "nothing preventing function to proceed";
46 return "unknown rocblas error occurred";
49 inline void check_ret(rocblas_status check, std::string op)
51 if (check != rocblas_status_success)
53 std::string msg =
"rocblas " + op +
"() failed with error: " + get_rocblas_error_msg(check);
54 throw std::runtime_error(msg);
// gemm overload for __half data. alpha is received by value and its address is
// what gets handed to rocblas_hgemm; handle, transa, transb, m, n, k, A, lda, B
// and ldb are forwarded unchanged.
// NOTE(review): the end of the parameter list and the end of the rocblas_hgemm
// argument list are not visible in this excerpt - presumably beta and the
// output matrix follow; confirm against the full file.
61 inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
62 rocblas_operation transb,
int m,
int n,
int k,
const __half alpha,
63 const __half *A,
int lda,
const __half *B,
int ldb,
const __half beta,
66 return rocblas_hgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
// gemm overload for float data. alpha is received by value and its address is
// what gets handed to rocblas_sgemm; handle, transa, transb, m, n, k, A, lda, B
// and ldb are forwarded unchanged.
// NOTE(review): the end of the parameter list and the end of the rocblas_sgemm
// argument list are not visible in this excerpt - presumably beta and the
// output matrix follow; confirm against the full file.
70 inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
71 rocblas_operation transb,
int m,
int n,
int k,
const float alpha,
72 const float *A,
int lda,
const float *B,
int ldb,
const float beta,
75 return rocblas_sgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
// gemm overload for double data. alpha is received by value and its address is
// what gets handed to rocblas_dgemm; handle, transa, transb, m, n, k, A, lda, B
// and ldb are forwarded unchanged.
// NOTE(review): the end of the parameter list and the end of the rocblas_dgemm
// argument list are not visible in this excerpt - presumably beta and the
// output matrix follow; confirm against the full file.
79 inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
80 rocblas_operation transb,
int m,
int n,
int k,
const double alpha,
81 const double *A,
int lda,
const double *B,
int ldb,
const double beta,
84 return rocblas_dgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
90 inline rocblas_status syrk(rocblas_handle handle, cublasFillMode_t uplo,
91 rocblas_operation trans,
int n,
int k,
const float alpha,
const float *A,
92 int lda,
const float beta,
float *C,
int ldc)
94 return rocblas_ssyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
97 inline rocblas_status syrk(rocblas_handle handle, cublasFillMode_t uplo,
98 rocblas_operation trans,
int n,
int k,
const double alpha,
const double *A,
99 int lda,
const double beta,
double *C,
int ldc)
101 return rocblas_dsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
// geam overload for float data. alpha and beta are received by value and passed
// to rocblas_sgeam by address; handle, transa, transb, m, n, A, lda and B are
// forwarded unchanged.
// NOTE(review): the end of the rocblas_sgeam argument list is not visible in
// this excerpt - presumably ldb, C and ldc follow; confirm against the full
// file.
106 inline rocblas_status geam(rocblas_handle handle, rocblas_operation transa,
107 rocblas_operation transb,
int m,
int n,
const float alpha,
const float *A,
108 int lda,
const float beta,
const float *B,
int ldb,
float *C,
int ldc)
110 return rocblas_sgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
114 inline rocblas_status geam(rocblas_handle handle, rocblas_operation transa,
115 rocblas_operation transb,
int m,
int n,
const double alpha,
const double *A,
116 int lda,
const double beta,
const double *B,
int ldb,
double *C,
int ldc)
118 return rocblas_dgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,