import torch
import torch.nn as nn
import numpy as np
import time
import random
import os
# Ensure consistent results (optional, but good for demonstrations)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
# Set default matmul precision (good practice for modern GPUs)
torch.set_float32_matmul_precision('high')
PyTorch as a GPU Computation Library
Introduction
PyTorch is widely known as a leading deep learning framework. However, it’s far more than that! PyTorch is also a powerful and versatile library for general-purpose, high-performance numerical computation. In this notebook, we’ll explore how to leverage PyTorch for GPU-accelerated computing beyond deep learning. We’ll cover fundamental tensor operations, object-oriented structuring with nn.Module, and the significant performance boosts offered by torch.compile (PyTorch’s JIT compiler).
PyTorch as a Linear Algebra Library: Tensors and GPU Offloading
Let’s begin with the foundation: tensors and how to harness the power of the GPU.
# --- Tensor Creation ---
# Walks through the basic tensor constructors and prints each result.

# Create tensors from lists
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(f"Tensor a: {a}")
print(f"Tensor b:\n{b}")

# Create tensors with specific data types
c = torch.tensor([1, 2, 3], dtype=torch.float64)
print(f"Tensor c: {c}")

# Create tensors filled with zeros, ones, or random numbers
zeros_tensor = torch.zeros(2, 3) # 2x3 tensor of zeros
ones_tensor = torch.ones(5) # 1D tensor of ones (length 5)
rand_uniform_tensor = torch.rand(3, 4) # 3x4, uniform random [0, 1)
rand_normal_tensor = torch.randn(2, 2) # 2x2, standard normal
print(f"Zeros Tensor:\n{zeros_tensor}")
print(f"Ones Tensor:\n{ones_tensor}")
print(f"Uniform Random Tensor:\n{rand_uniform_tensor}")
print(f"Standard Normal Random Tensor:\n{rand_normal_tensor}")

# Create a tensor from a NumPy array
numpy_array = np.array([[1, 2], [3, 4]])
tensor_from_numpy = torch.from_numpy(numpy_array)
print(f"Tensor from NumPy Array:\n{tensor_from_numpy}")
GPU Acceleration: Moving Tensors to the GPU
The real power of PyTorch lies in its seamless GPU integration. We can move tensors to the GPU using the .to() method. Let’s check for GPU availability and demonstrate moving tensors.
# --- GPU Check and Tensor Movement ---
# Pick the compute device once; the cells below reuse `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("CUDA is available! Using GPU.")
else:
    device = torch.device('cpu')
    print("CUDA not available. Using CPU.")

# Create the tensor on the CPU first so the explicit transfer below is visible.
# (Outside a demo, torch.randn(5, 5, device=device) avoids the extra copy.)
x = torch.randn(5, 5)
print(f"x is initially on device: {x.device}")

# Move the tensor to the selected device (GPU if available)
x = x.to(device)
print(f"x is now on device: {x.device}")

# Perform a simple operation (element-wise multiplication); the result stays
# on the same device as its input.
y = x * 2
print(f"y (result of x * 2) is on device: {y.device}")

# Move the result back to the CPU (for printing, etc.)
y = y.to('cpu')
print(f"y is now on device: {y.device}")
print(f"y:\n{y}")
CPU vs. GPU: Matrix Multiplication Benchmark
To truly appreciate the speedup offered by GPUs, let’s compare the execution time of a matrix multiplication on both the CPU and GPU. We’ll use large matrices to make the difference dramatic. Crucially, we use torch.cuda.synchronize() to ensure the GPU operation completes before we stop the timer.
# --- Matrix Multiplication Benchmark ---
size = 10000 # Large matrix size
# Create the operands on the CPU, then copy them to the selected device
a_cpu = torch.randn(size, size)
b_cpu = torch.randn(size, size)
a_gpu = a_cpu.to(device) # Move to GPU if available
b_gpu = b_cpu.to(device)

# CUDA calls are only valid when a GPU was actually selected; on the CPU
# fallback torch.cuda.synchronize() raises instead of being a no-op.
on_gpu = device.type == 'cuda'

# CPU timing
start_time = time.perf_counter()
c_cpu = torch.matmul(a_cpu, b_cpu)
end_time = time.perf_counter()
cpu_time = end_time - start_time
print(f"CPU time: {cpu_time:.4f} seconds")

# GPU timing (with synchronization)
if on_gpu:
    # Warm up on a small slice so one-off CUDA/cuBLAS initialisation is not
    # billed to the timed matmul, and drain the queue before starting the clock.
    torch.matmul(a_gpu[:64, :64], b_gpu[:64, :64])
    torch.cuda.synchronize()
start_time = time.perf_counter()
c_gpu = torch.matmul(a_gpu, b_gpu)
if on_gpu:
    torch.cuda.synchronize() # Ensure GPU operation completes!
end_time = time.perf_counter()
gpu_time = end_time - start_time
print(f"GPU time: {gpu_time:.4f} seconds")
print(f"Speedup: {cpu_time / gpu_time:.2f}x")
Other Useful Linear Algebra Operations
PyTorch has a rich set of functions beyond matrix multiplication. Here are a few examples:
# --- Other Linear Algebra Operations ---
# Each section prints the result of a few representative tensor operations.

# Element-wise operations
a = torch.tensor([1, 2, 3], device=device)
b = torch.tensor([4, 5, 6], device=device)
print(f"Element-wise addition: {a + b}")
print(f"Element-wise subtraction: {a - b}")
print(f"Element-wise multiplication: {a * b}")
print(f"Element-wise division: {a / b}")

# Reductions (each collapses the whole 3x4 tensor to a single value)
x = torch.randn(3, 4, device=device)
print(f"Sum of all elements: {x.sum()}")
print(f"Mean of all elements: {x.mean()}")
print(f"Max element: {x.max()}")
print(f"Min element: {x.min()}")

# Reshaping (3 * 4 = 12 elements in both target shapes)
y = x.view(12) # Reshape to a 1D tensor
print(f"Reshaped tensor (view): {y.shape}")
z = x.reshape(2, 6) # Reshape to a 2x6 tensor
print(f"Reshaped tensor (reshape): {z.shape}")

# Slicing and indexing (similar to NumPy)
print(f"First row of x: {x[0, :]}")
print(f"Second column of x: {x[:, 1]}")
PyTorch as an OOP Library: Introduction to nn.Module
While we won’t be building neural networks, PyTorch’s nn.Module class is incredibly useful for structuring any computation that involves parameters (values you want to manage or update). Think of nn.Module as a container for your operations and their associated data.
Defining a Custom nn.Module
Let’s create a simple module that projects a vector onto the column space of a matrix A, i.e. it computes y = Px with P = A A⁺, where A⁺ is the pseudoinverse of A. This will demonstrate the basic structure of an nn.Module.
# --- Custom nn.Module: Projection onto Column Space ---
class ProjectIntoColumnSpace(nn.Module):
    """Project inputs onto the column space of a fixed matrix ``A``.

    The projection matrix ``P = A @ pinv(A)`` is built on the first call to
    ``forward`` and cached, so later calls cost a single matmul.

    Args:
        A: 2-D tensor whose columns span the target subspace.
    """

    def __init__(self, A):
        super().__init__()
        # A is fixed data, not a trainable weight, so it is registered as a
        # buffer: it follows .to(device) but is not listed by parameters().
        # register_buffer is used because it also works on PyTorch releases
        # that predate nn.Buffer.
        self.register_buffer("A", A)
        # Declare the lazily built projector up front so forward() can test
        # for None instead of probing with hasattr().
        self.register_buffer("P", None)

    def forward(self, x):
        """Return ``P @ x``, the projection of ``x`` onto the column space of A."""
        if self.P is None:
            # Compute the projection matrix once and keep it as a buffer
            self.P = self.A @ torch.linalg.pinv(self.A)
        # Project x onto the column space of A
        return self.P @ x
# Create a random matrix A
A = torch.randn(10000, 1000)
# Initialize the module and move it (together with its buffers) to the device
model = ProjectIntoColumnSpace(A).to(device)
print('Tensor A is on device:', model.A.device)

# Create a random input tensor
x = torch.randn(10000, device=device)

# First call: this also builds and caches the projection matrix P
start_time = time.perf_counter()
y = model(x)
if device.type == 'cuda':
    torch.cuda.synchronize() # wait for the GPU before reading the clock
end_time = time.perf_counter()
print(f"Projection time (GPU): {end_time - start_time:.4f} seconds")
print(f"Output shape: {y.shape}")

# Second call: reuses the cached P, so only the matrix-vector product is timed
start_time = time.perf_counter()
y = model(x)
if device.type == 'cuda':
    torch.cuda.synchronize() # wait for the GPU before reading the clock
end_time = time.perf_counter()
print(f"Projection time (GPU): {end_time - start_time:.4f} seconds")
print(f"Output shape: {y.shape}")

# Perform projection on CPU for comparison
# NOTE: nn.Module.to() moves the module in place and returns it, so
# `model_cpu` and `model` are the same object from here on.
model_cpu = model.to('cpu') # Move the model to the CPU
x_cpu = x.to('cpu') # Move the input to the CPU
start_time = time.perf_counter()
y_cpu = model_cpu(x_cpu)
end_time = time.perf_counter()
print(f"Projection time (CPU): {end_time - start_time:.4f} seconds")
print(f"Output shape: {y_cpu.shape}")
Key Benefits of nn.Module:
- Organization: Keeps parameters and computation logic together.
- Parameter Management: Easy access to all parameters (e.g., model.parameters()).
- Device Management: Moving the module to the GPU (e.g., .to(device)) automatically moves all its parameters and buffers.
- Buffers: Buffers are like parameters but are not optimized.
PyTorch as a JIT Language: Leveraging torch.compile
While GPU acceleration provides a significant boost, PyTorch’s default “eager” execution mode can have overhead. Each operation is executed individually. torch.compile addresses this by acting as a Just-In-Time (JIT) compiler, optimizing your code for even greater performance.
Understanding Python Overhead
# --- Pure Python Implementation ---
def monte_carlo_pi_python(n_samples):
    """Estimate pi by sampling points in the square [-1, 1] x [-1, 1].

    Deliberately written as a plain Python loop: it is the slow baseline the
    vectorized and compiled versions are compared against.

    Args:
        n_samples: Number of random points to draw; must be at least 1.

    Returns:
        ``4 * (points inside the unit circle) / n_samples`` as a float.

    Raises:
        ValueError: If ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        # Fail with a clear message instead of a ZeroDivisionError below
        raise ValueError("n_samples must be a positive integer")
    inside_circle = 0
    for _ in range(n_samples):
        x = random.uniform(-1, 1)  # Random x coordinate
        y = random.uniform(-1, 1)  # Random y coordinate
        if x**2 + y**2 <= 1:  # Check if inside the circle
            inside_circle += 1
    return 4 * inside_circle / n_samples
# Sample count shared by the Monte Carlo runs that follow
n_samples = 10000000
# expect this to be slow
start_time = time.perf_counter()
pi_python = monte_carlo_pi_python(n_samples)
end_time = time.perf_counter()
print(f"Pure Python Pi Estimate: {pi_python:.6f}, Time: {end_time - start_time:.4f} seconds")

# --- PyTorch Implementation (Vectorized) ---
def monte_carlo_pi_pytorch(n_samples, device):
    """Vectorized Monte Carlo estimate of pi.

    Args:
        n_samples: Number of random points to draw.
        device: Device on which the samples are generated and reduced.

    Returns:
        A 0-dim float tensor on ``device`` holding
        ``4 * (points inside the unit circle) / n_samples``.
    """
    x = torch.rand(n_samples, device=device) * 2 - 1  # Range [-1, 1]
    y = torch.rand(n_samples, device=device) * 2 - 1
    inside_circle = (x**2 + y**2 <= 1).sum()  # Count points inside
    # .sum() of a bool mask is an integer tensor; convert before dividing
    return 4 * inside_circle.float() / n_samples
start_time = time.perf_counter()
pi_pytorch = monte_carlo_pi_pytorch(n_samples, device)
# GPU kernels run asynchronously, so wait for them before stopping the timer.
# On the CPU fallback there is nothing to wait for and the call would raise.
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"PyTorch (Uncompiled) Pi Estimate: {pi_pytorch:.6f}, Time: {end_time - start_time:.4f} seconds")
Introducing torch.compile
Vectorization reduces the overhead of Python, but you can go even further with torch.compile. Let’s see how it works and the performance benefits it offers.
torch.compile analyzes your PyTorch code and generates optimized code, often using the Triton kernel language for GPUs. This can lead to:
- Operator Fusion: Combining multiple operations into a single kernel.
- Kernel Specialization: Generating code tailored to specific tensor shapes and data types.
- Reduced Overhead: Minimizing the communication between the CPU and GPU.
Basic torch.compile Usage
The simplest way to use torch.compile is to pass your function to torch.compile. Let’s see how this works with a simple example.
# PyTorch (Compiled) - torch.compile only wraps the function here; the actual
# compilation is deferred to the first call, so this step is nearly instant.
start_time = time.perf_counter()
monte_carlo_pi_pytorch_compiled = torch.compile(monte_carlo_pi_pytorch)
print(f'Compiled Time: {time.perf_counter() - start_time:.4f} seconds')

# PyTorch (Compiled) - The first run triggers tracing and code generation
start_time = time.perf_counter()
pi_compiled = monte_carlo_pi_pytorch_compiled(n_samples, device)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"PyTorch (Compiled, warm-up) Pi Estimate: {pi_compiled:.6f}, Time: {end_time - start_time:.4f} seconds")

# PyTorch (Compiled) - Subsequent runs will be faster
start_time = time.perf_counter()
pi_compiled = monte_carlo_pi_pytorch_compiled(n_samples, device)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"PyTorch (Compiled) Pi Estimate: {pi_compiled:.6f}, Time: {end_time - start_time:.4f} seconds")
torch.compile provides significant speedups in this Monte Carlo simulation primarily by eliminating Python interpreter overhead and fusing operations. The pure Python version is slow because each iteration of the loop involves many individual Python operations. The uncompiled PyTorch version is faster due to vectorized operations, but still launches separate kernels for random number generation, squaring, addition, comparison, and summation. torch.compile, however, analyzes the entire function and can fuse these steps into far fewer kernels (ideally a single one), which avoids writing intermediate results to memory and minimizes kernel launch overhead. This is a classic example of operator fusion, a key optimization technique employed by torch.compile.
However, torch.compile is not a silver bullet. If the overhead of the Python interpreter is negligible compared to the computation, you may not see a significant speedup, and the compilation cost may outweigh the benefits.
# --- When compiling doesn't help ---
def element_wise_mult(A, B):
    """Return the element-wise product ``A * B``."""
    return A * B
A = torch.randn(10000, 10000, device=device)
B = torch.randn(10000, 10000, device=device)

# Uncompiled version
start_time = time.perf_counter()
result_uncompiled = element_wise_mult(A, B)
if device.type == 'cuda':  # synchronize() raises without a GPU
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Uncompiled function time: {end_time - start_time:.4f} seconds")

# Compiled version
compiled_element_wise_mult = torch.compile(element_wise_mult)
# Warm-up (important for JIT compilers)
start_time = time.perf_counter()
compiled_element_wise_mult(A, B)
if device.type == 'cuda':
    torch.cuda.synchronize()
print(f'Time for warm-up run: {time.perf_counter() - start_time:.4f} seconds')

# After warm-up
start_time = time.perf_counter()
result_compiled = compiled_element_wise_mult(A, B)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Compiled function time: {end_time - start_time:.4f} seconds")
Recompilation and Data Types
Like in Julia, torch.compile recompiles functions when the input types change. This can introduce overhead, so it’s important to be aware of when recompilation occurs.
# --- Recompilation Examples ---
# 1. Recompilation due to Data Type Change
x_int = torch.ones(10000, 10000, device=device, dtype=torch.int64)
y_int = torch.ones(10000, 10000, device=device, dtype=torch.int64)
start_time = time.perf_counter()
compiled_element_wise_mult(x_int, y_int) # Compile for int64
if device.type == 'cuda':  # synchronize() raises without a GPU
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"First call (int64) time: {end_time - start_time:.4f} seconds")

start_time = time.perf_counter()
compiled_element_wise_mult(x_int, y_int) # Reuse compiled int64 version
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Second call (int64) time: {end_time - start_time:.4f} seconds")

# 2. Recompilation due to Device Change
# (When `device` is already the CPU the inputs do not change device, so this
# call does not demonstrate a device-triggered recompile.)
x_cpu = x_int.cpu()
y_cpu = y_int.cpu()
start_time = time.perf_counter()
compiled_element_wise_mult(x_cpu, y_cpu) # Recompile for CPU
# No need for torch.cuda.synchronize() on CPU
end_time = time.perf_counter()
print(f"Call with CPU tensors time: {end_time - start_time:.4f} seconds")
Recompilation and Dynamic Shapes
Unlike in Julia, by default, torch.compile tries to be smart about input tensor sizes. The first time it compiles a function, it generates a specialized kernel that’s optimized for the specific input sizes it encountered. However, if it sees inputs of different sizes later, it will attempt to recompile with a more dynamic kernel that can handle a range of sizes, avoiding further recompilations (within limits). This behavior can be controlled using the dynamic argument to torch.compile (though we won’t dive into that here).
# --- Recompilation and Dynamic Shapes ---
# Reuses compiled_element_wise_mult from the earlier cell, which has so far
# only been called with 10000x10000 tensors.

# Different Size (Triggers Recompilation with Dynamic Kernel)
x_different = torch.ones(2000, 2000, device=device, dtype=torch.int64)
y_different = torch.ones(2000, 2000, device=device, dtype=torch.int64)
start_time = time.perf_counter()
compiled_element_wise_mult(x_different, y_different) # Recompiles for dynamic shapes
if device.type == 'cuda':  # synchronize() raises without a GPU
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"First call (different size) time: {end_time - start_time:.4f} seconds")

# Smaller Size (Uses Dynamic Kernel - No Recompilation)
x_smaller = torch.ones(500, 500, device=device, dtype=torch.int64)
y_smaller = torch.ones(500, 500, device=device, dtype=torch.int64)
start_time = time.perf_counter()
compiled_element_wise_mult(x_smaller, y_smaller) # Uses dynamic kernel
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"First call (smaller size) time: {end_time - start_time:.4f} seconds")

# Larger Size (Uses Dynamic Kernel - No Recompilation)
# NOTE: each 20000x20000 int64 tensor is 3.2 GB (20000 * 20000 * 8 bytes), and
# the product needs as much again; shrink these on memory-constrained devices.
x_larger = torch.ones(20000, 20000, device=device, dtype=torch.int64)
y_larger = torch.ones(20000, 20000, device=device, dtype=torch.int64)
start_time = time.perf_counter()
compiled_element_wise_mult(x_larger, y_larger) # Uses dynamic kernel
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"First call (larger size) time: {end_time - start_time:.4f} seconds")
Key Takeaways:
- torch.compile aims to balance specialization (for best performance) and dynamism (to avoid excessive recompilation).
- The dynamic=True/False/None argument to torch.compile gives you more control over this behavior, but the default (None) usually works well.
Compiling Different if-else Branches
Python is not a native JIT language, so torch.compile can’t handle arbitrary Python control flow. However, it can handle if-else statements where the condition depends on static values (like function arguments), not on the tensor data itself. It compiles a separate version for each branch.
# --- Compiling Different if-else Branches ---
# Another approach to using torch.compile is as a decorator
@torch.compile
def if_func(x, y, mode):
    """Return ``x + y`` when ``mode == 0`` and ``x - y`` otherwise.

    ``mode`` is an ordinary Python value rather than tensor data, so the
    branch is chosen while tracing, not inside the compiled graph.
    """
    if mode == 0:
        return x + y
    else:
        return x - y
x = torch.randn(1000, 1000, device=device)
y = torch.randn(1000, 1000, device=device)

# Call with mode=0 (compiles the addition branch)
start_time = time.perf_counter()
result_0 = if_func(x, y, 0)
if device.type == 'cuda':  # synchronize() raises without a GPU
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Time (mode=0, first call): {end_time - start_time:.4f} seconds")

# Call again with mode=0 (reuses the compiled addition branch)
start_time = time.perf_counter()
result_0_again = if_func(x, y, 0)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Time (mode=0, second call): {end_time - start_time:.4f} seconds")

# Call with mode=1 (compiles the subtraction branch)
start_time = time.perf_counter()
result_1 = if_func(x, y, 1)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Time (mode=1, first call): {end_time - start_time:.4f} seconds")

# Call again with mode=1 (reuses the compiled subtraction branch)
start_time = time.perf_counter()
result_1_again = if_func(x, y, 1)
if device.type == 'cuda':
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(f"Time (mode=1, second call): {end_time - start_time:.4f} seconds")
Recompilation Summary
torch.compile is a powerful tool for optimizing PyTorch code, but it’s somewhat a “hack” on top of Python, so it has very complex behavior on recompilation. Some rules of thumb are:
- Recompilation is triggered when the input types or devices change. (Same as in Julia)
- A “static” kernel for a specific input size is generated by default, but a dynamic kernel that can handle a range of sizes is generated if it sees inputs of different sizes later. (Different from Julia)
- Each branch of an if-else statement is compiled separately. (Different from Julia)
- torch.compile is designed to be user-friendly, but it has many complex behaviors that support this goal. See best practices at https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
Graph Breaks
torch.compile aims to capture the whole function in a single computation graph for full optimization. However, it can’t always do this. When it encounters code that can’t be traced, a “graph break” occurs, and torch.compile compiles codes before and after the break in separate graphs. This prevents optimization through the entire function and should be avoided when possible.
Graph breaks occur on things like:
- Data-dependent if-statements
- Many Python built-in functions
- C functions
# --- Graph Break Example ---
# Data-dependent branching is not supported
def data_dependent_branch(x, y):
    """Return ``x + y`` if the elements of ``x`` sum to more than 0, else ``x``.

    The condition depends on tensor *values*, which is what makes this
    function a graph-break example.
    """
    if x.sum() > 0:
        x = x + y
    return x
compiled_data_dependent_branch = torch.compile(data_dependent_branch, fullgraph=True) # Force full graph compilation
x = torch.ones(10000, device=device)
y = torch.ones(10000, device=device)
# Compilation will fail here: fullgraph=True forbids the graph break that the
# data-dependent `if` requires. Catch the error so the rest of the script runs.
try:
    result_pos = compiled_data_dependent_branch(x, y)
except RuntimeError as err:  # Dynamo's errors derive from RuntimeError
    print(f"fullgraph=True compilation failed as expected: {type(err).__name__}")

# We can compile this function with fullgraph=False, but it will lose the maximum performance benefit
compiled_data_dependent_branch = torch.compile(data_dependent_branch, fullgraph=False) # Allow partial graph compilation
start_time = time.perf_counter()
result_pos = compiled_data_dependent_branch(x, y)
if device.type == 'cuda':  # synchronize() raises without a GPU
    torch.cuda.synchronize()
print(f"Time (first call): {time.perf_counter() - start_time:.4f} seconds")

# not supported Python built-in functions
# time.time() is not supported
@torch.compile
def compiled_unsupported_func(x):
    """Return ``x * 2`` plus the current wall-clock time from ``time.time()``.

    The ``time.time()`` call is the point of the example: the tutorial uses it
    as a Python call that torch.compile cannot capture in the graph.
    """
    x = x * 2
    return x + time.time()
# Will raise a warning, but still usable
start_time = time.perf_counter()
result = compiled_unsupported_func(x)
# Only synchronize when a GPU is in use; the call raises on a CPU-only setup
if device.type == 'cuda':
    torch.cuda.synchronize()
print(f"Time: {time.perf_counter() - start_time:.4f} seconds")
Inspecting Compiled Code
You can inspect the compiled code using the torch._dynamo.explain. This is helpful for understanding how torch.compile is transforming your code, but it can be quite verbose.
# --- Inspecting the Compiled Graph ---
# For reference, the first function being explained (defined earlier):
#
#     def data_dependent_branch(x, y):
#         if x.sum() > 0:
#             x = x + y
#         return x
#
# Print Dynamo's explanation of how it traced the function
print(torch._dynamo.explain(data_dependent_branch)(x, y))

# For reference, the second function being explained (defined earlier):
#
#     def compiled_unsupported_func(x):
#         x = x * 2
#         return x + time.time()
#
# Print Dynamo's explanation of how it traced the function
print(torch._dynamo.explain(compiled_unsupported_func)(x))