mq_123

Calling an initialized function from a list inside a jitted JAX function

I have a jitted function that calls another function, which maps over a batch and calls a third function, inner_function, to compute a certain property. I also have a dictionary of initialized functions, initialized_functions_dic, and I want to call the appropriate one based on an argument such as info_1. Is there a way to make this work? Thanks in advance.

initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}


def inner_function(info_1, info_2, info_3):
    return 5 + initialized_functions_dic[info_1](info_2, info_3)

Indexing initialized_functions_dic[info_1] throws an error, because info_1 is a traced value inside the jitted function and cannot be used as a dictionary key.

Passing info_1 via static_argnums also fails, because info_1 is a JAX array, which is unhashable (type 'ArrayImpl').
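For context on the static_argnums error: static arguments must be hashable Python values, so a plain Python int works where a jax.Array does not. The sketch below assumes info_1 is a single value shared by the whole batch and known as a Python int at call time; init_function_1..3 are hypothetical stand-ins. It does not help when info_1 varies per batch element.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's init_function_1..3.
def init_function_1(info_2, info_3):
    return info_2 + info_3

def init_function_2(info_2, info_3):
    return info_2 * info_3

def init_function_3(info_2, info_3):
    return info_2 - info_3

initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}

# info_1 is static, so it must be a hashable Python int (not a jax.Array).
# jit retraces once for each distinct value of info_1.
@partial(jax.jit, static_argnums=0)
def batched(info_1, info_2, info_3):
    fn = initialized_functions_dic[info_1]  # ordinary dict lookup at trace time
    return jax.vmap(lambda a, b: 5 + fn(a, b))(info_2, info_3)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
print(batched(2, x, y))  # info_1 passed as the Python int 2
```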


Answers (1)

jakevdp

It sounds like you're looking for jax.lax.switch, which will switch between entries in a list of functions given an index:

from jax import lax

initialized_functions = [init_function_1, init_function_2, init_function_3]

def inner_function(info_1, info_2, info_3):
    idx = info_1 - 1  # the keys start at 1, but list indices start at 0
    args = (info_2, info_3)  # operands passed to whichever branch is selected
    return 5 + lax.switch(idx, initialized_functions, *args)
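A minimal end-to-end sketch of the same idea inside jit and vmap, where info_1 can differ per batch element. It assumes the dictionary keys are exactly 1, 2, ..., n, so the dict can be turned into a list ordered by key; init_function_1..3 and outer_function are hypothetical stand-ins. All branches must return the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the question's init_function_1..3.
def init_function_1(info_2, info_3):
    return info_2 + info_3

def init_function_2(info_2, info_3):
    return info_2 * info_3

def init_function_3(info_2, info_3):
    return info_2 - info_3

initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}

# Order the branches by key so that key k maps to list index k - 1.
# Assumes the keys are exactly 1, 2, ..., n.
initialized_functions = [initialized_functions_dic[k] for k in sorted(initialized_functions_dic)]

def inner_function(info_1, info_2, info_3):
    idx = info_1 - 1  # integer index into initialized_functions
    return 5 + lax.switch(idx, initialized_functions, info_2, info_3)

@jax.jit
def outer_function(info_1, info_2, info_3):
    # Map inner_function over the leading axis of all three arguments.
    return jax.vmap(inner_function)(info_1, info_2, info_3)

info_1 = jnp.array([1, 2, 3])
info_2 = jnp.array([1.0, 2.0, 3.0])
info_3 = jnp.array([4.0, 5.0, 6.0])
print(outer_function(info_1, info_2, info_3))
```

Two behaviors worth knowing: lax.switch clamps an out-of-range index into the valid range instead of raising, and when the index is batched under vmap, the switch is lowered to a select, so every branch is evaluated for every element.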
