Mark Lavin

Reputation: 1242

Tensorflow 2: How to compute forward jacobian for chain of bijectors

I have a collection of TensorFlow 2 (TensorFlow Probability) bijectors b0, b1, ..., bN, and I have constructed a derived Bijector class that wraps a chain of these primitive bijectors, thus:

import tensorflow_probability as tfp
tfb = tfp.bijectors

class MyBijector(tfb.Bijector):
    def __init__(self):
        self.bChain = tfb.Chain([b0, b1, ..., bN])

Do I have to define _forward_log_det_jacobian for this explicitly, or does TensorFlow figure out how to do it for me? If I do have to define it, can someone please remind me how the "chain rule" works in this case?
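For reference, the chain rule for this case can be written out as follows. This sketch assumes the tfb.Chain convention that the list is applied right to left, so Chain([b0, b1, ..., bN]) computes b0(b1(...bN(x))); the intermediate names x_k are introduced here only for the derivation. The Jacobian of the composition is the product of the individual Jacobians, the determinant of a product is the product of the determinants, and the log turns that product into a sum:

```latex
\begin{aligned}
x_{N+1} &= x, \qquad x_k = b_k(x_{k+1}) \ \ (k = N, \dots, 0), \qquad y = x_0,\\
\frac{\partial y}{\partial x} &= J_{b_0}(x_1)\, J_{b_1}(x_2) \cdots J_{b_N}(x_{N+1}),\\
\log\left|\det \frac{\partial y}{\partial x}\right| &= \sum_{k=0}^{N} \log\left|\det J_{b_k}(x_{k+1})\right|.
\end{aligned}
```

The important detail is that each term is evaluated at the input that bijector actually receives (x_{k+1}), not at the original x.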

Upvotes: 0

Views: 94

Answers (1)

Frightera

Reputation: 5079

I really have two questions. First, assuming I just write "b = tfb.Chain([b1, b2])" (that is, without defining a new class), does TensorFlow 2 "just know" how to calculate "call", or do I have to define it myself? Second, if I do have to define it myself, how do I define "forward_log_det_jacobian"; that is, how do I use the "chain rule" to do this?

Every Bijector, MyBijector included, has a __call__ method, and calling one bijector on another returns a Chain instance: bij1(bij2) gives the same composition as tfb.Chain([bij1, bij2]).

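In other words, TensorFlow Probability knows how to do it itself: Chain already implements forward, inverse, and the log-det-Jacobian methods by composing its parts (applying the chain rule internally), so you don't need to redefine them.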

Upvotes: 1
