Matrix Multiply (GEMM)

The API below provides transformation functions for General Matrix Multiply (GEMM) operations on complex and real values, with batching support for tensors of rank 3 and 4.

Cached API

template<typename TensorTypeC, typename TensorTypeA, typename TensorTypeB, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
void matx::matmul(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, cudaStream_t stream = 0, float alpha = 1.0, float beta = 0.0)

Run a GEMM without a plan

Creates a new GEMM plan in the cache if none exists, then uses that plan to execute the GEMM. This function is preferred over creating a plan directly, both for efficiency and for simpler code. Because only the signature of the GEMM determines whether a plan is cached, plans may be reused for different A/B/C matrices as long as they were configured with the same dimensions.

Template Parameters
  • TensorTypeC – Data type of C tensor

  • TensorTypeA – Data type of A tensor

  • TensorTypeB – Data type of B tensor

  • PROV – Provider type chosen from MatXMatMulProvider_t type

Parameters
  • c – C matrix view

  • a – A matrix view

  • b – B matrix view

  • stream – CUDA stream

  • alpha – Scalar multiplier applied to the product of A and B

  • beta – Scalar multiplier to apply to matrix C on input

Non-Cached API

template<typename TensorTypeC, typename TensorTypeA, typename TensorTypeB, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
class matx::matxMatMulHandle_t

Public Functions

inline matxMatMulHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b)

Construct a GEMM handle

Creates a GEMM handle for the given view shapes and provider type. The view shapes are used to create the underlying metadata for the GEMM, so a handle should only be used for views of identical sizes. The provider chooses the underlying library used to perform the GEMM. Certain providers have more features than others and may perform differently. At the moment, it is recommended to try different providers for a given matrix size until the optimal one is found. Different providers may also be used by creating multiple handles.

Template Parameters
  • TensorTypeC – Type of C tensor

  • TensorTypeA – Type of A tensor

  • TensorTypeB – Type of B tensor

  • PROV – Provider type chosen from MatXMatMulProvider_t type

Parameters
  • c – C matrix view

  • a – A matrix view

  • b – B matrix view

inline ~matxMatMulHandle_t()

GEMM handle destructor

Destroys any helper data used by the provider and frees any workspace memory that was created

inline void Exec(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, cudaStream_t stream, float alpha = 1.0f, float beta = 0.0f)

Execute a Matrix multiply (GEMM)

Execute a matrix multiply operation on two rank=2 input tensors into an output tensor. Using BLAS notation, tensor A has dimensions MxK, B is KxN, and C is MxN. Concretely:

\(\textbf{C} = \alpha\textbf{A}\textbf{B} + \beta\textbf{C}\)

MatX performs runtime checks to ensure the dimension constraints are met on all views. Unlike BLAS GEMMs, most parameters of the GEMM call are deduced from the views themselves; there is no need to specify dimensions or transpose operations. MatX will attempt to perform the GEMM in the most efficient way possible given its knowledge of the views.

While GEMMs are strictly rank=2 functions, rank 3 and higher tensors may be passed to this function, which has the effect of batching across the higher dimensions.

Note

Views passed to this function must not be permuted and currently must have contiguous strides.

Parameters
  • c – Output tensor C

  • a – Input tensor A

  • b – Input tensor B

  • stream – CUDA stream

  • alpha – Alpha value

  • beta – Beta value

enum matx::MatXMatMulProvider_t

Defines a provider type for a GEMM. The provider is directly tied to the underlying library used for the GEMM, and certain providers offer capabilities that others may not have.

Values:

enumerator PROVIDER_TYPE_CUTLASS

CUTLASS library.

enumerator PROVIDER_TYPE_CUBLASLT

cuBLASLt library.

enumerator PROVIDER_TYPE_AUTO

Automatically select.

enumerator PROVIDER_TYPE_SENTINEL

Sentinel value. Do not use.