
Rust API Reference

The wave_sdk crate provides Rust bindings for WAVE GPU compute. Its three modules cover device buffer management (`array`), device detection (`device`), and kernel compilation and dispatch (`kernel`).

Add to your Cargo.toml:

```toml
[dependencies]
wave_sdk = "0.1"
```

`wave_sdk::array`: buffer creation and device memory management.

```rust
pub enum ElementType {
    F16,
    F32,
    F64,
    I32,
    U32,
}
```

Represents the element type of a device buffer.
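Each element type has a fixed width, so the payload size of a buffer follows from its element type and element count. The sketch below uses a local copy of the enum and a hypothetical `size_in_bytes` helper (not part of `wave_sdk`). It assumes the standard widths (IEEE half, single and double precision, 32-bit integers) and ignores any padding or alignment the runtime may add.

```rust
// Local mirror of wave_sdk's ElementType, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementType {
    F16,
    F32,
    F64,
    I32,
    U32,
}

impl ElementType {
    /// Width of one element in bytes (hypothetical helper, not part of wave_sdk).
    pub fn size_in_bytes(self) -> usize {
        match self {
            ElementType::F16 => 2,
            ElementType::F32 | ElementType::I32 | ElementType::U32 => 4,
            ElementType::F64 => 8,
        }
    }
}

/// Payload bytes needed for `count` elements of type `ty` (no padding modelled).
pub fn buffer_bytes(ty: ElementType, count: usize) -> usize {
    ty.size_in_bytes() * count
}

fn main() {
    // A 1024-element f32 buffer carries 4096 bytes of payload.
    println!("{}", buffer_bytes(ElementType::F32, 1024));
}
```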

```rust
pub struct DeviceBuffer { /* opaque */ }
```

A handle to a device-resident buffer. Create one with the free functions in this module (`from_f32`, `zeros_f32`, `from_u32`, `zeros_u32`).

Fields (read-only):

| Field | Type | Description |
| --- | --- | --- |
| `count` | `usize` | Number of elements in the buffer. |

Methods:

`to_f32()` copies the buffer contents to a host `Vec<f32>`. Panics if the buffer's element type is not `F32`.

```rust
let buf = wave_sdk::array::from_f32(&[1.0, 2.0, 3.0]);
let host: Vec<f32> = buf.to_f32();
assert_eq!(host, vec![1.0, 2.0, 3.0]);
```

`to_u32()` copies the buffer contents to a host `Vec<u32>`. Panics if the buffer's element type is not `U32`.

```rust
let buf = wave_sdk::array::from_u32(&[10, 20, 30]);
let host: Vec<u32> = buf.to_u32();
assert_eq!(host, vec![10, 20, 30]);
```
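The panic on a type mismatch implies that a buffer records its element type alongside its data. As a host-only illustration, the hypothetical `MockBuffer` below (not the real `DeviceBuffer`, whose internals are opaque) keeps a type-tagged `Vec` and refuses to read it back as the wrong type.

```rust
// Host-only mock of a type-tagged buffer. `MockBuffer` and `Storage` are
// hypothetical and only model the documented panic behaviour of to_f32/to_u32.
#[derive(Debug, Clone)]
enum Storage {
    F32(Vec<f32>),
    U32(Vec<u32>),
}

#[derive(Debug, Clone)]
pub struct MockBuffer {
    pub count: usize,
    storage: Storage,
}

impl MockBuffer {
    pub fn from_f32(data: &[f32]) -> Self {
        MockBuffer { count: data.len(), storage: Storage::F32(data.to_vec()) }
    }

    pub fn from_u32(data: &[u32]) -> Self {
        MockBuffer { count: data.len(), storage: Storage::U32(data.to_vec()) }
    }

    /// Copies the contents out; panics unless the buffer holds f32 elements.
    pub fn to_f32(&self) -> Vec<f32> {
        match &self.storage {
            Storage::F32(v) => v.clone(),
            other => panic!("to_f32 called on a non-F32 buffer: {:?}", other),
        }
    }

    /// Copies the contents out; panics unless the buffer holds u32 elements.
    pub fn to_u32(&self) -> Vec<u32> {
        match &self.storage {
            Storage::U32(v) => v.clone(),
            other => panic!("to_u32 called on a non-U32 buffer: {:?}", other),
        }
    }
}

fn main() {
    let floats = MockBuffer::from_f32(&[1.0, 2.0, 3.0]);
    assert_eq!(floats.to_f32(), vec![1.0, 2.0, 3.0]);

    // Reading a u32 buffer as f32 panics; catch_unwind observes the panic
    // without aborting the program.
    let ints = MockBuffer::from_u32(&[10, 20, 30]);
    let result = std::panic::catch_unwind(|| ints.to_f32());
    assert!(result.is_err());
    println!("count = {}, mismatched read panicked", ints.count);
}
```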

Functions:

`from_f32` creates a device buffer from a host `f32` slice, copying the data to device memory.

```rust
let buf = wave_sdk::array::from_f32(&[1.0, 2.0, 3.0, 4.0]);
assert_eq!(buf.count, 4);
```

`zeros_f32(n)` creates a device buffer of `n` zero-valued `f32` elements.

```rust
let buf = wave_sdk::array::zeros_f32(1024);
assert_eq!(buf.count, 1024);
```

`from_u32` creates a device buffer from a host `u32` slice, copying the data to device memory.

```rust
let buf = wave_sdk::array::from_u32(&[0, 1, 2, 3]);
assert_eq!(buf.count, 4);
```

`zeros_u32(n)` creates a device buffer of `n` zero-valued `u32` elements.

```rust
let buf = wave_sdk::array::zeros_u32(256);
assert_eq!(buf.count, 256);
```

`wave_sdk::device`: GPU device detection and information.

```rust
pub enum GpuVendor {
    AMD,
    NVIDIA,
    Intel,
    Unknown,
}

pub struct Device {
    pub vendor: GpuVendor,
    pub name: String,
}
```

Information about a detected GPU device.

`detect()` returns the first available GPU, or `None` if no supported GPU is found.

```rust
use wave_sdk::device;

if let Some(dev) = device::detect() {
    println!("Found GPU: {} ({:?})", dev.name, dev.vendor);
} else {
    eprintln!("No GPU detected");
}
```
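`detect()` reports one of four `GpuVendor` values. How `wave_sdk` classifies a device is not documented here; one common approach, sketched below with a hypothetical `vendor_from_pci_id` helper and a local copy of the enum, is to match the PCI vendor ID (0x1002 for AMD, 0x10DE for NVIDIA, 0x8086 for Intel) and fall back to `Unknown` for anything else.

```rust
// Local mirror of wave_sdk's GpuVendor, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuVendor {
    AMD,
    NVIDIA,
    Intel,
    Unknown,
}

/// Hypothetical classifier: maps a PCI vendor ID to a GpuVendor.
/// Any ID not listed falls through to Unknown.
pub fn vendor_from_pci_id(id: u16) -> GpuVendor {
    match id {
        0x1002 => GpuVendor::AMD,
        0x10DE => GpuVendor::NVIDIA,
        0x8086 => GpuVendor::Intel,
        _ => GpuVendor::Unknown,
    }
}

fn main() {
    for id in [0x10DE_u16, 0x1002, 0x8086, 0x106B] {
        println!("{:#06x} -> {:?}", id, vendor_from_pci_id(id));
    }
}
```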

`wave_sdk::kernel`: kernel compilation and dispatch.

```rust
pub enum Language {
    Python,
    Rust,
    Cpp,
    TypeScript,
}
```

Source language for kernel compilation.
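When kernel sources live in files, the `Language` argument to `compile` can be chosen from the file extension. The helper below is hypothetical (not part of `wave_sdk`), works on a local copy of the enum, and assumes the conventional extensions `.py`, `.rs`, `.cpp`/`.cc`/`.cxx`, and `.ts`.

```rust
use std::path::Path;

// Local mirror of wave_sdk's kernel::Language, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Language {
    Python,
    Rust,
    Cpp,
    TypeScript,
}

/// Hypothetical helper: picks a source language from a file extension.
/// Returns None for a missing or unrecognised extension.
pub fn language_for_path(path: &Path) -> Option<Language> {
    let ext = path.extension()?.to_str()?.to_ascii_lowercase();
    match ext.as_str() {
        "py" => Some(Language::Python),
        "rs" => Some(Language::Rust),
        "cpp" | "cc" | "cxx" => Some(Language::Cpp),
        "ts" => Some(Language::TypeScript),
        _ => None,
    }
}

fn main() {
    for name in ["vadd.py", "vadd.rs", "vadd.cpp", "vadd.ts", "vadd.cu"] {
        println!("{} -> {:?}", name, language_for_path(Path::new(name)));
    }
}
```

Returning `Option` rather than guessing keeps an unsupported file (such as a `.cu` source) from being compiled under the wrong language.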

```rust
pub struct CompiledKernel { /* opaque */ }
```

A compiled WAVE kernel ready for dispatch.

compile(source: &str, lang: Language) -> Result<CompiledKernel, String>


Compile a kernel from source code. Returns a CompiledKernel on success, or an error string describing the compilation failure.

```rust
use wave_sdk::kernel::{compile, Language};

// The Python kernel body must be indented, as in any Python function.
let source = r#"
@wave_gpu.kernel
def add(a: f32, b: f32, out: f32):
    tid = thread_id()
    out[tid] = a[tid] + b[tid]
"#;
let kernel = compile(source, Language::Python).expect("compilation failed");
```
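The `add` kernel assigns one thread per element: thread `tid` reads `a[tid]` and `b[tid]` and writes their sum to `out[tid]`. A plain Rust loop over `tid` gives a host-side reference that can be handy for checking GPU results. The `add_reference` name is illustrative only, and the sketch assumes all three buffers have the same length.

```rust
/// Host-side reference for the `add` kernel: each loop iteration plays the
/// role of one GPU thread with index `tid`.
fn add_reference(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());
    for tid in 0..out.len() {
        out[tid] = a[tid] + b[tid];
    }
}

fn main() {
    let a = [1.0_f32, 2.0, 3.0, 4.0];
    let b = [5.0_f32, 6.0, 7.0, 8.0];
    let mut out = [0.0_f32; 4];
    add_reference(&a, &b, &mut out);
    println!("{:?}", out); // prints [6.0, 8.0, 10.0, 12.0]
}
```

With the inputs used in the `launch` example, this reference produces the same `[6.0, 8.0, 10.0, 12.0]` that the GPU run is expected to return.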

launch(&self, device: &Device, buffers: &[&DeviceBuffer], scalars: &[u32], grid: [u32; 3], workgroup: [u32; 3]) -> Result<(), String>


Dispatch the kernel on the given device.

```rust
use wave_sdk::{array, device, kernel};

// Same `add` kernel as in the compile() example.
let src = r#"
@wave_gpu.kernel
def add(a: f32, b: f32, out: f32):
    tid = thread_id()
    out[tid] = a[tid] + b[tid]
"#;

let dev = device::detect().expect("no GPU");
let a = array::from_f32(&[1.0, 2.0, 3.0, 4.0]);
let b = array::from_f32(&[5.0, 6.0, 7.0, 8.0]);
let out = array::zeros_f32(4);
let k = kernel::compile(src, kernel::Language::Python).expect("compilation failed");
k.launch(
    &dev,
    &[&a, &b, &out],
    &[],       // no scalar arguments
    [4, 1, 1], // grid dimensions
    [4, 1, 1], // workgroup dimensions
).expect("kernel launch failed");
let result = out.to_f32();
assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
```

Parameters:

| Parameter | Type | Description |
| --- | --- | --- |
| `device` | `&Device` | Target GPU device. |
| `buffers` | `&[&DeviceBuffer]` | Device buffers bound as kernel arguments, in declaration order. |
| `scalars` | `&[u32]` | Scalar values passed as kernel arguments (pushed as 32-bit constants). |
| `grid` | `[u32; 3]` | Global grid dimensions `[x, y, z]`. |
| `workgroup` | `[u32; 3]` | Workgroup dimensions `[x, y, z]`. |

Returns: Ok(()) on success. Err(String) if dispatch fails (e.g., invalid grid dimensions, device error).
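Two details of `launch` usually need a little host-side arithmetic. First, assuming `grid` gives the total number of threads per axis (as the 4-element example suggests) and should be a whole number of workgroups, an element count is rounded up to a multiple of the workgroup size; the surplus threads would then need a bounds check in the kernel. Second, because `scalars` is `&[u32]`, a floating-point scalar has to travel as its bit pattern via `f32::to_bits`, which assumes the kernel reinterprets that slot as `f32`. Both `grid_for` and `pack_scalars` are hypothetical helpers, not part of `wave_sdk`.

```rust
/// Rounds `n` up to the next multiple of `workgroup` (hypothetical helper).
/// Assumes `grid` counts total threads per axis, not workgroups.
fn grid_for(n: u32, workgroup: u32) -> u32 {
    assert!(workgroup > 0, "workgroup size must be non-zero");
    n.div_ceil(workgroup) * workgroup
}

/// Packs scalar arguments into the `&[u32]` form that `launch` takes
/// (hypothetical helper). The f32 travels as its IEEE-754 bit pattern.
fn pack_scalars(n: u32, alpha: f32) -> Vec<u32> {
    vec![n, alpha.to_bits()]
}

fn main() {
    let n = 1000;
    let workgroup = [256, 1, 1];
    let grid = [grid_for(n, workgroup[0]), 1, 1];
    let scalars = pack_scalars(n, 0.5);
    // 1000 elements round up to 1024 threads; the 24 extra threads
    // would have to skip work with a `tid < n` check in the kernel.
    println!("grid = {:?}, workgroup = {:?}, scalars = {:?}", grid, workgroup, scalars);
}
```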