
An extension library of WMMA API (Tensor Core API)


wmmae/wmma_extension


WMMA API Extension

This extension provides features for

  • mapping between memory and fragment (primitive functions)
  • operations for vectors
    • loading a vector as a fragment
    • storing a fragment as a vector
  • C++ interface for mma instructions [detail]
  • Error Correction (TCEC) for SGEMM emulation [detail]
  • arithmetic operators for fragments (+, -, *, /, fma) [detail]
  • utils [detail]
  • etc.

without using extra shared memory.

Important

Please specify a virtual architecture that matches the GPU the program will run on. For instance, a program compiled with -arch=sm_70 will not work correctly on Ampere GPUs.
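A small host-side sketch of the rule above: build the `-arch=sm_XY` flag from the compute capability of the GPU the binary will run on. On a real system `major` and `minor` would come from `cudaGetDeviceProperties` (the `cudaDeviceProp::major` / `minor` fields); the helper name `arch_flag` is hypothetical and not part of this library.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not part of wmma_extension): build the nvcc -arch flag
// for a compute capability, e.g. (8, 0) -> "-arch=sm_80".
std::string arch_flag(int major, int minor) {
  assert(major > 0 && minor >= 0 && minor < 10);
  return "-arch=sm_" + std::to_string(major * 10 + minor);
}
```

For an Ampere card with compute capability 8.6 this yields `-arch=sm_86`; building with `-arch=sm_70` for that card is exactly the case the warning describes.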

Requirements

  • CUDA (10.2 or later)
  • C++ (17 or later)

Supported architectures / fragments

  • sm_70: ((16, 16, 16), fp16/fp32)
  • sm_75: ((16, 16, 16), fp16/fp32)
  • sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
  • sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
  • sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (the wgmma instruction is not supported yet)
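The list above can be encoded as a small lookup table, for example to pick a fragment shape at build time. This is a sketch only: `FragmentShape` and `supported_shapes` are hypothetical names, the type strings are copied verbatim from the list, and an architecture that is not listed simply returns nothing (which says nothing about whether it would actually work).

```cpp
#include <cassert>
#include <string>
#include <vector>

// One fragment configuration from the list: M x N x K shape plus its types.
struct FragmentShape {
  int m, n, k;
  std::string types;  // as written in the list, e.g. "fp16/fp32"
};

// Encodes the "Supported architectures / fragments" list above.
std::vector<FragmentShape> supported_shapes(int sm) {
  const FragmentShape fp16{16, 16, 16, "fp16/fp32"};
  const FragmentShape tf32{16, 16, 8, "tf32/fp32"};
  switch (sm) {
    case 70:
    case 75:
      return {fp16};
    case 80:
    case 89:
    case 90:  // wgmma is not used on sm_90 yet
      return {fp16, tf32};
    default:
      return {};  // not in the list
  }
}
```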

Functions

Primitive functions

foreach

This function calculates the mapping between memory elements and fragment elements.

    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    __shared__ compute_t matrix[16 * 16];

    mtk::wmma::foreach<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            const auto m = mem_index % 16;
            const auto n = mem_index / 16;
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
        });
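The example above indexes a packed 16x16 column-major array, so `mem_index` splits into row `m = mem_index % 16` and column `n = mem_index / 16`. When the 16x16 tile instead sits inside a larger column-major matrix with leading dimension `ld`, only the final offset changes. A host-side sketch of that arithmetic, assuming (as the example implies) that `mem_index` enumerates a col_major 16x16 tile; `tile_offset` is a hypothetical helper, not a library function.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: map foreach's packed column-major index of a 16x16 tile
// to an offset in a larger column-major matrix with leading dimension ld,
// where the tile's top-left element is at (row0, col0).
std::size_t tile_offset(unsigned mem_index, std::size_t ld,
                        std::size_t row0, std::size_t col0) {
  assert(mem_index < 16 * 16 && ld >= row0 + 16);
  const std::size_t m = mem_index % 16;  // row inside the tile
  const std::size_t n = mem_index / 16;  // column inside the tile
  return (col0 + n) * ld + (row0 + m);
}
```

Inside the device lambda, the same expression `(col0 + n) * ld + (row0 + m)` would replace `n * 16 + m`; with `ld == 16` and a tile at the origin it reduces to `mem_index` itself.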

foreach_ij

This function calculates the mapping between the matrix element position (i, j) and fragment elements.

    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    __shared__ compute_t matrix[16 * 16];

    mtk::wmma::foreach_ij<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
            for (unsigned f = 0; f < fragment_index_count; f++)
                frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
        });
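Because `foreach_ij` hands the callback the logical position (i, j) rather than a memory index, the source array does not have to share the fragment's layout. A sketch of the two offsets involved, using hypothetical helper names: `col_major_offset(i, j, 16)` is the `j * 16 + i` used above, and `row_major_offset` is what would replace it for a source stored row by row.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: offset of element (i, j) in a row-major matrix
// with leading dimension ld (ld >= number of columns).
std::size_t row_major_offset(unsigned i, unsigned j, std::size_t ld) {
  assert(j < ld);
  return static_cast<std::size_t>(i) * ld + j;
}

// Column-major counterpart; with ld == 16 this is `j * 16 + i` from the example.
std::size_t col_major_offset(unsigned i, unsigned j, std::size_t ld) {
  assert(i < ld);
  return static_cast<std::size_t>(j) * ld + i;
}
```

The two layouts are related by a transpose: row-major (i, j) lands at the same offset as column-major (j, i).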

foreach_v

For matrix A/B

This function calculates the mapping between the elements of a given vector and fragment elements.

    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    __shared__ compute_t vector[16];

    mtk::wmma::foreach_v<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
        });
    // is equivalent to `load_vector`

For accumulator

    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
    __shared__ compute_t vector[16];

    mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
        });
    // is equivalent to `store_vector`

map

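This function returns the mapping between matrix element (i, j) and fragment elements (tid, fid), where tid is the lane index within the warp and fid is the index into that lane's fragment array x.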

    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;

    unsigned tid_list[2];
    unsigned fid_list[2];
    unsigned list_size;

    mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);

    for (unsigned k = 0; k < list_size; k++) {
        if ((threadIdx.x & 0x1f) == tid_list[k]) {
            frag_b.x[fid_list[k]] = 3.0f;
        }
    }
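`map` returns lists (here of up to two entries) because one matrix element can be held by more than one lane, and every holder has to be updated to keep the copies consistent. The host-side simulation below mimics the loop above for a 32-lane warp. The layout function `toy_map` and the types around it are purely hypothetical (the real, architecture-specific layout is what `mtk::wmma::map` computes); only the update pattern is meant to carry over.

```cpp
#include <array>
#include <cassert>

constexpr unsigned kWarpSize = 32;
constexpr unsigned kElemsPerLane = 16;  // toy value, not the real fragment size

// Per-lane storage, standing in for frag_b.x in each of the 32 threads.
using WarpFragment = std::array<std::array<float, kElemsPerLane>, kWarpSize>;

// HYPOTHETICAL layout: element (i, j) of a 16x16 matrix is held twice,
// by lanes i and i + 16, at index j. This is NOT the real WMMA layout.
void toy_map(unsigned tid_list[2], unsigned fid_list[2], unsigned& list_size,
             unsigned i, unsigned j) {
  tid_list[0] = i % 16;
  tid_list[1] = i % 16 + 16;
  fid_list[0] = fid_list[1] = j;
  list_size = 2;
}

// Same pattern as the device code: each lane writes only the entries naming it.
void set_element(WarpFragment& frag, unsigned i, unsigned j, float v) {
  unsigned tid_list[2], fid_list[2], list_size;
  toy_map(tid_list, fid_list, list_size, i, j);
  for (unsigned lane = 0; lane < kWarpSize; lane++) {  // plays "threadIdx.x & 0x1f"
    for (unsigned k = 0; k < list_size; k++) {
      if (lane == tid_list[k]) frag[lane][fid_list[k]] = v;
    }
  }
}
```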

Functions for vectors

Sample

    #include <mma.h>
    #include <wmma_extension/wmma_extension.hpp>

    __global__ void kernel() {
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
        nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

        __shared__ float vec16[16];

        mtk::wmma::load_vector(frag_a, vec16);
        mtk::wmma::load_vector(frag_b, vec16);

        nvcuda::wmma::fill_fragment(frag_c, 0.0f);
        nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

        mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
    }
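When checking results from a kernel like this one, a host reference for what `nvcuda::wmma::mma_sync` computes, D = A × B + C on a 16×16×16 tile, is handy. A minimal sketch assuming all operands are column-major 16×16 float arrays (half inputs converted beforehand); `Tile` and `mma_reference` are hypothetical names, and the sketch does not model how `load_vector` places a vector inside the fragments.

```cpp
#include <array>
#include <cassert>

constexpr int kM = 16, kN = 16, kK = 16;
using Tile = std::array<float, 16 * 16>;  // column-major: (row, col) at col * 16 + row

// Host reference for D = A * B + C (the operation mma_sync performs).
Tile mma_reference(const Tile& a, const Tile& b, const Tile& c) {
  Tile d{};
  for (int col = 0; col < kN; col++) {
    for (int row = 0; row < kM; row++) {
      float acc = c[col * kM + row];
      for (int k = 0; k < kK; k++) {
        acc += a[k * kM + row] * b[col * kK + k];  // A(row, k) * B(k, col)
      }
      d[col * kM + row] = acc;
    }
  }
  return d;
}
```

Tensor Cores accumulate with their own ordering and rounding, so compare device results against this reference with a tolerance rather than exact equality.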

Other functions

make_identity_matrix / add_eye


  • Arguments
    • dst_fragment : Destination fragment (accumulator)
    • alpha : diagonal element
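A host-side sketch of the dense matrices make_identity_matrix and add_eye are expected to leave in the accumulator, assuming `make_identity_matrix` writes alpha·I (alpha on the diagonal, zero elsewhere) and `add_eye` adds alpha to the diagonal of the existing contents. This is an interpretation of the argument list above, not the library's code; the reference function names are hypothetical and the fragment layout is ignored.

```cpp
#include <array>
#include <cassert>

using Tile = std::array<float, 16 * 16>;  // column-major 16x16 accumulator contents

// Assumed result of make_identity_matrix: alpha * I.
Tile identity_reference(float alpha) {
  Tile t{};  // all zeros
  for (int i = 0; i < 16; i++) t[i * 16 + i] = alpha;
  return t;
}

// Assumed result of add_eye: C + alpha * I.
Tile add_eye_reference(Tile c, float alpha) {
  for (int i = 0; i < 16; i++) c[i * 16 + i] += alpha;
  return c;
}
```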

fill_zero

  • Argument
    • dst_fragment : Destination fragment

Debugging functions

print_fragment

This function outputs the elements of a fragment.

  • Arguments
    • frag : Target fragment
    • name : name printed for the fragment (char*, optional)

Publication

    @inproceedings{ootomo_wmmae_2023,
      author = {Ootomo, Hiroyuki and Yokota, Rio},
      title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
      year = {2023},
      series = {HPC Asia '23}
    }

LICENSE

MIT

