An extension library for the WMMA API (Tensor Core API).
wmmae/wmma_extension
This extension provides the following features without using extra shared memory:
- mapping between memory and fragment elements (primitive functions)
- operations for vectors
  - loading a vector as a fragment
  - storing a fragment as a vector
- C++ interface for `mma` instructions [detail]
- Error Correction (TCEC) for SGEMM emulation [detail]
- arithmetic operators for fragments (`+`, `-`, `*`, `/`, `fma`) [detail]
- utils [detail]
- etc.
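Fragment arithmetic can stay entirely in registers because, assuming both operands have the same fragment type, they share one thread-to-element layout: slot `x[i]` of one fragment and slot `x[i]` of the other refer to the same matrix element, so each thread combines its own slots without exchanging data. The host-side sketch below illustrates that idea with a hypothetical `MockFragment` struct standing in for `nvcuda::wmma::fragment` (which does expose `num_elements` and `x[]`); it is not the library's implementation of these operators.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for one thread's view of a WMMA fragment:
// a fixed-size register array, like nvcuda::wmma::fragment::x[num_elements].
template <class T, std::size_t N>
struct MockFragment {
  static constexpr std::size_t num_elements = N;
  T x[N];
};

// Element-wise addition: valid because both operands use the same layout,
// so x[i] in `a` and x[i] in `b` refer to the same matrix element.
template <class T, std::size_t N>
MockFragment<T, N> operator+(const MockFragment<T, N>& a, const MockFragment<T, N>& b) {
  MockFragment<T, N> r{};
  for (std::size_t i = 0; i < N; i++) r.x[i] = a.x[i] + b.x[i];
  return r;
}

// Element-wise scaling by a scalar.
template <class T, std::size_t N>
MockFragment<T, N> operator*(const T alpha, const MockFragment<T, N>& a) {
  MockFragment<T, N> r{};
  for (std::size_t i = 0; i < N; i++) r.x[i] = alpha * a.x[i];
  return r;
}

// Fused multiply-add d = a * b + c, element by element.
template <class T, std::size_t N>
MockFragment<T, N> fma(const MockFragment<T, N>& a, const MockFragment<T, N>& b,
                       const MockFragment<T, N>& c) {
  MockFragment<T, N> r{};
  for (std::size_t i = 0; i < N; i++) r.x[i] = a.x[i] * b.x[i] + c.x[i];
  return r;
}
```

No thread ever reads another thread's slots here, which is what lets such operators work without shared memory.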
Important: Specify a virtual architecture appropriate for the target GPU. For instance, a program compiled with `-arch=sm_70` will not work correctly on Ampere GPUs.
Requirements:
- CUDA (10.2 or later)
- C++ (17 or later)

Supported architectures and fragment configurations ((M, N, K), types):
- sm_70: ((16, 16, 16), fp16/fp32)
- sm_75: ((16, 16, 16), fp16/fp32)
- sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (the `wgmma` instruction is not supported yet)
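The support list above can be encoded as a small host-side lookup, for example to pick a fragment shape after querying the device's compute capability. This is a sketch that simply transcribes the list; `listed_shapes` and `WmmaShape` are hypothetical names, not part of the library.

```cpp
#include <cassert>
#include <vector>

// One supported WMMA configuration: fragment shape (M, N, K) and the
// type pair exactly as written in the list.
struct WmmaShape {
  int m, n, k;
  const char* types;
};

// Transcription of the support list; `sm` is the compute capability written
// as an integer (e.g. 80 for sm_80). An empty result only means "not listed",
// not "proven unsupported".
std::vector<WmmaShape> listed_shapes(int sm) {
  std::vector<WmmaShape> shapes;
  const bool fp16 = (sm == 70 || sm == 75 || sm == 80 || sm == 89 || sm == 90);
  const bool tf32 = (sm == 80 || sm == 89 || sm == 90);
  if (fp16) shapes.push_back({16, 16, 16, "fp16/fp32"});
  if (tf32) shapes.push_back({16, 16, 8, "tf32/fp32"});
  return shapes;
}
```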
`mtk::wmma::foreach` calculates the mapping between memory elements and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
// compute_t and convert_to<half> are placeholders for your element type and conversion
__shared__ compute_t matrix[16 * 16];

mtk::wmma::foreach<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
        // column-major storage: mem_index = n * 16 + m
        const auto m = mem_index % 16;
        const auto n = mem_index / 16;
        for (unsigned i = 0; i < fragment_index_count; i++)
            frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
    });
```
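The callback style decouples the copy from the hardware layout: `foreach` supplies, for the calling thread, each memory index it holds and the fragment slots that hold it, and the lambda only performs the copy. The sketch below reproduces that control flow on the host with a made-up toy layout (a 4x4 column-major matrix split over 2 threads, one slot per element); `toy_foreach`, `toy_load` and `ToyFragment` are hypothetical, and the real WMMA layout is different and architecture-dependent.

```cpp
#include <cassert>

// Hypothetical toy layout (illustration only): a 4x4 column-major matrix
// over 2 threads with 8 slots each. Element mem_index lives in thread
// (mem_index % 2), slot (mem_index / 2).
constexpr unsigned kDim = 4, kThreads = 2, kPerThread = kDim * kDim / kThreads;

struct ToyFragment { float x[kPerThread]; };

// Mimics the shape of mtk::wmma::foreach: for every memory index held by
// thread `tid`, call func(frag_index_list, fragment_index_count, mem_index).
template <class Func>
void toy_foreach(unsigned tid, Func func) {
  for (unsigned mem_index = 0; mem_index < kDim * kDim; mem_index++) {
    if (mem_index % kThreads != tid) continue;  // not held by this thread
    const unsigned frag_index_list[1] = {mem_index / kThreads};
    func(frag_index_list, 1u, mem_index);
  }
}

// Same callback body as the example above (m = row, n = column).
void toy_load(ToyFragment& frag, const float* matrix, unsigned tid) {
  toy_foreach(tid, [&](const unsigned* frag_index_list, unsigned fragment_index_count,
                       unsigned mem_index) {
    const unsigned m = mem_index % kDim;
    const unsigned n = mem_index / kDim;
    for (unsigned i = 0; i < fragment_index_count; i++)
      frag.x[frag_index_list[i]] = matrix[n * kDim + m];
  });
}
```

Swapping in a different layout only changes `toy_foreach`; the loading lambda stays the same.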
`mtk::wmma::foreach_ij` calculates the mapping between the matrix element position (i, j) and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];

mtk::wmma::foreach_ij<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
        // (i, j) = (row, column); the matrix is stored column-major
        for (unsigned f = 0; f < fragment_index_count; f++)
            frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
    });
```
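With `foreach_ij` the callback receives logical coordinates (i, j) instead of a memory offset, so the same loop can read from any storage order. The host-side sketch below loads a row-major source into a toy fragment; the toy layout (4x4 matrix, 2 threads, element (i, j) in thread `(j * 4 + i) % 2`) and all `toy_*` names are hypothetical illustrations, not the real Tensor Core layout.

```cpp
#include <cassert>

// Hypothetical toy layout (illustration only): a 4x4 matrix over 2 threads;
// element (i, j) sits in thread (j * 4 + i) % 2, slot (j * 4 + i) / 2.
constexpr unsigned kN = 4, kThreads = 2, kSlots = kN * kN / kThreads;

struct ToyFragment { double x[kSlots]; };

// Mimics the shape of mtk::wmma::foreach_ij: the callback gets the slots of
// thread `tid` together with the matrix coordinates (i, j).
template <class Func>
void toy_foreach_ij(unsigned tid, Func func) {
  for (unsigned j = 0; j < kN; j++) {
    for (unsigned i = 0; i < kN; i++) {
      const unsigned linear = j * kN + i;
      if (linear % kThreads != tid) continue;  // not held by this thread
      const unsigned frag_index_list[1] = {linear / kThreads};
      func(frag_index_list, 1u, i, j);
    }
  }
}

// The source here is row-major with leading dimension `ld` (row length).
void toy_load_row_major(ToyFragment& frag, const double* matrix, unsigned ld, unsigned tid) {
  toy_foreach_ij(tid, [&](const unsigned* frag_index_list, unsigned count, unsigned i, unsigned j) {
    for (unsigned f = 0; f < count; f++)
      frag.x[frag_index_list[f]] = matrix[i * ld + j];
  });
}
```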
`mtk::wmma::foreach_v` calculates the mapping between a given vector and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];

// equivalent to `load_vector`
mtk::wmma::foreach_v<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
        for (unsigned i = 0; i < fragment_index_count; i++)
            frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
    });
```

```cuda
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];

// equivalent to `store_vector`; accumulator fragments carry no layout type,
// so the memory layout is passed explicitly
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
        for (unsigned i = 0; i < fragment_index_count; i++)
            vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
    });
```
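`foreach_v` visits only the fragment slots that correspond to vector entries, so a load followed by a store through the same mapping round-trips the vector. The host sketch below assumes a hypothetical toy layout in which the vector occupies column 0 of a 4x4 matrix split over 2 threads; which row or column the real `load_vector`/`store_vector` use is not stated here, so the toy's choice is an assumption, and all `toy_*` names are hypothetical.

```cpp
#include <cassert>

// Hypothetical toy layout (illustration only): 4x4 column-major matrix over
// 2 threads; linear index k lives in thread k % 2, slot k / 2.
// In this toy the vector is identified with column 0 (linear indices 0..3).
constexpr unsigned kN = 4, kThreads = 2, kSlots = kN * kN / kThreads;

struct ToyFragment { float x[kSlots]; };

// Mimics the shape of mtk::wmma::foreach_v: visit only the vector entries
// held by thread `tid`, passing their fragment slots and vector index.
template <class Func>
void toy_foreach_v(unsigned tid, Func func) {
  for (unsigned mem_index = 0; mem_index < kN; mem_index++) {
    if (mem_index % kThreads != tid) continue;  // not held by this thread
    const unsigned frag_index_list[1] = {mem_index / kThreads};
    func(frag_index_list, 1u, mem_index);
  }
}

void toy_load_vector(ToyFragment& frag, const float* vec, unsigned tid) {
  toy_foreach_v(tid, [&](const unsigned* list, unsigned count, unsigned mem_index) {
    for (unsigned i = 0; i < count; i++) frag.x[list[i]] = vec[mem_index];
  });
}

void toy_store_vector(float* vec, const ToyFragment& frag, unsigned tid) {
  toy_foreach_v(tid, [&](const unsigned* list, unsigned count, unsigned mem_index) {
    for (unsigned i = 0; i < count; i++) vec[mem_index] = frag.x[list[i]];
  });
}
```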
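`mtk::wmma::map` returns the mapping between matrix element (i, j) and fragment element (tid, fid), where tid is the lane index within the warp and fid is the index into the fragment's `x` array.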
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
// (i, j): row and column of the target matrix element
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
    // only the lane(s) holding element (i, j) update their fragment
    if ((threadIdx.x & 0x1f) == tid_list[k]) {
        frag_b.x[fid_list[k]] = __float2half(3.0f);
    }
}
```
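`map` works in the inverse direction: given (i, j) it lists every (lane, slot) pair holding that element, and each thread checks whether it is one of them. The sketch below mirrors the example above on the host with a hypothetical toy layout in which every element has exactly one owner; the example above reserves room for two entries, so real layouts may report more than one. `toy_map` and `toy_set_element` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical toy layout (illustration only): 4x4 column-major matrix over
// 2 threads; linear index k = j * 4 + i lives in thread k % 2, slot k / 2.
constexpr unsigned kN = 4, kThreads = 2, kSlots = kN * kN / kThreads;

struct ToyFragment { float x[kSlots]; };

// Mimics the shape of mtk::wmma::map: report every (thread, slot) pair that
// holds matrix element (i, j). The toy never duplicates elements, so
// list_size is always 1 here.
void toy_map(unsigned* tid_list, unsigned* fid_list, unsigned& list_size,
             unsigned i, unsigned j) {
  const unsigned k = j * kN + i;
  tid_list[0] = k % kThreads;
  fid_list[0] = k / kThreads;
  list_size = 1;
}

// What each thread runs: update the element only if this thread holds it.
void toy_set_element(ToyFragment& frag, unsigned my_tid, unsigned i, unsigned j, float value) {
  unsigned tid_list[2], fid_list[2], list_size;
  toy_map(tid_list, fid_list, list_size, i, j);
  for (unsigned k = 0; k < list_size; k++)
    if (my_tid == tid_list[k]) frag.x[fid_list[k]] = value;
}
```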
Example kernel combining `load_vector`, `mma_sync`, and `store_vector`:

```cuda
#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>

__global__ void kernel() {
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

    __shared__ float vec16[16];  // assumed to be filled before the loads below

    // load a 16-element vector into each input fragment
    mtk::wmma::load_vector(frag_a, vec16);
    mtk::wmma::load_vector(frag_b, vec16);

    nvcuda::wmma::fill_fragment(frag_c, 0.0f);
    nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

    // write the result fragment back as a vector
    mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
```
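When testing a kernel like the one above, a CPU reference for the operation `mma_sync` performs (D = A * B + C) helps check results. The sketch below is a plain column-major triple loop; `reference_mma` is a hypothetical helper, and since the device multiplies `half` inputs, compare against it with a tolerance rather than exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for D = A * B + C with an M x K matrix A, a K x N matrix B
// and an M x N matrix C, all column-major (element (i, j) at j * ld + i).
std::vector<float> reference_mma(const std::vector<float>& a, const std::vector<float>& b,
                                 const std::vector<float>& c, int M, int N, int K) {
  assert(a.size() == std::size_t(M * K) && b.size() == std::size_t(K * N) &&
         c.size() == std::size_t(M * N));
  std::vector<float> d(c);  // start from the accumulator C
  for (int j = 0; j < N; j++)
    for (int i = 0; i < M; i++) {
      float acc = 0.0f;
      for (int k = 0; k < K; k++) acc += a[k * M + i] * b[j * K + k];
      d[j * M + i] += acc;
    }
  return d;
}
```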
- Arguments
  - dst_fragment : Destination fragment (`accumulator`)
  - alpha : diagonal element

- Argument
  - dst_fragment : Destination fragment
This function prints the elements of a fragment.
- Arguments
  - frag : Target fragment
  - name : name printed with the fragment (`char*`, optional)
```bibtex
@inproceedings{ootomo_wmmae_2023,
  author = {Ootomo, Hiroyuki and Yokota, Rio},
  title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
  year = {2023},
  series = {HPC Asia '23}
}
```
License: MIT