zhihu/cuBERT
Highly customized and optimized BERT inference directly on NVIDIA (CUDA, CUBLAS) or Intel MKL, without tensorflow and its framework overhead.
ONLY BERT (Transformer) is supported.
Benchmark environment:

- Tesla P4
- 28 * Intel(R) Xeon(R) CPU E5-2680 v4 @ 2.40GHz
- Debian GNU/Linux 8 (jessie)
- gcc (Debian 4.9.2-10+deb8u1) 4.9.2
- CUDA: release 9.0, V9.0.176
- MKL: 2019.0.1.20181227
- tensorflow: 1.12.0
- BERT: seq_length = 32
GPU (cuBERT):

batch size | 128 (ms) | 32 (ms)
---|---|---
tensorflow | 255.2 | 70.0
cuBERT | 184.6 | 54.5
CPU (mklBERT):

batch size | 128 (ms) | 1 (ms)
---|---|---
tensorflow | 1504.0 | 69.9
mklBERT | 984.9 | 24.0
Note: MKL should be run under `OMP_NUM_THREADS=?` to control its thread number. Other environment variables and their possible values include:

- `KMP_BLOCKTIME=0`
- `KMP_AFFINITY=granularity=fine,verbose,compact,1,0`
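For example, a single run with these variables set might look like the following sketch; the thread count (matching the 28-core benchmark CPU above) and the binary name are placeholders to adapt to your machine:

```shell
# a sketch: 28 MKL/OpenMP threads, spin time disabled, compact thread affinity;
# replace ./cuBERT_test with your own binary
OMP_NUM_THREADS=28 \
KMP_BLOCKTIME=0 \
KMP_AFFINITY=granularity=fine,verbose,compact,1,0 \
./cuBERT_test
```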
cuBERT can be accelerated by Tensor Core and Mixed Precision on NVIDIA Volta and Turing GPUs. We support mixed precision as variables stored in fp16 with computation performed in fp32. The typical accuracy error is less than 1% compared with single-precision inference, while the speed achieves more than 2x acceleration.
We support the following two pooling methods.
- The standard BERT pooler, which is defined as:
```python
with tf.variable_scope("pooler"):
  # We "pool" the model by simply taking the hidden state corresponding
  # to the first token. We assume that this has been pre-trained
  first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
  self.pooled_output = tf.layers.dense(
      first_token_tensor,
      config.hidden_size,
      activation=tf.tanh,
      kernel_initializer=create_initializer(config.initializer_range))
```
- Simple average pooler:
```python
self.pooled_output = tf.reduce_mean(self.sequence_output, axis=1)
```
The following outputs are supported:

cuBERT_OutputType | Python code
---|---
`cuBERT_LOGITS` | `model.get_pooled_output() * output_weights + output_bias`
`cuBERT_PROBS` | `probs = tf.nn.softmax(logits, axis=-1)`
`cuBERT_POOLED_OUTPUT` | `model.get_pooled_output()`
`cuBERT_SEQUENCE_OUTPUT` | `model.get_sequence_output()`
`cuBERT_EMBEDDING_OUTPUT` | `model.get_embedding_output()`
Build from source:

```shell
mkdir build && cd build

# if build with CUDA
cmake -DCMAKE_BUILD_TYPE=Release -DcuBERT_ENABLE_GPU=ON -DCUDA_ARCH_NAME=Common ..
# or build with MKL
cmake -DCMAKE_BUILD_TYPE=Release -DcuBERT_ENABLE_MKL_SUPPORT=ON ..

make -j4

# install to /usr/local
# it will also install MKL if -DcuBERT_ENABLE_MKL_SUPPORT=ON
sudo make install
```
If you would like to run `tfBERT_benchmark` for performance comparison, please first install the tensorflow C API from https://www.tensorflow.org/install/lang_c.
Download the BERT test model `bert_frozen_seq32.pb` and `vocab.txt` from Dropbox, and put them under the dir `build` before running `make test` or `./cuBERT_test`.
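For instance, a test run might be wired up like this; this is only a sketch, and the location of the downloaded files is an assumption:

```shell
# assumed: the model files were downloaded next to the repository root
cp bert_frozen_seq32.pb vocab.txt build/
cd build
make test        # run the test suite
# or invoke the test binary directly
./cuBERT_test
```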
We provide a simple Python wrapper by Cython, which can be built and installed after the C++ build as follows:
```shell
cd python
python setup.py bdist_wheel

# install
pip install dist/cuBERT-xxx.whl

# test
python cuBERT_test.py
```
Please check the Python API usage and examples at cuBERT_test.py for more details.
The Java wrapper is implemented through JNA. After installing maven and building the C++ library, it can be built as follows:
```shell
cd java
mvn clean package # -DskipTests
```
When using the Java JAR, you need to set `jna.library.path` to the location of `libcuBERT.so` if it is not installed to the system path. And `jna.encoding` should be set to UTF8 as `-Djna.encoding=UTF8` in the JVM start-up script.
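For example, a JVM start-up command could look like the following sketch; the library location, jar names, and main class are assumptions to adapt to your deployment:

```shell
# assumed: libcuBERT.so installed to /usr/local/lib, application packaged as your-app.jar
java -Djna.library.path=/usr/local/lib \
     -Djna.encoding=UTF8 \
     -cp your-app.jar:cuBERT.jar \
     com.example.YourServer
```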
Please check the Java API usage and example at ModelTest.java for more details.
The pre-built Python binary package (currently only with MKL on Linux) can be installed as follows:

- Download and install MKL to the system path.
- Download the wheel package and `pip install cuBERT-xxx-linux_x86_64.whl`.
- Run `python -c 'import libcubert'` to verify your installation.
cuBERT is built with protobuf-c to avoid version and code conflicts with tensorflow protobuf.
Libraries compiled with different CUDA versions are not compatible.
MKL is dynamically linked. We install both cuBERT and MKL in `sudo make install`.
We assume the typical usage case of cuBERT is online serving, where concurrent requests of different batch_size should be served as fast as possible. Thus, throughput and latency should be balanced, especially in a pure CPU environment.
As the vanilla class `Bert` is not thread-safe because of its internal buffers for computation, a wrapper class `BertM` is written to hold locks of different `Bert` instances for thread safety. `BertM` will choose one underlying `Bert` instance in a round-robin manner, and consecutive requests of the same `Bert` instance might be queued by its corresponding lock.
One `Bert` is placed on one GPU card. The maximum number of concurrent requests is the number of usable GPU cards on one machine, which can be controlled by `CUDA_VISIBLE_DEVICES` if it is specified.
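For example, exposing two GPU cards allows at most two requests to run in parallel; this is only a sketch, and the server binary name is a placeholder:

```shell
# a sketch: two visible GPUs => two Bert instances, chosen round-robin by BertM
CUDA_VISIBLE_DEVICES=0,1 ./your_server
```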
For a pure CPU environment, it is more complicated than GPU. There are 2 levels of parallelism:
- Request level. Concurrent requests will compete for CPU resources if the online server itself is multi-threaded. If the server is single-threaded (for example, some server implementations in Python), things will be much easier.
- Operation level. The matrix operations are parallelized by OpenMP and MKL. The maximum parallelism is controlled by `OMP_NUM_THREADS`, `MKL_NUM_THREADS`, and many other environment variables. We refer our users to first read Using Threaded Intel® MKL in Multi-Thread Application and Recommended settings for calling Intel MKL routines from multi-threaded applications.
Thus, we introduce `CUBERT_NUM_CPU_MODELS` for better control of request-level parallelism. This variable specifies the number of `Bert` instances created on CPU/memory, which acts the same as `CUDA_VISIBLE_DEVICES` does for GPU.
If you have a limited number of CPU cores (old or desktop CPUs, or in Docker), it is not necessary to use `CUBERT_NUM_CPU_MODELS`. With 4 CPU cores, for example, a request-level parallelism of 1 and an operation-level parallelism of 4 should work quite well. But if you have many CPU cores, like 40, it might be better to try a request-level parallelism of 5 and an operation-level parallelism of 8.
In summary, `OMP_NUM_THREADS` or `MKL_NUM_THREADS` defines how many threads one model could use, and `CUBERT_NUM_CPU_MODELS` defines how many models in total.
Again, per-request latency and overall throughput should be balanced, and the trade-off differs with the model `seq_length`, `batch_size`, your CPU cores, your server QPS, and many other things. You should run a lot of benchmarks to achieve the best trade-off. Good luck!
Authors:

- fanliwen
- wangruixin
- fangkuan
- sunxian