Frank

Reputation: 169

Conv2d not accepting a tensor as input, saying it's not a Tensor

I want to pass a tensor through a 2D convolutional layer (nn.Conv2d). I can't run it because I get a TypeError, even though I have converted my NumPy array to a tensor.

I tried using tf.convert_to_tensor() to solve this problem, but it didn't work.

import numpy as np
import tensorflow as tf
import torch.nn as nn

class Generator():

  def __init__(self):

    self.conv1 = nn.Conv2d(1, 28, kernel_size=3, stride=1, padding=1)
    self.pool1 = nn.MaxPool2d(kernel_size=3, stride=0, padding=1)

    self.fc1 = nn.Linear(100, 10)
    self.fc2 = nn.Linear(10, 5)

  def forward_pass(self, x):  # Why do we pass the object itself in every method?

    x = self.conv1(x)
    print(x)
    x = self.pool1(x)
    print(x)

    x = self.fc1(x)
    print(x)
    x = self.fc2(x)
    print(x)

    return x

arr = tf.convert_to_tensor(np.random.random((3,28,28)))

gen = Generator()
gen.forward_pass(arr)


Error message:

TypeError                                 Traceback (most recent call last)

<ipython-input-31-9fa8e764dcdb> in <module>()
      1 gen = Generator()
----> 2 gen.forward_pass(arr)

2 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not Tensor

Upvotes: 1

Views: 4004

Answers (1)

Brennan Vincent

Reputation: 10666

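You are trying to pass a TensorFlow tensor to a PyTorch function. TensorFlow and PyTorch are separate projects with different data structures which, in general, cannot be used interchangeably in this way. The message "must be Tensor, not Tensor" is confusing because both libraries name their class Tensor: conv2d() expects a torch.Tensor, but tf.convert_to_tensor() returns a TensorFlow tensor.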
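If the data already exists as a TensorFlow tensor, one option is to go through NumPy, since eager tensors in TensorFlow 2 expose a .numpy() method. The sketch below assumes eager execution; to_torch is a hypothetical helper name, not an API of either library, and float32 is chosen because it is the default dtype of PyTorch layer weights.

```python
import numpy as np
import torch


def to_torch(x, dtype=torch.float32):
    """Hypothetical helper: convert a NumPy array, a torch.Tensor, or an
    eager TensorFlow tensor (anything exposing .numpy()) to a torch.Tensor."""
    if isinstance(x, torch.Tensor):
        return x.to(dtype)
    if hasattr(x, "numpy"):
        # Assumption: eager TF tensors expose .numpy(), which returns an ndarray
        x = x.numpy()
    # Build a PyTorch tensor from the ndarray, casting to the requested dtype
    return torch.as_tensor(np.asarray(x), dtype=dtype)


arr = to_torch(np.random.random((3, 28, 28)))
print(type(arr), arr.dtype)  # <class 'torch.Tensor'> torch.float32
```

With TensorFlow installed, to_torch(tf.convert_to_tensor(...)) would take the .numpy() branch; the example above only exercises the NumPy path.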

To convert a NumPy array to a PyTorch tensor, you can use:

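import numpy as np
import torch
# Cast float64 -> float32 (Conv2d's default dtype) and add a channel dim: (3, 1, 28, 28)
arr = torch.from_numpy(np.random.random((3,28,28))).float().unsqueeze(1)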
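Even with a PyTorch tensor as input, the Generator from the question would still fail further on: fc1 expects 100 input features, which does not match the pooled output, and a MaxPool2d stride of 0 is not a meaningful stride. A minimal corrected sketch follows, under assumptions the question does not state: the three 28x28 arrays are a batch of single-channel images (so a channel dimension is added with unsqueeze(1)), the pooling stride is assumed to be 2, fc1's input size is changed to match the flattened pooling output, and the class subclasses nn.Module with a forward method so that gen(arr) works.

```python
import numpy as np
import torch
import torch.nn as nn


class Generator(nn.Module):  # subclassing nn.Module registers the layers' parameters
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 28, kernel_size=3, stride=1, padding=1)
        # Assumption: stride=2 (the question used stride=0)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # conv1 keeps 28x28; pool1 halves it, giving 28 channels of 14x14
        self.fc1 = nn.Linear(28 * 14 * 14, 10)
        self.fc2 = nn.Linear(10, 5)

    def forward(self, x):
        x = self.pool1(self.conv1(x))
        x = torch.flatten(x, start_dim=1)  # (N, 28, 14, 14) -> (N, 5488)
        x = self.fc1(x)
        return self.fc2(x)


# Three single-channel 28x28 images: shape (N, C, H, W) = (3, 1, 28, 28)
arr = torch.from_numpy(np.random.random((3, 28, 28))).float().unsqueeze(1)

gen = Generator()
out = gen(arr)
print(out.shape)  # torch.Size([3, 5])
```

Like the original, this sketch has no activation functions between layers; those would normally be added for a real model.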

Upvotes: 2
