Reputation: 3
I have a single class called FloatTensor, in which I have overloaded operators such as + and *. Here is the code:
class FloatTensor {
public:
float val; // value of tensor
float grad; // value of grad
Operation *frontOp =NULL, *backOp =NULL;
FloatTensor* two;
FloatTensor() {
// default
}
FloatTensor(float val) {
this->val = val;
}
FloatTensor(float val, Operation* op) {
this->val = val;
this->backOp = op;
}
void backward(float grad) {
this->grad = grad;
if(this->backOp != NULL) {
this->backOp->backward(grad);
}
}
FloatTensor exp() {
this->frontOp = new ExponentOperation(this);
return this->frontOp->compute();
}
FloatTensor operator * (FloatTensor &two) {
this->frontOp = new MultiplyOperation(this, &two);
return this->frontOp->compute();
}
FloatTensor operator + (FloatTensor &two) {
this->frontOp = new AddOperation(this, &two);
return this->frontOp->compute();
}
FloatTensor operator / (FloatTensor &two) {
this->frontOp = new DivideOperation(this, &two);
return this->frontOp->compute();
}
};
In my main function, a simple expression with a single overloaded operator works fine:
int main() {
// X
FloatTensor x1(200); // automatic (stack) storage, not heap
FloatTensor x2(300);
// Weights
FloatTensor w1(222);
FloatTensor w2(907);
FloatTensor temp = (x1*w1);
}
However, when I combine more operators in a single expression, like this:
int main() {
// X
FloatTensor x1(200); // automatic (stack) storage, not heap
FloatTensor x2(300);
// Weights
FloatTensor w1(222);
FloatTensor w2(907);
FloatTensor temp = (x1*w1) + (x2*w2);
}
I get this error:
no operator "+" matches these operands -- operand types are: FloatTensor + FloatTensor
I would be very grateful if someone could explain why this is happening. I observed that these work:
x1*w1*x2*x1;
x1*w1 + x2;
But this does not:
x1*w1 + x2*w2;
Very strange.
Upvotes: 0
Views: 41
Reputation: 93264
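Your operators accept a non-const lvalue reference as an argument. Temporaries do not bind to non-const lvalue references. In (x1*w1) + (x2*w2), the right-hand operand x2*w2 is a temporary, so it cannot bind to FloatTensor &two; in x1*w1 + x2 the right-hand operand is a named object, which is why that form compiles. To accept temporaries, use: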
FloatTensor operator + (const FloatTensor &two)
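As a minimal sketch of the same idea, here is a stripped-down stand-in for the class. The names Scalar and weighted_sum are hypothetical and not from the question, and the Operation graph is left out entirely. It only shows that a const reference parameter accepts both named objects and temporaries. In the original class the operators also assign this->frontOp and store &two inside an Operation, so the Operation constructors would presumably need to take const FloatTensor* as well, and a pointer to a temporary dangles once the full expression ends, so a later backward() through it would not be safe.

```cpp
// Hypothetical, simplified stand-in for FloatTensor (no Operation graph).
struct Scalar {
    float val;

    explicit Scalar(float v) : val(v) {}

    // const reference parameter: binds to lvalues and to temporaries.
    // Trailing const: the operator does not modify the left operand.
    Scalar operator*(const Scalar& rhs) const {
        return Scalar(val * rhs.val);
    }

    Scalar operator+(const Scalar& rhs) const {
        return Scalar(val + rhs.val);
    }
};

// Both products are temporaries; the right-hand one is passed to
// operator+ through the const reference parameter.
Scalar weighted_sum(const Scalar& x1, const Scalar& w1,
                    const Scalar& x2, const Scalar& w2) {
    return (x1 * w1) + (x2 * w2);
}
```

With the values from the question, weighted_sum(Scalar(200), Scalar(222), Scalar(300), Scalar(907)) compiles and yields 200*222 + 300*907. Removing const from the parameter of operator+ brings back the "no operator matches" error for this expression.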
Upvotes: 3