Mosaic: Memory Profiling for PyTorch
Author: Basil Wong
What you will learn:
How to capture and analyze PyTorch memory snapshots
Identify memory savings from activation checkpointing
Debug unexpected memory usage from abandoned code
Integrate memory analysis into training pipelines
Prerequisites:
PyTorch v2.0.0 or later
CUDA-capable GPU
Basic understanding of PyTorch training loops
This tutorial demonstrates how to use Mosaic, a post-processing memory snapshot analysis tool for PyTorch. Mosaic helps analyze GPU memory usage in distributed deep learning, providing detailed insights into memory allocations, peak usage, and memory imbalances across parallel workers.
Mosaic was instrumental in debugging OOM issues during the 405B LLaMA training and is now open source.
Introduction to Mosaic
Overview
In distributed deep learning, understanding GPU memory usage is critical for optimizing training efficiency and debugging Out-of-Memory (OOM) errors. Mosaic is a post-analysis tool for memory usage designed to work with large-scale jobs: it analyzes PyTorch memory snapshots captured while a PyTorch training job runs.
Getting Started
Clone the mosaic repository and install from the mosaic directory:
```shell
git clone https://github.com/facebookresearch/mosaic
cd mosaic
python3 -m venv venv
source venv/bin/activate
pip3 install -r requirements.txt
pip3 install -e .
```
Alternatively, install directly via pip:
```shell
pip install git+https://github.com/facebookresearch/mosaic.git
```
Simple Usage Examples
1. Peak Memory Usage Analysis
When addressing memory problems like OOM errors, focusing on peak memory usage is crucial. The mosaic_get_memory_usage_peak command presents a stack trace of the memory allocations that contributed to the peak memory usage:
```shell
mosaic_get_memory_usage_peak --snapshot <path to snapshot>
```
2. Categorical Memory Profiling
Mosaic classifies allocations into categories (activation, backward, optimizer, etc.):
Activation Memory: Tensors saved for backward pass
Gradient Memory: Gradients computed during backpropagation
Optimizer State: Adam/SGD momentum and variance buffers
Parameter Memory: Model weights
```shell
mosaic_get_memory_profile --snapshot <path> --out-path <html> \
    --profile categories
```
An example HTML output looks like:

Categorical memory profiling showing memory breakdown by type (activation, gradient, optimizer, etc.)
To maintain allocation order for the categories, add --preserve-allocation-order:
```shell
mosaic_get_memory_profile --snapshot <path> --out-path <html> \
    --profile categories --preserve-allocation-order
```
Categorical profiling with --preserve-allocation-order shows memory allocations in chronological order
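The parameter, gradient, and optimizer-state categories listed above correspond to tensors you can measure directly. The sketch below is not how Mosaic categorizes allocations; it uses a small, hypothetical CPU-only model and standard PyTorch calls to show why, with an Adam-style optimizer, optimizer state is about twice the parameter memory: `torch.optim.AdamW` keeps an `exp_avg` (momentum) and an `exp_avg_sq` (variance) buffer per parameter. The scalar `step` entry is deliberately excluded.

```python
import torch


def tensor_bytes(t):
    """Bytes occupied by a tensor's elements."""
    return t.numel() * t.element_size()


torch.manual_seed(0)
# Hypothetical small model standing in for "Model weights".
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512), torch.nn.ReLU(), torch.nn.Linear(512, 256)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One training step so that gradients and optimizer state exist.
loss = model(torch.randn(32, 256)).pow(2).mean()
loss.backward()
optimizer.step()

param_bytes = sum(tensor_bytes(p) for p in model.parameters())
grad_bytes = sum(tensor_bytes(p.grad) for p in model.parameters())
# Only the two per-parameter AdamW buffers; the scalar "step" entry is skipped.
optim_bytes = sum(
    tensor_bytes(buf)
    for state in optimizer.state.values()
    for key, buf in state.items()
    if key in ("exp_avg", "exp_avg_sq")
)

print(f"Parameter memory: {param_bytes / 2**20:.2f} MiB")
print(f"Gradient memory:  {grad_bytes / 2**20:.2f} MiB")
print(f"Optimizer state:  {optim_bytes / 2**20:.2f} MiB")
```

Activation memory does not show up in this accounting because saved activations only live between the forward pass and the point in the backward pass where they are consumed; that transient memory is what Case 1 below measures.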
3. Custom Dictionary Profiling
For targeted analysis via regex pattern matching:
```shell
mosaic_get_memory_profile --snapshot <path> --profile custom \
    --custom-profile '{"ncclx": "ncclx"}'
```
This is invaluable for tracking specific kernels, optimizers, or custom code patterns:

Custom profiling with regex patterns to track specific operations like NCCL communications
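Conceptually, a custom profile maps a category name to a regular expression and attributes each allocation whose stack trace matches it to that category. The function below is a hypothetical, self-contained illustration of that idea, not Mosaic's implementation: `bucket_allocations`, the `(size, frames)` input format, the first-match-wins rule, and the `"other"` bucket are all assumptions made for this sketch, and the frame strings are invented.

```python
import re
from collections import defaultdict


def bucket_allocations(allocations, custom_profile):
    """Sum allocation sizes per category using regex matches on stack frames.

    allocations: iterable of (size_in_bytes, list_of_frame_strings) pairs (assumed format).
    custom_profile: {category_name: regex}, shaped like the --custom-profile JSON.
    Returns {category_name: total_bytes}; unmatched allocations go to "other".
    """
    patterns = {name: re.compile(regex) for name, regex in custom_profile.items()}
    totals = defaultdict(int)
    for size, frames in allocations:
        for name, pattern in patterns.items():
            # The first category whose regex matches any frame wins.
            if any(pattern.search(frame) for frame in frames):
                totals[name] += size
                break
        else:
            totals["other"] += size
    return dict(totals)


# Made-up allocations for illustration only.
allocations = [
    (4096, ["ncclx::AllReduce", "torch/distributed/distributed_c10d.py:all_reduce"]),
    (2048, ["torch/optim/adamw.py:_init_group", "zeros_like"]),
    (1024, ["torch/nn/modules/linear.py:forward"]),
]
result = bucket_allocations(allocations, {"ncclx": "ncclx", "optimizer": r"torch/optim/"})
print(result)  # {'ncclx': 4096, 'optimizer': 2048, 'other': 1024}
```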
Dependencies and Imports
Let's set up the required dependencies and imports for this tutorial.
```python
import subprocess
import sys
import shutil
from contextlib import contextmanager
import pickle

# Fix for sphinx-gallery environment where __main__.__file__ may not exist
# This is needed for transformers library compatibility
import os

if not hasattr(sys.modules["__main__"], "__file__"):
    # Use this file's path as a fallback, or a dummy path if __file__ is not available
    try:
        sys.modules["__main__"].__file__ = os.path.abspath(__file__)
    except NameError:
        # __file__ not available, use transformers modeling file as fallback
        import transformers.modeling_utils

        sys.modules["__main__"].__file__ = transformers.modeling_utils.__file__

import torch
from torch.utils.data import DataLoader, Dataset

# Install dependencies if needed
try:
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "transformers"])
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

try:
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract
except ImportError:
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "git+https://github.com/facebookresearch/mosaic.git",
        ]
    )
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
```
Shared Utilities
These helper classes and functions are used throughout the tutorial.
```python
class RandomTokenDataset(Dataset):
    """Generates random token sequences for training.

    This dataset creates random input sequences suitable for language model
    training, simulating real training data without requiring actual text.
    """

    def __init__(self, vocab_size, seq_length=512, num_samples=100, seed=None):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.num_samples = num_samples
        self.generator = None
        if seed is not None:
            self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):  # noqa: ARG002
        if self.generator is not None:
            input_ids = torch.randint(
                0, self.vocab_size, (self.seq_length,), generator=self.generator
            )
        else:
            input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))
        return {"input_ids": input_ids, "labels": input_ids.clone()}


@contextmanager
def capture_memory_snapshot(output_path):
    """Context manager to capture and save PyTorch CUDA memory snapshots.

    This captures all GPU memory allocations during the context and saves
    them to a pickle file for later analysis with Mosaic.

    Args:
        output_path: Path to save the memory snapshot pickle file.
    """
    torch.cuda.memory._record_memory_history(max_entries=100000)
    try:
        yield
    finally:
        snapshot = torch.cuda.memory._snapshot()
        torch.cuda.memory._record_memory_history(enabled=None)
        with open(output_path, "wb") as f:
            pickle.dump(snapshot, f)
        print(f"✓ Memory snapshot saved to {output_path}")
```
Case 1: Understanding Memory Differences with Activation Checkpointing
This section demonstrates how to use Mosaic to analyze and compare GPU memory usage between different model configurations.
What we’ll do:
Train GPT-2 and capture a memory snapshot (baseline)
Enable activation checkpointing and train again (modified)
Use Mosaic to identify exactly where memory savings occur
Training Function for Activation Checkpointing Comparison
```python
def run_training_ac(
    activation_checkpointing: bool,
    snapshot_path: str,
    batch_size: int = 4,
    seq_length: int = 512,
    num_steps: int = 5,
):
    """Run training loop and capture memory snapshot.

    Args:
        activation_checkpointing: Whether to enable gradient checkpointing.
        snapshot_path: Path to save the memory snapshot.
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps to run.

    Returns:
        Peak GPU memory usage in GB.
    """
    # Clear any previous memory
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    device = torch.device("cuda")

    # Load model
    print(f"Loading GPT-2 (activation_checkpointing={activation_checkpointing})...")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    if activation_checkpointing:
        model.gradient_checkpointing_enable()
        print("Activation checkpointing is ENABLED")
    else:
        print("Activation checkpointing is DISABLED")

    model = model.to(device)
    model.train()

    # Create dataset and dataloader
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size,
        seq_length=seq_length,
        num_samples=100,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Training loop with memory capture
    print(f"Running {num_steps} training steps...")

    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break

            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            print(f"  Step {step + 1}/{num_steps}, Loss: {loss.item():.4f}")

    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"✓ Peak GPU memory: {peak_memory_gb:.2f} GB")

    # Cleanup
    del model, optimizer
    torch.cuda.empty_cache()

    return peak_memory_gb
```
Run Baseline Training (Without Activation Checkpointing)
Note
This tutorial requires a CUDA-capable GPU. If you're running in Google Colab, make sure to select a GPU runtime: Runtime → Change runtime type → Hardware accelerator → GPU
```python
if not torch.cuda.is_available():
    print("=" * 60)
    print("WARNING: No CUDA GPU detected!")
    print("=" * 60)
    print("\nThis tutorial requires a CUDA-capable GPU for memory profiling.")
    print("\nIf you're running in Google Colab:")
    print("  1. Go to Runtime → Change runtime type")
    print("  2. Set Hardware accelerator to 'GPU'")
    print("  3. Click 'Save' and re-run the notebook")
    print("\nSkipping GPU memory profiling examples...")
    HAS_CUDA = False
else:
    HAS_CUDA = True

# Check if Mosaic CLI is available
HAS_MOSAIC_CLI = shutil.which("mosaic_get_memory_profile") is not None
if HAS_CUDA and not HAS_MOSAIC_CLI:
    print("Note: Mosaic CLI not found. Install Mosaic to generate HTML profiles.")
    print("  pip install git+https://github.com/facebookresearch/mosaic.git")

if HAS_CUDA:
    print("=" * 60)
    print("BASELINE: Training WITHOUT Activation Checkpointing")
    print("=" * 60)

    baseline_memory = run_training_ac(
        activation_checkpointing=False,
        snapshot_path="snapshot_baseline.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )
```
Run Modified Training (With Activation Checkpointing)
```python
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("MODIFIED: Training WITH Activation Checkpointing")
    print("=" * 60)

    ac_memory = run_training_ac(
        activation_checkpointing=True,
        snapshot_path="snapshot_with_ac.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )

    # Summary
    print("\n" + "=" * 60)
    print("MEMORY COMPARISON SUMMARY")
    print("=" * 60)
    print(f"Baseline (no AC): {baseline_memory:.2f} GB")
    print(f"With AC: {ac_memory:.2f} GB")
    if baseline_memory > 0:
        saved_pct = 100 * (baseline_memory - ac_memory) / baseline_memory
        print(f"Memory Saved: {baseline_memory - ac_memory:.2f} GB ({saved_pct:.1f}%)")
```
Generate Categorical Memory Profiles with Mosaic
Use Mosaic to generate HTML profiles for both snapshots.
```python
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("\n" + "=" * 60)
    print("MOSAIC: Categorical Memory Profiling")
    print("=" * 60)

    # Generate HTML profiles using subprocess
    print("\nGenerating baseline profile...")
    result1 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot", "snapshot_baseline.pickle",
            "--out-path", "profile_baseline.html",
            "--profile", "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate", "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result1.stdout)
    if result1.stderr:
        print(result1.stderr)

    print("\nGenerating activation checkpointing profile...")
    result2 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot", "snapshot_with_ac.pickle",
            "--out-path", "profile_with_ac.html",
            "--profile", "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate", "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result2.stdout)
    if result2.stderr:
        print(result2.stderr)

    if result1.returncode == 0 and result2.returncode == 0:
        print("\nGenerated profile_baseline.html")
        print("Generated profile_with_ac.html")
        print("\nDownload these files to view the interactive memory profiles.")
    else:
        print("\nNote: Mosaic profile generation encountered issues.")
        print("This may happen if running in an environment without full Mosaic support.")
```
Download Generated Files (Google Colab)
If running in Google Colab, uncomment the following lines to download the generated snapshot and profile files:
```python
# from google.colab import files
#
# print("Downloading memory snapshots and profiles...")
# files.download('snapshot_baseline.pickle')
# files.download('snapshot_with_ac.pickle')
# files.download('profile_baseline.html')
# files.download('profile_with_ac.html')
```
Results Interpretation: Activation Checkpointing
The generated HTML profiles visualize memory usage over time, with allocations colored by category. Here's what the profiles look like:

Baseline (without activation checkpointing): Notice the large activation memory (shown in one color) that persists throughout the forward pass.

With activation checkpointing: Activation memory is significantly reduced as intermediate activations are discarded and recomputed during the backward pass.
What We Observed
Based on the Mosaic categorical profiling results:
| Metric | Baseline | With Activation Checkpointing | Difference |
|---|---|---|---|
| Total Peak Memory | 4.62 GB | 2.55 GB | 2.07 GB (45% reduction) |
| Activation Memory | 2.93 GB | 872.79 MB | 2.08 GB saved (71% reduction) |
| Backward/Gradient Memory | 793.39 MB | 785.27 MB | 8 MB (minimal change) |
| Optimizer State | 949.4 MB | 949.4 MB | No change |
| Unknown | 32 KB | 32 KB | No change |
Key Insights
Primary Finding: Activation memory dropped from 2.93 GB → 872 MB (71% reduction), which accounts for nearly all of the total memory savings.
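The percentages in the table can be reproduced from its own numbers. The check below assumes the table's GB and MB are binary units (1 GB = 1024 MB). Under that assumption the activation row works out to the reported 2.08 GB and 71%, whereas decimal units (1 GB = 1000 MB) would give about 2.06 GB and 70%.

```python
MB_PER_GB = 1024  # assumption: binary units, which the reported numbers are consistent with


def reduction(before, after):
    """Return (absolute saving, percent reduction)."""
    saved = before - after
    return saved, 100 * saved / before


# Total peak memory (both values in GB)
peak_saved_gb, peak_pct = reduction(4.62, 2.55)

# Activation memory: 2.93 GB baseline vs. 872.79 MB with checkpointing
act_saved_mb, act_pct = reduction(2.93 * MB_PER_GB, 872.79)
act_saved_gb = act_saved_mb / MB_PER_GB

print(f"Total peak: {peak_saved_gb:.2f} GB saved ({peak_pct:.0f}%)")  # Total peak: 2.07 GB saved (45%)
print(f"Activations: {act_saved_gb:.2f} GB saved ({act_pct:.0f}%)")  # Activations: 2.08 GB saved (71%)
```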
Why Does This Happen?
Activation checkpointing is a memory optimization technique that trades compute for memory:
Without AC (Baseline): All intermediate activations from the forward pass are stored in memory for use during backpropagation. GPT-2 has 12 transformer layers, each storing multiple activations (attention outputs, MLP outputs, etc.). For batch_size=4, seq_length=512, this adds up quickly.
With AC (Optimized): Only activations at checkpoint boundaries are stored; intermediate activations are recomputed during the backward pass. This dramatically reduces activation memory (71% in our case) while other memory categories remain unchanged.
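The recompute trade-off can be observed on a toy block without a GPU. The sketch below applies `torch.utils.checkpoint.checkpoint` (PyTorch's activation checkpointing utility; to our understanding, Hugging Face's `gradient_checkpointing_enable()` applies it per transformer block by default) to a small, hypothetical MLP block. It counts how often the block's forward runs and checks that the gradients are unchanged: with checkpointing, the forward runs a second time during the backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for one transformer MLP sub-block.
block = torch.nn.Sequential(
    torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64)
)
x = torch.randn(8, 64)
forward_calls = {"count": 0}


def run_block(inp):
    # Count every execution of the block's forward pass.
    forward_calls["count"] += 1
    return block(inp)


def train_step(use_checkpoint):
    """One forward + backward; returns (#forward executions, parameter gradients)."""
    block.zero_grad()
    forward_calls["count"] = 0
    inp = x.clone().requires_grad_(True)
    if use_checkpoint:
        # Intermediate activations inside run_block are not kept;
        # the backward pass re-runs run_block to recompute them.
        out = checkpoint(run_block, inp, use_reentrant=False)
    else:
        out = run_block(inp)
    out.sum().backward()
    return forward_calls["count"], [p.grad.clone() for p in block.parameters()]


base_calls, base_grads = train_step(use_checkpoint=False)
ckpt_calls, ckpt_grads = train_step(use_checkpoint=True)
print(base_calls, ckpt_calls)  # 1 2
print(all(torch.allclose(a, b) for a, b in zip(base_grads, ckpt_grads)))  # True
```

The extra forward call is the compute cost paid for not storing the block's intermediate activations.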
How Mosaic Helped
Mosaic's categorical profiling immediately identified:
Activation memory is the category with the largest difference (2.08 GB saved)
Backward/Gradient memory stayed nearly constant (793 MB → 785 MB)
Optimizer state remained unchanged (949 MB), as expected: activation checkpointing does not change the model parameters that the optimizer tracks
Without Mosaic: You would need to manually instrument your code, track allocations, and categorize them yourself.
With Mosaic: You get instant categorical breakdowns with exact numbers, making it straightforward to identify and quantify memory optimizations.
Case 2: Debugging Unexpected Memory Usage
This section demonstrates how to use Mosaic to debug a model that uses more memory than expected when you're not sure why.
What we'll do:
Train GPT-2 and capture a memory snapshot.
Train GPT-2 with a bug that introduces additional memory and capture a memory snapshot.
Use Mosaic to identify potential culprits introducing additional memory.
The Buggy Model
This model has abandoned debug code that creates unnecessary GPU memory overhead. Someone added projection layers to "analyze hidden states" during debugging, but forgot to remove them before training.
```python
class GPT2WithDebugOverhead(torch.nn.Module):
    """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.

    This wrapper adds extra projection layers that consume memory but serve no
    purpose, simulating abandoned debug code that was never cleaned up.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        config = base_model.config

        # BUG: Large projection layers from an abandoned experiment
        self.debug_projections = torch.nn.ModuleList(
            [
                torch.nn.Linear(config.n_embd, config.n_embd * 4)
                for _ in range(config.n_layer)
            ]
        )
        debug_params = sum(p.numel() for p in self.debug_projections.parameters())
        print(f"  [DEBUG] Added {config.n_layer} debug projection layers")
        print(f"  [DEBUG] Extra parameters: {debug_params:,}")

    def forward(self, input_ids=None, labels=None, **kwargs):
        # Run normal GPT-2 forward with hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )

        # BUG: Project all hidden states through debug layers
        projected = []
        for _layer_idx, (hidden, proj) in enumerate(
            zip(outputs.hidden_states[1:], self.debug_projections)
        ):
            proj_hidden = proj(hidden)
            projected.append(proj_hidden)

        # Tie to loss so gradients flow through
        debug_regularization = sum(p.mean() for p in projected) * 1e-10

        return CausalLMOutputWithCrossAttentions(
            loss=outputs.loss + debug_regularization,
            logits=outputs.logits,
        )
```
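Why does the tiny `* 1e-10` term matter? It connects the debug projections to the loss, so after `backward()` their parameters have a `.grad`, and `torch.optim.AdamW` only creates state for parameters whose gradient is not `None`. The toy sketch below uses hypothetical `base` and `debug_proj` layers on CPU and compares the number of optimizer state entries with and without that term.

```python
import torch


def optimizer_state_entries(tie_debug_to_loss):
    """Return how many parameters AdamW allocated state for after one step."""
    torch.manual_seed(0)
    base = torch.nn.Linear(16, 16)        # hypothetical stand-in for the real model
    debug_proj = torch.nn.Linear(16, 64)  # hypothetical stand-in for one debug projection
    params = list(base.parameters()) + list(debug_proj.parameters())
    optimizer = torch.optim.AdamW(params, lr=1e-5)

    hidden = base(torch.randn(4, 16))
    loss = hidden.pow(2).mean()
    if tie_debug_to_loss:
        # Same trick as GPT2WithDebugOverhead: a negligible term that still adds a graph edge.
        loss = loss + debug_proj(hidden).mean() * 1e-10
    loss.backward()
    optimizer.step()
    return len(optimizer.state)


print(optimizer_state_entries(tie_debug_to_loss=False))  # 2 (base weight and bias only)
print(optimizer_state_entries(tie_debug_to_loss=True))   # 4 (debug weight and bias too)
```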
Training Functions for Debug Comparison
```python
def run_training_clean(snapshot_path, num_steps=3):
    """Training with the normal model."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")

    print("Loading clean model (no debug overhead)...")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    print("Running training (should contain no debug overhead)...")

    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"  Step {step + 1}, Loss: {loss.item():.4f}")

    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")

    del model, optimizer
    torch.cuda.empty_cache()

    return peak_memory


def run_training_with_bug(snapshot_path, num_steps=3):
    """Training with the buggy model."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")

    print("Loading buggy model with debug overhead...")
    # Load pretrained GPT-2 and wrap it with the debug overhead
    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
    model = GPT2WithDebugOverhead(base_model).to(device)
    model.train()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    print("Running training (WITH debug overhead bug)...")

    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"  Step {step + 1}, Loss: {loss.item():.4f}")

    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")

    del model, optimizer
    torch.cuda.empty_cache()

    return peak_memory
```
Run Training for Baseline (Clean Model)
```python
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with baseline model")
    print("=" * 60)

    baseline_memory_debug = run_training_clean(
        "snapshot_debug_baseline.pickle", num_steps=3
    )
```
Run Training WITH the Bug
```python
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with debug projection overhead (BUG)")
    print("=" * 60)

    buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
```
Use Mosaic to Find the Problem
Analyze both snapshots to identify the source of extra memory usage. We'll run Mosaic's peak memory analysis on each snapshot separately.
Analyze the Baseline (Clean) Snapshot
```python
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Baseline Snapshot")
    print("=" * 60)

    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
```
Analyze the Buggy Snapshot
```python
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Buggy Snapshot")
    print("=" * 60)

    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
```
Analyzing the Mosaic Output
When you run Mosaic's peak memory analysis, it shows stack traces for each memory allocation. Let's look at how to find abandoned or unnecessary code that's bloating the memory.
1. Optimizer State Allocations Delta
In the buggy snapshot output, the first two stack traces represent the optimizer state allocations (such as zeros_like for the Adam optimizer state); look for torch/optim/adam.py in the stack trace.
In the buggy model's snapshot, these two traces account for about 0.21 GB more memory in total:
| Version | Stack Trace Position | Calls | Memory (per trace) |
|---|---|---|---|
| Buggy model | 1st and 2nd | 172 calls | 0.569 GB + 0.569 GB |
| Baseline | 2nd and 3rd | 148 calls | 0.464 GB + 0.464 GB |
What this tells us: the optimizer is tracking more tensors! This is your first clue that there are extra parameters or tensors in the computation graph.
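These numbers can be cross-checked from the model architecture, under two assumptions that are ours rather than Mosaic's: each of the two stack traces corresponds to one AdamW state buffer (`exp_avg` or `exp_avg_sq`, allocated with `zeros_like` once per parameter tensor), and all parameters are fp32. The shapes below follow the standard `gpt2` configuration (vocabulary 50,257, 1,024 positions, hidden size 768, 12 layers, `lm_head` tied to the token embedding) plus the 12 `Linear(768, 3072)` debug projections.

```python
import math

GIB = 1024**3
BYTES_PER_FP32 = 4
VOCAB, N_POSITIONS, N_EMBD, N_LAYER = 50257, 1024, 768, 12  # GPT-2 small


def gpt2_param_shapes():
    """Parameter shapes of GPT2LMHeadModel; lm_head is tied to wte, so it is not counted twice."""
    shapes = [(VOCAB, N_EMBD), (N_POSITIONS, N_EMBD)]  # wte, wpe
    for _ in range(N_LAYER):
        shapes += [
            (N_EMBD,), (N_EMBD,),                 # ln_1 weight, bias
            (N_EMBD, 3 * N_EMBD), (3 * N_EMBD,),  # attn.c_attn
            (N_EMBD, N_EMBD), (N_EMBD,),          # attn.c_proj
            (N_EMBD,), (N_EMBD,),                 # ln_2
            (N_EMBD, 4 * N_EMBD), (4 * N_EMBD,),  # mlp.c_fc
            (4 * N_EMBD, N_EMBD), (N_EMBD,),      # mlp.c_proj
        ]
    return shapes + [(N_EMBD,), (N_EMBD,)]  # ln_f


def debug_projection_shapes():
    """One Linear(n_embd, 4 * n_embd) per layer, as in GPT2WithDebugOverhead."""
    return [(4 * N_EMBD, N_EMBD), (4 * N_EMBD,)] * N_LAYER


def adam_buffer_summary(shapes):
    """(#zeros_like calls, GiB) for ONE AdamW state buffer across all parameters."""
    total_elements = sum(math.prod(s) for s in shapes)
    return len(shapes), total_elements * BYTES_PER_FP32 / GIB


base = adam_buffer_summary(gpt2_param_shapes())
buggy = adam_buffer_summary(gpt2_param_shapes() + debug_projection_shapes())
print(base[0], round(base[1], 3))          # 148 0.464
print(buggy[0], round(buggy[1], 3))        # 172 0.569
print(round(2 * (buggy[1] - base[1]), 2))  # 0.21
```

The 24 extra calls are the weight and bias of the 12 debug projections, and their two fp32 AdamW buffers account for the roughly 0.21 GB delta.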
2. Additional Activation Allocations
The buggy version shows extra allocations that don't appear in the baseline model. Scrolling down the Mosaic output of the buggy model, we can see additional stack traces that contain:
torch::autograd::Engine::evaluate_function: we're in the backward pass
AddmmBackward0::apply: computing gradients for an addmm operation
empty_cuda at the bottom: allocating a new CUDA tensor to store the gradient
0.176 GB from matrix multiply gradients (AddmmBackward0, mm_mat1_backward)
Memory Total Explanation
Total Peak Dynamic Memory Usage: This is the peak memory that changes during execution, measured relative to the starting point of the snapshot. It tracks memory allocations that occur during the traced execution timeline.
Total Static Memory Usage: This is the "starting memory" or baseline memory that exists before tracing begins. It's estimated by the PyTorch visualizer and remains constant throughout the snapshot (it doesn't come with stack traces).
Note
In the snapshots you may also observe differences in total static memory usage, which account for the rest of the gap between the two peaks.
Total Overall Peak Memory Usage: Dynamic + Static
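To make "dynamic" concrete: a dynamic peak is the largest net amount allocated since recording began. The helper below is a hypothetical sketch, not Mosaic's algorithm. It assumes the dictionary layout returned by `torch.cuda.memory._snapshot()`, where `snapshot["device_traces"][device]` is a list of events carrying `"action"` and `"size"` keys, and it only looks at `"alloc"` and `"free_completed"` events. It does not estimate static memory.

```python
def dynamic_peak_bytes(snapshot, device=0):
    """Largest net allocation (bytes) seen while replaying one device's trace.

    Frees of blocks allocated before recording started push the running total
    below zero, so the peak is measured relative to the start of the trace.
    """
    current = 0
    peak = 0
    for event in snapshot["device_traces"][device]:
        if event["action"] == "alloc":
            current += event["size"]
        elif event["action"] == "free_completed":
            current -= event["size"]
        peak = max(peak, current)
    return peak


# Synthetic trace for illustration (real traces come from torch.cuda.memory._snapshot()).
fake_snapshot = {
    "device_traces": [[
        {"action": "alloc", "size": 100},
        {"action": "alloc", "size": 50},
        {"action": "free_requested", "size": 100},
        {"action": "free_completed", "size": 100},
        {"action": "alloc", "size": 200},
    ]]
}
print(dynamic_peak_bytes(fake_snapshot))  # 250
```

On a real file you would pass `pickle.load(f)` for one of the `.pickle` snapshots saved above; Mosaic's own numbers may differ because it handles more event types and estimates static memory separately.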
```python
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f"Baseline (clean model): {baseline_memory_debug:.2f} GB")
    print(f"With bug (debug projections): {buggy_memory:.2f} GB")
    print(f"Extra memory from bug: {buggy_memory - baseline_memory_debug:.2f} GB")
```
Case 3: Integrating Memory Analysis into Your Training Pipeline
This section demonstrates how to use Mosaic programmatically (as a Python dependency) to automatically capture memory snapshots during training, get structured memory breakdown data for monitoring and dashboards, and build automated memory monitoring for large-scale training.
Mosaic integrates memory analysis directly into your training pipeline.
Training with Automatic Memory Capture
```python
def run_training_with_memory_capture(
    batch_size=4,
    seq_length=512,
    num_steps=5,
    snapshot_path="training_snapshot.pickle",
):
    """Run training and automatically capture memory snapshot."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")

    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(tokenizer.vocab_size, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    print(f"Running {num_steps} training steps with memory capture...")

    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            outputs.loss.backward()
            optimizer.step()
            print(f"  Step {step + 1}/{num_steps}, Loss: {outputs.loss.item():.4f}")

    peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"✓ PyTorch reported peak memory: {peak_memory_gb:.3f} GB")

    del model, optimizer
    torch.cuda.empty_cache()

    return snapshot_path


if HAS_CUDA:
    print("\n" + "=" * 60)
    print("CASE 3: Pipeline Integration")
    print("=" * 60)

    pipeline_snapshot_path = run_training_with_memory_capture(
        batch_size=4, seq_length=512
    )
```
Mosaic Memory Analysis via Python API
Instead of using CLI commands, we can use Mosaic's Python API directly for programmatic integration.
```python
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("MOSAIC MEMORY ANALYSIS (via Python API)")
    print("=" * 60)

    # Load and analyze the memory snapshot
    memory_abstract = MemoryAbstract(memory_snapshot_file=pipeline_snapshot_path)
    memory_abstract.load_memory_snapshot()

    # Analyze peak memory usage
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")

    # Get results
    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory

    print(f"Peak dynamic memory: {dynamic_peak / 1024**3:.3f} GiB")
    print(f"Static memory: {static_memory / 1024**3:.3f} GiB")
    print(f"Overall peak memory: {overall_peak / 1024**3:.3f} GiB")
    print("✓ Analysis complete using Mosaic Python API")
```
Reusable Memory Analysis Function
Create a reusable function for analyzing training memory snapshots.
```python
def analyze_training_memory(snapshot_path):
    """Analyze a memory snapshot using Mosaic's Python API.

    Returns a structured dictionary with memory breakdown.

    Args:
        snapshot_path: Path to the memory snapshot pickle file.

    Returns:
        Dictionary containing memory analysis results.
    """
    # Load snapshot
    memory_abstract = MemoryAbstract(memory_snapshot_file=snapshot_path)
    memory_abstract.load_memory_snapshot()

    # Analyze peak memory
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")

    # Extract results
    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory

    return {
        "snapshot_path": snapshot_path,
        "dynamic_peak_memory_bytes": dynamic_peak,
        "static_memory_bytes": static_memory,
        "overall_peak_memory_bytes": overall_peak,
        "dynamic_peak_memory_gib": dynamic_peak / 1024**3,
        "static_memory_gib": static_memory / 1024**3,
        "overall_peak_memory_gib": overall_peak / 1024**3,
    }


if HAS_CUDA:
    analysis = analyze_training_memory(pipeline_snapshot_path)
    print("\nMemory Analysis Result:")
    for key, value in analysis.items():
        print(f"  {key}: {value}")
```
Complete Training Pipeline with Memory Monitoring
This demonstrates a training pipeline with integrated Mosaic memory monitoring that can be adapted for CI/CD, monitoring dashboards, or capacity planning.
```python
def training_pipeline_with_memory_monitoring(
    model_name: str,
    batch_size: int,
    seq_length: int,
    num_steps: int = 5,
    snapshot_path: str = "pipeline_snapshot.pickle",
) -> dict:
    """Complete training pipeline with integrated Mosaic memory monitoring.

    Can be integrated into CI/CD, monitoring dashboards, or capacity planning.

    Args:
        model_name: HuggingFace model name to use.
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps.
        snapshot_path: Path to save the memory snapshot.

    Returns:
        Dictionary containing training and memory analysis report.
    """
    # CUDA is required: snapshot capture and peak-memory stats below are CUDA-only.
    device = torch.device("cuda")

    # Setup
    print(f"Loading model: {model_name}")
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Training with memory capture
    print(f"Running {num_steps} training steps...")
    with capture_memory_snapshot(snapshot_path):
        for step in range(num_steps):
            input_ids = torch.randint(
                0, tokenizer.vocab_size, (batch_size, seq_length)
            ).to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(f"  Step {step + 1}/{num_steps}, Loss: {outputs.loss.item():.4f}")

    pytorch_peak_gb = torch.cuda.max_memory_allocated() / 1024**3

    # Mosaic analysis using Python API
    print("Analyzing memory with Mosaic...")
    memory_abstract = MemoryAbstract(memory_snapshot_file=snapshot_path)
    memory_abstract.load_memory_snapshot()
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")

    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory

    report = {
        "model": model_name,
        "config": {
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
        },
        "pytorch_peak_memory_gb": pytorch_peak_gb,
        "mosaic_analysis": {
            "dynamic_peak_gib": dynamic_peak / 1024**3,
            "static_memory_gib": static_memory / 1024**3,
            "overall_peak_gib": overall_peak / 1024**3,
        },
        "snapshot_path": snapshot_path,
    }

    del model, optimizer
    torch.cuda.empty_cache()

    return report


# Run the pipeline
if HAS_CUDA:
    report = training_pipeline_with_memory_monitoring(
        "gpt2", batch_size=4, seq_length=512, num_steps=5
    )

    print("\n" + "=" * 60)
    print("PIPELINE REPORT")
    print("=" * 60)
    print(f"Model: {report['model']}")
    print(f"Config: {report['config']}")
    print(f"PyTorch Peak Memory: {report['pytorch_peak_memory_gb']:.3f} GB")
    print(f"Mosaic Dynamic Peak: {report['mosaic_analysis']['dynamic_peak_gib']:.3f} GiB")
    print(f"Mosaic Overall Peak: {report['mosaic_analysis']['overall_peak_gib']:.3f} GiB")
```
CI/CD and Dashboard Integration Patterns
These patterns show how to integrate Mosaic analysis into automated workflows.
```python
import json
```
Pattern 1: CI/CD Memory Regression Testing
```python
def check_memory_regression(report, threshold_gib=5.0):
    """Check if memory usage exceeds threshold for CI/CD pipelines.

    Args:
        report: Memory analysis report from training_pipeline_with_memory_monitoring.
        threshold_gib: Maximum allowed memory in GiB.

    Raises:
        AssertionError: If memory exceeds threshold.
    """
    peak = report["mosaic_analysis"]["overall_peak_gib"]
    assert peak < threshold_gib, (
        f"Memory regression! {peak:.2f} GiB > {threshold_gib} GiB"
    )
    print(f"Memory check passed: {peak:.2f} GiB < {threshold_gib} GiB threshold")
```
Pattern 2: Export to JSON for Dashboards
```python
if HAS_CUDA:
    check_memory_regression(report, threshold_gib=8.0)

    with open("memory_report.json", "w") as f:
        json.dump(report, f, indent=2, default=str)
    print("Memory report exported to memory_report.json")
```
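An absolute threshold catches large regressions; a CI job can also compare against a previously exported report so that smaller relative increases are flagged. The sketch below is a hypothetical extension of `check_memory_regression`: it loads a baseline JSON in the report format written above and fails if `overall_peak_gib` grew by more than a chosen tolerance. The function name, the 5% default, and the numbers in the usage example are illustrative only.

```python
import json
import os
import tempfile


def check_against_baseline(report, baseline_path, tolerance=0.05):
    """Fail if overall peak memory grew more than `tolerance` (a fraction) over a stored baseline."""
    with open(baseline_path) as f:
        baseline = json.load(f)
    old = baseline["mosaic_analysis"]["overall_peak_gib"]
    new = report["mosaic_analysis"]["overall_peak_gib"]
    growth = (new - old) / old
    assert growth <= tolerance, (
        f"Memory regression: {new:.2f} GiB vs baseline {old:.2f} GiB (+{growth:.1%})"
    )
    print(f"Within tolerance: {new:.2f} GiB vs baseline {old:.2f} GiB ({growth:+.1%})")
    return growth


# Illustration with made-up numbers and a temporary baseline file.
baseline_report = {"mosaic_analysis": {"overall_peak_gib": 4.00}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "memory_report.json")
    with open(path, "w") as f:
        json.dump(baseline_report, f)
    growth = check_against_baseline({"mosaic_analysis": {"overall_peak_gib": 4.10}}, path)
```

Committing the baseline report alongside the code keeps the reference point versioned with the model that produced it.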
Conclusion
This tutorial demonstrated three key use cases for Mosaic memory profiling:
Case 1: Activation Checkpointing Analysis
Used Mosaic to compare memory usage between baseline and optimized models
Identified that activation checkpointing reduced activation memory by 71%
Mosaic’s categorical profiling made it trivial to pinpoint memory savings
Case 2: Debugging Unexpected Memory Usage
Created a “buggy” model with abandoned debug code
Used mosaic_get_memory_usage_peak to identify extra allocations
Stack traces revealed optimizer state tracking extra parameters
Case 3: Pipeline Integration
Demonstrated programmatic usage via Mosaic’s Python API
Showed integration patterns for CI/CD and dashboards with structured reports