GSoC'2021 | TensorFlow implementation of Wav2Vec2

GSoC

This repository presents an implementation of the Wav2Vec2 model [1] in TensorFlow 2.0, built as part of Google Summer of Code.

For a quick demo, please check out this. You can find the final report of the project here.

Notebooks

The repository comes with shiny Colab Notebooks. Below you can find a list of them. Spin them up and don't forget to have fun!

| Notebook | Description |
|---|---|
| tensorflow/hub | This notebook gives you a template to fine-tune a pre-trained Wav2Vec2 SavedModel |
| Open In Colab | This notebook demonstrates conversion of the TF Wav2Vec2 model to ONNX and compares the latency of the ONNX-exported model & the TF model on CPU |
| Open In Colab | This notebook demonstrates Wav2Vec2 evaluation (without any padding) on LibriSpeech data |
| Open In Colab | This notebook demonstrates Wav2Vec2 SavedModel evaluation (with constant padding up to length 246000) on LibriSpeech data |
| Open In Colab | This notebook shows a small demo of how to use Wav2Vec2 for inference on the ASR task |
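The constant-padding evaluation mentioned in the notebook list above works by forcing every utterance to one fixed length, so the SavedModel always sees a static input shape. A minimal numpy sketch of that idea (the `pad_to` helper and the zero-padding/truncation choices are illustrative assumptions, not the notebook's exact code):

```python
import numpy as np

# fixed length used in the constant-padding evaluation above
MAX_LENGTH = 246000

def pad_to(speech: np.ndarray, length: int = MAX_LENGTH) -> np.ndarray:
    """Zero-pad (or truncate) a 1-D waveform to a constant length."""
    if len(speech) >= length:
        return speech[:length]
    return np.pad(speech, (0, length - len(speech)))

# two utterances of different lengths become one statically-shaped batch
batch = np.stack([pad_to(np.random.randn(160000)), pad_to(np.random.randn(50000))])
print(batch.shape)  # (2, 246000)
```

Padding every example to the same length wastes compute on short clips, which is exactly the latency trade-off the two evaluation notebooks (with and without padding) compare.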

Checkpoints

Below is a summary of checkpoints obtained during the project:

| 🤗 Hub Checkpoint | TFHub SavedModel | Description |
|---|---|---|
| gsoc-wav2vec2 | wav2vec2 | This checkpoint is TensorFlow's equivalent of pre-trained Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using convert_torch_to_tf.py |
| gsoc-wav2vec2-960h | wav2vec2-960h | This checkpoint is TensorFlow's equivalent of fine-tuned Wav2Vec2 by Facebook. PyTorch weights are converted into TensorFlow using convert_torch_to_tf.py |
| finetuned-wav2vec2-960h | - | This checkpoint is obtained by fine-tuning the Wav2Vec2 model on 960h of the LibriSpeech dataset during my GSoC tenure. You can reproduce training by running main.py on a TPU v3-8 |
| gsoc-wav2vec2-robust | wav2vec2-robust | This checkpoint is TensorFlow's equivalent of pre-trained Wav2Vec2-robust by Facebook. PyTorch weights are converted into TensorFlow using convert_torch_to_tf.py |
| gsoc-wav2vec2-xlsr-53 | wav2vec2-xlsr-53 | This checkpoint is TensorFlow's equivalent of pre-trained Wav2Vec2-xlsr-53 by Facebook. PyTorch weights are converted into TensorFlow using convert_torch_to_tf.py |

To know more about the process of obtaining the first two checkpoints, please check out this section; to know about the process of getting the last checkpoint, please check out this section.

Using this Repository

Install the Wav2Vec2 model from this repository using the pip command:

```shell
# this will install the wav2vec2 package
pip3 install git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
```

You can use the fine-tuned checkpoints (from 🤗 Hub) like this:

```python
from wav2vec2 import Wav2Vec2ForCTC, Wav2Vec2Config

config = Wav2Vec2Config()
model = Wav2Vec2ForCTC(config)
# now use this model like any other TF model

# in case you are interested in an already trained model, use the `.from_pretrained` method
model_id = "finetuned-wav2vec2-960h"
model = Wav2Vec2ForCTC.from_pretrained(model_id)
```

Additionally, you can use the SavedModel from TFHub like this:

```python
import tensorflow_hub as hub

model_url = "https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1"
model = hub.KerasLayer(model_url)
# use this `model`, just like any other TF SavedModel
```

Please check out the notebooks referred to in this repository for more information on using the Wav2Vec2 model.
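For ASR inference, the model produces per-frame logits over a character vocabulary, and the simplest way to turn those into text is greedy CTC decoding: take the arg-max token per frame, collapse consecutive repeats, then drop blanks. Below is a minimal, library-free sketch of that decoding step (the blank id of 0 and the toy vocabulary are assumptions for illustration; the real vocabulary ships with the checkpoint):

```python
def ctc_greedy_decode(frame_ids, blank_id=0):
    """Collapse repeated frame predictions and remove CTC blanks."""
    out, prev = [], None
    for t in frame_ids:
        if t != prev and t != blank_id:
            out.append(t)
        prev = t
    return out

# toy vocabulary: 0 = CTC blank, then characters
vocab = {1: "h", 2: "i", 3: " "}
frame_ids = [1, 1, 0, 2, 2, 2, 0, 0, 3, 1, 0, 2]  # e.g. argmax of logits per frame
text = "".join(vocab[i] for i in ctc_greedy_decode(frame_ids))
print(text)  # "hi hi"
```

Note that a blank between two identical tokens keeps both (CTC's way of encoding doubled letters), which is why repeats are collapsed before blanks are removed rather than after.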

Reproducing this project

Setting Up

```shell
# install & set up TensorFlow first
pip3 install tensorflow==2.5  # only `TF==2.5` is tested for now!

# install the other requirements of this project using the following command:
pip3 install -qr requirements.txt
sudo apt-get install libsndfile1-dev

# switch to the code directory for the further steps
cd src
```

For using TPUs, it's essential to store model weights and datasets in a GCS bucket so that the TPU can access them directly. Hence, we will create 2 GCS buckets: one for checkpointing and the other for storing the LibriSpeech tfrecords.

```shell
# these bucket names will be required to run the training script later
export DATA_BUCKET_NAME="gsoc-librispeech-us"
export CKPT_BUCKET_NAME="gsoc-checkpoints-us"

# create GCS buckets
gsutil mb gs://${DATA_BUCKET_NAME}
gsutil mb gs://${CKPT_BUCKET_NAME}
```

Preparing dataset

Now we will download the LibriSpeech dataset from the official website and convert it into tfrecords using make_tfrecords.py. Finally, we will export all the tfrecords to the GCS bucket.

```shell
# possible values are `dev-clean`, `train-clean-100`, `train-clean-360`, `train-other-500`, `test-clean`
# you will have to follow the same steps for all of the configurations specified above
export DATA_SPLIT=dev-clean

wget https://www.openslr.org/resources/12/${DATA_SPLIT}.tar.gz
tar -xf ${DATA_SPLIT}.tar.gz

python3 make_tfrecords.py --data_dir LibriSpeech/${DATA_SPLIT} -d ${DATA_SPLIT} -n 50

# transfer tfrecords to the GCS bucket
gsutil cp -r ${DATA_SPLIT} gs://<DATA_BUCKET_NAME>/${DATA_SPLIT}
```

Now your GCS bucket (DATA_BUCKET_NAME) should look like this:

```
.
|- ${DATA_SPLIT}
    |- ${DATA_SPLIT}-0.tfrecord
    |- ${DATA_SPLIT}-1.tfrecord
    .
    .
```

Follow the above steps for all other data splits, changing the DATA_SPLIT environment variable each time.
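The per-split steps above can be wrapped in one small loop. This is an illustrative dry-run sketch, not a script shipped with the repository: it assumes `DATA_BUCKET_NAME` is already exported, and it only prints each command; drop the `echo` prefixes to actually download, convert, and upload every split.

```shell
# dry run over every LibriSpeech configuration used in this project
for DATA_SPLIT in dev-clean train-clean-100 train-clean-360 train-other-500 test-clean; do
  echo wget "https://www.openslr.org/resources/12/${DATA_SPLIT}.tar.gz"
  echo tar -xf "${DATA_SPLIT}.tar.gz"
  echo python3 make_tfrecords.py --data_dir "LibriSpeech/${DATA_SPLIT}" -d "${DATA_SPLIT}" -n 50
  echo gsutil cp -r "${DATA_SPLIT}" "gs://${DATA_BUCKET_NAME}/${DATA_SPLIT}"
done
```

Keep in mind the train-* splits are tens of gigabytes each, so make sure the machine has enough disk space before running the real commands.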

Model training

Now that everything is installed and the GCS buckets are configured, training can be started with a single command.

Note: The following commands assume that you have already exported the DATA_BUCKET_NAME & CKPT_BUCKET_NAME environment variables.

The following command will fine-tune the wav2vec2 model on single/multiple GPUs or Colab/Kaggle TPUs:

```shell
python3 main.py
```

For training on Cloud TPUs, run the following command:

```shell
# export the `TPU_NAME` environment variable first; this flag ensures that your VM
# connects to the specified TPUs and that the TPUs become visible to TensorFlow
TPU_NAME=<tpu-name> python3 main.py
```

Running Conversion script

You can convert original PyTorch checkpoints (from Facebook) using the conversion script available in this repository.
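Under the hood, converting PyTorch checkpoints to TensorFlow mostly comes down to matching parameter names and fixing layout differences between the frameworks; for example, `torch.nn.Linear` stores its weight as `(out_features, in_features)` while a Keras `Dense` kernel is `(in_features, out_features)`. A minimal numpy illustration of that transpose (this shows the general idea only, not the actual logic of `convert_torch_to_tf.py`):

```python
import numpy as np

def torch_linear_to_dense_kernel(torch_weight: np.ndarray) -> np.ndarray:
    """Convert a torch Linear weight (out, in) to a Keras Dense kernel (in, out)."""
    return torch_weight.T

torch_w = np.arange(6).reshape(2, 3)            # (out_features=2, in_features=3)
kernel = torch_linear_to_dense_kernel(torch_w)  # (in_features=3, out_features=2)

# both layouts compute the same linear map on an input vector
x = np.ones(3)
assert np.allclose(torch_w @ x, x @ kernel)
print(kernel.shape)  # (3, 2)
```

Convolutions need analogous axis permutations, which is why a dedicated conversion script is more reliable than hand-copying weights.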

```shell
# --hf_model_id: HuggingFace Hub ID of the model you want to convert
# --with_lm_head: export `Wav2Vec2ForCTC` (with LM head) instead of `Wav2Vec2Model`
python3 convert_torch_to_tf.py \
    --hf_model_id facebook/wav2vec2-base \
    --with_lm_head
```

Running tests

```shell
# first install `torch` & `transformers`
pip3 install torch transformers

# run this from the root of this repository
pytest -sv tests
```

Acknowledgement

References

[1] Baevski, Alexei, et al. "wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations." arXiv:2006.11477 [cs, eess], Oct. 2020, http://arxiv.org/abs/2006.11477.

End Notes

Please create an issue if you encounter any problems while using this project. Don't forget to 🌟 this repository if you like this work.

