Training Resnet50 on Cloud TPU with PyTorch

Note: This page applies to the Cloud TPU API. For Ironwood (TPU7x), you must use Google Kubernetes Engine (GKE). For more information, see About TPUs in GKE.

This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimized image classification models that use PyTorch and the ImageNet dataset.

The model in this tutorial is based on Deep Residual Learning for Image Recognition, which first introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and demonstrates training the model using PyTorch/XLA.
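The defining feature of the ResNet architecture is the identity shortcut: each block learns a residual function whose output is added back to the block's input, which keeps gradients flowing through very deep networks. The snippet below is an illustrative NumPy sketch of a simplified, fully connected residual block; real ResNet-50 blocks use three convolutions with batch normalization, and the names `relu`, `residual_block`, `w1`, and `w2` are ours, not taken from the model code.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, w1, w2):
    """Simplified residual block: out = relu(x + W2 @ relu(W1 @ x)).

    The identity shortcut (adding x back) is the key idea; everything
    else about real ResNet blocks (convolutions, batch norm) is omitted.
    """
    h = relu(w1 @ x)      # first transformation
    h = w2 @ h            # second transformation, no activation yet
    return relu(x + h)    # identity shortcut: add the input back

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
w1 = rng.standard_normal((8, 8))
w2 = np.zeros((8, 8))  # with zero weights, the block reduces to relu(x)
out = residual_block(x, w1, w2)
```

Because `w2` is zero here, the residual branch contributes nothing and the block passes its input straight through the ReLU, which is exactly why residual blocks are easy to optimize: the shortcut lets a block behave as a near-identity until its weights learn something useful.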

Warning: This tutorial uses a third-party dataset. Google provides no representation, warranty, or other guarantees about the validity, or any other aspects, of this dataset.

Objectives

  • Create a TPU VM and connect to it over SSH.
  • Install PyTorch/XLA and clone the PyTorch/XLA repository.
  • Run the ResNet-50 training script with fake data.
  • Clean up the resources you created.

Costs

In this document, you use billable components of Google Cloud, including Cloud TPU.

To generate a cost estimate based on your projected usage, use the pricing calculator.

New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctlyset up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    Note: If you don't plan to keep the resources that you create in this procedure, create a project instead of selecting an existing project. After you finish these steps, you can delete the project, removing all resources associated with the project.

    Go to project selector

  3. Verify that billing is enabled for your Google Cloud project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you created when you've finished with them to avoid unnecessary charges.

Create a TPU VM

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a TPU VM:

    gcloud compute tpus tpu-vm create your-tpu-name \
      --accelerator-type=v3-8 \
      --version=tpu-ubuntu2204-base \
      --zone=us-central1-a \
      --project=your-project
    Note: The first time you run a command in a new Cloud Shell VM, an Authorize Cloud Shell page is displayed. Click Authorize at the bottom of the page to allow gcloud to make Google Cloud API calls with your credentials.
  3. Connect to your TPU VM using SSH:

    gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central1-a
  4. Install PyTorch/XLA on your TPU VM:

    (vm)$ pip install torch torch_xla[tpu] torchvision \
      -f https://storage.googleapis.com/libtpu-releases/index.html \
      -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clone the PyTorch/XLA GitHub repository:

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Run the training script with fake data:

    (vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

    The --fake_data flag trains on randomly generated data, so you don't need to download the ImageNet dataset first.
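Internally, test_train_mp_imagenet.py follows the standard PyTorch/XLA training pattern: move the model to the XLA device, build tensors (or a data loader) on that device, and apply optimizer updates through torch_xla so the XLA graph is executed each step. The following is a minimal sketch of that pattern, not the actual script; it assumes you run it on the TPU VM configured above, with torch and torch_xla installed, and the tiny stand-in model is ours rather than ResNet-50.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# Acquire the XLA (TPU) device; requires PJRT_DEVICE=TPU in the environment.
device = xm.xla_device()

# Tiny stand-in model for brevity; the tutorial's script builds ResNet-50.
model = nn.Sequential(nn.Flatten(), nn.Linear(224 * 224 * 3, 1000)).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# Fake inputs, mirroring the script's --fake_data flag.
images = torch.randn(8, 3, 224, 224, device=device)
labels = torch.randint(0, 1000, (8,), device=device)

for step in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(images), labels)
    loss.backward()
    # xm.optimizer_step applies the update and synchronizes the XLA graph,
    # which is what actually triggers compilation and execution on the TPU.
    xm.optimizer_step(optimizer)
```

The real script additionally wraps the data loader for device transfer and launches one process per TPU core; this sketch only shows the single-device training loop at its core.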

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Disconnect from the TPU VM:

    (vm)$exit

    Your prompt should now be username@projectname, showing that you are in Cloud Shell.

  2. Delete your TPU VM.

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
      --zone=us-central1-a

What's next

Except as otherwise noted, the content of this page is licensed under theCreative Commons Attribution 4.0 License, and code samples are licensed under theApache 2.0 License. For details, see theGoogle Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2025-11-25 UTC.