An extension library of WMMA API (Tensor Core API)
wmmae/wmma_extension
This extension provides the following features without using extra shared memory:
- mapping between memory and fragment (primitive functions)
- operations for vectors
  - loading a vector as a fragment
  - storing a fragment as a vector
- C++ interface for `mma` instructions [detail]
- Error Correction (TCEC) for SGEMM emulation [detail]
- arithmetic operators for fragments (`+`, `-`, `*`, `/`, `fma`) [detail]
- utils [detail]
- etc.
Important
Please specify the virtual architecture that matches your target GPU. For instance, a program compiled with `-arch=sm_70` will not work correctly on Ampere GPUs.
Requirements
- CUDA (10.2 or later)
- C++ (17 or later)

Supported architectures and fragment shapes ((M, N, K), input/accumulator types)
- sm_70: ((16, 16, 16), fp16/fp32)
- sm_75: ((16, 16, 16), fp16/fp32)
- sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (the `wgmma` instruction is not supported yet)
`mtk::wmma::foreach` calculates the mapping between memory and fragment elements.
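```cpp
// Inside a kernel. compute_t and convert_to<T> are placeholders for the tile's
// element type and a type-conversion helper; neither is defined here.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];

mtk::wmma::foreach<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      const auto m = mem_index % 16;  // row in the column-major tile
      const auto n = mem_index / 16;  // column in the column-major tile
      for (unsigned i = 0; i < fragment_index_count; i++)
        frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
    });
```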
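In the `foreach` lambda, `mem_index` is an offset into a 16×16 column-major tile: `m = mem_index % 16` is the row and `n = mem_index / 16` the column, so `matrix[n * 16 + m]` reads exactly `matrix[mem_index]`. A minimal host-side sketch (the helpers `split_col_major` and `join_col_major` are hypothetical, not part of wmma_extension) checks that round trip for all 256 offsets:

```cpp
#include <cassert>

// Hypothetical helpers (not part of wmma_extension) for a 16x16 tile stored
// column-major: element (m, n) lives at offset n * 16 + m.
constexpr unsigned kLd = 16;  // leading dimension (rows per column)

struct RowCol {
  unsigned m;  // row index
  unsigned n;  // column index
};

// Same arithmetic as the foreach lambda: row = offset % 16, column = offset / 16.
constexpr RowCol split_col_major(unsigned mem_index) {
  return RowCol{mem_index % kLd, mem_index / kLd};
}

constexpr unsigned join_col_major(unsigned m, unsigned n) { return n * kLd + m; }

// True if every offset in [0, 256) survives split -> join unchanged,
// i.e. matrix[n * 16 + m] in the lambda reads matrix[mem_index].
bool round_trip_holds() {
  for (unsigned idx = 0; idx < kLd * kLd; ++idx) {
    const RowCol rc = split_col_major(idx);
    if (join_col_major(rc.m, rc.n) != idx) return false;
  }
  return true;
}

static_assert(split_col_major(17).m == 1 && split_col_major(17).n == 1, "offset 17 is (1, 1)");
```

The same decomposition works for any leading dimension; only `kLd` changes.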
`mtk::wmma::foreach_ij` calculates the mapping between a matrix element position (i, j) and fragment elements.
```cpp
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];

mtk::wmma::foreach_ij<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
      for (unsigned f = 0; f < fragment_index_count; f++)
        frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
    });
```
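The callback passed to `foreach_ij` never needs to know the fragment layout; the library supplies the slot list and the (i, j) position. To illustrate that contract without a GPU, the sketch below drives the same lambda body from a hypothetical toy layout (`toy_foreach_ij`, 8 slots per lane, row-major assignment). The real WMMA layout is opaque and architecture dependent, so only the callback pattern carries over.

```cpp
#include <array>
#include <cassert>

// Toy layout (hypothetical, NOT the real WMMA layout): 32 lanes x 8 slots
// cover a 16x16 tile; slot f of lane t holds element (i, j) with
// p = t * 8 + f, i = p / 16, j = p % 16.
constexpr unsigned kSlots = 8;

// Emulates the foreach_ij callback contract for a single lane.
template <class Func>
void toy_foreach_ij(unsigned lane, Func func) {
  for (unsigned f = 0; f < kSlots; ++f) {
    const unsigned p = lane * kSlots + f;
    const unsigned frag_index_list[1] = {f};  // one slot per element in this toy
    func(frag_index_list, 1u, p / 16, p % 16);
  }
}

// Column-major 16x16 tile whose element at offset k has value k.
std::array<float, 256> iota_col_major() {
  std::array<float, 256> m{};
  for (unsigned k = 0; k < 256; ++k) m[k] = static_cast<float>(k);
  return m;
}

// Fills one lane's slots using the same lambda body as the README example.
std::array<float, kSlots> fill_lane(unsigned lane, const std::array<float, 256>& matrix) {
  std::array<float, kSlots> x{};
  toy_foreach_ij(lane, [&](const unsigned* frag_index_list,
                           const unsigned fragment_index_count,
                           const unsigned i, const unsigned j) {
    for (unsigned f = 0; f < fragment_index_count; f++)
      x[frag_index_list[f]] = matrix[j * 16 + i];  // (i, j) in a column-major tile
  });
  return x;
}
```

For lane 1 the toy layout places row 0, columns 8 to 15 in its slots, so `fill_lane(1, iota_col_major())[0]` is `matrix[8 * 16]`, i.e. 128.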
`mtk::wmma::foreach_v` calculates the mapping between a given vector and fragment elements.
```cpp
// Equivalent to `load_vector`
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];

mtk::wmma::foreach_v<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      for (unsigned i = 0; i < fragment_index_count; i++)
        frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
    });
```

```cpp
// Equivalent to `store_vector`
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];

mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      for (unsigned i = 0; i < fragment_index_count; i++)
        vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
    });
```
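Both `foreach_v` snippets use the same callback shape; only the direction of the copy changes. If one vector element is replicated in several threads, the load writes it into every listed slot and the store writes the identical value back from each holder. The host-side sketch below models that with a hypothetical toy warp in which vector element `k` lives in slot 0 of lanes `k` and `k + 16`; it illustrates the pattern, not the library's actual layout.

```cpp
#include <array>
#include <cassert>

// Toy warp model (hypothetical, NOT the real WMMA layout): 32 lanes with
// 8 fragment slots each. Vector element k is held in slot 0 of lanes k and
// k + 16, so every element has two holders.
constexpr unsigned kWarp = 32, kSlots = 8, kLen = 16;
using Lane = std::array<float, kSlots>;
using Warp = std::array<Lane, kWarp>;

// Emulates the foreach_v callback contract for one lane: the lane-local
// slot list for one vector element, plus that element's index.
template <class Func>
void toy_foreach_v(unsigned lane, Func func) {
  const unsigned frag_index_list[1] = {0};
  func(frag_index_list, 1u, lane % kLen);
}

// Load: every holder copies vector[mem_index] into its listed slots.
Warp toy_load_vector(const std::array<float, kLen>& vec) {
  Warp w{};
  for (unsigned lane = 0; lane < kWarp; ++lane)
    toy_foreach_v(lane, [&](const unsigned* list, unsigned count, unsigned mem_index) {
      for (unsigned i = 0; i < count; ++i) w[lane][list[i]] = vec[mem_index];
    });
  return w;
}

// Store: every holder writes its (identical) copy back. On a GPU the lanes
// run concurrently; here they run one after another.
std::array<float, kLen> toy_store_vector(const Warp& w) {
  std::array<float, kLen> vec{};
  for (unsigned lane = 0; lane < kWarp; ++lane)
    toy_foreach_v(lane, [&](const unsigned* list, unsigned count, unsigned mem_index) {
      for (unsigned i = 0; i < count; ++i) vec[mem_index] = w[lane][list[i]];
    });
  return vec;
}
```

Because both holders of an element carry the same value, the duplicate writes in the store are harmless, and a load followed by a store returns the original vector.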
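`mtk::wmma::map` returns the mapping between matrix element (i, j) and fragment elements (tid, fid), where tid is the lane index within the warp and fid is the index into that lane's `x` array.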
```cpp
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;

// (i, j): position of the matrix element to set
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
  // Only the lane that holds the element writes it
  if ((threadIdx.x & 0x1f) == tid_list[k]) {
    frag_b.x[fid_list[k]] = 3.0f;
  }
}
```
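With `map`, every lane asks for the same (i, j), gets the same holder list, and only the lane whose id matches `tid_list[k]` writes. The host-side emulation below uses `toy_map`, a hypothetical stand-in for `mtk::wmma::map` with a row-major toy layout and exactly one holder per element; the real function can report up to two holders, and its layout is architecture dependent.

```cpp
#include <array>
#include <cassert>

constexpr unsigned kWarp = 32, kSlots = 8;
using Warp = std::array<std::array<float, kSlots>, kWarp>;

// Hypothetical stand-in for mtk::wmma::map with a toy row-major layout:
// element (i, j) of a 16x16 tile is held by lane p / 8 in slot p % 8,
// where p = i * 16 + j. It reports exactly one holder.
void toy_map(unsigned tid_list[2], unsigned fid_list[2], unsigned& list_size,
             unsigned i, unsigned j) {
  const unsigned p = i * 16 + j;
  tid_list[0] = p / kSlots;
  fid_list[0] = p % kSlots;
  list_size = 1;
}

// Emulates the warp-wide pattern: every lane asks for the same (i, j) and
// only the lane whose id matches a holder writes into its own slot.
void set_element(Warp& w, unsigned i, unsigned j, float value) {
  for (unsigned lane = 0; lane < kWarp; ++lane) {  // one iteration per thread
    unsigned tid_list[2], fid_list[2], list_size;
    toy_map(tid_list, fid_list, list_size, i, j);
    for (unsigned k = 0; k < list_size; ++k)
      if ((lane & 0x1f) == tid_list[k]) w[lane][fid_list[k]] = value;
  }
}

// Counts slots equal to value across the whole toy warp.
unsigned count_equal(const Warp& w, float value) {
  unsigned n = 0;
  for (const auto& lane : w)
    for (float x : lane) n += (x == value);
  return n;
}
```

Setting element (3, 5) touches only lane 6, slot 5 in this toy layout (p = 53); every other lane skips the write.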
Example of `mtk::wmma::load_vector` and `mtk::wmma::store_vector`:

```cpp
#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>

__global__ void kernel() {
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

  __shared__ float vec16[16];  // assumed to be filled before use

  // Load the same 16-element vector into the A and B fragments
  mtk::wmma::load_vector(frag_a, vec16);
  mtk::wmma::load_vector(frag_b, vec16);

  nvcuda::wmma::fill_fragment(frag_c, 0.0f);
  nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

  // Store the accumulator fragment back as a vector
  mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
```
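`nvcuda::wmma::mma_sync` computes D = A·B + C on full 16×16×16 tiles. When validating a kernel such as the `load_vector`/`store_vector` example, a plain host reference of that product is handy. The sketch below uses hypothetical helper names, stores tiles column-major (element (r, c) at index c * 16 + r), and works in float throughout, whereas the device inputs are half, so compare device results against it with a tolerance.

```cpp
#include <array>
#include <cassert>

// Host-side reference for one 16x16x16 tile: D = A * B + C.
// All tiles are column-major: element (r, c) is at index c * 16 + r.
constexpr unsigned kN = 16;
using Tile = std::array<float, kN * kN>;

Tile reference_mma(const Tile& a, const Tile& b, const Tile& c) {
  Tile d{};
  for (unsigned col = 0; col < kN; ++col)
    for (unsigned row = 0; row < kN; ++row) {
      float acc = c[col * kN + row];
      for (unsigned k = 0; k < kN; ++k)
        acc += a[k * kN + row] * b[col * kN + k];  // A(row, k) * B(k, col)
      d[col * kN + row] = acc;
    }
  return d;
}

// 16x16 identity tile (column-major, though the layout does not matter here).
Tile identity_tile() {
  Tile t{};
  for (unsigned i = 0; i < kN; ++i) t[i * kN + i] = 1.0f;
  return t;
}
```

Multiplying by `identity_tile()` with a zero accumulator returns B unchanged, which makes a quick sanity check of the indexing.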
- Arguments
  - dst_fragment : Destination fragment (`accumulator`)
  - alpha : diagonal element
- Argument
  - dst_fragment : Destination fragment
This function outputs the elements of a fragment.
- Arguments
  - frag : Target fragment
  - name : Name printed with the fragment (`char*`, optional)
Citation
```bibtex
@inproceedings{ootomo_wmmae_2023,
  author = {Ootomo, Hiroyuki and Yokota, Rio},
  title  = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
  year   = {2023},
  series = {HPC Asia '23}
}
```
License: MIT