kapre: Keras Audio Preprocessors
Keras Audio Preprocessors: compute STFT, ISTFT, mel-spectrograms, and more on the GPU in real time.
Tested on Python 3.8+, with type hints for a better development experience.
Why Kapre?
- You can optimize DSP parameters (e.g. `n_fft`) like any other hyperparameter.
- Your model deployment becomes much simpler and more consistent.
- Your code and model have fewer dependencies.
- Quick and easy!
- Consistent with 1D/2D TensorFlow batch shapes.
- Data format agnostic (`channels_first` and `channels_last`).
- Less error prone: Kapre layers are tested against Librosa (STFT, decibel, etc.), which is (trust me) trickier than you think.
- Kapre layers have some extended APIs beyond the default `tf.signal` implementation, such as:
  - A perfectly invertible `STFT` and `InverseSTFT` pair
  - Mel-spectrogram with more options
- Reproducibility: Kapre is available on pip with versioning.
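An invertible `STFT`/`InverseSTFT` pair relies on a classic property of windowed Fourier analysis: if you overlap-add the inverse-transformed frames and divide by the summed analysis window, you get the original signal back wherever that sum is nonzero. The dependency-free sketch below illustrates the idea with a naive DFT, a periodic Hann window, and 50% overlap (so the window copies sum to 1 in the interior). It is only an illustration under those assumptions, not Kapre's TensorFlow-based implementation, and all function names are hypothetical.

```python
import cmath
import math
from typing import List


def periodic_hann(n: int) -> List[float]:
    # Periodic Hann window: copies shifted by n // 2 sum to 1.
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * i / n) for i in range(n)]


def dft(frame: List[float]) -> List[complex]:
    # Naive O(n^2) discrete Fourier transform (keeps the sketch dependency-free).
    n = len(frame)
    return [sum(frame[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spec: List[complex]) -> List[float]:
    # Inverse DFT; the input came from a real frame, so the imaginary part is round-off.
    n = len(spec)
    return [(sum(spec[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n).real
            for t in range(n)]


def stft(x: List[float], n_fft: int, hop: int) -> List[List[complex]]:
    # Window each full frame (no padding) and transform it.
    win = periodic_hann(n_fft)
    return [dft([x[s + i] * win[i] for i in range(n_fft)])
            for s in range(0, len(x) - n_fft + 1, hop)]


def istft(frames: List[List[complex]], n_fft: int, hop: int, length: int) -> List[float]:
    # Overlap-add the inverse-transformed frames, then divide by the summed window.
    win = periodic_hann(n_fft)
    out = [0.0] * length
    win_sum = [0.0] * length
    for f, spec in enumerate(frames):
        start = f * hop
        for i, value in enumerate(idft(spec)):
            out[start + i] += value
            win_sum[start + i] += win[i]
    # Samples where the window sum is (near) zero cannot be recovered, e.g. x[0] here.
    return [o / w if w > 1e-8 else 0.0 for o, w in zip(out, win_sum)]


x = [math.sin(0.3 * t) + 0.01 * t for t in range(64)]
y = istft(stft(x, n_fft=16, hop=8), n_fft=16, hop=8, length=len(x))
print(max(abs(a - b) for a, b in zip(x[1:], y[1:])))  # only floating-point round-off remains
```

Only `x[0]` is excluded from the comparison, because the first frame's Hann window is zero there and no other frame covers that sample; padding the signal at both ends before framing is a common way to avoid this edge loss.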
A typical workflow with Kapre:
- Preprocess your audio dataset: resample the audio to the right sampling rate and store the audio signals (waveforms).
- In your ML model, add a Kapre layer, e.g. `kapre.time_frequency.STFT()`, as the first layer of the model.
- The data loader simply loads audio signals and feeds them into the model.
- In your hyperparameter search, include DSP parameters like `n_fft` to boost performance.
- When deploying the final model, all you need to remember is the sampling rate of the signal. No extra dependencies or preprocessing!
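Because `n_fft` and related DSP parameters become ordinary hyperparameters in this workflow, it helps to see how each setting changes the spectrogram shape the rest of the model receives. The helper below is a hypothetical sketch, not a Kapre function; it assumes `tf.signal.stft`-style framing (frames of `win_length` samples every `hop_length` samples, `n_fft // 2 + 1` frequency bins, and `pad_end=True` producing `ceil(n_samples / hop_length)` frames).

```python
from itertools import product
from typing import Optional, Tuple


def stft_output_shape(n_samples: int, n_ch: int, n_fft: int, hop_length: int,
                      win_length: Optional[int] = None, pad_end: bool = False,
                      data_format: str = "channels_last") -> Tuple[int, int, int]:
    """Hypothetical helper, not part of Kapre: per-example STFT output shape.

    Assumes tf.signal.stft-style framing: one frame of win_length samples every
    hop_length samples, and n_fft // 2 + 1 bins from the real FFT.
    """
    if win_length is None:
        win_length = n_fft  # assumed default: window as long as the FFT
    if pad_end:
        n_frames = -(-n_samples // hop_length)  # ceil(n_samples / hop_length)
    else:
        n_frames = max(0, 1 + (n_samples - win_length) // hop_length)
    n_freq = n_fft // 2 + 1
    if data_format == "channels_last":
        return (n_frames, n_freq, n_ch)
    return (n_ch, n_frames, n_freq)


# Candidate DSP hyperparameters for 1-second, 6-channel audio at 44.1 kHz.
for n_fft, hop_length in product([512, 1024, 2048], [256, 512]):
    print(n_fft, hop_length, stft_output_shape(44100, 6, n_fft, hop_length))
```

Under these assumptions, the settings of the one-shot example (`n_fft=2048`, `win_length=2018`, `hop_length=1024`, `pad_end=False`, input shape `(44100, 6)`) give `(42, 1025, 6)`; calling `model.summary()` on the real Kapre model is the authoritative check.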
Install Kapre with pip:

```sh
pip install kapre
```
Kapre includes comprehensive type hints for better IDE support and development experience.
Run type checking with our included script:
```sh
python scripts/check_types.py
```
Or use your preferred type checker:
```sh
# With mypy
pip install mypy
mypy kapre/

# With pyright
pip install pyright
pyright kapre/
```
For development:

```sh
# Install development dependencies
pip install -e ".[dev]"

# Run tests
pytest tests/

# Run type checking
python scripts/check_types.py

# Format code
black kapre/ tests/

# Lint code
flake8 kapre/ tests/
```
Please refer to the Kapre API documentation at https://kapre.readthedocs.io
A one-shot example:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Softmax
from kapre import STFT, Magnitude, MagnitudeToDecibel
from kapre.composed import get_melspectrogram_layer, get_log_frequency_spectrogram_layer

# 6 channels (!), maybe 1-sec audio signal, for an example.
input_shape = (44100, 6)
sr = 44100
model = Sequential()
# A STFT layer
model.add(STFT(n_fft=2048, win_length=2018, hop_length=1024,
               window_name=None, pad_end=False,
               input_data_format='channels_last', output_data_format='channels_last',
               input_shape=input_shape))
model.add(Magnitude())
model.add(MagnitudeToDecibel())  # these three layers can be replaced with get_stft_magnitude_layer()

# Alternatively, you may want to use a melspectrogram layer
# melgram_layer = get_melspectrogram_layer()
# or log-frequency layer
# log_stft_layer = get_log_frequency_spectrogram_layer()

# add more layers as you want
model.add(Conv2D(32, (3, 3), strides=(2, 2)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(GlobalAveragePooling2D())
model.add(Dense(10))
model.add(Softmax())

# Compile the model
model.compile('adam', 'categorical_crossentropy')  # if single-label classification

# train it with raw audio sample inputs
# for example, you may have functions (placeholders here) that load your data as below.
x = load_x()  # e.g., x.shape = (10000, 44100, 6), matching channels_last input_shape
y = load_y()  # e.g., y.shape = (10000, 10) if it's 10-class classification
# then..
model.fit(x, y)
# Done!
```
- See the Jupyter notebook in the example folder.
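Decibel scaling, as done by the `MagnitudeToDecibel` layer in the one-shot example, is one of the details that is easy to get subtly wrong: tiny values need a floor before the log, values are expressed relative to a reference, and very quiet bins are usually clipped to a fixed dynamic range. The dependency-free sketch below follows librosa's `power_to_db` convention (`10 * log10`, an `amin` floor, `ref`, and `top_db`, with librosa's defaults); it is an illustration under those assumptions, not Kapre's implementation. librosa uses `amplitude_to_db` (`20 * log10`) for magnitudes instead, so check the Kapre API documentation for the scaling and defaults its layer actually uses.

```python
import math
from typing import List, Optional


def power_to_db(power: List[float], ref: float = 1.0, amin: float = 1e-10,
                top_db: Optional[float] = 80.0) -> List[float]:
    """Librosa-style power-to-decibel sketch (not Kapre's implementation)."""
    if amin <= 0:
        raise ValueError("amin must be strictly positive")
    # Floor values at amin so log10 never sees zero, then express them relative to ref.
    ref_db = 10.0 * math.log10(max(amin, abs(ref)))
    db = [10.0 * math.log10(max(amin, p)) - ref_db for p in power]
    # Clip anything more than top_db below the loudest value.
    if top_db is not None:
        floor = max(db) - top_db
        db = [max(d, floor) for d in db]
    return db


print(power_to_db([1.0, 0.1, 0.0]))  # the silent bin is clipped to 80 dB below the peak
```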
The `STFT` layer is not TFLite compatible (due to `tf.signal.stft`). To create a TFLite-compatible model, first train using the normal Kapre layers, then create a new model that replaces `STFT` and `Magnitude` with `STFTTflite` and `MagnitudeTflite`, and copy the trained weights into it. The TFLite-compatible layers are restricted to a batch size of 1, which prevents their use during training.
```python
# assumes you have run the one-shot example above.
from kapre import STFTTflite, MagnitudeTflite

model_tflite = Sequential()
model_tflite.add(STFTTflite(n_fft=2048, win_length=2018, hop_length=1024,
                            window_name=None, pad_end=False,
                            input_data_format='channels_last',
                            output_data_format='channels_last',
                            input_shape=input_shape))
model_tflite.add(MagnitudeTflite())
model_tflite.add(MagnitudeToDecibel())
model_tflite.add(Conv2D(32, (3, 3), strides=(2, 2)))
model_tflite.add(BatchNormalization())
model_tflite.add(ReLU())
model_tflite.add(GlobalAveragePooling2D())
model_tflite.add(Dense(10))
model_tflite.add(Softmax())

# load the trained weights into the tflite compatible model.
model_tflite.set_weights(model.get_weights())
```
Please cite this paper if you use Kapre for your work.
```bibtex
@inproceedings{choi2017kapre,
  title={Kapre: On-GPU Audio Preprocessing Layers for a Quick Implementation of Deep Neural Network Models with Keras},
  author={Choi, Keunwoo and Joo, Deokjin and Kim, Juho},
  booktitle={Machine Learning for Music Discovery Workshop at 34th International Conference on Machine Learning},
  year={2017},
  organization={ICML}
}
```