roz

Reputation: 39

How to use triton.language.device_print for numbers?

I am using Triton 3.1.0. I need to debug a Triton program by printing values from device code, but device_print does not seem to accept numbers. Below is a minimal example.

import triton
import triton.language as tl

@triton.jit
def kernel():
    pid = tl.program_id(0)
    tl.device_print(pid)

kernel[(1,)]()

It fails with "AssertionError: int32[] is not string".

If I pass a string, e.g., "hello", instead of pid in the example, it works.

I tried str() to convert the number to a string. But it fails with "NameError('str is not defined')".

Printing floating-point numbers, e.g., 3.0, does not work either.

How can we print numbers?

Upvotes: 2

Views: 54

Answers (1)

mehdi si-mohammed

Reputation: 1

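Change tl.device_print(pid) to tl.device_print("pid", pid). The first argument of device_print must be a string prefix (that is what the "int32[] is not string" assertion is checking); the values to print are passed after it.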

Upvotes: 0
