bsaoptima
bsaoptima

Reputation: 175

Triton-lang: how to handle block sizes

I am trying to use triton-lang to perform a simple element-wise product between a column vector and a matrix, both complex-valued. I can make the code work if I don't specify block sizes, but I can't figure out how to partition my grid or how to handle my pointers. I somewhat understand the theory of how it should work, but I can't make it work in practice.

def cdot(x: torch.Tensor, y: torch.Tensor):
    """Reference implementation: broadcasted element-wise (complex) product of x and y."""
    return torch.mul(x, y)

def cdot_triton(x: torch.Tensor, y: torch.Tensor, BLOCK_SIZE):
    """Compute z = x * y (element-wise complex product) via the Triton kernel.

    x is a complex column vector and y a complex matrix; their real and
    imaginary parts are handed to the kernel as separate strided views.

    Args:
        x: complex CUDA tensor (column vector).
        y: complex CUDA tensor (matrix); output has y's shape and dtype.
        BLOCK_SIZE: number of elements each Triton program processes.

    Returns:
        z: complex CUDA tensor written by the kernel.
    """
    # preallocate the output (same shape/dtype as y)
    z = torch.empty_like(y)

    # check arguments
    assert x.is_cuda and y.is_cuda and z.is_cuda

    # total number of complex elements to produce
    N = z.numel()

    # 1D launch: the kernel only reads tl.program_id(0), so the previous 2D
    # grid of (N // BLOCK_SIZE)**2 programs merely repeated each program
    # along the unused second axis (identical redundant stores). One
    # ceil-div axis covers all N elements and produces the same output.
    grid = lambda meta: ((N + BLOCK_SIZE - 1) // BLOCK_SIZE,)

    # .real/.imag are stride-2 views into the interleaved complex storage;
    # the kernel compensates with its 2* offset factors.
    cdot_kernel[grid](x.real, x.imag, y.real, y.imag, z.real, z.imag, N, BLOCK_SIZE)

    return z

@triton.jit
def cdot_kernel(
    x_real_ptr,  # base of x.real — stride-2 view into interleaved complex storage
    x_imag_ptr,  # base of x.imag (same storage, offset by one float)
    y_real_ptr,  # base of y.real, same stride-2 layout
    y_imag_ptr,  # base of y.imag
    z_real_ptr,  # output real part, written through a stride-2 view of z
    z_imag_ptr,  # output imaginary part
    N: tl.constexpr,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
):
    # One program per chunk along launch axis 0 (axis 1 of the grid is never read).
    row = tl.program_id(0)
    # NOTE(review): the arange spans 2*BLOCK_SIZE lanes, but every load/store
    # below masks with col < BLOCK_SIZE, so half the lanes never do useful work.
    col = tl.arange(0, 2*BLOCK_SIZE)


    # NOTE(review): this guard discards every program with row >= BLOCK_SIZE,
    # so only the first BLOCK_SIZE chunks of the output are ever written —
    # consistent with the reported output being wrong past the first entries.
    if row < BLOCK_SIZE:
        # NOTE(review): the chunk stride here is BLOCK_SIZE, but for the
        # intended z[i, j] = x[i] * y[i, j] the row stride should be the
        # number of columns of y — a value this kernel is never passed.
        idx = row * BLOCK_SIZE + col
        # The 2* factors appear to compensate for the stride-2 .real/.imag
        # views over interleaved complex storage (element i at float offset
        # 2*i) — TODO confirm against the views passed by the launcher.
        x_real = tl.load(x_real_ptr + 2*row)
        x_imag = tl.load(x_imag_ptr + 2*row)
        y_real = tl.load(y_real_ptr + 2*idx, mask=col<BLOCK_SIZE, other=0)
        y_imag = tl.load(y_imag_ptr + 2*idx, mask=col<BLOCK_SIZE, other=0)

        # Complex multiply: (a+bi)(c+di) = (ac - bd) + (ad + bc)i
        z_real = x_real * y_real - x_imag * y_imag
        z_imag = x_real * y_imag + x_imag * y_real

        tl.store(z_real_ptr + 2*idx, z_real, mask=col<BLOCK_SIZE)
        tl.store(z_imag_ptr + 2*idx, z_imag, mask=col<BLOCK_SIZE)
        
# ===========================================
# Test kernel
# ===========================================

# Small smoke test: complex column vector against a complex square matrix.
size = 4
dtype = torch.complex64

# Random complex inputs on the GPU (creation order kept for RNG determinism).
x = torch.rand((size, 1), dtype=dtype, device='cuda')
y = torch.rand((size, size), dtype=dtype, device='cuda')

# Eager PyTorch reference vs. the Triton kernel's result.
out_dot = cdot(x, y)
out_kernel = cdot_triton(x, y, BLOCK_SIZE=2)

This is the output:

tensor([[-0.1322+1.1461j, -0.1098+0.8015j,  0.2948+1.2155j, -0.1326+0.6076j],
        [-0.3687+0.4646j,  0.2349+0.5802j,  0.0568+0.9461j, -0.0457+0.3213j],
        [ 0.0523+0.9351j,  0.4409+0.5076j,  0.3956+0.4018j,  0.6230+0.9270j],
        [-0.3503+0.7194j, -0.3742+0.2311j, -0.3353+0.3884j, -0.3478+0.6724j]],
       device='cuda:0')
tensor([[-0.1322+1.1461j, -0.1098+0.8015j,  0.0617+1.0408j, -0.1988+0.4788j],
        [ 0.1147+0.2296j,  0.0686+0.1161j,  0.0647+0.4044j,  0.0795+0.6407j],
        [-0.2396+0.6326j, -0.3587+0.5878j, -0.1563+0.4028j, -0.2933+0.3294j],
        [-0.1214+0.3678j,  0.0440+0.9951j,  0.3342+1.1360j,  0.6796+0.6590j]],
       device='cuda:0')

As you can see, only the first two values of the top row are accurate.

Any ideas on how I can make this element-wise dot product work?

Many thanks!

Upvotes: 0

Views: 218

Answers (0)

Related Questions