LLM distributed supervised fine-tuning with JAX

25 Jan, 2024
In this article, we review the process for fine-tuning a Bidirectional Encoder Representations from Transformers (BERT)-based large language model (LLM) using JAX for a text classification task. We explore techniques for parallelizing this fine-tuning procedure across multiple AMD GPUs, then evaluate our model’s performance on a holdout dataset. For this, we use a BERT-base-cased transformer model with a General Language Understanding Evaluation (GLUE) benchmark dataset on multiple AMD GPUs.
We focus on two Single Program, Multiple Data (SPMD) parallelism methods in JAX. These are:
Using a pmap function for straightforward data distribution over a single leading axis.
Using jit, Mesh, and mesh_utils functions to shard data across devices, providing greater control over parallelization.
Our emphasis is on the first method, and we provide details on the second method in the final section.
In developing this article, we referenced this tutorial, which we highly recommend.
What is supervised fine-tuning?#
In the era of artificial intelligence (AI), transformer architecture-based models like BERT, GPT-3, and their successors have provided a sturdy foundation for achieving cutting-edge performance across various natural language processing (NLP) tasks, including text classification, text generation, and sentiment analysis. Nonetheless, when applied in isolation to these specific tasks, these large, pre-trained models often exhibit limitations. Supervised fine-tuning (SFT) provides a solution to these limitations.
Unlike pre-trained models, which undergo broad, unsupervised training on massive and diverse datasets, SFT adopts a focused and resource-efficient approach. Typically, this requires a relatively compact, high-quality dataset that is precisely tailored to the given task. SFT can improve model performance to a state-of-the-art level without the need for protracted training periods, as it is able to leverage the extensive knowledge acquired by pre-trained models.
The SFT process consists of fine-tuning the model’s existing weights or adding extra parameters to ensure alignment with the intricacies of the designated task. Often, this adaptation incorporates task-specific layers, such as the addition of a softmax layer for classification, which enhances the model’s ability to address supervised tasks.
What is JAX?#
JAX is a high-performance numerical computation library for Python. Compared with traditional machine learning frameworks, such as TensorFlow and PyTorch, JAX emphasizes speed and efficiency through composable function transformations. JAX utilizes Just-in-Time (JIT) compilation, automatic differentiation, and a built-in capability to vectorize and parallelize code, which allows for simple adaptation to AI accelerators (GPUs and TPUs).
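To make these transformations concrete, here is a minimal sketch that combines jit, grad, and vmap on a toy linear model. The function name `mse` and the small arrays are illustrative assumptions and are not part of the fine-tuning code in this article.

```python
import jax
import jax.numpy as jnp


def mse(w, x, y):
    """Mean squared error of a toy linear model y ~ x @ w (illustrative)."""
    return jnp.mean((x @ w - y) ** 2)


# grad differentiates with respect to the first argument; jit compiles with XLA.
grad_mse = jax.jit(jax.grad(mse))

# vmap turns a per-example function into one that runs over the leading axis.
per_example_sq_err = jax.vmap(
    lambda w, xi, yi: (xi @ w - yi) ** 2, in_axes=(None, 0, 0)
)

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 0.0])

g = grad_mse(w, x, y)              # gradient of the batch loss, shape (2,)
errs = per_example_sq_err(w, x, y)  # one squared error per row of x
print(g, errs)
```

The same three building blocks appear later in the article: the training step is differentiated with `jax.value_and_grad` and compiled either by `pmap` or by `jit`.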
Why use AMD GPUs?#
AMD GPUs stand out for their robust open-source support, with tools like ROCm and HIP that make them easily adaptable to AI workflows. AMD’s competitive price-to-performance ratio caters to anyone seeking cost-effective solutions for AI and deep learning tasks. As AMD’s presence in the market grows, more machine learning libraries and frameworks are adding AMD GPU support.
Hardware requirements and running environment#
To harness the computational capabilities required for this task, we leverage the AMD Accelerator Cloud (AAC). AAC is a platform that offers on-demand cloud computing resources and APIs on a pay-as-you-go basis. Specifically, we use a JAX Docker container with 8 GPUs (on AAC) to utilize the full potential of cutting-edge GPU parallel computing.
This article is hardware-agnostic, meaning that access to AAC is not a requirement for successfully running the code examples provided. As long as you have access to accelerator devices, such as GPUs or TPUs, you should be able to run the code examples with minimal code modifications. If you’re using AMD GPUs, make sure you have ROCm and its compatible versions of JAX and Jaxlib installed correctly. Refer to the following tutorials for installation instructions:
JAX and Jaxlib installation: You can also directly pull a JAX Docker image in the link.
Code example on SFT of a transformer model#
For this demonstration, we fine-tune a transformer-based LLM (bert-base-cased) using a General Language Understanding Evaluation (GLUE) benchmark dataset, Quora Question Pairs (QQP). This dataset consists of over 400,000 pairs of questions, each accompanied by a binary annotation that indicates if the two questions are paraphrases of each other. The input variables are the sentences of the two questions, while the output variable is a binary indicator denoting whether the questions share the same meaning.
Installation#
First, install the required packages (%%capture is a cell magic that will suppress the output of the cell).
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort  # Jupyter Notebook code formatter; optional
Import the remaining packages and functionalities.
import os
from itertools import chain
from typing import Callable

import evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
)

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
When it runs its first operation, JAX pre-allocates 75% of total GPU memory to reduce allocation overhead and fragmentation, but this can trigger out-of-memory (OOM) errors. To avoid them, disable the default behavior by setting the XLA_PYTHON_CLIENT_PREALLOCATE environment variable to false.
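As a hedged sketch, the snippet below shows the setting used in this article next to two other memory-related environment variables described in the JAX GPU memory allocation documentation. The two commented-out options are not used in this article, and the 50% fraction is an arbitrary example value.

```python
import os

# These variables should be set before JAX runs its first operation
# (in practice, at the top of the notebook or script).

# Option 1 (used in this article): allocate GPU memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option 2 (not used here): keep pre-allocation but lower the fraction.
# The value 0.50 is an arbitrary example.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

# Option 3 (not used here): allocate exactly what is needed and release it
# when no longer needed. This is slow and mainly useful for debugging.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])
```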
Check if the GPU devices are detectable by JAX. If not, you may need to re-install ROCm, JAX, and Jaxlib. If JAX is installed correctly, you can see all the GPU devices you requested, which in our case is 8 GPUs.
jax.local_devices()
[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3), gpu(id=4), gpu(id=5), gpu(id=6), gpu(id=7)]
Get the fine-tuning dataset and pre-trained model checkpoint#
Specify the settings for your fine-tuning process: the dataset, the pre-trained model, and how many samples you want processed per batch and per device.
task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64
Load the dataset and evaluation metric module.
raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)
The next few code blocks show how to tokenize the text data with the model-specific tokenizer and load the tokenized training and validation data. Using the same tokenizer as the pre-trained model ensures that each word is mapped to the same token IDs, and therefore to the same embedding vectors, as during pre-training.
Note that both the training and evaluation datasets are 10% subsamples of the original training data: after shuffling, the first 10% of the rows are used for training and the next 10% for evaluation. Despite this reduction, the QQP dataset still provides sufficient data for achieving commendable performance and allows us to observe metric improvements after each epoch. This subsampling approach also expedites our training process for illustration.
Process the training and evaluation datasets using the data preprocessing function and the map wrapper’s batch and parallel processing features. You can view the tokenized dataset in the following output.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    texts = (examples["question1"], examples["question2"])
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = examples["label"]
    return processed

# Details about how to handle and process Hugging Face datasets:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"Shape of the original training dataset is: {data.shape}")
print(f"Shape of the current training dataset is: {train_data.shape}")
print(f"Shape of the current evaluation dataset is: {eval_data.shape}")
Shape of the original training dataset is: (363846, 4)
Shape of the current training dataset is: (36384, 4)
Shape of the current evaluation dataset is: (36385, 4)
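The split sizes printed above follow directly from the index arithmetic in the select calls. The sketch below reproduces that arithmetic in plain Python, using the row count from the output; the variable names are illustrative.

```python
n_rows = 363846  # rows in the original QQP training split, from the output above

train_end = int(n_rows * 0.1)  # end of the first 10% of the shuffled rows
eval_end = int(n_rows * 0.2)   # end of the next 10% of the shuffled rows

train_indices = range(0, train_end)
eval_indices = range(train_end, eval_end)

# The two index ranges are disjoint, so no training row is reused for evaluation.
overlap = set(train_indices) & set(eval_indices)

print(len(train_indices), len(eval_indices), len(overlap))
```

The one-row difference between the two subsets comes from truncating 36384.6 and 72769.2 to integers.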
train_dataset = train_data.map(preprocess_function, batched=True, remove_columns=train_data.column_names)
eval_dataset = eval_data.map(preprocess_function, batched=True, remove_columns=eval_data.column_names)
# You can view the tokenized dataset with the output of this cell.
pd.DataFrame(train_dataset[:3])
Download the pre-trained model configurations and checkpoint from Hugging Face. Note that you’ll see a warning message stating that some of the model weights weren’t used. This is expected because the BERT model checkpoint is a PreTraining model class and you’re initializing a SequenceClassification model. The warning message states: You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. This is what we’ll focus on in the rest of this blog.
num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)
Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Define the state of your fine-tuning model#
The following code blocks show you how to set up training parameters, such as the number of training epochs and the initial learning rate. A learning rate schedule decays the learning rate linearly as training progresses, which helps keep learning efficient and stable.
num_train_epochs = 6
learning_rate = 2e-5
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 512
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
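To see what the schedule defined above produces, here is a pure-Python sketch of a linear decay with the same arguments. It is a stand-in written for illustration, not the optax implementation, and it assumes the 8-device setup and the 36,384-row training subset used in this article.

```python
def linear_schedule(init_value, end_value, transition_steps):
    """Pure-Python sketch of a linear decay from init_value to end_value."""

    def schedule(step):
        # Clamp the step so the value stays at end_value after the transition.
        step = min(max(step, 0), transition_steps)
        frac = step / transition_steps
        return init_value + (end_value - init_value) * frac

    return schedule


# Values from this article: 8 devices x 64 samples per device, 6 epochs.
total_batch_size = 64 * 8                    # 512
steps_per_epoch = 36384 // total_batch_size  # incomplete last batch is dropped
num_train_steps = steps_per_epoch * 6

lr = linear_schedule(2e-5, 0.0, num_train_steps)
print(steps_per_epoch, num_train_steps)
print(lr(0), lr(num_train_steps // 2), lr(num_train_steps))
```

The learning rate starts at 2e-5, reaches half of that value midway through training, and ends at zero on the final step.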
Next, establish the training state, which bundles the optimizer and the loss function and manages the updates to the model’s parameters throughout training.
Use the state object to initialize and update the model. Each training step takes the current state as input and returns an updated state that incorporates the new batch of data, while the model instance itself is preserved.
Flax offers a user-friendly class (flax.training.train_state.TrainState) that holds the model’s apply function, the model parameters, and the optimizer; the subclass defined in this section additionally stores the loss function. When supplied with gradients computed from a batch of data, the state updates the model parameters using the apply_gradients function.
The following code blocks show how to define and establish the training state, optimizer, and loss function.
class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

# Create a decay_mask_fn function to make sure that weight decay is not applied to any bias or
# LayerNorm weights, as it may not improve model performance and even be harmful.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
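The mask rule is easier to follow on a small example. The sketch below applies the same condition to a toy parameter tree; the `flatten` helper is a hypothetical stand-in for flax's `flatten_dict`, and the parameter names are made up for illustration.

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf} (stand-in for flatten_dict)."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


# Toy parameter tree; leaf values are placeholders, only the names matter.
toy_params = {
    "dense": {"kernel": 0, "bias": 0},
    "LayerNorm": {"scale": 0, "bias": 0},
}

# Same condition as decay_mask_fn: True means weight decay is applied.
mask = {
    path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
    for path in flatten(toy_params)
}

for path, decayed in mask.items():
    print("/".join(path), decayed)
```

Only the dense kernel is decayed; every bias and the LayerNorm scale are excluded.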
# Standard Adam optimizer with weight decay
def adamw(weight_decay):
    return optax.adamw(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )

def loss_function(logits, labels):
    xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
    return jnp.mean(xentropy)

def eval_function(logits):
    return logits.argmax(-1)
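As a worked example of what the loss and evaluation functions compute, the following plain-Python sketch evaluates the softmax cross-entropy and the argmax prediction for two made-up examples. The logits and labels are illustrative values, not model outputs.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


batch_logits = [[2.0, 0.0], [0.0, 0.0]]  # illustrative values
batch_labels = [0, 1]

losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
mean_loss = sum(losses) / len(losses)

# Equivalent of logits.argmax(-1): index of the largest logit per example.
predictions = [max(range(len(z)), key=z.__getitem__) for z in batch_logits]

print(losses, mean_loss, predictions)
```

A confident, correct prediction (first example) yields a small loss, while uniform logits (second example) yield log 2 regardless of the label.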
# Instantiate the TrainState
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)
Define how to train, evaluate the model, and enable parallelization#
The train_step and eval_step functions define how the model should be trained and evaluated. The train step follows the standard training process:
Calculate the loss with the current weights.
Calculate the gradients of the loss function with respect to the weights.
Update the weights with the gradients and learning rate.
Repeat the above steps until the stopping criterion has been met.
Note that the lax.pmean function averages the gradients computed from the data shards on all 8 GPU devices. Because every device then applies the same averaged gradient, the model parameters stay synchronized across all GPU devices.
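The reason averaging works can be checked with a small plain-Python sketch: when the shards have equal size, the mean of the per-shard gradients equals the gradient of the full batch. The scalar model, the data, and the choice of 4 simulated devices are assumptions made for illustration.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

num_devices = 4  # illustrative; the article uses 8
shard_size = len(xs) // num_devices

# Each simulated device computes the gradient on its own equally sized shard.
shard_grads = []
for d in range(num_devices):
    lo, hi = d * shard_size, (d + 1) * shard_size
    shard_grads.append(grad_mse(w, xs[lo:hi], ys[lo:hi]))

# pmean-style reduction: every device ends up with this same average.
averaged = sum(shard_grads) / num_devices
full_batch = grad_mse(w, xs, ys)

print(shard_grads, averaged, full_batch)
```

This equivalence is why data-parallel training with a total batch of 512 behaves like single-device training on the same batch.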
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": learning_rate_function(state.step)},
        axis_name="batch",
    )
    return new_state, metrics, new_dropout_rng

def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)
Next, apply the jax.pmap function to the defined train_step and eval_step functions. Applying pmap() to a function compiles that function with XLA (similar to jit()), then runs it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Simply put, this step sends the training and evaluation functions to all GPU devices. You’ll also need to send the training state to all GPU devices via flax.jax_utils.replicate. These steps ensure you’re updating the state, via distributed training, on all GPU devices.
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)
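The interaction between pmap and pmean can be tried on a toy function before running the full training step. This is a minimal sketch that adapts to however many local devices are visible (including a single CPU device); the function name and the data are illustrative.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 8 in this article; 1 on a CPU-only machine


def device_mean_and_global_mean(x):
    local_mean = jnp.mean(x)  # computed from this device's shard only
    # Averaged across every device that participates in the "batch" axis.
    global_mean = jax.lax.pmean(local_mean, axis_name="batch")
    return local_mean, global_mean


parallel_fn = jax.pmap(device_mean_and_global_mean, axis_name="batch")

# The leading axis must equal the number of devices: one row per device.
data = jnp.arange(n * 4, dtype=jnp.float32).reshape((n, 4))
local_means, global_means = parallel_fn(data)

print(local_means)   # one value per device
print(global_means)  # the same value repeated on every device
```

The `axis_name` passed to `pmap` is what `pmean` refers to, which is why the training step and the `pmap` call above both use "batch".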
Define the data loader functions that return a data batch generator. A new batch of data is fed into each step of the final training and evaluation loops.
def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
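The shard call in both loaders adds a leading device axis so that pmap can hand one slice to each device. The plain-Python sketch below mimics that reshaping with nested lists; `shard_rows` is a hypothetical helper written for illustration, not the Flax function.

```python
def shard_rows(rows, num_devices):
    """Split a list of rows into num_devices equal, contiguous chunks."""
    if len(rows) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = len(rows) // num_devices
    return [rows[d * per_device:(d + 1) * per_device] for d in range(num_devices)]


num_devices = 8             # as in this article
per_device_batch_size = 64
total_batch_size = num_devices * per_device_batch_size  # 512

# Stand-in batch: 512 rows of 128 token IDs each (all zeros here).
batch = [[0] * 128 for _ in range(total_batch_size)]
sharded = shard_rows(batch, num_devices)

shape = (len(sharded), len(sharded[0]), len(sharded[0][0]))
print(shape)  # (8, 64, 128)
```

This is also why the loaders drop the incomplete final batch: a batch that cannot be divided evenly across the devices cannot be reshaped this way.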
A pseudo-random number generator (PRNG) key is generated based on an integer seed, and is then split into 8 new keys so that each GPU device gets a different key. Then run the training steps to update the state based on the pre-defined training parameters, such as number of epochs and total_batch_size. After finishing each epoch, run the evaluation step on the eval dataset to see the accuracy and f1 metrics. Because you used a smaller dataset than the original training dataset in the benchmark, you can see the train loss and the eval metrics (accuracy and f1) improve steadily over the first few epochs.
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=list(chain(*predictions)), references=list(chain(*labels)))
            progress_bar_eval.update(1)
    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(f"{i + 1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762
2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789
3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797
4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806
5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802
6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805
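For readers unfamiliar with the two reported metrics, the sketch below computes accuracy and binary F1 by hand on a handful of made-up labels. It assumes the positive class is 1 (the two questions are paraphrases); the function name and the data are illustrative and do not come from the evaluate library.

```python
def accuracy_and_f1(predictions, references):
    """Accuracy and binary F1 (positive class = 1) from two equal-length lists."""
    pairs = list(zip(predictions, references))
    tp = sum(p == 1 and r == 1 for p, r in pairs)  # true positives
    fp = sum(p == 1 and r == 0 for p, r in pairs)  # false positives
    fn = sum(p == 0 and r == 1 for p, r in pairs)  # false negatives
    accuracy = sum(p == r for p, r in pairs) / len(pairs)
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return {"accuracy": accuracy, "f1": f1}


# Illustrative labels: 1 = paraphrases, 0 = not paraphrases.
preds = [1, 0, 1, 1, 0, 0]
refs = [1, 0, 0, 1, 1, 0]
scores = accuracy_and_f1(preds, refs)
print(scores)
```

F1 ignores true negatives, which is one reason it can sit below accuracy, as it does in the results above.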
Using JAX device mesh to achieve parallelism#
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)
Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
@jax.jit
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
    return new_state, metrics, new_dropout_rng

@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))
# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(data_mesh, P("batch"))  # naming axes of the sharded partition

def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}
        yield batch

def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()}
        yield batch
# Replicate the model and optimizer variables on all devices
def get_replicated_train_state(devices, state):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_",))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())
    # Apply the distribution settings to the model variables
    state = jax.device_put(state, var_replication)
    return state

state = get_replicated_train_state(devices, state)
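The two placements used in this section, a batch split along the "batch" mesh axis and parameters replicated everywhere, can be tried on toy arrays. This is a minimal sketch that builds the mesh directly from `jax.devices()` instead of `mesh_utils`; the array shapes and the `predict` function are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

devices = np.array(jax.devices())  # 1-D array of all visible devices
mesh = Mesh(devices, axis_names=("batch",))

split_rows = NamedSharding(mesh, P("batch"))  # shard axis 0 across devices
replicated = NamedSharding(mesh, P())         # full copy on every device

n = devices.size
x = jax.device_put(jnp.ones((4 * n, 3)), split_rows)  # toy batch
w = jax.device_put(jnp.ones((3,)), replicated)        # toy parameters


@jax.jit
def predict(w, x):
    # jit follows the input shardings; no pmap or explicit pmean is needed.
    return x @ w


y = predict(w, x)
shard_shapes = [s.data.shape for s in x.addressable_shards]
print(y.shape, shard_shapes)
```

Each device holds 4 rows of the toy batch, mirroring how each GPU in this article receives 64 of the 512 examples while the parameters are copied to all of them.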
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)
for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=True) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=list(predictions), references=list(labels))
            progress_bar_eval.update(1)
    eval_metric = metric.compute()

    loss = round(train_metrics["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(f"{i + 1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}")
1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759
2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788
3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795
4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805
5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804
6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805