Treinar um modelo usando a TPU v5e
Com uma pegada menor de 256 chips por pod, a TPU v5e é otimizada para treinamento, ajuste e disponibilização de transformadores, conversão de texto em imagem e redes neurais convolucionais (CNNs). Para saber como usar o Cloud TPU v5e para disponibilização, consulte Inferência com a v5e.
Para mais informações sobre o hardware e as configurações de TPU do Cloud TPU v5e, consulte TPU v5e.
Introdução
Confira as seções a seguir para saber como começar a usar a TPU v5e.
Solicitação de cotas
Você precisa de cota para usar a TPU v5e em jobs de treinamento. Há diferentes tipos de cota para TPUs on demand, TPUs reservadas e VMs spot de TPU. Outras cotas são necessárias ao usar a TPU v5e para inferência. Para mais informações sobre cotas, consulte Cotas. Para solicitar cota da TPU v5e, entre em contato com a equipe de vendas do Cloud.
Criar uma conta e um projeto do Google Cloud
Você precisa ter uma conta e um projeto do Google Cloud para usar o Cloud TPU. Para mais informações, consulte Configurar um ambiente do Cloud TPU.
Criar um Cloud TPU
A prática recomendada é provisionar Cloud TPUs v5e como recursos em fila usando o comando queued-resource create. Para mais informações, consulte Gerenciar recursos em fila.
Você também pode usar a API Create Node (gcloud compute tpus tpu-vm create) para provisionar Cloud TPUs v5e. Para mais informações, consulte Gerenciar recursos de TPU.
Para mais informações sobre as configurações da v5e disponíveis para treinamento, consulte Tipos de Cloud TPU v5e para treinamento.
Configuração de framework
Esta seção descreve o processo geral de configuração para o treinamento de modelos personalizados usando o JAX ou o PyTorch com a TPU v5e.
Para instruções sobre a configuração da inferência, consulte Introdução à inferência na v5e.
Defina algumas variáveis de ambiente:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Configuração para JAX
Se você tiver frações com mais de oito chips, vai ter várias VMs em uma fração. Nesse caso, use a flag --worker=all para executar em uma única etapa a instalação em todas as VMs de TPU. Assim, você não precisa usar SSH para fazer login em cada uma delas:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' Descrições de flags de comando
| Variável | Descrição |
| TPU_NAME | O ID de texto atribuído pelo usuário da TPU criado quando a solicitação de recurso em fila é alocada. |
| PROJECT_ID | Nome do projeto doGoogle Cloud . Use um projeto atual ou crie outro em Configurar o projeto do Google Cloud . |
| ZONE | Consulte o documento Regiões e zonas de TPU para saber quais são as zonas disponíveis. |
| worker | A VM de TPU que tem acesso às TPUs. |
Execute o comando a seguir para verificar o número de dispositivos. As saídas mostradas aqui foram produzidas com uma fração da v5litepod-16. Esse código testa se tudo está instalado corretamente. Para isso, ele verifica se o JAX reconhece os TensorCores do Cloud TPU e pode executar operações básicas:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"' A saída será assim:
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... 16 4 16 4 16 4 16 4 jax.device_count() mostra o número total de chips na fração especificada. jax.local_device_count() indica a contagem de chips acessíveis por uma única VM na fração.
# Check the number of chips in the given slice by summing the count of chips # from all VMs through the # jax.local_device_count() API call. gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"' A saída será assim:
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.] Confira os tutoriais do JAX neste documento para fazer treinamentos na v5e usando o JAX.
Configuração para PyTorch
A v5e só aceita o ambiente de execução PJRT, e o PyTorch usa o PJRT a partir da versão 2.1 como ambiente de execução padrão para todas as versões de TPU.
Esta seção descreve como usar o PJRT na v5e com o PyTorch/XLA e fornece comandos para todos os workers.
Instalar dependências
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Substitua PYTORCH_VERSION pela versão do PyTorch que você quer usar. PYTORCH_VERSION é usado para especificar a mesma versão do PyTorch/XLA. A 2.6.0 é a recomendação.
Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: introdução e Versões do PyTorch/XLA.
Para mais informações sobre a instalação do PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Se você receber um erro ao instalar os wheels para torch, torch_xla ou torchvision, como pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, faça downgrade da versão com este comando:
pip3 install setuptools==62.1.0 Executar um script com PJRT
unset LD_PRELOAD Confira um exemplo que usa um script Python para fazer um cálculo em uma VM v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/ export PJRT_DEVICE=TPU export PT_XLA_DEBUG=0 export USE_TORCH=ON unset LD_PRELOAD export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"' Isso gera um resultado como este:
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... xla:0 tensor([[ 1.8611, -0.3114, -2.4208], [-1.0731, 0.3422, 3.1445], [ 0.5743, 0.2379, 1.1105]], device='xla:0') xla:0 tensor([[ 1.8611, -0.3114, -2.4208], [-1.0731, 0.3422, 3.1445], [ 0.5743, 0.2379, 1.1105]], device='xla:0') Confira os tutoriais do PyTorch neste documento para fazer treinamentos na v5e usando o PyTorch.
Exclua a TPU e o recurso em fila no fim da sessão. Para excluir um recurso em fila, você precisa primeiro excluir a fração e depois o recurso:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet Você também pode seguir essas duas etapas para remover solicitações de recursos em fila que estão no estado FAILED.
Exemplos de JAX/FLAX
As seções a seguir descrevem exemplos de como treinar modelos do JAX e do FLAX em TPUs v5e.
Treinar o ImageNet na v5e
Neste tutorial, descrevemos como treinar o ImageNet na v5e usando dados de entrada falsos. Se você quiser usar dados reais, consulte o arquivo README no GitHub.
Configuração
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. SERVICE_ACCOUNTO endereço de e-mail da conta de serviço. Para encontrá-lo, acesse a página Contas de serviço no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.comQUEUED_RESOURCE_IDO ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}Quando o recurso em fila estiver no estado
ACTIVE, será possível acessar por SSH a VM de TPU:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}Quando o QueuedResource estiver no estado
ACTIVE, a saída será parecida com esta:state: ACTIVE Instale a versão mais recente do JAX e do jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'Clone o modelo do ImageNet e instale os requisitos correspondentes:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"Para gerar dados falsos, o modelo precisa de informações sobre as dimensões do conjunto de dados. Para coletá-las, acesse os metadados do conjunto de dados do ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Treinar o modelo
Depois de concluir todas as etapas, você pode treinar o modelo.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py" Excluir a TPU e o recurso em fila
Exclua a TPU e o recurso em fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet Modelos FLAX do Hugging Face
Os modelos do Hugging Face implementados em FLAX funcionam imediatamente no Cloud TPU v5e. Esta seção fornece instruções para executar modelos conhecidos.
Treinar o ViT no Imagenette
Neste tutorial, mostramos como treinar o modelo Vision Transformer (ViT) do Hugging Face usando o conjunto de dados Imagenette da Fast AI no Cloud TPU v5e.
O modelo ViT foi o primeiro a treinar um codificador Transformer no ImageNet com resultados excelentes em comparação com as redes convolucionais. Para mais informações, consulte Visão geral do ViT.
Configuração
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. SERVICE_ACCOUNTO endereço de e-mail da conta de serviço. Para encontrá-lo, acesse a página Contas de serviço no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.comQUEUED_RESOURCE_IDO ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}Será possível acessar por SSH a VM de TPU quando o recurso em fila estiver no estado
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}Quando o recurso em fila estiver no estado
ACTIVE, a saída será assim:state: ACTIVE Instale o JAX e a biblioteca dele:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'Faça o download do repositório do Hugging Face e instale os requisitos:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'Faça o download do conjunto de dados Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Treinar o modelo
Treine o modelo com um buffer pré-mapeado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3' Excluir a TPU e o recurso em fila
Exclua a TPU e o recurso em fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet Resultados do comparativo do ViT
O script de treinamento foi executado em v5litepod-4, v5litepod-16 e v5litepod-64. A tabela a seguir mostra a capacidade de processamento com diferentes tipos de aceleradores.
| Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
| Período | 3 | 3 | 3 |
| Tamanho global do lote | 32 | 128 | 512 |
| Capacidade de processamento (exemplos/segundo) | 263,40 | 429,34 | 470,71 |
Treinar a difusão no Pokémon
Neste tutorial, mostramos como treinar o modelo Stable Diffusion do Hugging Face usando o conjunto de dados Pokémon no Cloud TPU v5e.
O modelo Stable Diffusion é um modelo de conversão de texto em imagem baseado em espaço latente que gera imagens realistas com base em qualquer entrada de texto. Para saber mais, acesse estes recursos:
Configuração
Defina uma variável de ambiente para o nome do bucket de armazenamento:
export GCS_BUCKET_NAME=your_bucket_name
Configure um bucket de armazenamento para a saída do modelo:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. SERVICE_ACCOUNTO endereço de e-mail da conta de serviço. Para encontrá-lo, acesse a página Contas de serviço no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.comQUEUED_RESOURCE_IDO ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}Quando o recurso em fila estiver no estado
ACTIVE, será possível acessar por SSH a VM de TPU:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}Quando o recurso em fila estiver no estado
ACTIVE, a saída será assim:state: ACTIVE Instale o JAX e a biblioteca dele.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'Faça o download do repositório do Hugging Face e instale os requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Treinar o modelo
Treine o modelo com um buffer pré-mapeado de 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command=" git clone https://github.com/google/maxdiffusion cd maxdiffusion pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip3 install -r requirements.txt pip3 install . pip3 install gcsfs export LIBTPU_INIT_ARGS='' python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \ jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \ per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \ output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash" Limpeza
Exclua a TPU, o recurso em fila e o bucket do Cloud Storage no final da sessão.
Exclua a TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quietExclua o recurso em fila:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quietExclua o bucket do Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Resultados do comparativo para difusão
O script de treinamento foi executado na v5litepod-4, na v5litepod-16 e na v5litepod-64. A tabela a seguir mostra a capacidade de processamento.
| Tipo de acelerador | v5litepod-4 | v5litepod-16 | v5litepod-64 |
| Etapa do treinamento | 1500 | 1500 | 1500 |
| Tamanho global do lote | 32 | 64 | 128 |
| Capacidade de processamento (exemplos/segundo) | 36,53 | 43,71 | 49,36 |
PyTorch/XLA
As seções a seguir descrevem exemplos de como treinar modelos do PyTorch/XLA em TPUs v5e.
Treinar o ResNet usando o ambiente de execução do PJRT
O PyTorch/XLA está migrando do XRT para o PJRT a partir do PyTorch 2.0. Confira as instruções atualizadas para configurar a v5e para as cargas de trabalho de treinamento do PyTorch/XLA.
Configuração
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. SERVICE_ACCOUNTO endereço de e-mail da conta de serviço. Para encontrá-lo, acesse a página Contas de serviço no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.comQUEUED_RESOURCE_IDO ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}Quando o QueuedResource estiver no estado
ACTIVE, será possível acessar por SSH a VM de TPU:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}Quando o recurso em fila estiver no estado
ACTIVE, a saída será assim:state: ACTIVE Instalar dependências específicas do PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Substitua
PYTORCH_VERSIONpela versão do PyTorch que você quer usar.PYTORCH_VERSIONé usado para especificar a mesma versão do PyTorch/XLA. A 2.6.0 é a recomendação.Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: introdução e Versões do PyTorch/XLA.
Para mais informações sobre a instalação do PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Treinar o modelo do ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' date export PJRT_DEVICE=TPU export PT_XLA_DEBUG=0 export USE_TORCH=ON export XLA_USE_BF16=1 export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so git clone https://github.com/pytorch/xla.git cd xla/ git checkout release-r2.6 python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 --num_workers=16 --log_steps=300 --batch_size=64 --profile' Excluir a TPU e o recurso em fila
Exclua a TPU e o recurso em fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet Resultado do comparativo
A tabela a seguir mostra a comparação entre as capacidades de processamento.
| Tipo de acelerador | Capacidade de processamento (exemplos/segundo) |
| v5litepod-4 | 4.240 ex/s |
| v5litepod-16 | 10.810 ex/s |
| v5litepod-64 | 46.154 ex/s |
Treinar o ViT na v5e
Neste tutorial, explicamos como executar o VIT na v5e usando o repositório do Hugging Face no PyTorch/XLA no conjunto de dados cifar10.
Configuração
Crie variáveis de ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descrições de variáveis de ambiente
Variável Descrição PROJECT_IDO ID do projeto do Google Cloud . Use um projeto atual ou crie um novo. TPU_NAMEO nome da TPU. ZONEA zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU. ACCELERATOR_TYPEO tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU. RUNTIME_VERSIONA versão do software do Cloud TPU. SERVICE_ACCOUNTO endereço de e-mail da conta de serviço. Para encontrá-lo, acesse a página Contas de serviço no console do Google Cloud . Por exemplo:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.comQUEUED_RESOURCE_IDO ID de texto atribuído pelo usuário da solicitação de recurso em fila. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}Quando o QueuedResource estiver no estado
ACTIVE, será possível acessar por SSH a VM de TPU:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}Quando o recurso em fila estiver no estado
ACTIVE, a saída será assim:state: ACTIVE Instalar dependências do PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Substitua
PYTORCH_VERSIONpela versão do PyTorch que você quer usar.PYTORCH_VERSIONé usado para especificar a mesma versão do PyTorch/XLA. A 2.6.0 é a recomendação.Para mais informações sobre as versões do PyTorch e do PyTorch/XLA, consulte PyTorch: introdução e Versões do PyTorch/XLA.
Para mais informações sobre a instalação do PyTorch/XLA, consulte Instalação do PyTorch/XLA.
Faça o download do repositório do Hugging Face e instale os requisitos.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Treinar o modelo
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export PT_XLA_DEBUG=0 export USE_TORCH=ON export TF_CPP_MIN_LOG_LEVEL=0 export XLA_USE_BF16=1 export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so cd vittransformers python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \ --remove_unused_columns=False \ --label_names=pixel_values \ --mask_ratio=0.75 \ --norm_pix_loss=True \ --do_train=true \ --do_eval=true \ --base_learning_rate=1.5e-4 \ --lr_scheduler_type=cosine \ --weight_decay=0.05 \ --num_train_epochs=3 \ --warmup_ratio=0.05 \ --per_device_train_batch_size=8 \ --per_device_eval_batch_size=8 \ --logging_strategy=steps \ --logging_steps=30 \ --evaluation_strategy=epoch \ --save_strategy=epoch \ --load_best_model_at_end=True \ --save_total_limit=3 \ --seed=1337 \ --output_dir=MAE \ --overwrite_output_dir=true \ --logging_dir=./tensorboard-metrics \ --tpu_metrics_debug=true' Excluir a TPU e o recurso em fila
Exclua a TPU e o recurso em fila no fim da sessão.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet Resultado do comparativo
A tabela a seguir mostra a comparação entre as capacidades de processamento de diferentes tipos de aceleradores.
| v5litepod-4 | v5litepod-16 | v5litepod-64 | |
| Período | 3 | 3 | 3 |
| Tamanho global do lote | 32 | 128 | 512 |
| Capacidade de processamento (exemplos/segundo) | 201 | 657 | 2.844 |