GSoC'2021 | TensorFlow implementation of Wav2Vec2
thevasudevgupta/gsoc-wav2vec2
This repository presents an implementation of the Wav2Vec2 model [1] in TensorFlow 2.0 as a part of Google Summer of Code.
For a quick demo, please check out this. You can find the final report of the project here.
The repository comes with shiny Colab Notebooks. Below you can find a list of them. Spin them up and don't forget to have fun!
| Notebook | Description |
|---|---|
| tensorflow/hub | This notebook gives you a template to fine-tune a pre-trained Wav2Vec2 SavedModel |
| | This notebook demonstrates conversion of the TF Wav2Vec2 model to ONNX and compares the latency of the ONNX-exported model and the TF model on CPU |
| | This notebook demonstrates Wav2Vec2 evaluation (without any padding) on LibriSpeech data |
| | This notebook demonstrates Wav2Vec2 SavedModel evaluation (with constant padding up to a length of 246000) on LibriSpeech data |
| | This notebook shows a small demo of how to use Wav2Vec2 for inference on the ASR task |
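The SavedModel evaluation notebook listed above pads every input to a constant length of 246000 samples (about 15.4 s of 16 kHz LibriSpeech audio). A minimal sketch of that kind of preprocessing follows, assuming mono waveforms already loaded as NumPy arrays. `pad_or_truncate` and `make_batch` are hypothetical helpers, not functions from this repository, and cutting off clips longer than the limit is an assumption about how over-length inputs might be handled.

```python
import numpy as np

# Fixed input length mentioned in the notebook list (about 15.4 s at 16 kHz).
AUDIO_MAXLEN = 246000


def pad_or_truncate(speech, max_len: int = AUDIO_MAXLEN, pad_value: float = 0.0) -> np.ndarray:
    """Right-pad a 1-D waveform with `pad_value` up to `max_len` samples.

    Clips longer than `max_len` are cut off (an assumption; the notebook
    may treat them differently).
    """
    speech = np.asarray(speech, dtype=np.float32)
    if speech.ndim != 1:
        raise ValueError("expected a 1-D waveform")
    if speech.shape[0] >= max_len:
        return speech[:max_len]
    return np.pad(speech, (0, max_len - speech.shape[0]), mode="constant", constant_values=pad_value)


def make_batch(waveforms, max_len: int = AUDIO_MAXLEN) -> np.ndarray:
    """Stack waveforms of different lengths into one (batch, max_len) float32 array."""
    return np.stack([pad_or_truncate(w, max_len) for w in waveforms])
```

For example, `make_batch([np.zeros(16000), np.zeros(300000)])` yields an array of shape `(2, 246000)`. A fixed length like this lets a SavedModel be exported with a static input shape, at the cost of wasted computation on short clips.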
Below is a summary of checkpoints obtained during the project:
| 🤗 Hub Checkpoint | TF Hub SavedModel | Description |
|---|---|---|
| gsoc-wav2vec2 | wav2vec2 | This checkpoint is TensorFlow's equivalent of the pre-trained Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py` |
| gsoc-wav2vec2-960h | wav2vec2-960h | This checkpoint is TensorFlow's equivalent of the fine-tuned Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py` |
| finetuned-wav2vec2-960h | - | This checkpoint was obtained by fine-tuning the Wav2Vec2 model on 960 hours of the LibriSpeech dataset during my GSoC tenure. You can reproduce the training by running `main.py` on a TPU v3-8 |
| gsoc-wav2vec2-robust | wav2vec2-robust | This checkpoint is TensorFlow's equivalent of the pre-trained Wav2Vec2-robust by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py` |
| gsoc-wav2vec2-xlsr-53 | wav2vec2-xlsr-53 | This checkpoint is TensorFlow's equivalent of the pre-trained Wav2Vec2-xlsr-53 by Facebook. PyTorch weights are converted into TensorFlow using `convert_torch_to_tf.py` |
To learn more about how the first two checkpoints were obtained, see the conversion-script instructions below; to learn how the `finetuned-wav2vec2-960h` checkpoint was obtained, see the training instructions below.
Install the `Wav2Vec2` model from this repository using the `pip` command:
```shell
# this will install the wav2vec2 package
pip3 install git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
```
You can use the fine-tuned checkpoints (from 🤗 Hub) like this:
```python
from wav2vec2 import Wav2Vec2ForCTC, Wav2Vec2Config

config = Wav2Vec2Config()
model = Wav2Vec2ForCTC(config)
# now use this model like any other TF model

# in case you are interested in an already trained model, use the `.from_pretrained` method
model_id = "finetuned-wav2vec2-960h"
model = Wav2Vec2ForCTC.from_pretrained(model_id)
```
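A CTC model such as `Wav2Vec2ForCTC` is expected to output per-frame logits over a character vocabulary, from which a transcript is recovered by CTC decoding. Below is a minimal greedy (best-path) CTC decoder in NumPy as an illustration. The toy vocabulary, the blank at index 0 and the `|` word delimiter are assumptions borrowed from common Wav2Vec2 character vocabularies, not read from this repository's tokenizer.

```python
import numpy as np

# Toy character vocabulary (assumption): index 0 is the CTC blank ("<pad>"),
# "|" marks word boundaries, as in common Wav2Vec2 character vocabularies.
VOCAB = ["<pad>", "|", "A", "C", "T", "H", "E"]
BLANK_ID = 0
WORD_DELIMITER = "|"


def ctc_greedy_decode(logits: np.ndarray, vocab=VOCAB, blank_id: int = BLANK_ID) -> str:
    """Greedy CTC decoding for one utterance.

    `logits` has shape (time, vocab_size). Pick the best token per frame,
    merge consecutive repeats, then drop blanks.
    """
    best_path = np.argmax(logits, axis=-1)
    tokens = []
    previous = None
    for token_id in best_path.tolist():
        # Emit a token only when it differs from the previous frame and is not blank.
        if token_id != previous and token_id != blank_id:
            tokens.append(vocab[token_id])
        previous = token_id
    text = "".join(tokens).replace(WORD_DELIMITER, " ")
    return " ".join(text.split())  # collapse repeated or trailing spaces
```

Note that a blank between two identical tokens (`T`, blank, `T`) is what allows CTC to emit a doubled letter; without it, the repeats would be merged into one. The real vocabulary must come from the tokenizer used when the checkpoint was trained.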
Additionally, you can use the `SavedModel` from TF Hub like this:
```python
import tensorflow_hub as hub

model_url = "https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1"
model = hub.KerasLayer(model_url)
# use this `model` just like any other TF SavedModel
```
Please check out the notebooks listed in this repository for more information on using the `Wav2Vec2` model.
To reproduce this project, first set up the environment:

```shell
# install & set up TensorFlow first
pip3 install tensorflow==2.5  # only `TF==2.5` is tested for now!

# install other requirements of this project using the following commands:
pip3 install -qr requirements.txt
sudo apt-get install libsndfile1-dev

# switch to the code directory for further steps
cd src
```
To use TPUs, model weights and datasets must be stored in GCS buckets so that the TPU can access them directly. Hence we will create two GCS buckets: one for checkpointing and the other for storing the LibriSpeech TFRecords.
```shell
# these bucket names will be required to run the training script later
export DATA_BUCKET_NAME="gsoc-librispeech-us"
export CKPT_BUCKET_NAME="gsoc-checkpoints-us"

# create GCS buckets
gsutil mb gs://${DATA_BUCKET_NAME}
gsutil mb gs://${CKPT_BUCKET_NAME}
```
Next, we will download the LibriSpeech dataset from the official website and convert it into TFRecords using `make_tfrecords.py`. Finally, we will upload all the TFRecords to the GCS bucket.
```shell
# possible values are `dev-clean`, `train-clean-100`, `train-clean-360`, `train-other-500`, `test-clean`
# you will have to follow the same steps for all the configurations (specified above)
export DATA_SPLIT=dev-clean

wget https://www.openslr.org/resources/12/${DATA_SPLIT}.tar.gz
tar -xf ${DATA_SPLIT}.tar.gz

python3 make_tfrecords.py --data_dir LibriSpeech/${DATA_SPLIT} -d ${DATA_SPLIT} -n 50

# transfer tfrecords to the GCS bucket
gsutil cp -r ${DATA_SPLIT} gs://${DATA_BUCKET_NAME}/${DATA_SPLIT}
```
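Each LibriSpeech chapter directory contains FLAC files plus a `<speaker>-<chapter>.trans.txt` file whose lines have the form `<utterance-id> <TRANSCRIPT>`. The standard-library sketch below shows how such a split could be indexed and distributed across shards named like the `${DATA_SPLIT}-0.tfrecord` files in the bucket layout that follows. It is not the code in `make_tfrecords.py`; treating `-n 50` as the number of shards and using round-robin assignment are assumptions, and `read_transcripts` / `assign_shards` are hypothetical helpers.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def read_transcripts(split_dir: Path) -> List[Tuple[Path, str]]:
    """Collect (flac_path, transcript) pairs from a LibriSpeech split directory.

    Expected layout: <split_dir>/<speaker>/<chapter>/<speaker>-<chapter>.trans.txt,
    where each line is "<utterance-id> <TRANSCRIPT>".
    """
    samples = []
    for trans_file in sorted(split_dir.glob("*/*/*.trans.txt")):
        for line in trans_file.read_text().splitlines():
            if not line.strip():
                continue
            utt_id, text = line.strip().split(" ", 1)
            samples.append((trans_file.parent / f"{utt_id}.flac", text))
    return samples


def assign_shards(samples, split_name: str, num_shards: int) -> Dict[str, list]:
    """Round-robin samples into `num_shards` files named "<split>-<i>.tfrecord" (hypothetical scheme)."""
    shards = {f"{split_name}-{i}.tfrecord": [] for i in range(num_shards)}
    for index, sample in enumerate(samples):
        shards[f"{split_name}-{index % num_shards}.tfrecord"].append(sample)
    return shards
```

Each shard's list of (audio path, transcript) pairs would then be decoded and serialized, for example with `tf.io.TFRecordWriter`. Sharding keeps individual files small and lets a TPU input pipeline read several of them in parallel.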
Now your GCS bucket (`DATA_BUCKET_NAME`) should look like this:
```
.
|- ${DATA_SPLIT}
    |- ${DATA_SPLIT}-0.tfrecord
    |- ${DATA_SPLIT}-1.tfrecord
    .
    .
```
Follow the above steps for all other data splits, changing the `DATA_SPLIT` environment variable each time.
Now that everything is installed and the GCS buckets are configured, a single command starts training.
Note: The following commands assume that you have already exported the `DATA_BUCKET_NAME` and `CKPT_BUCKET_NAME` environment variables.
The following command will fine-tune the wav2vec2 model on single/multiple GPUs or Colab/Kaggle TPUs:
```shell
python3 main.py
```
For training on Cloud TPUs, run the following command:
```shell
# set the `TPU_NAME` environment variable (here inline) to the name of your Cloud TPU;
# this ensures that your VM connects to the specified TPU and that the TPU becomes visible to TensorFlow
TPU_NAME=<tpu-name> python3 main.py
```
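For reference, the usual TensorFlow 2 pattern for turning a `TPU_NAME` environment variable into a distribution strategy is sketched below, falling back to `tf.distribute.MirroredStrategy` for one or more GPUs. It is an assumption that `main.py` does something similar; `resolve_tpu_name` and `make_strategy` are hypothetical helpers, and Colab/Kaggle TPU auto-detection is left out of this sketch.

```python
import os
from typing import Mapping, Optional


def resolve_tpu_name(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the TPU name from the environment, or None if it is unset or blank."""
    name = env.get("TPU_NAME", "").strip()
    return name or None


def make_strategy(env: Mapping[str, str] = os.environ):
    """Build a tf.distribute strategy: TPUStrategy if TPU_NAME is set, else MirroredStrategy."""
    import tensorflow as tf  # imported here so `resolve_tpu_name` has no TF dependency

    tpu_name = resolve_tpu_name(env)
    if tpu_name is None:
        # Replicates across all visible GPUs (or runs on the CPU if there are none).
        return tf.distribute.MirroredStrategy()

    # Connect this VM to the named Cloud TPU and make its cores visible to TensorFlow.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
```

The model and optimizer would then be created inside `with strategy.scope():` so that their variables are placed on the selected devices.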
You can convert original PyTorch checkpoints (from Facebook) using the conversion script available in this repository.
```shell
# --hf_model_id: HuggingFace Hub ID of the model you want to convert
# --with_lm_head: whether to use `Wav2Vec2ForCTC` or `Wav2Vec2Model` from this repository
python3 convert_torch_to_tf.py \
  --hf_model_id facebook/wav2vec2-base \
  --with_lm_head
```
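Part of converting PyTorch weights to TensorFlow is fixing tensor layouts, because the two frameworks store some kernels in different orders: `torch.nn.Linear` keeps its weight as `(out_features, in_features)` while a Keras `Dense` kernel is `(in_features, out_features)`, and `torch.nn.Conv1d` uses `(out_channels, in_channels, kernel_size)` while Keras `Conv1D` uses `(kernel_size, in_channels, out_channels)`. The NumPy sketch below illustrates only these layout changes. It is not taken from `convert_torch_to_tf.py`, which presumably also maps parameter names and may handle further cases (for example, Wav2Vec2's weight-normalized positional convolution); `convert_parameter` and its `kind` labels are hypothetical.

```python
import numpy as np


def linear_weight_to_dense_kernel(weight: np.ndarray) -> np.ndarray:
    """torch.nn.Linear weight (out, in) -> Keras Dense kernel (in, out)."""
    if weight.ndim != 2:
        raise ValueError("expected a 2-D Linear weight")
    return weight.T


def conv1d_weight_to_keras_kernel(weight: np.ndarray) -> np.ndarray:
    """torch.nn.Conv1d weight (out, in, k) -> Keras Conv1D kernel (k, in, out)."""
    if weight.ndim != 3:
        raise ValueError("expected a 3-D Conv1d weight")
    return np.transpose(weight, (2, 1, 0))


def convert_parameter(kind: str, array: np.ndarray) -> np.ndarray:
    """Hypothetical dispatcher; `kind` would come from a parameter-name mapping table."""
    if kind == "linear":
        return linear_weight_to_dense_kernel(array)
    if kind == "conv1d":
        return conv1d_weight_to_keras_kernel(array)
    # Biases and LayerNorm scale/offset vectors keep their layout.
    return array
```

A quick sanity check after converting is to feed the same input through both layouts and compare outputs with `np.allclose`, since a transposed kernel with matching shapes (e.g. a square `Linear`) would otherwise fail silently.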
To run the tests:

```shell
# first install `torch` & `transformers`
pip3 install torch transformers

# run this from the root of this repository
pytest -sv tests
```
Acknowledgements:

- Sayak Paul, Morgan Roff, and Jaeyoun Kim for mentoring me throughout the project.
- TensorFlow team & TRC for providing access to TPUs during my GSoC tenure.
[1] Baevski, Alexei, et al. "Wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations." ArXiv:2006.11477 [Cs, Eess], Oct. 2020. arXiv.org, http://arxiv.org/abs/2006.11477.
Please create an issue if you encounter any problems while using this project. Don't forget to 🌟 this repository if you like this work.