fml  0.1-0
Fused Matrix Library
gpublas.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_ARCH_HIP_GPUBLAS_H
6 #define FML_GPU_ARCH_HIP_GPUBLAS_H
7 #pragma once
8 
9 
#include <stdexcept>
#include <string>

#include <rocblas.h>
11 
12 
13 namespace gpublas
14 {
15  namespace err
16  {
17  inline std::string get_rocblas_error_msg(cublasStatus_t check)
18  {
19  if (check == rocblas_status_success)
20  return "";
21  else if (check == rocblas_status_invalid_handle)
22  return "invalid handle";
23  else if (check == rocblas_status_not_implemented)
24  return "function not implemented";
25  else if (check == rocblas_status_invalid_pointer)
26  return "invalid data";
27  else if (check == rocblas_status_invalid_size)
28  return "invalid size";
29  else if (check == rocblas_status_memory_error)
30  return "failed internal memory operation";
31  else if (check == rocblas_status_internal_error)
32  return "internal library failure";
33  else if (check == rocblas_status_perf_degraded)
34  return "performance degraded from low device memory";
35  else if (check == rocblas_status_size_query_mismatch)
36  return "unmatched start/stop size query";
37  else if (check == rocblas_status_size_increased)
38  return "queried device memory size increased";
39  else if (check == rocblas_status_size_unchanged)
40  return "queried device memory size unchanged";
41  else if (check == rocblas_status_invalid_value)
42  return "invalid paramater";
43  else if (check == rocblas_status_continue)
44  return "nothing preventing function to proceed";
45  else
46  return "unknown rocblas error occurred";
47  }
48 
49  inline void check_ret(rocblas_status check, std::string op)
50  {
51  if (check != rocblas_status_success)
52  {
53  std::string msg = "rocblas " + op + "() failed with error: " + get_rocblas_error_msg(check);
54  throw std::runtime_error(msg);
55  }
56  }
57  }
58 
59 
60 
61  inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
62  rocblas_operation transb, int m, int n, int k, const __half alpha,
63  const __half *A, int lda, const __half *B, int ldb, const __half beta,
64  __half *C, int ldc)
65  {
66  return rocblas_hgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
67  &beta, C, ldc);
68  }
69 
70  inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
71  rocblas_operation transb, int m, int n, int k, const float alpha,
72  const float *A, int lda, const float *B, int ldb, const float beta,
73  float *C, int ldc)
74  {
75  return rocblas_sgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
76  &beta, C, ldc);
77  }
78 
79  inline rocblas_status gemm(rocblas_handle handle, rocblas_operation transa,
80  rocblas_operation transb, int m, int n, int k, const double alpha,
81  const double *A, int lda, const double *B, int ldb, const double beta,
82  double *C, int ldc)
83  {
84  return rocblas_dgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
85  &beta, C, ldc);
86  }
87 
88 
89 
90  inline rocblas_status syrk(rocblas_handle handle, cublasFillMode_t uplo,
91  rocblas_operation trans, int n, int k, const float alpha, const float *A,
92  int lda, const float beta, float *C, int ldc)
93  {
94  return rocblas_ssyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
95  }
96 
97  inline rocblas_status syrk(rocblas_handle handle, cublasFillMode_t uplo,
98  rocblas_operation trans, int n, int k, const double alpha, const double *A,
99  int lda, const double beta, double *C, int ldc)
100  {
101  return rocblas_dsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
102  }
103 
104 
105 
106  inline rocblas_status geam(rocblas_handle handle, rocblas_operation transa,
107  rocblas_operation transb, int m, int n, const float alpha, const float *A,
108  int lda, const float beta, const float *B, int ldb, float *C, int ldc)
109  {
110  return rocblas_sgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
111  ldb, C, ldc);
112  }
113 
114  inline rocblas_status geam(rocblas_handle handle, rocblas_operation transa,
115  rocblas_operation transb, int m, int n, const double alpha, const double *A,
116  int lda, const double beta, const double *B, int ldb, double *C, int ldc)
117  {
118  return rocblas_dgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
119  ldb, C, ldc);
120  }
121 }
122 
123 
124 #endif