Reputation: 2252
I am currently using tfjs 3.8 to load a segmentation model (loaded as a `tf.GraphModel`) on the client side. To create the input `Tensor`, I call `browser.fromPixels(imageData)`, which creates the `Tensor` on the CPU from an `ImageData` object that is also on the CPU. Since I'm using tfjs' `webgl` backend, the data is then sent to the GPU when calling the `model.predict(tensor)` function. All of this works well, except that my `ImageData` object is created from an image on a canvas with a `WebGLRenderingContext`, meaning the data originates on the GPU. This GPU->CPU->GPU round trip is slowing down my pipeline, which I am trying to optimize.

I briefly searched tfjs and could not find a way to create a `Tensor` directly on the GPU and avoid the GPU->CPU download. Is there a way I could keep my data on the GPU?
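For reference, here is a minimal sketch of the slow path described above, assuming the `tf.GraphModel` is already loaded into a `model` variable and `glCanvas` is the canvas backed by the `WebGLRenderingContext` (both names are placeholders, not from the original code):

```js
import * as tf from '@tensorflow/tfjs';

// glCanvas has a WebGLRenderingContext; reading its pixels into ImageData
// via a 2D canvas forces a GPU -> CPU download.
const ctx2d = document.createElement('canvas').getContext('2d');
ctx2d.canvas.width = glCanvas.width;
ctx2d.canvas.height = glCanvas.height;
ctx2d.drawImage(glCanvas, 0, 0);
const imageData = ctx2d.getImageData(0, 0, glCanvas.width, glCanvas.height);

// fromPixels builds the tensor from CPU data; predict() then re-uploads it
// to the GPU for the webgl backend.
const input = tf.browser.fromPixels(imageData).expandDims(0);
const mask = model.predict(input); // model is the tf.GraphModel loaded earlier
input.dispose();
```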
Upvotes: 1
Views: 318
Reputation: 883
A detailed conversation on this topic is in this thread: https://github.com/tensorflow/tfjs/issues/5765
Upvotes: 1
Reputation: 2252
The solution is simply to pass the canvas with the WebGL context directly to the `browser.fromPixels(canvas)` call. This will create the `Tensor` directly on the GPU.
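A minimal sketch of this fix, assuming the `webgl` backend is active and `glCanvas` is the canvas holding the `WebGLRenderingContext` (variable names are placeholders):

```js
import * as tf from '@tensorflow/tfjs';

await tf.setBackend('webgl');
await tf.ready();

// Passing the canvas element itself lets the webgl backend consume the
// canvas as a texture source, skipping the ImageData readback on the CPU.
const input = tf.browser.fromPixels(glCanvas).expandDims(0);
const mask = model.predict(input); // same tf.GraphModel as in the question
input.dispose();
```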
Upvotes: 0