baybridge is an in-progress AMD-first kernel DSL and portability layer.

Today it provides:

- a Python frontend with an example-friendly `import baybridge as cute` surface
- a portable kernel IR
- textual/reference lowering backends (`mlir_text`, `gpu_text`, `gpu_mlir`, `waveasm_ref`, `aster_ref`, `hipkittens_ref`, and `ptx_ref`)
- executable HIP lowering for a narrow validated subset
- an executable PTX backend for a narrow validated subset via the CUDA driver JIT
- an optional executable HipKittens backend for a narrow BF16 GEMM subset
- a reference runtime for basic examples and `@jit` wrapper execution
It does not yet provide:
- PTX-to-AMDGCN translation
- broad executable lowering for the full DSL surface
These paths are usable now:

- basic `@kernel` and `@jit` authoring
- topology builtins such as `arch.thread_idx()`, `arch.block_idx()`, `block_dim()`, `grid_dim()`, and `lane_id()`
- runtime tensors
- direct scalar indexing in traced kernels: `g_c[row, col] = g_a[row, col] + g_b[row, col]`
- `compile(...)` from explicit `TensorSpec` annotations
- `compile(...)` from sample runtime arguments
- `@jit` wrappers that launch a single traceable kernel
- runtime fallback for simple `@jit` wrappers that cannot be traced yet
Current boundary:

- Python control flow on traced dynamic values is not supported.
- If a launched kernel uses `if thread_idx(...) == ...`-style control flow, `compile(...)` falls back to the reference runtime for `@jit` wrappers instead of producing IR.
- Direct `@kernel` compilation fails in that case rather than silently degrading.
baybridge is currently tested with a repo-local `.venv`:

```sh
git clone <repo>
cd baybridge
python3 -m venv .venv
. .venv/bin/activate
python -m pip install -U pip
python -m pip install -e '.[dev]'
```

Run the test suite:

```sh
PYTHONDONTWRITEBYTECODE=1 PYTHONPATH=src .venv/bin/python -m pytest -q -p no:cacheprovider
```

This runs immediately through the reference runtime.
```python
import baybridge as cute

@cute.kernel
def hello_world_kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Hello world")

@cute.jit
def hello_world():
    cute.printf("hello world")
    hello_world_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1))

hello_world()

compiled = cute.compile(hello_world)
compiled()
```

Notes:

- `hello_world()` executes directly on the reference runtime.
- `compile(hello_world)` succeeds, but because the launched kernel uses dynamic Python control flow, the compiled artifact executes through the runtime path rather than producing portable IR.

The next kernel is traceable because it avoids a dynamic Python `if`.
```python
import baybridge as cute

@cute.kernel(launch=cute.LaunchConfig(grid=(1, 1, 1), block=(4, 1, 1)))
def indexed_add_kernel(
    g_a: cute.Tensor,
    g_b: cute.Tensor,
    g_c: cute.Tensor,
):
    tidx, _, _ = cute.arch.thread_idx()
    _, n = g_a.shape
    col = tidx % n
    row = tidx // n
    g_c[row, col] = g_a[row, col] + g_b[row, col]

a = cute.tensor([[1, 2], [3, 4]], dtype="f32")
b = cute.tensor([[10, 20], [30, 40]], dtype="f32")
c = cute.zeros((2, 2), dtype="f32")

artifact = cute.compile(indexed_add_kernel, a, b, c, backend="gpu_text")
print(artifact.ir.to_dict())
print(artifact.lowered_module.text)
```

This produces:

- portable IR containing `thread_idx`, `constant`, `mod`, `floordiv`, `load`, `add`, and `store`
- `gpu_text` output containing `memref.load` and `memref.store`
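The `tidx`-to-`(row, col)` arithmetic in `indexed_add_kernel` can be checked with plain Python, independent of Baybridge; this sketch only mirrors the index math, not the tracer:

```python
# Mirror of the index arithmetic in indexed_add_kernel for a 2x2 tensor:
# block=(4, 1, 1) launches thread ids 0..3, and each id owns one element.
n = 2  # number of columns
coords = [(tidx // n, tidx % n) for tidx in range(4)]
print(coords)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Every element of the 2x2 tensor is covered exactly once, which is why the kernel needs no bounds check at this launch shape.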
If a `@jit` wrapper only specializes and launches a traceable kernel, `compile(...)` returns the launched kernel IR.

```python
import baybridge as cute

@cute.jit
def indexed_add(m_a: cute.Tensor, m_b: cute.Tensor, m_c: cute.Tensor):
    indexed_add_kernel(m_a, m_b, m_c).launch(grid=(1, 1, 1), block=(4, 1, 1))

compiled = cute.compile(indexed_add, a, b, c, backend="portable")
print(compiled.ir.name)                    # indexed_add_kernel
print(compiled.ir.metadata["wrapped_by"])  # indexed_add
```

- the `kernel` decorator
  - device-style kernel definition
  - direct `compile(...)` produces kernel IR when the body is traceable
- the `jit` decorator
  - host-style specialization/wrapper function
  - can execute directly on the reference runtime
  - can compile either to launched kernel IR or to a runtime fallback artifact
You can supply tensor information in two ways:

- Explicit annotations:

```python
@cute.kernel
def copy_kernel(
    src: cute.TensorSpec(shape=(128,), dtype="f16"),
    dst: cute.TensorSpec(shape=(128,), dtype="f16"),
):
    cute.copy(src, dst)
```

- Sample runtime arguments:

```python
a = cute.tensor([[1, 2], [3, 4]], dtype="f32")
b = cute.tensor([[10, 20], [30, 40]], dtype="f32")
c = cute.zeros((2, 2), dtype="f32")

artifact = cute.compile(indexed_add_kernel, a, b, c)
```

The reference runtime uses lightweight CPU-side tensors:
```python
import baybridge as cute

a = cute.tensor([[1, 2, 3], [4, 5, 6]], dtype="f32")
b = cute.zeros((2, 3), dtype="f32")
print(a.shape)    # (2, 3)
print(a.dtype)    # f32
print(a.tolist()) # [[1, 2, 3], [4, 5, 6]]
```

Useful helpers:

- `baybridge.tensor(data, dtype=...)`
- `baybridge.zeros(shape, dtype=...)`
- `baybridge.clone(tensor)`
- `baybridge.size(value, mode=...)`
For a current backend inventory, validation matrix, and benchmark notes, see `docs/backend-status.md`.
`compile(...)` supports:

- `backend="portable"` - stores portable IR only
- `backend="mlir_text"` - lowers portable IR to a baybridge-specific textual form
- `backend="gpu_text"` - lowers portable IR to GPU/ROCDL-flavored textual IR
- `backend="gpu_mlir"` - lowers portable IR to a stricter MLIR GPU/module form suitable for external backend consumption
- `backend="ptx_ref"` - lowers a narrow exact Baybridge subset to driver-JIT-loadable PTX text
  - current validated subset:
    - canonical indexed rank-1 dense copy: `f32`, `i32`, `i1`
    - canonical indexed rank-1 dense pointwise:
      - `f32`: `add/sub/mul/div/max/min/atan2`
      - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - `i1`: `and/or/xor`
    - canonical indexed rank-1 unary `neg`/`abs`: `f32`, `i32`
    - canonical indexed rank-1 integer `bitnot`: `i32`, `i1`
    - canonical indexed rank-1 compare-to-`i1`:
      - `f32`: `cmp_lt/le/gt/ge/eq/ne`
      - `i32`: `cmp_lt/le/gt/ge/eq/ne`
    - canonical indexed rank-1 scalar-broadcast compare-to-`i1` from a scalar kernel parameter or a rank-1 extent-1 tensor:
      - `f32`: `cmp_lt/le/gt/ge/eq/ne`
      - `i32`: `cmp_lt/le/gt/ge/eq/ne`
    - canonical indexed rank-1 `select` from `i1` predicate tensors:
      - tensor-vs-tensor branch dtypes: `f32`, `i32`, `i1`
      - scalar-branch forms from a scalar kernel parameter or a rank-1 extent-1 tensor
      - scalar-branch dtypes: `f32`, `i32`, `i1`
    - canonical indexed rank-1 unary:
      - `f32`: `round/floor/ceil/trunc/sqrt/rsqrt/sin/cos/acos/asin/atan/exp/exp2/log/log2/log10/erf`
    - canonical indexed rank-1 scalar broadcast from a scalar kernel parameter or a rank-1 extent-1 tensor:
      - `f32`: `add/sub/mul/div/max/min/atan2`
      - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
    - exact current launch contract for indexed rank-1 families:
      - `grid=(grid.x >= 1, 1, 1)`, `block=(block.x >= 1, 1, 1)`
      - `grid.x * block.x >= extent`
    - direct `threadIdx.x` rank-1 dense copy: `f32`, `i32`, `i1`
    - direct `threadIdx.x` rank-1 dense pointwise:
      - `f32`: `add/sub/mul/div/max/min/atan2`
      - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - `i1`: `and/or/xor`
    - direct `threadIdx.x` rank-1 unary `neg`/`abs`: `f32`, `i32`
    - direct `threadIdx.x` rank-1 integer `bitnot`: `i32`, `i1`
    - direct `threadIdx.x` rank-1 compare-to-`i1`:
      - `f32`: `cmp_lt/le/gt/ge/eq/ne`
      - `i32`: `cmp_lt/le/gt/ge/eq/ne`
    - direct `threadIdx.x` rank-1 scalar-broadcast compare-to-`i1` from a scalar kernel parameter or a rank-1 extent-1 tensor:
      - `f32`: `cmp_lt/le/gt/ge/eq/ne`
      - `i32`: `cmp_lt/le/gt/ge/eq/ne`
    - direct `threadIdx.x` rank-1 `select` from `i1` predicate tensors:
      - tensor-vs-tensor branch dtypes: `f32`, `i32`, `i1`
      - scalar-branch forms from a scalar kernel parameter or a rank-1 extent-1 tensor
      - scalar-branch dtypes: `f32`, `i32`, `i1`
    - direct `threadIdx.x` rank-1 unary:
      - `f32`: `abs`, `neg`, `round`, `floor`, `ceil`, `trunc`, `sqrt`, `rsqrt`, `sin`, `cos`, `acos`, `asin`, `atan`, `exp`, `exp2`, `log`, `log2`, `log10`, `erf`
      - `i32`: `abs`, `neg`
    - direct `threadIdx.x` scalar broadcast from a scalar kernel parameter or a rank-1 extent-1 tensor:
      - `f32`: `add/sub/mul/div/max/min/atan2`
      - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
    - exact current launch contract for direct rank-1 families: `grid=(1, 1, 1)`, `block=(block.x >= extent, 1, 1)`
    - exact rank-1 `copy_reduce`: `bb.copy(copy_reduce_atom, src, dst)`
      - supported reductions: `add`, `max`, `min`, `and`, `or`, `xor`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - direct single-block: `grid=(1, 1, 1)`, `block=(block.x >= extent, 1, 1)`
        - indexed multi-block: `grid=(grid.x >= 1, 1, 1)`, `block=(block.x >= 1, 1, 1)`, `grid.x * block.x >= extent`
    - exact 2D `f32`/`i32`/`i1` dense copy bundles: `bb.copy(src, dst)`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact 2D `copy_reduce` bundles: `bb.copy(copy_reduce_atom, src, dst)`
      - supported reductions:
        - `f32`: `add`, `max`, `min`
        - `i32`: `add`, `max`, `min`, `and`, `or`, `xor`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D `f32`/`i32` scalar-broadcast bundles from scalar kernel parameters: `dst.store(src.load() <op> alpha)`
      - supported ops:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact 2D `f32`/`i32` scalar-broadcast compare-to-`i1` bundles from scalar kernel parameters: `dst.store(src.load() <cmp> alpha)`
      - supported ops: `cmp_lt/le/gt/ge/eq/ne`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D `f32`/`i32` tensor-source scalar-broadcast bundles from rank-1 extent-1 tensors: `dst.store(src.load() <op> alpha[0])`
      - supported ops:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D `f32`/`i32` tensor-source scalar-broadcast compare-to-`i1` bundles from rank-1 extent-1 tensors: `dst.store(src.load() <cmp> alpha[0])`
      - supported ops: `cmp_lt/le/gt/ge/eq/ne`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D `f32`/`i32` dense tensor-binary bundles: `dst.store(lhs.load() <op> rhs.load())`
      - supported ops:
        - `f32`: `add`, `sub`, `mul`, `div`, `max`, `min`
        - `i32`: `add`, `sub`, `mul`, `div`, `max`, `min`, `bitand`, `bitor`, `bitxor`
        - `i1`: `bitand`, `bitor`, `bitxor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact 2D `f32`/`i32` dense compare-to-`i1` bundles: `dst.store(lhs.load() <cmp> rhs.load())`
      - supported ops: `cmp_lt/le/gt/ge/eq/ne`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D dense `select` bundles from `i1` predicate tensors: `dst.store(bb.where(pred.load(), lhs.load(), rhs.load()))`
      - currently validated on:
        - dense `f32`
        - broadcast-compatible `i32`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D scalar-branch `select` bundles from `i1` predicate tensors: `dst.store(bb.where(pred.load(), src.load(), alpha))`, `dst.store(bb.where(pred.load(), alpha[0], src.load()))`
      - scalar-branch forms from a scalar kernel parameter or a rank-1 extent-1 tensor
      - currently validated on:
        - `f32` scalar-kernel-parameter branches
        - `i32` extent-1-tensor scalar branches
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D `f32`/`i32` broadcast tensor-binary bundles: `dst.store(lhs.load() <op> rhs.load())`
      - broadcast-compatible `lhs` and `rhs` shapes expanded to the destination shape
      - supported ops:
        - `f32`: `add`, `sub`, `mul`, `div`, `max`, `min`
        - `i32`: `add`, `sub`, `mul`, `div`, `max`, `min`, `bitand`, `bitor`, `bitxor`
        - `i1`: `bitand`, `bitor`, `bitxor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact 2D `f32`/`i32` broadcast compare-to-`i1` bundles: `dst.store(lhs.load() <cmp> rhs.load())`
      - broadcast-compatible `lhs` and `rhs` shapes expanded to the destination shape
      - supported ops: `cmp_lt/le/gt/ge/eq/ne`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D broadcast `select` bundles from `i1` predicate tensors: `dst.store(bb.where(pred.load(), lhs.load(), rhs.load()))`
      - broadcast-compatible `lhs` and `rhs` shapes expanded to the predicate/destination shape
      - currently validated on `i32`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= cols`
    - exact 2D unary bundles:
      - `f32`: `dst.store(abs(src.load()))`, `dst.store(-src.load())`, `dst.store(bb.sqrt(src.load()))`, `dst.store(bb.rsqrt(src.load()))`
      - `i32`: `dst.store(abs(src.load()))`, `dst.store(-src.load())`, `dst.store(~src.load())`
      - `i1`: `dst.store(~src.load())`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact 2D `f32`/`i32` rowwise or columnwise reduction families: `dst_rows.store(src.load().reduce(bb.ReductionOp.<op>, init, reduction_profile=(None, 1)))`, `dst_cols.store(src.load().reduce(bb.ReductionOp.<op>, init, reduction_profile=(1, None)))`
      - supported ops:
        - `f32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`
        - `i32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, `reduce_xor`
      - exact current launch contracts:
        - serial: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
        - parallel: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`
      - current parallel lowering maps one thread to one output row or column and loops over the reduced axis inside the thread
    - exact parallel 2D row-tiled variants of the same tensor families above:
      - dense copy for `f32`/`i32`/`i1`
      - scalar-broadcast from scalar kernel parameters:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - scalar-broadcast compare-to-`i1` from scalar kernel parameters:
        - `f32`: `cmp_lt/le/gt/ge/eq/ne`
        - `i32`: `cmp_lt/le/gt/ge/eq/ne`
      - extent-1 tensor-source scalar-broadcast:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
      - extent-1 tensor-source scalar-broadcast compare-to-`i1`:
        - `f32`: `cmp_lt/le/gt/ge/eq/ne`
        - `i32`: `cmp_lt/le/gt/ge/eq/ne`
      - dense tensor-binary:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
        - `i1`: `bitand/bitor/bitxor`
      - dense compare-to-`i1`:
        - `f32`: `cmp_lt/le/gt/ge/eq/ne`
        - `i32`: `cmp_lt/le/gt/ge/eq/ne`
      - dense `select` from `i1` predicate tensors:
        - tensor-vs-tensor branches:
          - validated dense `f32`
          - validated broadcast `i32`
        - scalar-branch forms:
          - validated `f32` from scalar kernel params
          - validated `i32` from rank-1 extent-1 tensors
      - broadcast tensor-binary:
        - `f32`: `add/sub/mul/div/max/min/atan2`
        - `i32`: `add/sub/mul/div/max/min`, `bitand`, `bitor`, `bitxor`
        - `i1`: `bitand/bitor/bitxor`
      - broadcast compare-to-`i1`:
        - `f32`: `cmp_lt/le/gt/ge/eq/ne`
        - `i32`: `cmp_lt/le/gt/ge/eq/ne`
      - broadcast `select` from `i1` predicate tensors: validated on `i32`
      - dense unary:
        - `f32`: `abs`, `neg`, `round`, `floor`, `ceil`, `trunc`, `sqrt`, `rsqrt`, `sin`, `cos`, `acos`, `asin`, `atan`, `exp`, `exp2`, `log`, `log2`, `log10`, `erf`
        - `i32`: `abs`, `neg`, `bitnot`
        - `i1`: `bitnot`
      - exact current launch contract:
        - `grid=(grid.x, grid.y, 1)`, `block=(power_of_two, 1, 1)`
        - `grid.x * block.x >= cols`, `grid.y >= 1`
      - current lowering maps each thread to a column tile and advances rows by `grid.y`
    - serial rank-1 scalar reductions to `dst[0]`:
      - `f32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`
      - `i32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, `reduce_xor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact parallel rank-1 scalar reduction to `dst[0]`:
      - supported ops:
        - `f32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`
        - `i32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, `reduce_xor`
      - current launch contract: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`
      - current lowering uses a single-block shared-memory reduction
    - exact 2D `f32`/`i32` reduction bundle:
      - scalar reduction to `dst_scalar[0]`
      - row reduction to `dst_rows`
      - column reduction to `dst_cols`
      - supported ops:
        - `f32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`
        - `i32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, `reduce_xor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact parallel 2D `f32`/`i32` reduction bundle:
      - scalar reduction to `dst_scalar[0]`
      - row reduction to `dst_rows`
      - column reduction to `dst_cols`
      - supported ops:
        - `f32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`
        - `i32`: `reduce_add`, `reduce_mul`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, `reduce_xor`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(power_of_two, 1, 1)`, `block.x >= max(rows, cols)`
      - current lowering uses a single-block shared-memory scalar reduction plus per-thread row/column accumulation
    - exact 2D `f32`/`i32` tensor-factory bundle: `dst_zero.store(bb.zeros_like(dst_zero))`, `dst_one.store(bb.ones_like(dst_one))`, `dst_full.store(bb.full_like(dst_full, fill_value))`
      - exact current launch contract: `grid=(1, 1, 1)`, `block=(1, 1, 1)`
    - exact parallel 2D row-tiled `f32`/`i32` tensor-factory bundle:
      - same `zeros_like`/`ones_like`/`full_like` bundle as above
      - exact current launch contract:
        - `grid=(grid.x, grid.y, 1)`, `block=(power_of_two, 1, 1)`
        - `grid.x * block.x >= cols`, `grid.y >= 1`
      - current lowering maps each thread to a column tile and advances rows by `grid.y`
  - current dtypes: `f32`, `i32`; `i1` for exact copy and bitwise tensor families only
- `backend="ptx_exec"`
  - launches the same exact PTX subset through `libcuda.so.1`
  - no `nvcc`, `nvrtc`, or `ptxas` is required in the backend path
  - works with:
    - Baybridge `RuntimeTensor` values through host staging
    - CUDA `TensorHandle` values with real device pointers
    - raw CUDA DLPack-capable tensor objects
  - `compile(...)` auto-selects `ptx_exec` on `NvidiaTarget` when the sample args show a device-resident tensor path
  - when all tensor sample args are staged `RuntimeTensor` values, `compile(...)` now prefers `ptx_ref` and warns how to request `backend="ptx_exec"` explicitly if you still want correctness-first staged execution
  - for PTX reduction-style kernels, any staged `RuntimeTensor` tensor sample args now keep auto-selection on `ptx_ref`; Baybridge only auto-selects `ptx_exec` there when tensor sample args are fully device-resident
  - warns once per built launcher when tensor arguments are staged through host memory
  - reuses per-argument device allocations across repeated staged `RuntimeTensor` launches, but the staged path is still primarily host-copy bound in current measurements
  - when you pass `stream=...`, DLPack-capable inputs are normalized with that launch stream before the PTX backend runs
  - current measured PTX performance is materially better with direct CUDA `TensorHandle` inputs or raw CUDA DLPack objects; staged `RuntimeTensor` execution remains correctness-first and host-copy dominated
  - current representative raw CUDA DLPack timings are effectively on top of the direct `TensorHandle` path for the same PTX kernels
  - scalar kernel parameters are supported for the exact PTX scalar-broadcast family
  - boolean `i1` tensors are now validated on the exact PTX data path for:
    - rank-1 copy
    - rank-1 boolean pointwise `and`/`or`/`xor`
    - rank-1 boolean `bitnot`
    - serial, exact single-block parallel, and exact row-tiled multiblock 2D copy
    - serial, exact single-block parallel, and exact row-tiled multiblock 2D tensor-binary/broadcast-binary `bitand`/`bitor`/`bitxor`
  - boolean `i1` output tensors are now validated for the exact rank-1, serial 2D, and exact parallel 2D row-tiled compare families in both `ptx_ref` and `ptx_exec`, including scalar-broadcast compare-to-`i1`
  - boolean `i1` predicate tensors are now also validated as inputs to the exact PTX `select` family in both rank-1 and 2D forms:
    - rank-1 direct and canonical indexed forms
    - serial and exact parallel 2D row-tiled tensor-select forms:
      - dense `f32`
      - broadcast `i32`
    - serial and exact parallel 2D row-tiled scalar-branch select forms:
      - `f32` from scalar kernel params
      - `i32` from rank-1 extent-1 tensors
    - rank-1 direct and canonical indexed select still remains validated for `f32`, `i32`, `i1`
  - current PTX native math stays driver-only and libdevice-free:
    - exact rank-1 unary now exists for:
      - `f32`: `abs`, `neg`, `round`, `floor`, `ceil`, `trunc`, `sqrt`, `rsqrt`, `sin`, `cos`, `acos`, `asin`, `atan`, `exp`, `exp2`, `log`, `log2`, `log10`, `erf`
      - `i32`: `abs`, `neg`
    - exact 2D serial/single-block-parallel/row-tiled unary now exists for:
      - `f32`: `abs`, `neg`, `round`, `floor`, `ceil`, `trunc`, `sqrt`, `rsqrt`, `sin`, `cos`, `acos`, `asin`, `atan`, `exp`, `exp2`, `log`, `log2`, `log10`, `erf`
      - `i32`: `abs`, `neg`, `bitnot`
      - `i1`: `bitnot`
    - `round` lowers to native PTX `cvt.rni.f32.f32`
    - `floor` lowers to native PTX `cvt.rmi.f32.f32`
    - `ceil` lowers to native PTX `cvt.rpi.f32.f32`
    - `trunc` lowers to native PTX `cvt.rzi.f32.f32`
    - `sqrt` lowers to native PTX `sqrt.rn.f32`
    - `rsqrt` lowers to native PTX `rsqrt.approx.f32`
    - `sin` lowers to native PTX `sin.approx.f32`
    - `cos` lowers to native PTX `cos.approx.f32`
    - `acos` lowers via `sqrt(max(0, 1 - x*x))` plus the same explicit `atan2` core with the square-root term as `y`; no `libdevice` is used
    - `asin` lowers via `sqrt(max(0, 1 - x*x))` plus the same explicit `atan2` core; no `libdevice` is used
    - `atan` lowers via an explicit reciprocal-plus-piecewise polynomial PTX approximation; no `libdevice` is used
    - `atan2` lowers via the same reciprocal-plus-piecewise polynomial core plus explicit quadrant correction; no `libdevice` is used
    - `exp` lowers via native PTX `ex2.approx.f32` after scaling by `log2(e)`
    - `exp2` lowers to native PTX `ex2.approx.f32`
    - `log` lowers via native PTX `lg2.approx.f32` times `ln(2)`
    - `log2` lowers to native PTX `lg2.approx.f32`
    - `log10` lowers via native PTX `lg2.approx.f32` times `log10(2)`
    - `erf` lowers via an explicit PTX polynomial-plus-`ex2.approx.f32` approximation; no `libdevice` is used
  - current PTX reductions stay intentionally exact:
    - serial lowerings still exist for the traced 1D scalar-reduction form, the exact traced 2D row/column reduction form, and the exact traced 2D reduction-bundle form
    - exact single-block shared-memory lowerings now also exist for the parallel scalar-reduction family and the parallel 2D reduction-bundle family
    - exact per-output-thread lowerings now also exist for the parallel 2D row/column reduction family
    - that 2D row/column reduction family now also has an exact multiblock output-tiling path when `grid.x * block.x >= output_extent`
    - exact row+column reduction bundles without the scalar output now also exist in serial, single-block parallel, and exact multiblock output-tiled forms
  - current PTX `copy_reduce` stays intentionally exact:
    - rank-1 serial/direct/indexed
    - 2D serial
    - 2D single-block parallel
    - broader fused NVIDIA-only TMA-reduce semantics are still not claimed here
  - broader CUDA math that would require `libdevice` is still intentionally deferred
- `backend="waveasm_ref"`
  - lowers supported kernels to WaveASM-oriented MLIR plus tool invocation hints
  - this is a reference backend, not an executable backend
  - emits a `*.waveasm_repro/` bundle next to the lowered artifact with `kernel.mlir`, `repro.sh`, and `manifest.json`
  - the bundle path is exposed on the compiled artifact as `artifact.debug_bundle_dir`
  - you can also generate it directly with `baybridge.emit_waveasm_repro(...)` or `tools/emit_waveasm_repro.py`
  - `tools/compare_backends.py` can compile the same kernel through multiple backends, optionally execute supported ones, and report bundle/debug paths in JSON
  - pass `--include-env` to `tools/compare_backends.py` to include Python/platform/tool/version metadata for cross-machine diffs
  - sample factories used with `tools/compare_backends.py` may optionally accept `backend_name` and `target_arch` so they can return backend-specific compile/run args
  - `BAYBRIDGE_WAVEASM_ROOT` can point at a Wave checkout to improve tool discovery hints
- `backend="waveasm_exec"`
  - experimental backend behind `BAYBRIDGE_EXPERIMENTAL_WAVEASM_EXEC=1`
- experimental backend behind
- `backend="aster_ref"`
  - lowers supported kernels to ASTER-oriented reference MLIR plus tool/runtime hints
- this is a reference backend, not an executable backend
  - emits a `*.aster_repro/` bundle next to the lowered artifact with `kernel.mlir`, `repro.sh`, and `manifest.json`
  - `BAYBRIDGE_ASTER_ROOT` can point at an ASTER checkout or install root to improve tool and package discovery
  - current experimental family:
- single-global-tensor pointwise kernels
- single-global-tensor pointwise math kernels
- single-global-tensor shared-memory staging
  - not enabled by default because current upstream WaveASM execution still has correctness issues for Baybridge kernels, including the scalar global `memref.load`/`memref.store` SRD aliasing bug tracked in iree-org/wave#1117
  - does not yet cover reductions, multi-buffer copy, GEMM, or Baybridge custom tensor-SSA ops
  - also emits a `*.waveasm_repro/` bundle so the exact MLIR/toolchain path can be reproduced outside Baybridge
artifact.debug_bundle_dir - requires
waveasm-translate,clang++orclang, and prefersld.lldfor final HSACO linking
- `backend="flydsl_ref"`
  - lowers a supported Baybridge subset to a FlyDSL-oriented reference module
  - this is a reference backend, not an executable backend
  - if `BAYBRIDGE_FLYDSL_ROOT` points at a FlyDSL checkout, the artifact includes a direct root/build hint
- `backend="flydsl_exec"`
  - lowers a validated Baybridge tensor subset to a FlyDSL-flavored Python module and launcher scaffold
  - current supported family:
    - pointwise tensor kernels
    - broadcasted tensor binary ops
    - tensor factories via `fill`
    - simple reductions
    - shared-memory staging and whole-tensor copy
  - requires a built and importable FlyDSL environment, not just a source checkout
  - `BAYBRIDGE_FLYDSL_ROOT` should point at a checkout with `build-fly/python_packages` or `build/python_packages`
  - fake/local harness execution can launch with:
    - `baybridge.from_dlpack(...)` tensor handles
    - raw DLPack-capable tensors such as `torch.Tensor`
    - plain `baybridge.RuntimeTensor` values when `torch` is importable
  - real upstream FlyDSL execution is currently validated for a narrow subset:
    - 1D `f32` pointwise binary ops through the copy-atom/register path
      - hardware-validated today: add, sub, mul, div
    - canonical linear indexed pointwise binary ops of the form `block_idx.x * block_dim.x + thread_idx.x` are also hardware-validated today: add, sub, mul, div
    - 1D `f32` copy
    - 1D `f32` unary math bundle through the copy-atom/register path
      - hardware-validated today: exp, log, cos, erf, exp2, log2, log10, sqrt, sin, rsqrt
    - 2D `f32` broadcast binary through the row-slice/copy-atom path when `grid == block == (1, 1, 1)`
      - hardware-validated today: add, sub, mul, div
    - 2D `f32` reduction bundle through the row-slice/copy-atom path when `grid == block == (1, 1, 1)`
      - hardware-validated today:
        - add: full reduction to `(1,)` plus row reduction to `(M,)`
        - mul: full reduction to `(1,)` plus row reduction to `(M,)`
    - 2D `f32` unary math bundle through the row-slice/copy-atom path when `grid == block == (1, 1, 1)`
      - hardware-validated today: exp, log, cos, erf
    - 2D `f32` tensor-factory bundle through the row-slice/copy-atom path when `grid == block == (1, 1, 1)`
    - 1D `f32` shared-memory staging copy when the traced kernel is exactly a shared-memory round-trip and `block.x == extent`
    - `acos`, `asin`, `atan`, and `atan2` remain outside the validated real subset because the current upstream FlyDSL pipeline does not lower the corresponding scalar libcalls cleanly on the active AMD environments
  - broader real upstream FlyDSL execution is still gated behind `BAYBRIDGE_EXPERIMENTAL_REAL_FLYDSL_EXEC=1`
    - reason: Baybridge's current generic lowering still does not match FlyDSL's real buffer/tile access model; the validated subset uses specialized upstream-compatible lowerings
    - `tools/compare_backends.py --execute` reports that gate as `skipped_unvalidated_real_flydsl_exec`
  - intended as the first executable FlyDSL bridge, not full FlyDSL coverage yet
- `backend="hipkittens_ref"`
  - lowers matched GEMM and attention-family kernels to a HipKittens-oriented C++ reference artifact
  - this is a reference backend, not an executable backend
  - if `BAYBRIDGE_HIPKITTENS_ROOT` points at a HipKittens checkout, the artifact includes a direct include-path build hint
- `backend="hipcc_exec"`
  - builds and runs a narrow executable HIP subset through `hipcc`
- `backend="hipkittens_exec"`
  - builds and runs a narrow executable HipKittens subset
  - current supported kernel family: pure BF16/F16 GEMM composed from HipKittens micro tiles
  - currently validated for `gfx950`
  - `gfx942` executable support is currently limited to the BF16 slice; F16 exec is not wired yet
  - transposed F16 GEMM stays on `hipkittens_ref` on `gfx950`; current upstream HipKittens gfx950 MMA templates reject that family
  - RMSNorm stays on `hipkittens_ref` on `gfx950`; current upstream HipKittens gfx950 headers do not compile the generated RMSNorm kernel
  - supported micro tiles:
    - `f16/f16 -> f32`
    - `bf16/bf16 -> f32`
    - `(32, 16) x (16, 32) -> (32, 32)`
    - `(16, 32) x (32, 16) -> (16, 16)`
  - full tensor shapes may be larger as long as they are exact multiples of one supported micro tile and the reduction dimension is tiled accordingly
  - requires `BAYBRIDGE_HIPKITTENS_ROOT` to point at a HipKittens checkout
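The "exact multiples of one supported micro tile" rule for `hipkittens_exec` GEMM shapes can be sketched as a plain divisibility check. The helper name is hypothetical; the real validation lives inside the backend:

```python
def matches_micro_tile(m, k, n, tile=(32, 16, 32)):
    """Check whether an (m, k) x (k, n) GEMM is an exact multiple of one
    supported micro tile, e.g. (32, 16) x (16, 32) -> (32, 32).
    Hypothetical sketch of the shape rule described above."""
    tile_m, tile_k, tile_n = tile
    return m % tile_m == 0 and k % tile_k == 0 and n % tile_n == 0

print(matches_micro_tile(64, 32, 64))  # True: every dimension tiles exactly
print(matches_micro_tile(48, 16, 32))  # False: 48 is not a multiple of 32
```

The reduction dimension `k` is tiled the same way, which is why it must also be an exact multiple of the tile's inner extent.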
Default backend behavior:

- if the compile target is `cute.NvidiaTarget(...)` and no explicit backend is given, Baybridge auto-selects:
  - `ptx_exec` when the CUDA driver is available and the traced kernel matches the current exact PTX subset
  - otherwise `ptx_ref`
- if `compile(...)` is called without an explicit backend, the traced kernel matches `hipkittens_exec` on `gfx950`, and `BAYBRIDGE_HIPKITTENS_ROOT` points at a usable checkout, Baybridge auto-prefers `hipkittens_exec`
- the same auto-preference path can apply to `gfx942`, but only when the local ROCm toolchain satisfies the HipKittens `cdna3` header requirements
- otherwise, if the traced kernel matches a HipKittens tensorop GEMM or attention family, Baybridge auto-prefers `hipkittens_ref`
- otherwise, if the traced kernel matches the validated `flydsl_exec` subset and the active FlyDSL environment is ready, Baybridge auto-prefers `flydsl_exec` only when:
  - the inputs match the validated path
  - and the kernel falls into either the currently validated real-upstream subset or the broader experimental subset with `BAYBRIDGE_EXPERIMENTAL_REAL_FLYDSL_EXEC=1`
- otherwise it falls back to the normal default textual backend
- if `BAYBRIDGE_EXEC_ARCH` is set, Baybridge uses that architecture as the default compile target
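The precedence order above can be sketched as a plain decision function. This is an illustrative reading of the policy, not Baybridge's implementation; all parameter names and the `"default_textual"` stand-in are hypothetical:

```python
def choose_backend(target_is_nvidia=False, cuda_driver_ok=False,
                   matches_ptx_subset=False, matches_hipkittens_exec=False,
                   matches_hipkittens_family=False, matches_flydsl_exec=False):
    """Hypothetical sketch of the auto-selection order described above."""
    # NVIDIA targets: executable PTX only when the driver is present
    # and the traced kernel is inside the exact validated subset.
    if target_is_nvidia:
        if cuda_driver_ok and matches_ptx_subset:
            return "ptx_exec"
        return "ptx_ref"
    # AMD-side auto-preference order.
    if matches_hipkittens_exec:
        return "hipkittens_exec"
    if matches_hipkittens_family:
        return "hipkittens_ref"
    if matches_flydsl_exec:
        return "flydsl_exec"
    return "default_textual"  # stand-in for the normal textual backend

print(choose_backend(target_is_nvidia=True))  # ptx_ref
```

The ordering matters: an executable match always wins over the corresponding reference backend, and the textual fallback is reached only when nothing else matches.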
Example:

```python
artifact = cute.compile(indexed_add_kernel, a, b, c, backend="gpu_text")
print(artifact.lowered_module.text)
```

HipKittens reference example:

```python
artifact = cute.compile(gemm_kernel, a, b, c, backend="hipkittens_ref")
print(artifact.lowered_module.dialect)  # hipkittens_cpp
print(artifact.lowered_module.text)
```

HipKittens executable example:

```python
artifact = cute.compile(
    bf16_gemm_kernel,
    a,
    b,
    c,
    backend="hipkittens_exec",
    target=cute.AMDTarget(arch="gfx950"),
)
artifact(a, b, c)
```

PTX executable example:

```python
artifact = cute.compile(
    indexed_add_kernel,
    a,
    b,
    c,
    backend="ptx_exec",
    target=cute.NvidiaTarget(sm="sm_80"),
)
artifact(a, b, c)
```

These tests are the best executable documentation today:
- `tests/test_examples_runtime.py` - hello world, naive elementwise add
- `tests/test_examples_tracing.py` - direct indexed kernel tracing, traced launch wrapper behavior
- `tests/test_backend_gpu.py` - `gpu_text` lowering surface
- `tests/test_backend_hipkittens_ref.py` - HipKittens reference backend family matching
- `tests/test_backend_hipkittens_exec.py` - HipKittens executable BF16 GEMM path
- `tests/test_validation.py` - explicit unsupported cases and diagnostics
Not implemented yet:

- `zipped_divide` - composition-based tiled tensor views
- slice/view objects with `.load()`/`.store()` semantics
- broad NVIDIA executable lowering beyond the current exact PTX subset

Important limitations in the current tracer:

- dynamic Python `if`, `while`, and other control flow on traced scalars are rejected
- runtime fallback is only used for `@jit` wrappers
- direct kernel compilation expects the kernel body itself to be traceable
- `src/baybridge/frontend.py` - user-facing DSL surface
- `src/baybridge/runtime.py` - reference runtime tensors and launch execution
- `src/baybridge/tracing.py` - portable IR builder and traced values
- `src/baybridge/compiler.py` - compile/cache flow and runtime fallback policy
- `src/baybridge/backends/mlir_text.py` - portable textual lowering
- `src/baybridge/backends/gpu_text.py` - GPU/ROCDL-flavored textual lowering
- `src/baybridge/backends/hipkittens_ref.py` - HipKittens-oriented reference lowering for matched kernel families
- `src/baybridge/backends/hipkittens_exec.py` - narrow executable HipKittens lowering for supported BF16 GEMM tiles
baybridge is past the scaffolding stage and is now useful for:
- trying the frontend
- validating portability decisions
- exercising a reference runtime
- inspecting the portable IR and textual AMD-oriented lowerings
It is not yet a drop-in ROCm execution backend for full upstream DSL programs.