Reputation: 175
I am trying to use triton-lang
to perform a simple element-wise dot product between a column vector and a matrix that both have complex value. I can make the code work if I don't specify block_sizes but I can't figure out how to cut my grid and how to handle my pointers. I somewhat understand the theory on how it should work but I can't make it work.
def cdot(x: torch.Tensor, y: torch.Tensor):
return x * y
def cdot_triton(x: torch.Tensor, y: torch.Tensor, BLOCK_SIZE):
# preallocate the output
z = torch.empty_like(y)
# check arguments
assert x.is_cuda and y.is_cuda and z.is_cuda
# get vector size
N = z.numel()
# 1D launch kernel where each block gets its own program
grid = lambda meta: (N // BLOCK_SIZE, N // BLOCK_SIZE)
# launch the kernel
cdot_kernel[grid](x.real, x.imag, y.real, y.imag, z.real, z.imag, N, BLOCK_SIZE)
return z
@triton.jit
def cdot_kernel(
x_real_ptr,
x_imag_ptr,
y_real_ptr,
y_imag_ptr,
z_real_ptr,
z_imag_ptr,
N: tl.constexpr, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
):
row = tl.program_id(0)
col = tl.arange(0, 2*BLOCK_SIZE)
if row < BLOCK_SIZE:
idx = row * BLOCK_SIZE + col
x_real = tl.load(x_real_ptr + 2*row)
x_imag = tl.load(x_imag_ptr + 2*row)
y_real = tl.load(y_real_ptr + 2*idx, mask=col<BLOCK_SIZE, other=0)
y_imag = tl.load(y_imag_ptr + 2*idx, mask=col<BLOCK_SIZE, other=0)
z_real = x_real * y_real - x_imag * y_imag
z_imag = x_real * y_imag + x_imag * y_real
tl.store(z_real_ptr + 2*idx, z_real, mask=col<BLOCK_SIZE)
tl.store(z_imag_ptr + 2*idx, z_imag, mask=col<BLOCK_SIZE)
# ===========================================
# Test kernel
# ===========================================
size = 4
dtype = torch.complex64
x = torch.rand((size, 1), device='cuda', dtype=dtype)
y = torch.rand((size, size), device='cuda', dtype=dtype)
out_dot = cdot(x,y)
out_kernel = cdot_triton(x,y, BLOCK_SIZE=2)
This is the output:
tensor([[-0.1322+1.1461j, -0.1098+0.8015j, 0.2948+1.2155j, -0.1326+0.6076j],
[-0.3687+0.4646j, 0.2349+0.5802j, 0.0568+0.9461j, -0.0457+0.3213j],
[ 0.0523+0.9351j, 0.4409+0.5076j, 0.3956+0.4018j, 0.6230+0.9270j],
[-0.3503+0.7194j, -0.3742+0.2311j, -0.3353+0.3884j, -0.3478+0.6724j]],
device='cuda:0')
tensor([[-0.1322+1.1461j, -0.1098+0.8015j, 0.0617+1.0408j, -0.1988+0.4788j],
[ 0.1147+0.2296j, 0.0686+0.1161j, 0.0647+0.4044j, 0.0795+0.6407j],
[-0.2396+0.6326j, -0.3587+0.5878j, -0.1563+0.4028j, -0.2933+0.3294j],
[-0.1214+0.3678j, 0.0440+0.9951j, 0.3342+1.1360j, 0.6796+0.6590j]],
device='cuda:0')
As you can see only the 2 first values of the top row are accurate.
Any ideas on how I can make this element-wise dot product work?
Many thanks!
Upvotes: 0
Views: 218