Python API Reference

The wave_gpu module provides Python bindings for WAVE GPU compute. It exposes array creation, kernel compilation, device intrinsics, and a kernel decorator for writing GPU kernels in Python.

pip install wave-gpu

wave_gpu.array(data, dtype="f32") -> WaveArray

Create a device array from a Python list or iterable.

import wave_gpu
a = wave_gpu.array([1.0, 2.0, 3.0])
b = wave_gpu.array([1, 2, 3, 4], dtype="u32")

Parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| data | list or iterable | (required) | Source data to upload to the device. |
| dtype | str | "f32" | Element type. One of: "f16", "f32", "f64", "i32", "u32". |

Returns: WaveArray - a handle to a device-resident buffer.

Raises: ValueError if dtype is not a supported type. TypeError if elements cannot be converted to the target type.
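
The error behavior described above can be mirrored on the host before any upload is attempted. The sketch below is not part of wave_gpu: `pack_host_data` and `_FORMATS` are hypothetical names, and the mapping from dtype strings to `struct` format codes (little-endian, standard sizes) is an assumption about the in-memory layout.

```python
import struct

# Hypothetical mapping from the documented dtype strings to struct format codes.
_FORMATS = {"f16": "e", "f32": "f", "f64": "d", "i32": "i", "u32": "I"}


def pack_host_data(data, dtype="f32"):
    """Validate and pack host data, raising the same error types as documented."""
    if dtype not in _FORMATS:
        raise ValueError(f"unsupported dtype: {dtype!r}")
    values = list(data)  # accept any iterable, not only lists
    try:
        # "<" selects little-endian with standard sizes (an assumption).
        return struct.pack(f"<{len(values)}{_FORMATS[dtype]}", *values)
    except (struct.error, OverflowError) as exc:
        raise TypeError(f"cannot convert elements to {dtype}: {exc}") from exc


packed = pack_host_data([1, 2, 3, 4], dtype="u32")
print(len(packed))  # 16 (4 elements * 4 bytes each)
```

Checking on the host first can surface a bad dtype or a mixed-type list early, although the library performs its own validation regardless.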


wave_gpu.zeros(n, dtype="f32") -> WaveArray

Create a device array of n zeros.

buf = wave_gpu.zeros(1024)
buf_int = wave_gpu.zeros(256, dtype="i32")

Parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n | int | (required) | Number of elements. |
| dtype | str | "f32" | Element type. |

Returns: WaveArray


wave_gpu.ones(n, dtype="f32") -> WaveArray

Create a device array of n ones.

buf = wave_gpu.ones(512, dtype="f64")

Parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n | int | (required) | Number of elements. |
| dtype | str | "f32" | Element type. |

Returns: WaveArray


WaveArray

A handle to a device-resident buffer. WaveArray objects are returned by all array creation functions and by kernel outputs.

| Property | Type | Description |
| --- | --- | --- |
| data | memoryview | Raw byte view of the device buffer contents (triggers a device-to-host copy). |
| dtype | str | Element type string (e.g., "f32", "u32"). |
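
Because `data` is a raw byte view, the bytes have to be reinterpreted according to `dtype` before they are meaningful. The sketch below uses stand-in bytes rather than a real device buffer, and assumes the view is contiguous, byte-formatted, and in native byte order; none of that is confirmed by this reference.

```python
import struct

# Stand-in for WaveArray.data: raw bytes of three f32 values in native byte order.
raw = memoryview(struct.pack("3f", 1.0, 2.0, 3.0))

# Reinterpret the byte view as 32-bit floats without copying the bytes.
as_floats = raw.cast("f")
print(as_floats.tolist())  # [1.0, 2.0, 3.0]
```

For other element types the cast code would change accordingly (for example "I" for "u32" or "d" for "f64").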

to_list() copies the buffer contents back to the host and returns them as a Python list.

a = wave_gpu.array([1.0, 2.0, 3.0])
print(a.to_list()) # [1.0, 2.0, 3.0]

len(a) returns the number of elements in the buffer.

a = wave_gpu.zeros(128)
len(a) # 128

Indexing with a[i] accesses a single element and triggers a device-to-host copy of that element.

a = wave_gpu.array([10.0, 20.0, 30.0])
a[1] # 20.0

Negative indexing is supported. Slicing is not currently supported.
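
Since slicing is not supported on the device handle, one possible workaround is to copy the buffer to the host with to_list() and slice the resulting list. In the sketch below, `host_slice` is a hypothetical helper and `_FakeArray` is a stand-in that exposes only to_list(), so the example runs without a GPU.

```python
def host_slice(arr, start=None, stop=None, step=None):
    """Copy the whole buffer to the host, then slice the resulting list."""
    return arr.to_list()[start:stop:step]


class _FakeArray:
    """Stand-in exposing only to_list(); a real WaveArray would be used instead."""

    def __init__(self, values):
        self._values = list(values)

    def to_list(self):
        return list(self._values)


a = _FakeArray([10.0, 20.0, 30.0, 40.0])
print(host_slice(a, 1, 3))  # [20.0, 30.0]
```

This copies the entire buffer even for a short slice, so it is best suited to small arrays or occasional inspection.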


wave_gpu.kernel is a decorator that marks a Python function as a GPU kernel. The function body is compiled to WAVE binary format at decoration time.

@wave_gpu.kernel
def vector_add(a: wave_gpu.f32, b: wave_gpu.f32, out: wave_gpu.f32):
    tid = wave_gpu.thread_id()
    out[tid] = a[tid] + b[tid]

The decorated function becomes a callable that accepts WaveArray arguments and dispatch parameters:

a = wave_gpu.array([1.0, 2.0, 3.0, 4.0])
b = wave_gpu.array([5.0, 6.0, 7.0, 8.0])
out = wave_gpu.zeros(4)
vector_add(a, b, out, grid=(4, 1, 1), workgroup=(4, 1, 1))
print(out.to_list()) # [6.0, 8.0, 10.0, 12.0]

Dispatch parameters (keyword arguments):

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| grid | tuple[int, int, int] | (required) | Global grid dimensions (x, y, z). |
| workgroup | tuple[int, int, int] | (required) | Workgroup dimensions (x, y, z). |
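
To build intuition for what a dispatch does, the vector_add kernel can be emulated sequentially on the CPU. This is only a mental model, not how wave_gpu executes kernels: `emulate_dispatch` and `vector_add_body` are hypothetical names, only the x axis is modeled, and the sketch assumes the x component of grid counts the total number of invocations, which matches the four-element example using grid=(4, 1, 1).

```python
def emulate_dispatch(body, grid, *buffers):
    """Run `body` once per invocation along the x axis, sequentially on the CPU."""
    for tid in range(grid[0]):
        body(tid, *buffers)


def vector_add_body(tid, a, b, out):
    # Same arithmetic as the vector_add kernel; tid stands in for thread_id().
    out[tid] = a[tid] + b[tid]


a = [1.0, 2.0, 3.0, 4.0]
b = [5.0, 6.0, 7.0, 8.0]
out = [0.0] * 4
emulate_dispatch(vector_add_body, (4, 1, 1), a, b, out)
print(out)  # [6.0, 8.0, 10.0, 12.0]
```

On a GPU the invocations run in parallel and in no guaranteed order, so a kernel body should not rely on the sequential ordering this loop happens to provide.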

wave_gpu.device() queries the detected GPU device.

dev = wave_gpu.device()
print(dev.vendor) # "AMD", "NVIDIA", "Intel", or "Unknown"
print(dev.name) # e.g., "AMD Radeon RX 7900 XTX"

Returns: DeviceInfo with the following attributes:

| Attribute | Type | Description |
| --- | --- | --- |
| vendor | str | GPU vendor string. |
| name | str | GPU device name. |

The following intrinsics are available inside @wave_gpu.kernel functions. They must not be called outside of a kernel context.

wave_gpu.thread_id() returns the global thread index of the current invocation.

Returns the workgroup index of the current invocation.

Returns the lane index within the current wave (0 to wave_width() - 1).

wave_width() returns the wave width (number of lanes per wave). This is typically 32 or 64, depending on the hardware.
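
The relationship between a thread's position and its lane can be sketched with plain integer arithmetic. The helper below is hypothetical and assumes that threads within a workgroup are assigned to waves consecutively, which this reference does not state explicitly.

```python
def wave_and_lane(local_index, wave_width):
    """Split a linear index within a workgroup into (wave number, lane index)."""
    return divmod(local_index, wave_width)


print(wave_and_lane(70, 32))  # (2, 6)
print(wave_and_lane(70, 64))  # (1, 6)
```

The same thread lands in a different wave depending on the wave width, which is why code that depends on lane layout should query the width rather than hard-code 32 or 64.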

Synchronize all threads in the current workgroup. All threads must reach the barrier before any thread proceeds past it. Must not be called inside divergent control flow.
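
The barrier semantics described above have a close CPU analogue in Python's `threading.Barrier`. The sketch below is an analogy only and does not use wave_gpu: four host threads stand in for a workgroup, each writes its own slot, waits at the barrier, and then reads a neighbor's slot.

```python
import threading

N = 4
shared = [0] * N
result = [0] * N
barrier = threading.Barrier(N)


def worker(i):
    shared[i] = i * 10               # phase 1: each thread writes its own slot
    barrier.wait()                   # no thread continues until all N have arrived
    result[i] = shared[(i + 1) % N]  # phase 2: reading a neighbor's slot is now safe


threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(result)  # [10, 20, 30, 0]
```

If one worker skipped `barrier.wait()`, the others would block indefinitely, which parallels the rule against calling the barrier inside divergent control flow.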


Type annotations used in kernel function signatures to declare buffer element types.

| Type | Description | Size |
| --- | --- | --- |
| wave_gpu.f16 | 16-bit IEEE 754 half-precision float | 2 bytes |
| wave_gpu.f32 | 32-bit IEEE 754 single-precision float | 4 bytes |
| wave_gpu.f64 | 64-bit IEEE 754 double-precision float | 8 bytes |
| wave_gpu.i32 | 32-bit signed integer | 4 bytes |
| wave_gpu.u32 | 32-bit unsigned integer | 4 bytes |

These types are also valid values for the dtype string parameter: "f16", "f32", "f64", "i32", "u32".
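
The element sizes listed above make it straightforward to estimate how much device memory a buffer's payload needs. The helper below is hypothetical and counts only element data; any padding or alignment the runtime may add is not covered by this reference and is ignored here.

```python
# Element sizes in bytes, copied from the type table in this section.
ELEMENT_SIZE = {"f16": 2, "f32": 4, "f64": 8, "i32": 4, "u32": 4}


def buffer_nbytes(n, dtype="f32"):
    """Return the payload size in bytes of an n-element buffer."""
    if dtype not in ELEMENT_SIZE:
        raise ValueError(f"unsupported dtype: {dtype!r}")
    return n * ELEMENT_SIZE[dtype]


print(buffer_nbytes(1024))        # 4096
print(buffer_nbytes(512, "f64"))  # 4096
```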