Sampath

Reputation: 11

Naive matrix multiplication implementation in Triton

I am trying to implement a naive matrix multiplication in Triton, but I get the following error:

**IndexError: map::at**
    214         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
    215             passes.llvmir.add_di_scope(pm)
--> 216         pm.run(mod)
    217         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    218         llvm.init_targets()

Any suggestions on how to resolve this issue would be much appreciated.

Here is the code:

import triton
import triton.language as tl
import torch
import numpy as np

# matmul(M x K, K x N) = M x N
# Example configuration used below:
#   M = 16, K = 16, N = 16
#   ROW_BLOCK_SIZE = 16, COL_BLOCK_SIZE = 16, K_SIZE = 16

@triton.jit
def mat_mul_gpu(mat1, mat2, result, M, N, K, ROW_BLOCK_SIZE : tl.constexpr, COL_BLOCK_SIZE : tl.constexpr, K_SIZE: tl.constexpr):
  """Compute one ROW_BLOCK_SIZE x COL_BLOCK_SIZE tile of result = mat1 @ mat2.

  mat1 is M x K, mat2 is K x N and result is M x N. All three are addressed
  as contiguous row-major buffers: element (i, j) of an R x C matrix lives at
  offset i * C + j. No stride arguments are taken, so non-contiguous tensors
  are not supported.

  Launch with grid = (cdiv(M, ROW_BLOCK_SIZE), cdiv(N, COL_BLOCK_SIZE)).
  Program (pid0, pid1) writes rows [pid0 * ROW_BLOCK_SIZE, ...) and columns
  [pid1 * COL_BLOCK_SIZE, ...) of result. The shared K dimension is consumed
  in K_SIZE-wide chunks, so K may be larger than K_SIZE.

  NOTE(review): per the accepted answer, an fp32 version of this kernel
  failed to compile ("IndexError: map::at") on a T4 GPU in Colab, and
  switching the inputs to fp16 was the workaround
  (https://github.com/triton-lang/triton/issues/5557).
  """
  row_block_id = tl.program_id(axis=0)
  col_block_id = tl.program_id(axis=1)

  # Global row/column indices owned by this program. Without the block
  # offsets, every program would compute (and overwrite) tile (0, 0).
  rows = row_block_id * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)[:, None]
  cols = col_block_id * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE)[None, :]
  k_offsets = tl.arange(0, K_SIZE)

  accum = tl.zeros((ROW_BLOCK_SIZE, COL_BLOCK_SIZE), dtype=tl.float32)

  # Walk the shared K dimension one K_SIZE-wide chunk at a time.
  for k_start in range(0, K, K_SIZE):
    ks = k_start + k_offsets

    # Masks prevent out-of-bounds memory access on ragged edges.
    a_mask = (rows < M) & (ks[None, :] < K)
    b_mask = (ks[:, None] < K) & (cols < N)

    # other=0.0 makes masked-out lanes contribute nothing to the dot product.
    # Without it, their loaded values are undefined.
    a = tl.load(mat1 + rows * K + ks[None, :], mask=a_mask, other=0.0)
    b = tl.load(mat2 + ks[:, None] * N + cols, mask=b_mask, other=0.0)
    accum += tl.dot(a, b)

  result_mask = (rows < M) & (cols < N)
  tl.store(result + rows * N + cols, accum, mask=result_mask)


mat1_shape = (16, 16)
mat2_shape = (16, 16)
if mat1_shape[1] != mat2_shape[0]:
  raise ValueError(f"inner dimensions differ: {mat1_shape} @ {mat2_shape}")
# matmul(M x K, K x N) = M x N, so the output takes mat1's rows and mat2's columns.
result_shape = (mat1_shape[0], mat2_shape[1])
torch.manual_seed(0)

mat1 = torch.rand(mat1_shape, device="cuda", dtype=torch.float)
mat2 = torch.rand(mat2_shape, device="cuda", dtype=torch.float)
d_result = torch.empty(result_shape, device="cuda", dtype=torch.float)

ROW_BLOCK_SIZE = 16
COL_BLOCK_SIZE = 16
K_SIZE = mat1_shape[1]
num_row_blocks = triton.cdiv(result_shape[0], ROW_BLOCK_SIZE)
num_col_blocks = triton.cdiv(result_shape[1], COL_BLOCK_SIZE)
grid = (num_row_blocks, num_col_blocks)

# Pass the compile-time (tl.constexpr) block sizes by keyword so their roles are explicit.
mat_mul_gpu[grid](
  mat1, mat2, d_result,
  result_shape[0], result_shape[1], mat1_shape[1],
  ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
  COL_BLOCK_SIZE=COL_BLOCK_SIZE,
  K_SIZE=K_SIZE,
)

# Copy the device result back to the host and compare it with PyTorch's matmul.
result = d_result.cpu()
max_abs_err = (result - torch.matmul(mat1, mat2).cpu()).abs().max().item()
print(f"max abs error vs torch.matmul: {max_abs_err:.3e}")

Upvotes: 0

Views: 53

Answers (1)

Sampath

Reputation: 11

I fixed the issue by switching from FP32 to FP16. I was running the code in Google Colab on a T4 GPU, and it looks like there is an issue with FP32 support on the T4: https://github.com/triton-lang/triton/issues/5557

Upvotes: 1

Related Questions