PyTorch XLA 워크로드 프로파일링
성능 최적화는 효율적인 머신러닝 모델을 빌드하는 데 중요한 부분입니다. XProf 프로파일링 도구를 사용하여 머신러닝 워크로드의 성능을 측정할 수 있습니다. XProf를 사용하면 XLA 기기에서 모델 실행의 자세한 trace를 캡처할 수 있습니다. 이러한 trace를 사용하면 성능 병목 현상을 식별하고, 기기 활용도를 파악하고, 코드를 최적화할 수 있습니다.
이 가이드에서는 PyTorch XLA 스크립트에서 trace를 프로그래매틱 방식으로 캡처하고 XProf 및 Tensorboard를 사용하여 시각화하는 프로세스를 설명합니다.
trace 캡처
기존 학습 스크립트에 몇 줄의 코드를 추가하여 trace를 캡처할 수 있습니다. trace를 캡처하는 기본 도구는 torch_xla.debug.profiler 모듈이며, 이 모듈은 일반적으로 xp 별칭으로 가져옵니다.
1. 프로파일러 서버 시작
trace를 캡처하려면 먼저 프로파일러 서버를 시작해야 합니다. 이 서버는 스크립트의 백그라운드에서 실행되며 trace 데이터를 수집합니다. 기본 실행 블록의 시작 부분 근처에서 xp.start_server()를 호출하여 시작할 수 있습니다.
2. trace 기간 정의
프로파일링할 코드를 xp.start_trace() 및 xp.stop_trace() 호출 내에 래핑합니다. start_trace 함수는 trace 파일이 저장되는 디렉터리의 경로를 사용합니다.
가장 관련성 높은 작업을 포착하기 위해 기본 학습 루프를 래핑하는 것이 일반적입니다.
# The directory where the trace files are stored. log_dir = '/root/logs/' # Start tracing xp.start_trace(log_dir) # ... your training loop or other code to be profiled ... train_mnist() # Stop tracing xp.stop_trace() 3. 맞춤 trace 라벨 추가
기본적으로 캡처되는 trace는 하위 수준의 PyTorch XLA 함수이며 탐색하기 어려울 수 있습니다. xp.Trace() 컨텍스트 관리자를 사용하여 코드의 특정 섹션에 맞춤 라벨을 추가할 수 있습니다. 이러한 라벨은 프로파일러의 타임라인 뷰에 이름이 지정된 블록으로 표시되므로 데이터 준비, 순방향 패스 또는 옵티마이저 단계와 같은 특정 작업을 훨씬 쉽게 식별할 수 있습니다.
다음 예시에서는 학습 단계의 여러 부분에 컨텍스트를 추가하는 방법을 보여줍니다.
def forward(self, x): # This entire block will be labeled 'forward' in the trace with xp.Trace('forward'): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = x.view(-1, 7*7*64) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) # You can also nest context managers for more granular detail for batch_idx, (data, target) in enumerate(train_loader): with torch_xla.step(): with xp.Trace('train_step_data_prep_and_forward'): optimizer.zero_grad() data, target = data.to(device), target.to(device) output = model(data) with xp.Trace('train_step_loss_and_backward'): loss = loss_fn(output, target) loss.backward() with xp.Trace('train_step_optimizer_step_host'): optimizer.step() 전체 예시
다음 예시에서는 mnist_xla.py 파일을 기반으로 PyTorch XLA 스크립트에서 trace를 캡처하는 방법을 보여줍니다.
import torch import torch.optim as optim from torchvision import datasets, transforms # PyTorch/XLA specific imports import torch_xla import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp def train_mnist(): # ... (model definition and data loading code) ... print("Starting training...") # ... (training loop as defined in the previous section) ... print("Training finished!") if __name__ == '__main__': # 1. Start the profiler server server = xp.start_server(9012) # 2. Start capturing the trace and define the output directory xp.start_trace('/root/logs/') # Run the training function that contains custom trace labels train_mnist() # 3. Stop the trace xp.stop_trace() trace 시각화
스크립트가 완료되면 trace 파일이 지정한 디렉터리(예: /root/logs/)에 저장됩니다. XProf 및 TensorBoard를 사용하여 이 trace를 시각화할 수 있습니다.
TensorBoard를 설치합니다.
pip install tensorboard_plugin_profile tensorboard
TensorBoard를 실행합니다. TensorBoard가
xp.start_trace()에서 사용한 로그 디렉터리를 가리키도록 합니다.tensorboard --logdir /root/logs/
프로필을 확인합니다. 웹브라우저에서 TensorBoard가 제공한 URL(일반적으로
http://localhost:6006)을 엽니다. PyTorch XLA - Profile 탭으로 이동하여 대화형 trace를 확인합니다. 만든 맞춤 라벨을 확인하고 모델의 여러 부분의 실행 시간을 분석할 수 있습니다.
Google Cloud 를 사용하여 워크로드를 실행하는 경우 cloud-diagnostics-xprof 도구를 사용하는 것이 좋습니다. Tensorboard 및 XProf를 실행하는 VM을 사용하여 간소화된 프로필 수집 및 보기 환경을 제공합니다.