huggingface/flux-fast


Making Flux go brrr on GPUs. With simple recipes from this repo, we enabled ~2.5x speedup on Flux.1-Schnell and Flux.1-Dev using (mainly) pure PyTorch code and a beefy GPU like an H100. This repo is NOT meant to be a library or an out-of-the-box solution. Please fork the repo, hack on the code, and share your results 🤗

Check out the accompanying blog post here.

Updates

July 18, 2025: First caching mechanism in flux-fast with cache-dit. Check out the accompanying PR. Thanks to @DefTruth for the contribution!

July 1, 2025: This repository now supports AMD MI300X GPUs using AITER kernels (PR). The README has been updated to provide instructions on how to run on AMD GPUs. Thanks to @jammm for the contribution!

June 28, 2025: This repository now supports Flux.1 Kontext Dev. We enabled ~2.5x speedup on it. Check out this section for more details.

Results

Description | Image
Flux.1-Schnell | new_flux_schnell_plot
Flux.1-Dev | flux_dev_result_plot

Summary of the optimizations:

  • Running with the bfloat16 precision
  • torch.compile
  • Combining q,k,v projections for attention computation
  • torch.channels_last memory format for the decoder output
  • Flash Attention v3 (FA3) with (unscaled) conversion of inputs to torch.float8_e4m3fn
  • Dynamic float8 quantization of activations and Linear layer weights via torchao's float8_dynamic_activation_float8_weight
  • Inductor flags:
    • conv_1x1_as_mm = True
    • epilogue_fusion = False
    • coordinate_descent_tuning = True
    • coordinate_descent_check_all_directions = True
  • torch.export + Ahead-of-time Inductor (AOTI) + CUDAGraphs
  • cache acceleration with cache-dit: DBCache

All of the above optimizations are lossless (outside of minor numerical differences sometimes introduced through the use of torch.compile / torch.export) EXCEPT FOR dynamic float8 quantization. Disable quantization if you want the same quality results as the baseline while still being quite a bit faster.
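For example, with the benchmarking script described below, the documented --disable_quant flag keeps the other optimizations enabled while skipping float8 quantization:

python run_benchmark.py --disable_quant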

Here are some example outputs with Flux.1-Schnell for the prompt "A cat playing with a ball of yarn":

Configuration | Output
Baseline | baseline_output
Fully-optimized (with quantization) | fast_output

Setup

We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.

The numbers reported here were gathered using:

For NVIDIA:

  • torch==2.8.0.dev20250605+cu126 - note that we rely on some fixes since 2.7
  • torchao==0.12.0.dev20250610+cu126 - note that we rely on a fix in the 06/10 nightly
  • diffusers - with this fix included
  • flash_attn_3==3.0.0b1

For AMD:

  • torch==2.8.0.dev20250605+rocm6.4 - note that we rely on some fixes since 2.7
  • torchao==0.12.0.dev20250610+rocm6.4 - note that we rely on a fix in the 06/10 nightly
  • diffusers - with this fix included
  • aiter-0.1.4.dev17+gd0384d4

To install deps on NVIDIA:

pip install -U huggingface_hub[hf_xet] accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250610+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126

(For NVIDIA) To install flash attention v3, follow the instructions at https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.

To install deps on AMD:

pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torchao==0.12.0.dev20250610+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install git+https://github.com/ROCm/aiter

(For AMD) Instead of flash attention v3, we use AITER (https://github.com/ROCm/aiter), which provides the required fp8 MHA kernels.

For hardware, we used a 96GB 700W H100 GPU and a 192GB MI300X GPU. Some of the optimizations applied (bfloat16, torch.compile, combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
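As a rough, untested sketch of that CPU path (the exact flag combination is an assumption; FA3 and the AOTI binaries are GPU-oriented, so they are disabled or swapped out here), something like the following could be tried with the benchmarking script described below:

python run_benchmark.py --device cpu --disable_fa3 --compile_export_mode compile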

Run the optimized pipeline

On NVIDIA:

python gen_image.py --prompt "An astronaut standing next to a giant lemon" --output-file output.png --use-cached-model

This will include all optimizations and will attempt to use pre-cached binary models generated via torch.export + AOTI. To generate these binaries for subsequent runs, run the above command without the --use-cached-model flag.

Important

The binaries won't work for hardware that is sufficiently different from the hardware they were obtained on. For example, if the binaries were obtained on an H100, they won't work on an A100. Further, the binaries are currently Linux-only and include dependencies on specific versions of system libs such as libstdc++; they will not work if they were generated in a sufficiently different environment than the one present at runtime. The PyTorch Compiler team is working on solutions for more portable binaries / artifact caching.
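If you want to guard against incompatible cached binaries programmatically, one option (a minimal sketch, assuming the use_export_aoti helper shown later in this README and a cache_dir variable of your own, not logic from gen_image.py itself) is to attempt loading and fall back to re-exporting:

# Minimal sketch: try the cached AOTI binaries first and fall back to re-exporting
# for the current hardware / environment if loading fails.
try:
    pipeline = use_export_aoti(pipeline, cache_dir=cache_dir, serialize=False)
except Exception:
    # serialize=True re-exports and overwrites the stale binaries in cache_dir.
    pipeline = use_export_aoti(pipeline, cache_dir=cache_dir, serialize=True)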

On AMD:

python gen_image.py --prompt "A cat playing with a ball of yarn" --output-file output.png --compile_export_mode compile

Currently, torch.export is not working as expected on AMD; use torch.compile instead, as shown in the command above.

Benchmarking

run_benchmark.py is the main script for benchmarking the different optimization techniques. Usage:

usage: run_benchmark.py [-h] [--ckpt CKPT] [--prompt PROMPT] [--image IMAGE] [--cache-dir CACHE_DIR]
                        [--use-cached-model] [--device {cuda,cpu}] [--num_inference_steps NUM_INFERENCE_STEPS]
                        [--output-file OUTPUT_FILE] [--seed SEED] [--trace-file TRACE_FILE] [--disable_bf16]
                        [--compile_export_mode {compile,export_aoti,disabled}] [--disable_fused_projections]
                        [--disable_channels_last] [--disable_fa3] [--disable_quant] [--disable_inductor_tuning_flags]
                        [--cache_dit_config CACHE_DIT_CONFIG]

options:
  -h, --help            show this help message and exit
  --ckpt {black-forest-labs/FLUX.1-schnell,black-forest-labs/FLUX.1-dev,black-forest-labs/FLUX.1-Kontext-dev}
                        Model checkpoint path (default: black-forest-labs/FLUX.1-schnell)
  --prompt PROMPT       Text prompt (default: A cat playing with a ball of yarn)
  --image IMAGE         Image to use for Kontext (default: None)
  --cache-dir CACHE_DIR
                        Cache directory for storing exported models (default: ~/.cache/flux-fast)
  --use-cached-model    Attempt to use cached model only (don't re-export) (default: False)
  --device {cuda,cpu}   Device to use (default: cuda)
  --num_inference_steps NUM_INFERENCE_STEPS
                        Number of denoising steps (default: 4)
  --output-file OUTPUT_FILE
                        Output image file path (default: output.png)
  --seed SEED           Random seed to use (default: 42)
  --trace-file TRACE_FILE
                        Output PyTorch Profiler trace file path (default: None)
  --disable_bf16        Disables usage of torch.bfloat16 (default: False)
  --compile_export_mode {compile,export_aoti,disabled}
                        Configures how torch.compile or torch.export + AOTI are used (default: export_aoti)
  --disable_fused_projections
                        Disables fused q,k,v projections (default: False)
  --disable_channels_last
                        Disables usage of torch.channels_last memory format (default: False)
  --disable_fa3         Disables use of Flash Attention V3 (default: False)
  --disable_quant       Disables usage of dynamic float8 quantization (default: False)
  --disable_inductor_tuning_flags
                        Disables use of inductor tuning flags (default: False)
  --cache_dit_config CACHE_DIT_CONFIG
                        Cache options config of cache-dit: DBCache (default: None)

Note that all optimizations are on by default and each can be individually toggled. Example run:

# Run with all optimizations and output a trace file alongside benchmark numbers
python run_benchmark.py --trace-file profiler_trace.json.gz

After an experiment has been run, you should expect to see mean / variance times in seconds for 10 benchmarking runs printed to STDOUT (a minimal sketch of such a timing loop follows the list below), as well as:

  • A .png image file corresponding to the experiment (e.g. output.png). The path can be configured via --output-file.
  • An optional PyTorch Profiler trace (e.g. profiler_trace.json.gz). The path can be configured via --trace-file.
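For reference, here is a minimal, illustrative sketch of how such mean / variance numbers could be gathered; it is not necessarily the exact logic inside run_benchmark.py:

import time

import numpy as np
import torch

def benchmark(pipeline, prompt, num_inference_steps=4, num_runs=10):
    # Warmup call so compilation / caching doesn't pollute the timings.
    pipeline(prompt, num_inference_steps=num_inference_steps)
    times = []
    for _ in range(num_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipeline(prompt, num_inference_steps=num_inference_steps)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return np.mean(times), np.var(times)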

Important

For benchmarking purposes, we use reasonable defaults. For example, for all the benchmarking experiments, we use the 1024x1024 resolution. For Schnell, we use 4 denoising steps, and for Dev and Kontext, we use 28.
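For example, a Flux.1-Dev benchmark run matching these settings would look like:

python run_benchmark.py --ckpt black-forest-labs/FLUX.1-dev --num_inference_steps 28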

Flux.1 Kontext Dev

We ran the exact same setup as above on Flux.1 Kontext Dev and obtained the following result:

flux_kontext_plot

Here are some example outputs for the prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" and this image (an example benchmark command follows the notes below):

Configuration | Output
Baseline | baseline_output
Fully-optimized (with quantization) | fast_output

Notes

  • You need to install diffusers with this fix included
  • You need to install torchao with this fix included
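For example, a Kontext benchmark run matching the setup above would look roughly like the following (the input image path is a placeholder):

python run_benchmark.py --ckpt black-forest-labs/FLUX.1-Kontext-dev \
    --image input.png \
    --prompt "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" \
    --num_inference_steps 28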

Improvements, progressively

Baseline

For completeness, we demonstrate a (terrible) baseline here using the default torch.float32 dtype. There's no practical reason to do this over loading in torch.bfloat16, and the results are slow enough that they ruin the readability of the graph above when included (~7.5 sec).

from diffusers import FluxPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
BFloat16
import torch
from diffusers import FluxPipeline

# Load the pipeline in bfloat16 and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
torch.compile
import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Compile the compute-intensive portions of the model: denoising transformer / decoder
# "max-autotune" mode tunes kernel hyperparameters and applies CUDAGraphs
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)
pipeline.vae.decode = torch.compile(
    pipeline.vae.decode, mode="max-autotune", fullgraph=True
)

# warmup for a few iterations; trigger compilation
for _ in range(3):
    pipeline(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=4,
    ).images[0]

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Combining attention projection matrices
import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae = pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]

Note that torch.compile is able to perform this fusion automatically, so we do not observe a speedup from the fusion (outside of noise) when torch.compile is enabled.

channels_last memory format
import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Flash Attention V3 / aiter

Flash Attention V3 is substantially faster on H100s than the previous iteration FA2, due in large part to float8 support. As this kernel isn't quite available yet within PyTorch Core, we implement a custom attention processor FlashFusedFluxAttnProcessor3_0 that uses the flash_attn_interface python bindings directly. We also ensure proper PyTorch custom op integration so that the op integrates well with torch.compile / torch.export. Inputs are converted to float8 in an unscaled fashion before kernel invocation, and outputs are converted back to the original dtype on the way out.

On AMD GPUs, we use aiter instead, which also provides fp8 MHA kernels.
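For illustration, here is a minimal sketch of that pattern. It is not the actual FlashFusedFluxAttnProcessor3_0 from this repo: the kernel call is replaced by a placeholder (falling back to SDPA), since the exact FA3 Python binding signature isn't reproduced here, and the op namespace is made up. It only shows the unscaled float8 round-trip and a custom-op registration that keeps torch.compile / torch.export happy.

import torch

def _fa3_attention_placeholder(q8, k8, v8):
    # Placeholder for the real FA3 kernel call via flash_attn_interface.
    # For a self-contained sketch, fall back to SDPA in bfloat16.
    return torch.nn.functional.scaled_dot_product_attention(
        q8.to(torch.bfloat16), k8.to(torch.bfloat16), v8.to(torch.bfloat16)
    )

# Register the fp8 attention as a PyTorch custom op so it composes cleanly with
# torch.compile / torch.export (the "flux_fast_sketch" namespace is hypothetical).
@torch.library.custom_op("flux_fast_sketch::fp8_attention", mutates_args=())
def fp8_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    orig_dtype = q.dtype
    # Unscaled conversion of inputs to float8 before kernel invocation...
    q8, k8, v8 = (t.to(torch.float8_e4m3fn) for t in (q, k, v))
    out = _fa3_attention_placeholder(q8, k8, v8)
    # ...and back to the original dtype on the way out.
    return out.to(orig_dtype)

@fp8_attention.register_fake
def _(q, k, v):
    # Fake (meta) implementation used by torch.compile / torch.export for tracing.
    return torch.empty_like(q)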

import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
float8 quantization
import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
Inductor tuning flags
import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# Tune Inductor flags
config = torch._inductor.config
config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
# adjust autotuning algorithm
config.coordinate_descent_tuning = True
config.coordinate_descent_check_all_directions = True
config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

# compilation details omitted (see above)
...

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
torch.export + Ahead-Of-Time Inductor (AOTI)

To avoid initial compilation times, we can use torch.export + Ahead-Of-Time Inductor (AOTI). This will serialize a binary, precompiled form of the model without initial compilation overhead.

import os
import pathlib

import torch

# Apply torch.export + AOTI. If serialize=True, writes out the exported models within the cache_dir.
# Otherwise, attempts to load previously-exported models from the cache_dir.
# This function also applies CUDAGraphs on the loaded models.
def use_export_aoti(pipeline, cache_dir, serialize=False):
    from torch._inductor.package import load_package

    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer export ===
    # torch.export requires a representative set of example args to be passed in
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "guidance": None,
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, 512, 4096),
        "txt_ids": _example_tensor(512, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Possibly serialize model out
    transformer_package_path = os.path.join(cache_dir, "exported_transformer.pt2")
    if serialize:
        # Apply export
        exported_transformer: torch.export.ExportedProgram = torch.export.export(
            pipeline.transformer, args=(), kwargs=transformer_kwargs
        )
        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_transformer,
            package_path=transformer_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    loaded_transformer = load_package(
        transformer_package_path, run_single_threaded=True
    )

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. CUDAGraphs are utilized in torch.compile with mode="max-autotune", but
    # they must be manually applied for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)

    # hack to get around export's limitations
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {
        "return_dict": False,
    }

    # Possibly serialize model out
    decoder_package_path = os.path.join(cache_dir, "exported_decoder.pt2")
    if serialize:
        # Apply export
        exported_decoder: torch.export.ExportedProgram = torch.export.export(
            pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
        )
        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_decoder,
            package_path=decoder_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )

    loaded_decoder = load_package(decoder_package_path, run_single_threaded=True)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline

Note that, unlike for torch.compile, running a model loaded from the torch.export + AOTI workflow doesn't use CUDAGraphs by default. This was found to result in a ~5% performance decrease vs. torch.compile. To address this discrepancy, we manually record / replay CUDAGraphs over the exported models using the following helper:

# wrapper to automatically handle CUDAGraph record / replay over the given function
def cudagraph(f):
    from torch.utils._pytree import tree_map_only

    _graphs = {}

    def f_(*args, **kwargs):
        key = hash(tuple(
            tuple(kwargs[a].shape)
            for a in sorted(kwargs.keys())
            if isinstance(kwargs[a], torch.Tensor)
        ))
        if key in _graphs:
            # use the cached wrapper if one exists. this will perform CUDAGraph replay
            wrapped, *_ = _graphs[key]
            return wrapped(*args, **kwargs)

        # record a new CUDAGraph and cache it for future use
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs)  # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)

        def wrapped(*args, **kwargs):
            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
            # inputs must be copied into the fixed input memory locations.
            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
            for key in kwargs:
                if isinstance(kwargs[key], torch.Tensor):
                    in_kwargs[key].copy_(kwargs[key])
            g.replay()
            # clone() outputs on the way out to disconnect them from the fixed output memory
            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
            return [o.clone() for o in out_tensors]

        # cache function that does CUDAGraph replay
        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return f_

Finally, here is the fully-optimized form of the model:

import torch
from diffusers import FluxPipeline

# Load the pipeline and place its model components on CUDA.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")

# Use channels_last memory format
pipeline.vae.to(memory_format=torch.channels_last)

# Combine attention projection matrices for (q, k, v)
pipeline.transformer.fuse_qkv_projections()
pipeline.vae.fuse_qkv_projections()

# Use FA3; reference FlashFusedFluxAttnProcessor3_0 impl for details
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

# Apply float8 quantization on weights and activations
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
quantize_(
    pipeline.transformer,
    float8_dynamic_activation_float8_weight(),
)

# Tune Inductor flags
config = torch._inductor.config
config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
# adjust autotuning algorithm
config.coordinate_descent_tuning = True
config.coordinate_descent_check_all_directions = True
config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

# Apply torch.export + AOTI with CUDAGraphs
pipeline = use_export_aoti(pipeline, cache_dir=args.cache_dir, serialize=False)

prompt = "A cat playing with a ball of yarn"
image = pipeline(prompt, num_inference_steps=4).images[0]
cache acceleration with cache-dit: DBCache

You can use cache-dit to further speed up the FLUX models; different configurations of compute blocks (F12B12, etc.) can be customized in cache-dit: DBCache. Please check cache-dit for more details. For example:

# Install: pip install -U cache-dit
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# cache-dit: DBCache configs
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 0,
    "max_cached_steps": -1,  # -1 means no limit
    "Fn_compute_blocks": 1,  # Fn, F1, F12, etc.
    "Bn_compute_blocks": 0,  # Bn, B0, B12, etc.
    "residual_diff_threshold": 0.12,
    # TaylorSeer options
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    # TaylorSeer cache type can be hidden_states or residual
    "taylorseer_cache_type": "residual",
    "taylorseer_kwargs": {
        "n_derivatives": 2,
    },
}

apply_cache_on_pipe(pipeline, **cache_options)

By the way, cache-dit is designed to work compatibly with torch.compile. You can easily use cache-dit together with torch.compile to achieve even better performance. For example:

apply_cache_on_pipe(pipeline, **cache_options)

# cache-dit relies heavily on dynamic Python operations to maintain the cache_context,
# so it is necessary to introduce graph breaks at appropriate positions to be compatible
# with torch.compile. Thus, we compile the transformer with `max-autotune-no-cudagraphs`
# mode if cache-dit is enabled. Otherwise, we compile with `max-autotune` mode.
pipeline.transformer = torch.compile(
    pipeline.transformer,
    mode="max-autotune-no-cudagraphs",
    fullgraph=False,
)

Under the configuration of cache-dit + F1B0 + no warmup + TaylorSeer, it only takes 7.42 seconds on an NVIDIA L20, a cumulative speedup of 3.36x (compared to the baseline of 24.94 seconds), while still maintaining high precision with a PSNR of 23.23.

Configuration (FLUX.1-dev, 28 steps) | PSNR | Time (L20) | Output image
Baseline: BF16, w/o torch.compile, w/o cache-dit | inf | 24.94s | output
BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags | 21.77 | 13.26s | bf16_compile_qkv_chan_quant_flags_trn
BF16 + compile + qkv projection + channels_last + float8 quant + inductor flags + cache-dit + F1B0 + no warmup + TaylorSeer | 23.23 | 7.42s | bf16_cache_F1B0W0M0_taylorseer_compile_qkv_chan_quant_flags_trn

Important Notes

  1. Please add the --cache_dit_config cache_config.yaml flag to use cache-dit. cache-dit does not currently work with torch.export: it extends Flux with dynamic Python operations, so it may not be possible to export the model using torch.export. (A sketch of how the config file maps onto cache options follows this list.)
  2. Please modify the cache_config.yaml file to change the cache-dit: DBCache configuration and test the quality and performance under different settings.
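The exact schema of cache_config.yaml isn't reproduced here, but conceptually the flag maps a YAML file onto the cache_options dict shown earlier. Below is a hypothetical sketch of that wiring; the file name, key layout, and enum handling are assumptions rather than the repo's actual parsing code:

import torch
import yaml
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

# Hypothetical: read DBCache options (keys mirroring the cache_options dict above)
# from a YAML file and apply them, similar in spirit to --cache_dit_config.
with open("cache_config.yaml") as f:
    cache_options = yaml.safe_load(f)
# If the YAML stores cache_type as a string, map it back to the CacheType enum.
cache_options["cache_type"] = CacheType[cache_options["cache_type"]]

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(pipeline, **cache_options)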
