Matrix Multiply (GEMM)
The API below provides transformation functions for General Matrix Multiplies (GEMMs) on real and complex values, with batching support for tensors of rank 3 and 4.
Cached API
-
template<typename TensorTypeC, typename TensorTypeA, typename TensorTypeB, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
void matx::matmul(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, cudaStream_t stream = 0, float alpha = 1.0, float beta = 0.0)
Run a GEMM without a plan
Creates a new GEMM plan in the cache if none exists, and uses that plan to execute the GEMM. This function is preferred over creating a plan directly, both for efficiency and for simpler code. Because it uses only the signature of the GEMM to decide whether a plan is cached, it may be able to reuse plans for different A/B/C matrices as long as they were configured with the same dimensions.
- Template Parameters
TensorTypeC – Type of the C tensor
TensorTypeA – Type of the A tensor
TensorTypeB – Type of the B tensor
PROV – Provider type chosen from MatXMatMulProvider_t type
- Parameters
c – C matrix view
a – A matrix view
b – B matrix view
stream – CUDA stream
alpha – Scalar multiplier to apply to matrix A
beta – Scalar multiplier to apply to matrix C on input
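The caching behavior described above can be illustrated with a minimal host-side sketch. It is not MatX's implementation: `GemmSignature`, `GemmPlan`, and `PlanCache` are hypothetical names introduced here, and the signature is assumed to consist only of the M, K, and N dimensions (a real cache key would likely also cover data types, strides, and the provider). The point it shows is that two calls with the same dimensions share one plan, regardless of which matrices are passed.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

// Hypothetical cache key: only the GEMM dimensions, never the data pointers.
struct GemmSignature {
  std::size_t m, k, n;
  bool operator==(const GemmSignature &o) const {
    return m == o.m && k == o.k && n == o.n;
  }
};

// Combines the three dimensions into a single hash value.
struct GemmSignatureHash {
  std::size_t operator()(const GemmSignature &s) const {
    std::size_t h = std::hash<std::size_t>{}(s.m);
    h = h * 31 + std::hash<std::size_t>{}(s.k);
    h = h * 31 + std::hash<std::size_t>{}(s.n);
    return h;
  }
};

// Stand-in for the expensive-to-build plan (library descriptors, workspace).
struct GemmPlan {
  GemmSignature sig;
};

class PlanCache {
public:
  // Returns the cached plan for this signature, creating it on first use.
  std::shared_ptr<GemmPlan> GetOrCreate(const GemmSignature &sig) {
    auto it = cache_.find(sig);
    if (it != cache_.end()) {
      return it->second;  // Cache hit: reuse the existing plan.
    }
    auto plan = std::make_shared<GemmPlan>(GemmPlan{sig});
    cache_.emplace(sig, plan);
    ++plans_created_;
    return plan;
  }

  // Number of plans built so far (cache misses).
  std::size_t plans_created() const { return plans_created_; }

private:
  std::unordered_map<GemmSignature, std::shared_ptr<GemmPlan>,
                     GemmSignatureHash> cache_;
  std::size_t plans_created_ = 0;
};
```

In this sketch, calling `GetOrCreate({64, 32, 16})` twice returns the same plan object, while `GetOrCreate({64, 32, 8})` builds a second one.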
Non-Cached API
-
template<typename TensorTypeC, typename TensorTypeA, typename TensorTypeB, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
class matx::matxMatMulHandle_t
Public Functions
-
inline matxMatMulHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b)
Construct a GEMM handle
Creates a GEMM handle for the given view shapes and provider type. The view shapes are used to create the underlying metadata for the GEMM, so a handle should only be used with views of identical sizes. The provider selects the underlying library used to perform the GEMM. Some providers offer more features than others, and their performance may differ. At the moment, it is recommended to try different providers for a given matrix size until the best-performing one is found. Different providers may also be used side by side by creating multiple handles.
- Template Parameters
TensorTypeC – Type of the C tensor
TensorTypeA – Type of the A tensor
TensorTypeB – Type of the B tensor
PROV – Provider type chosen from MatXMatMulProvider_t type
- Parameters
c – C matrix view
a – A matrix view
b – B matrix view
-
inline ~matxMatMulHandle_t()
GEMM handle destructor
Destroys any helper data used by the provider and frees any workspace memory that was created.
-
inline void Exec(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, cudaStream_t stream, float alpha = 1.0f, float beta = 0.0f)
Execute a matrix multiply (GEMM)
Execute a matrix multiply operation on two rank=2 input tensors into an output tensor. Using BLAS notation, tensor A has dimensions MxK, B is KxN, and C is MxN. Concretely:
\(\textbf{C} = \alpha\textbf{A}\textbf{B} + \beta\textbf{C}\)
MatX performs runtime checks to ensure that the dimension constraints are met on all views. Unlike BLAS GEMMs, most parameters of the GEMM call are deduced from the views themselves; there is no need to specify dimensions or transpose operations. MatX will attempt to perform the GEMM in the most efficient way possible given what it knows about the views.
While GEMMs are strictly rank=2 functions, rank 3 and higher tensors may be passed to this function, which has the effect of batching across the higher dimensions.
Note
Views passed to matxGemm currently must not be permuted and must have a contiguous stride.
- Template Parameters
T1 – Type of beta
T2 – Type of alpha
- Parameters
c – Output tensor C
a – Input tensor A
b – Input tensor B
stream – CUDA stream
alpha – Alpha value
beta – Beta value
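To make the formula and the batching rule concrete, here is a host-side reference sketch. It is not how MatX or its providers compute the result: `ReferenceGemm` and `BatchedReferenceGemm` are hypothetical helper names, and the data is assumed to be row-major and contiguous, with the batch as the outermost dimension of a rank-3 tensor. It computes C = αAB + βC for each batch and rejects buffers whose sizes do not match the stated dimensions.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Row-major reference GEMM: C = alpha * A * B + beta * C,
// where A is M x K, B is K x N, and C is M x N.
void ReferenceGemm(const float *a, const float *b, float *c,
                   std::size_t m, std::size_t k, std::size_t n,
                   float alpha = 1.0f, float beta = 0.0f) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {
        acc += a[i * k + p] * b[p * n + j];
      }
      // With beta == 0 the previous contents of C are ignored entirely.
      c[i * n + j] = (beta == 0.0f) ? alpha * acc
                                    : alpha * acc + beta * c[i * n + j];
    }
  }
}

// Rank-3 batching: every batch is an independent rank-2 GEMM.
// Layouts: A is [batches, M, K], B is [batches, K, N], C is [batches, M, N].
void BatchedReferenceGemm(const std::vector<float> &a,
                          const std::vector<float> &b,
                          std::vector<float> &c,
                          std::size_t batches, std::size_t m,
                          std::size_t k, std::size_t n,
                          float alpha = 1.0f, float beta = 0.0f) {
  // Runtime check that the buffer sizes agree with the stated dimensions.
  if (a.size() != batches * m * k || b.size() != batches * k * n ||
      c.size() != batches * m * n) {
    throw std::invalid_argument("GEMM dimension mismatch");
  }
  for (std::size_t batch = 0; batch < batches; ++batch) {
    ReferenceGemm(a.data() + batch * m * k, b.data() + batch * k * n,
                  c.data() + batch * m * n, m, k, n, alpha, beta);
  }
}
```

A reference like this can be useful for validating GPU results on small inputs, since each batch is computed independently with a plain triple loop.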
-
enum matx::MatXMatMulProvider_t
Defines a provider type for a GEMM. The provider is directly tied to the underlying library used for the GEMM, and certain providers offer capabilities that others may not have.
Values:
-
enumerator PROVIDER_TYPE_CUTLASS
CUTLASS library.
-
enumerator PROVIDER_TYPE_CUBLASLT
cuBLASLt library.
-
enumerator PROVIDER_TYPE_AUTO
Automatically select.
-
enumerator PROVIDER_TYPE_SENTINEL
Sentinel value. Do not use.