
Python SDK

The wave_gpu Python package lets you write GPU kernels in Python and run them on any supported GPU backend. This guide walks through installation, device setup, array management, kernel authoring, and error handling.

Install from PyPI:

pip install wave-gpu

Import the package in your code:

import wave_gpu

Before launching kernels, confirm that a GPU is available. wave_gpu.device() returns a DeviceInfo object describing the first detected device:

dev = wave_gpu.device()
print(dev.vendor) # e.g. "Apple"
print(dev.name) # e.g. "Apple M2 Max"

If no GPU is found, the call raises a RuntimeError.
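If you want startup to degrade gracefully instead of crashing, wrap detection in a try/except. A minimal sketch (the detect_device helper name is mine, not part of the SDK):

```python
def detect_device():
    # Hypothetical helper: returns a DeviceInfo on success, or None when
    # the SDK is not installed or no GPU is detected
    # (wave_gpu.device() raises RuntimeError in the latter case).
    try:
        import wave_gpu
        return wave_gpu.device()
    except (ImportError, RuntimeError):
        return None
```

Callers can then branch on the result, e.g. fall back to a CPU code path when detect_device() returns None.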

WaveArray is the primary buffer type. It lives in device-accessible memory. Create one with the factory helpers:

# From existing Python data
a = wave_gpu.array([1.0, 2.0, 3.0], dtype="f32")
# Pre-filled buffers
z = wave_gpu.zeros(1024, dtype="f32")
o = wave_gpu.ones(1024, dtype="f32")

Supported dtypes: f16, f32, f64, i32, u32.
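When sizing host-side staging buffers it helps to know the element widths. The SDK docs do not state byte sizes here, so the table below assumes the usual IEEE-754 and two's-complement layouts:

```python
# Assumed element widths in bytes for the supported dtypes
# (standard IEEE-754 / two's-complement sizes; an assumption,
# not confirmed by the SDK documentation).
DTYPE_SIZE = {"f16": 2, "f32": 4, "f64": 8, "i32": 4, "u32": 4}

def buffer_bytes(n_elements, dtype):
    # Total byte size of a buffer holding n_elements of the given dtype.
    return n_elements * DTYPE_SIZE[dtype]
```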

WaveArray exposes a few useful properties and methods:

| Member | Description |
| --- | --- |
| data | Raw underlying data |
| dtype | Element type string |
| len(arr) | Number of elements (__len__) |
| arr[i] | Element access (__getitem__) |
| arr.to_list() | Copy contents back to a Python list |
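To make the interface concrete, here is an illustrative pure-Python stand-in exposing the same members. HostArray is hypothetical and exists only for exposition; the real WaveArray lives in device-accessible memory:

```python
class HostArray:
    # Hypothetical host-side stand-in mirroring WaveArray's interface.
    def __init__(self, data, dtype="f32"):
        self.data = list(data)   # raw underlying data
        self.dtype = dtype       # element type string, e.g. "f32"

    def __len__(self):
        return len(self.data)    # number of elements

    def __getitem__(self, i):
        return self.data[i]      # element access

    def to_list(self):
        return list(self.data)   # copy contents back to a Python list
```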

Decorate a plain Python function with @wave_gpu.kernel to mark it as a GPU kernel. Inside the kernel body you can use WAVE intrinsics to determine the current thread’s position and synchronize:

@wave_gpu.kernel
def vector_add(a, b, out, n):
    tid = wave_gpu.thread_id()
    if tid < n:
        out[tid] = a[tid] + b[tid]

| Intrinsic | Description |
| --- | --- |
| thread_id() | Global thread index |
| workgroup_id() | Workgroup (block) index |
| lane_id() | Lane within the current wave/warp |
| wave_width() | Number of lanes per wave |
| barrier() | Workgroup-level synchronization barrier |
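A useful way to reason about kernels is to emulate the SPMD model on the CPU: the kernel body runs once per thread index. The sketch below is mine (run_on_cpu and the explicit tid parameter are not part of the SDK; on the GPU, the index comes from wave_gpu.thread_id() instead):

```python
def run_on_cpu(body, n_threads, *args):
    # Emulates a dispatch: the kernel body executes once per thread id.
    for tid in range(n_threads):
        body(tid, *args)

def vector_add_body(tid, a, b, out, n):
    # Same logic as the vector_add kernel above, with the thread id
    # passed explicitly instead of read from wave_gpu.thread_id().
    if tid < n:
        out[tid] = a[tid] + b[tid]
```

Note how the tid < n guard makes it safe to launch more threads than elements, a pattern that matters once grids are rounded up to workgroup boundaries.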

Kernel parameters must be either WaveArray objects (for buffer access) or plain int / float values (for scalar uniforms).

Call the decorated function directly. WAVE compiles the kernel on first invocation and caches the result. You can optionally specify the dispatch grid and workgroup size:

n = 1024
a = wave_gpu.array([float(i) for i in range(n)], dtype="f32")
b = wave_gpu.ones(n, dtype="f32")
out = wave_gpu.zeros(n, dtype="f32")
# Simple launch - WAVE infers a 1-D grid from the first buffer
vector_add(a, b, out, n)
# Explicit grid and workgroup dimensions
vector_add(a, b, out, n, grid=(n // 256, 1, 1), workgroup=(256, 1, 1))
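Note that n // 256 only covers every element when n is an exact multiple of the workgroup size. The usual fix is ceiling division; a small sketch (the grid_1d helper is mine, not part of the SDK):

```python
def grid_1d(n, workgroup=256):
    # Ceiling division: enough workgroups that every element gets a thread,
    # even when n is not a multiple of the workgroup size. The tid < n
    # guard inside the kernel discards the excess threads.
    return ((n + workgroup - 1) // workgroup, 1, 1)
```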

After a kernel completes, read data back to the host with to_list():

result = out.to_list()
print(result[:8]) # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

You can also index individual elements for quick spot-checks:

assert out[0] == 1.0
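For a fuller check, compare the entire readback against a CPU reference with a small tolerance, since f32 arithmetic on the device need not match Python's float64 results bit-for-bit. A hand-rolled comparison (allclose here is my helper, not an SDK function):

```python
def allclose(xs, ys, tol=1e-6):
    # True when both sequences have the same length and every pair of
    # corresponding elements differs by at most tol.
    return len(xs) == len(ys) and all(abs(x - y) <= tol for x, y in zip(xs, ys))
```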

WAVE surfaces two main exception types:

  • RuntimeError - raised when kernel compilation fails (e.g. unsupported intrinsic, invalid shader generation) or when no device is found.
  • TypeError - raised when kernel arguments do not match the expected signature (wrong type or count).

Handle them with standard try / except:

try:
    broken_kernel(a, b, out, n)
except RuntimeError as e:
    print(f"Compilation or device error: {e}")
except TypeError as e:
    print(f"Bad kernel arguments: {e}")

See the full Python API Reference for detailed method signatures and advanced options.