[pipelines] text-to-audio pipeline standardization #39796


Open
gante wants to merge 8 commits into huggingface:main from gante:tts_pipeline_standardization

Conversation

@gante (Contributor) commented Jul 30, 2025 (edited)

What does this PR do?

⚠️ TODO before merging: after settling on the design, update pipeline usage in model docs accordingly.

This PR standardizes `text-to-audio` such that the following lines work with most* text-to-audio models:

```python
from transformers import pipeline

synthesiser = pipeline("text-to-audio", "facebook/musicgen-large")
music = synthesiser("A low-fi song with a strong bassline")
synthesiser.save_audio(music, "test.wav")
```

On most* models where voice control is possible, it can be controlled through the `voice` pipeline argument. The valid values for `voice` are model-dependent, and this argument is documented accordingly:

```python
from transformers import pipeline

tts_pipeline = pipeline("text-to-audio", "sesame/csm-1b")
audio = tts_pipeline("I just got bamboozled by my cat.", voice="1")
tts_pipeline.save_audio(audio, "test.wav")
```

Core changes

Prior to this PR, recent models with `text-to-audio` capabilities had no pipeline support (e.g. CSM, Dia, Qwen2.5 Omni). There was also no standardized way to control the voice when the model generates speech.

With this PR,TextToAudioPipeline:

  • Uses a `processor` whenever possible, automatically (as opposed to needing a flag to control it);
  • Takes a `voice` argument, which a few models can use out of the box. Whether a model can take `voice` in the pipeline is specified by properties of the model (if future models have the same properties, they will also have `voice` support);
  • Standardizes outputs: ALL models using the pipeline will return `{"audio": <np.array with shape (audio_channels, sequence_length)>, "sampling_rate": <int>}`. Different models return different array formats, and as a result we can see different saving scripts in the model cards -- the pipeline standardizes it;
  • Adds a function to save the audio, for convenience. This way, users don't need to learn about `soundfile` or alternatives. Uses the processor's `save_audio` whenever it is available.
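The standardized output contract above can be sketched as a small shape-normalization helper. This is a numpy-only illustration with a hypothetical name (`standardize_audio_output`), not the PR's actual implementation, which operates on model outputs inside the pipeline:

```python
import numpy as np

def standardize_audio_output(audio: np.ndarray, sampling_rate: int) -> dict:
    """Normalize an arbitrary mono/stereo waveform to the pipeline's contract:
    {"audio": <array of shape (audio_channels, sequence_length)>, "sampling_rate": <int>}."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 1:
        # (sequence_length,) -> (1, sequence_length)
        audio = audio[np.newaxis, :]
    elif audio.ndim == 2 and audio.shape[0] > audio.shape[1]:
        # Heuristic: assume (sequence_length, audio_channels) and transpose.
        audio = audio.T
    return {"audio": audio, "sampling_rate": sampling_rate}

out = standardize_audio_output(np.zeros(16000), 16000)
print(out["audio"].shape)  # (1, 16000)
```

With a single output shape, one saving helper works for every model instead of the per-model saving scripts currently found in model cards.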

Model support

Models with out-of-the-box pipeline support:

  • TTS models with `voice` support:
    • CSM
    • Qwen2.5 Omni
    • Bark
  • TTS models without `voice` support:
    • Dia -- voice is set in the prompt; needs chat templates? (we can hardcode it in the pipeline, though 🤔)
    • FastSpeech2Conformer (model has no voice control)
    • SeamlessM4T and variants (model has no voice control)
    • Vits (model has no voice control)
  • TTA models:
    • Musicgen and variants

Models that have special requirements:

  • SpeechT5 (requires a `speaker_embeddings` argument; we could hide the complexity and accept `voice: int`, but the voice dataset most commonly used to pull the embeddings from isn't compatible with `datasets==4.0.0` 💔)
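The "hide the complexity" idea for SpeechT5 could look like the sketch below: a small precomputed table mapping an integer `voice` id to a speaker embedding. Everything here is hypothetical (function name, table, embedding values); in practice the embeddings would be x-vectors pulled from a speaker-embedding dataset rather than placeholders:

```python
import numpy as np

# Hypothetical lookup table of precomputed speaker embeddings, keyed by an
# integer voice id, so the pipeline could accept voice=0, 1, 2, ...
# Placeholder values stand in for real 512-dim x-vectors.
SPEAKER_EMBEDDINGS = {i: np.full((1, 512), float(i), dtype=np.float32) for i in range(3)}

def voice_to_speaker_embeddings(voice: int) -> np.ndarray:
    """Map a pipeline-level `voice` id to the model-level `speaker_embeddings` tensor."""
    if voice not in SPEAKER_EMBEDDINGS:
        raise ValueError(f"Unknown voice id {voice}; valid ids: {sorted(SPEAKER_EMBEDDINGS)}")
    return SPEAKER_EMBEDDINGS[voice]

emb = voice_to_speaker_embeddings(1)
print(emb.shape)  # (1, 512)
```

The table would be built once at pipeline init, which is exactly the step currently blocked by the `datasets==4.0.0` incompatibility mentioned above.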

@github-actions (bot) commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, bark, csm, dia, fastspeech2_conformer, qwen2_5_omni, speecht5

@gante requested review from ebezzam and eustlb, Jul 30, 2025 16:17
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```
audio (`dict[str, Any]` or `list[dict[str, Any]]`):
    The audio returned by the pipeline. The dictionary (or each dictionary, if it is a list) should
    contain two keys:
    - `"audio"`: The audio waveform.
```
ebezzam (Contributor):

Cool to have a saving function 👍

Maybe useful to mention that the audio is expected to have shape `(nb_channels, audio_length)`? (unless self.processor expects something else...)

And are we sure it isn't a tensor? Otherwise we can also add something like this?

I guess `synthesiser` and `tts_pipeline` try to ensure this dimension order, but maybe useful in case people get issues saving.

Contributor:

cf. this comment

Comment on lines +257 to +259
```
voice (`str`, *optional*):
    The voice to use for the generation, if the model is a text-to-speech model that supports multiple
    voices. Please refer to the model docs in transformers for model-specific examples.
```
Contributor:

How about something more generic like `preset` or `style`? At the moment, TTS models mainly make use of this feature (Bark calls it `voice_preset`), but maybe music generation models may use it at one point?

wdyt @eustlb?

Contributor:

agree! `preset` is good IMO

```python
# Or we need to postprocess to get the waveform
raise ValueError(
    f"Unexpected keys in the audio output format: {audio.keys()}. Expected one of "
    "`waveform` or `audio`"
)
```
Contributor:

and "sequences"

```python
    waveform = waveform.unsqueeze(0)
if len(waveform.shape) == 2:  # (bsz, audio_length) -> (bsz, audio_channels=1, audio_length)
    waveform = waveform.unsqueeze(1)
```
Contributor:

do we need a condition for the case where `self.audio_channels > 1` but the batch dimension is missing?
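The case the reviewer raises could be handled with one extra branch. Below is a numpy sketch under a hypothetical name (`ensure_batched_waveform`); the PR's actual code works on torch tensors with `unsqueeze`, but the shape logic is the same:

```python
import numpy as np

def ensure_batched_waveform(waveform: np.ndarray, audio_channels: int = 1) -> np.ndarray:
    """Normalize a waveform to (bsz, audio_channels, audio_length), including the
    ambiguous 2D case where audio_channels > 1 but the batch dimension is missing."""
    if waveform.ndim == 1:
        # (audio_length,) -> (1, 1, audio_length)
        waveform = waveform[np.newaxis, np.newaxis, :]
    elif waveform.ndim == 2:
        if audio_channels > 1 and waveform.shape[0] == audio_channels:
            # (audio_channels, audio_length) -> (1, audio_channels, audio_length)
            waveform = waveform[np.newaxis, :, :]
        else:
            # (bsz, audio_length) -> (bsz, 1, audio_length)
            waveform = waveform[:, np.newaxis, :]
    return waveform

print(ensure_batched_waveform(np.zeros((2, 100)), audio_channels=2).shape)  # (1, 2, 100)
print(ensure_batched_waveform(np.zeros((4, 100)), audio_channels=1).shape)  # (4, 1, 100)
```

The 2D case is inherently ambiguous (batch-of-mono vs. single stereo clip), so disambiguating on the model's declared `audio_channels` seems like the least surprising choice.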

@eustlb (Contributor) left a comment:

Thank you @gante for the good work! 🤗
My comments (open to discussion!!!, more details in the review):

  1. move from `voice` to `preset`, which should be passed blindly to the processor. This way, we offload all the logic to the processor, with the advantage of reducing maintenance and keeping it general enough to include new approaches in the future
  2. centralize the saving function in `audio_utils`

```python
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))


class BarkModelTester:
```
Contributor:

Strange this was here, that's an easy unbloat! 😅


```python
@slow
@require_torch
def test_csm_model(self):
```
Contributor:

yes yes yes! Thanks for adding it

```diff
         output_audio: Optional[bool] = False,
         **kwargs,
-    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+    ) -> Union[CsmGenerateOutput, torch.LongTensor, list[torch.FloatTensor]]:
```
Contributor:

Nice catch!

Comment on lines +195 to +197
```python
if audio_value.ndim == 2 and audio_value.shape[0] in (1, 2):
    # (nb_channels, audio_length) -> (audio_length, nb_channels), as expected by `soundfile`
    audio_value = np.transpose(audio_value)
```
eustlb (Contributor):

I don't get why you're adding this for CSM and Dia. The fact that the models are mono is more of a design choice than a parameter, and such things should not change.

Nevertheless, if we add a `save_audio` in `audio_utils` and rely on this function for saving audio in the processor, it will make sense to have it there, WDYT?

@ebezzam (Contributor) commented Jul 31, 2025 (edited):

Agree with centralizing/standardizing a `save_audio` function in `audio_utils` to avoid code duplication across different processors.

Another idea @eustlb: in several audio libraries I've seen the use of an `Audio` object to group together useful properties and methods. WDYT? Here's an example below (quickly drafted with ChatGPT):

```python
import os

import numpy as np
import soundfile as sf


class Audio:
    def __init__(self, waveform: np.ndarray, sample_rate: int, path: str = None):
        """
        Args:
            waveform: NumPy array of shape (channels, samples).
            sample_rate: Sampling rate in Hz.
            path: Optional original file path.
        """
        if waveform.ndim == 1:
            waveform = waveform[np.newaxis, :]  # Convert to (1, samples)
        elif waveform.shape[0] > waveform.shape[1]:
            # Assume it's (samples, channels) -- transpose to (channels, samples)
            waveform = waveform.T
        self.waveform = waveform
        self.sample_rate = sample_rate
        self.path = path

    @classmethod
    def load(cls, path: str):
        """Load audio from file (converted to (channels, samples))."""
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Audio file not found: {path}")
        data, sample_rate = sf.read(path, always_2d=True)  # (samples, channels)
        waveform = data.T  # Transpose to (channels, samples)
        return cls(waveform, sample_rate, path)

    def save(self, path: str):
        """Save audio to file (converted to (samples, channels))."""
        data = self.waveform.T  # Convert to (samples, channels)
        sf.write(path, data, self.sample_rate)
        self.path = path

    def duration_seconds(self) -> float:
        return self.waveform.shape[1] / self.sample_rate

    # can have other things for spectrogram, MFCC, augmentations, etc. which can call functional versions
```

It can make working with audio much more convenient (no need to pass the sample rate each time, and the channel dimension stays consistent). And it can be used alongside a functional approach.

Comment on lines +142 to +146
```python
def _get_sampling_rate(self, sampling_rate: Optional[int] = None) -> Optional[int]:
    """
    Get the sampling rate from the model config, generation config, processor, or vocoder. Can be overridden by
    `sampling_rate` in `__init__`.
    """
```
eustlb (Contributor):

I don't think we should be able to override it. Like the number of channels, it's more of a fixed parameter of a given model than something users should be able to play with.

Moreover, I don't see a model where it would make sense to have it in the generation config (Bark has it, but it's silent 🤔).

Sampling rate should live in two places:

  1. processor → to prepare the inputs in the correct sampling rate
  2. config → to indicate on which sampling rate the model has been trained

priority order should therefore be model config > processor (+ special case of the vocoder)
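The proposed priority order could be sketched as below. The function name (`resolve_sampling_rate`) and the exact placement of the vocoder special case are assumptions for illustration; the review only fixes "model config > processor":

```python
from types import SimpleNamespace

def resolve_sampling_rate(model_config, processor=None, vocoder=None):
    """Resolve the sampling rate with the priority proposed in the review:
    vocoder special case, then model config, then processor."""
    if vocoder is not None and getattr(vocoder.config, "sampling_rate", None) is not None:
        return vocoder.config.sampling_rate
    if getattr(model_config, "sampling_rate", None) is not None:
        return model_config.sampling_rate
    if processor is not None:
        return getattr(processor, "sampling_rate", None)
    return None

# Stub objects standing in for real config/processor instances.
config = SimpleNamespace(sampling_rate=24000)
processor = SimpleNamespace(sampling_rate=16000)
print(resolve_sampling_rate(config, processor))                             # 24000
print(resolve_sampling_rate(SimpleNamespace(sampling_rate=None), processor))  # 16000
```

Crucially there is no user-facing override parameter, matching the point above that the rate is a fixed property of the model, not a knob.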

Contributor:

Agree with priority being in model config

gante (Contributor, Author):

Agreed with not allowing overwrites! (the option to overwrite it was present before, so I left it -> I will remove it then 👍 )

```python
    sampling_rate = self.vocoder.config.sampling_rate

if self.sampling_rate is None:
    if sampling_rate is None:
```
Contributor:

detail but can be safely removed

Suggested change:

```diff
-if sampling_rate is None:
```

```python
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding

from ...cache_utils import Cache
```
Contributor:

Trusting you on that one!

Comment on lines +184 to +190
```python
for attr in ["audio_channels", "num_audio_channels"]:
    audio_channels = getattr(obj, attr, None)
    if audio_channels is not None:
        break
# WARNING: this default may cause issues in the future. We may need a more precise way of detecting the number
# of channels.
audio_channels = audio_channels or 1
```
Contributor:

That's ok I think, stereo is exceptional. It's almost always left unspecified, with the convention of mono in that situation.

Comment on lines +211 to +218
```python
voice = str(kwargs.pop("voice", "0"))
if not voice.isdigit():
    logger.warning(
        f"With {self.model.name_or_path}, the voice pipeline argument must be a digit. Got voice={voice}, "
        "using voice=0 instead."
    )
    voice = "0"
conversation = [{"role": voice, "content": [{"type": "text", "text": text[0]}]}]
```
eustlb (Contributor):

hum, that's a bit misleading for CSM, since the voice id does not refer to a predefined voice but rather a voice given a conversational context (cf. this doc example -- the model can be fine-tuned to do predefined voices though! cf. here).

Here, setting a voice ID without providing a conversational context will just pick a random voice.

In the future I see multiple different ways of setting this voice: mainly either via a custom way of doing voice cloning, or by using a predefined voice identifier (number/name/etc). In both situations, that's something that will end up being done in the processor.

What we can do is move from `voice` to `preset`, which will be passed directly to the processor, and update the current processors' `__call__` so that they can handle this `preset` kwarg.
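The "pass `preset` blindly to the processor" idea can be sketched with stubs. Both names below (`DummyProcessor`, `preprocess`) are hypothetical stand-ins for the real pipeline/processor classes; the point is only the kwarg flow:

```python
class DummyProcessor:
    """Stub standing in for a model's processor. A real processor would interpret
    `preset` in a model-specific way (speaker id, voice-cloning prompt, ...)."""
    def __call__(self, text, preset=None, **kwargs):
        return {"text": text, "preset": preset}

def preprocess(text, processor, **kwargs):
    # The pipeline pops `preset` and forwards it untouched, so all
    # model-specific voice logic lives in the processor, not the pipeline.
    preset = kwargs.pop("preset", None)
    return processor(text, preset=preset, **kwargs)

inputs = preprocess("Hello!", DummyProcessor(), preset="speaker_1")
print(inputs)  # {'text': 'Hello!', 'preset': 'speaker_1'}
```

This keeps the pipeline free of per-model branches: adding a new voice mechanism only touches that model's processor.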

```python
        output_list.append(output_dict)
    return output_list

def save_audio(self, audio: Union[dict[str, Any], list[dict[str, Any]]], path: Union[str, list[str]]):
```
eustlb (Contributor):

Can we move this to `audio_utils.py`?
The best would be to redefine the `AudioInput` type as a dict with keys `audio` (typed as the former `AudioInput`) and `sampling_rate` (float), so that we finally move away from representing audio in the lib as a bare array with a separate sampling rate. This should not be breaking, as it's only type hinting. It might require replacing `AudioInput` with the current type hints in the places where it's currently used.
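A minimal sketch of that proposed type, assuming a `TypedDict` shape (the name `AudioDict` and the helper are illustrative, not the actual `transformers` API):

```python
from typing import TypedDict

import numpy as np

class AudioDict(TypedDict):
    # Hypothetical shape of the proposed redefinition: the waveform and its
    # sampling rate travel together instead of as two loose values.
    audio: np.ndarray        # (audio_channels, sequence_length)
    sampling_rate: int

def duration_seconds(audio: AudioDict) -> float:
    return audio["audio"].shape[-1] / audio["sampling_rate"]

clip: AudioDict = {"audio": np.zeros((1, 32000)), "sampling_rate": 16000}
print(duration_seconds(clip))  # 2.0
```

Since `TypedDict` is erased at runtime, existing call sites that already pass such dicts would keep working, which is why the change is non-breaking.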

Contributor:

See #39796 (comment) for an idea on moving away from audio as an independent array.

@vasqu (Contributor) left a comment (edited):

Just for me so I can make a proper review after vacay :D

Regarding Dia with voice: it does not use the voice (encoded into codebooks) within the encoder, but as a prefix to the decoder. Since it's an encoder-decoder architecture, I'm not sure whether a chat template would really be suitable here. However, the processor should be able to take in audio and text and process them into the correct parameters, e.g. input ids and decoder input ids.

```python
@require_torch
def test_dia_model(self):
    """Tests Dia with the text-to-audio pipeline"""
    speech_generator = pipeline(task="text-to-audio", model="buttercrab/dia-v1-1.6b", framework="pt")
```
Contributor:

Suggested change:

```diff
-speech_generator = pipeline(task="text-to-audio", model="buttercrab/dia-v1-1.6b", framework="pt")
+speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", framework="pt")
```

Let's use the official ckpt imo


Reviewers

@ebezzam left review comments

@vasqu left review comments

@eustlb left review comments

At least 1 approving review is required to merge this pull request.


5 participants

@gante @HuggingFaceDocBuilderDev @ebezzam @vasqu @eustlb
