Training a simple neural network, with tensorflow/datasets data loading

Copyright 2018 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Forked from `neural_network_and_data_loading.ipynb`


Let’s combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the tensorflow/datasets data loading API to load images and labels (because it’s pretty great, and the world doesn’t need yet another data loading library :P).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won’t use any neural network libraries or special APIs for building our model.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```

Hyperparameters

Let’s get a few bookkeeping items out of the way.

```python
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
```
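To make the layout of `params` concrete, the sketch below re-creates the two helpers and prints each layer’s weight and bias shapes. It also counts the total number of parameters. The counting is an illustrative addition, not part of the original tutorial.

```python
from jax import random

# Same helpers as in the cell above, repeated so this sketch runs on its own.
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.key(0))

# Each layer is a (weights, biases) pair; weights are stored as (out, in),
# so a layer applies as jnp.dot(w, x) + b to a single input vector x.
shapes = [(w.shape, b.shape) for w, b in params]
for s in shapes:
  print(s)

# Total number of trainable scalars across all layers (illustrative addition).
num_params = sum(w.size + b.size for w, b in params)
print(num_params)
```

This prints three shape pairs: `((512, 784), (512,))`, `((512, 512), (512,))` and `((10, 512), (10,))`. The total is 669,706 parameters.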

Auto-batching predictions

Let us first define our prediction function. Note that we’re defining this for a single image example. We’re going to use JAX’s `vmap` function to automatically handle mini-batches, with no performance penalty.

```python
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
```
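The last line of `predict` subtracts `logsumexp(logits)` from the final-layer scores. This turns them into log-probabilities (a log-softmax), which is the form the loss function later expects. A quick check on made-up logits, compared against `jax.nn.log_softmax`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up final-layer scores for one example with three classes.
logits = jnp.array([2.0, 1.0, 0.1])

# The same normalization `predict` applies: subtract the log-sum-exp.
log_probs = logits - logsumexp(logits)

# Exponentiating recovers probabilities that sum to (approximately) 1.
probs = jnp.exp(log_probs)
print(probs.sum())

# This is exactly a log-softmax, so it agrees with jax.nn.log_softmax,
# and it does not change which class has the highest score.
print(jnp.allclose(log_probs, jax.nn.log_softmax(logits)))
print(jnp.argmax(log_probs) == jnp.argmax(logits))
```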

Let’s check that our prediction function only works on single images.

```python
# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
(10,)
```python
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
```
Invalid shapes!
```python
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
```
(10, 10)
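The `in_axes=(None, 0)` argument passed when building `batched_predict` tells `vmap` to reuse `params` unchanged and to map over axis 0 of the images. Below is a minimal sketch on a toy function that checks the vectorized call against an explicit Python loop. The function `per_example` is hypothetical and used only for illustration.

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical per-example function: a matrix-vector product, like one layer of `predict`.
def per_example(w, x):
  return jnp.dot(w, x)

w = jnp.arange(6.0).reshape(2, 3)    # shared weights, shape (2, 3)
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 inputs, each of shape (3,)

# in_axes=(None, 0): pass w through unchanged, map over axis 0 of xs.
batched = vmap(per_example, in_axes=(None, 0))(w, xs)

# Reference result from an explicit Python loop over the batch.
looped = jnp.stack([per_example(w, x) for x in xs])

print(batched.shape)  # (4, 2)
print(jnp.allclose(batched, looped))
```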

At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything.
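`grad` differentiates with respect to a whole pytree of parameters, not just a single array. For a list of `(w, b)` tuples like `params`, it returns gradients with the same list-of-tuples structure. That is what allows `update` to zip `params` with `grads`. Here is a minimal sketch using a hypothetical two-layer linear model; `toy_loss` is for illustration only. It also checks that a `jit`-compiled version returns the same value.

```python
import jax.numpy as jnp
from jax import grad, jit

# Toy parameters with the same structure as `params`: a list of (w, b) tuples.
params = [(jnp.ones((2, 3)), jnp.zeros(2)), (jnp.ones((1, 2)), jnp.zeros(1))]
x = jnp.array([1.0, 2.0, 3.0])

def toy_loss(params, x):
  # Hypothetical two-layer linear model with a squared-output loss.
  h = x
  for w, b in params:
    h = jnp.dot(w, h) + b
  return jnp.sum(h ** 2)

# grad differentiates w.r.t. the first argument and mirrors its structure.
grads = grad(toy_loss)(params, x)
print([(dw.shape, db.shape) for dw, db in grads])

# jit compiles the same function; the result agrees with the un-jitted version.
fast_loss = jit(toy_loss)
print(jnp.allclose(fast_loss(params, x), toy_loss(params, x)))
```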

Utility and loss functions

```python
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
```
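As written, `loss` averages `preds * targets` over both the batch axis and the class axis. Each one-hot row has exactly one nonzero entry, so this equals the usual mean cross-entropy divided by the number of classes `k`. The constant factor only rescales the effective step size. The check below uses made-up log-probabilities, not values from the tutorial.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k (same as above)."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([0, 2, 1])
targets = one_hot(labels, 3)
print(targets)

# Made-up log-probabilities for 3 examples and 3 classes (each row sums to 1 in prob space).
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.1, 0.8],
                               [0.3, 0.4, 0.3]]))

# The tutorial's loss: mean over *all* batch-by-class entries.
tutorial_loss = -jnp.mean(log_probs * targets)

# Standard mean cross-entropy: average over examples of -log p(true class).
cross_entropy = -jnp.mean(jnp.sum(log_probs * targets, axis=1))

# With k = 3 classes, the two differ by exactly a factor of 3.
print(jnp.allclose(tutorial_loss * 3, cross_entropy))
```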

Data loading with tensorflow/datasets

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let’s just use them instead of reinventing anything. We’ll use the tensorflow/datasets data loader.

```python
import tensorflow as tf
# Ensure TF does not see the GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
```
```python
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
```
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
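The preprocessing above flattens each `(28, 28, 1)` MNIST image into a 784-element vector and one-hot encodes its label. The sketch below replays those two steps on random synthetic arrays with the same shapes and dtype that tfds uses for MNIST (`uint8` images). This is an assumption that lets it run without downloading the dataset. Note that the pixels remain `uint8` values in [0, 255]; the tutorial does not rescale them.

```python
import numpy as np
import jax.numpy as jnp

# Synthetic stand-in (an assumption) for the tfds arrays: uint8 images of shape
# (N, 28, 28, 1) and integer labels in [0, 10).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28, 1), dtype=np.uint8)
labels = jnp.asarray(rng.integers(0, 10, size=(5,)))

h, w, c = images.shape[1:]
num_pixels = h * w * c  # 28 * 28 * 1 = 784

# Flatten each image into a vector, as done for the full train and test sets.
flat = jnp.reshape(images, (len(images), num_pixels))

# One-hot encode the labels with the same expression as `one_hot` above.
targets = jnp.array(labels[:, None] == jnp.arange(10), jnp.float32)

print(flat.shape, flat.dtype)  # (5, 784) uint8
print(targets.shape)           # (5, 10)
```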

Training loop

```python
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
```
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341
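Stripped of the data pipeline, the training loop repeats one pattern: call a `jit`-compiled `update` on each batch and reassign the returned parameters. The sketch below replays that pattern on a hypothetical linear-regression problem with synthetic data and one full batch. The model, the loss and `step_size` are assumptions for illustration, not the tutorial’s MLP. Using a toy problem makes it easy to confirm that the updates drive the loss down.

```python
import jax.numpy as jnp
from jax import grad, jit, random

step_size = 0.1  # hypothetical learning rate for this toy problem

# Hypothetical toy problem: recover true_w from synthetic data y = X @ true_w.
X = random.normal(random.key(0), (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = X @ true_w

def loss(w, X, y):
  # Mean squared error of a linear model.
  return jnp.mean((X @ w - y) ** 2)

@jit
def update(w, X, y):
  # One plain gradient-descent step, mirroring the tutorial's `update`.
  return w - step_size * grad(loss)(w, X, y)

w = jnp.zeros(3)
initial_loss = loss(w, X, y)
for step in range(100):
  w = update(w, X, y)  # compiled on the first call, reused afterwards
final_loss = loss(w, X, y)

print(float(initial_loss), float(final_loss))
print(w)
```

`jit` traces and compiles `update` on its first call. Later calls with the same shapes and dtypes reuse the cached compiled version, so the first call is usually the slowest. This is one plausible contributor to the first epoch above taking longer than the rest.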

We’ve now used most of the JAX API: `grad` for derivatives, `jit` for speedups and `vmap` for auto-vectorization. We used the NumPy API (through `jax.numpy`) to specify all of our computation, borrowed the great data loaders from tensorflow/datasets, and ran the whole thing on the GPU.

