Reputation: 1242
I have a collection of TensorFlow 2 Bijectors b0, b1, ... bN
and I have constructed a derived Bijector class that consists of a chain of these primitive Bijectors, thus:
class MyBijector(Bijector):
    def __init__(self):
        self.bChain = tfb.Chain([b0, b1, ... bN])
Do I have to define _forward_log_det_jacobian for this explicitly, or does TensorFlow figure out how to do this for me? If I have to define it, can someone please remind me how the "chain rule" works in this case?
Upvotes: 0
Views: 94
Reputation: 5079
I really have two questions. First, assuming I just write "b = tfb.Chain([b1, b2])" (that is, without defining a new class), does TensorFlow 2 "just know" how to calculate "call", or do I have to define it myself? Second, if I do have to define it myself, how do I define "forward_log_det_jacobian", that is, how do I use the "chain rule" to do this?
MyBijector will have a __call__ method, which lets you easily get a chain instance (i.e. bij1(bij2)).
In other words, TensorFlow knows how to do it itself; you don't need to redefine it.
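As for the "chain rule" part of the question: for a composition, the log-determinants add up, with each bijector's term evaluated at the input it actually receives. tfb.Chain applies its list right to left, so for Chain([b0, b1]) the result is b1's term at x plus b0's term at b1.forward(x). Here is a minimal pure-Python sketch of that bookkeeping for scalar maps. The Exp, Scale and Chain classes below are toy stand-ins written for illustration (an assumption of this sketch), not the TensorFlow Probability classes or their implementation.

```python
import math


# Toy scalar bijectors, for illustration only (not the TFP classes).
class Exp:
    def forward(self, x):
        return math.exp(x)

    def forward_log_det_jacobian(self, x):
        # y = exp(x), dy/dx = exp(x), so log|dy/dx| = x
        return x


class Scale:
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return self.scale * x

    def forward_log_det_jacobian(self, x):
        # y = scale * x, dy/dx = scale, independent of x
        return math.log(abs(self.scale))


class Chain:
    """Composes bijectors, applying the list right to left."""

    def __init__(self, bijectors):
        self.bijectors = list(bijectors)

    def forward(self, x):
        for b in reversed(self.bijectors):
            x = b.forward(x)
        return x

    def forward_log_det_jacobian(self, x):
        # Chain rule: sum each term, evaluated at the intermediate
        # value that the bijector receives as its input.
        total = 0.0
        for b in reversed(self.bijectors):
            total += b.forward_log_det_jacobian(x)
            x = b.forward(x)
        return total


chain = Chain([Exp(), Scale(2.0)])  # y = exp(2 * x)
x = 0.5
fldj = chain.forward_log_det_jacobian(x)

# Analytic check: dy/dx = 2 * exp(2x), so log|dy/dx| = log(2) + 2x
expected = math.log(2.0) + 2.0 * x
print(math.isclose(fldj, expected))
```

The point of the loop is that x is updated after each term is added, so every bijector sees the output of the one before it. That running sum is what you would otherwise have to write by hand.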
Upvotes: 1