arc_lupus

Reputation: 4114

Static_cast converts to wrong data type, but the result is still correct?

Based on the comments on my last question (Getter-function for derived class in derived class, when using pointer to base class), I've been told that I have to use static_cast to cast my pointer from the base class back to the derived class in order to access the derived class's members. I tested that using the following code:

#include <iostream>

class class_data{
public:
    int val_a = 0;
    double val_b = 0.;
};

class overridden_class_data : public class_data{
public:
    int val_c = 0;
}; 

class overridden_class_data_II : public class_data{
public:
    int val_c = 12;
}; 

class BaseClass{
public:
    BaseClass(){};

    virtual void print_data() = 0;
    virtual class_data *get_local_data() = 0;

    class_data local_class_data;
};

class DerivedClass : public BaseClass{
public:
    DerivedClass(){
        local_class_data.val_a = 10;
        local_class_data.val_b = 100.;
        local_class_data.val_c = 14;
    };

    void print_data() override{
        std::cout << "Hello World\n";
    }

    class_data * get_local_data() override {
        return &local_class_data;
    }

    overridden_class_data local_class_data;
};

class DerivedClassII : public BaseClass{
public:
    DerivedClassII(){
        local_class_data.val_a = 10;
        local_class_data.val_b = 100.;
    };

    void print_data() override{
        std::cout << "Hello World\n";
    }

    class_data * get_local_data() override {
        return &local_class_data;
    }

    overridden_class_data_II local_class_data;
};

void test_func(BaseClass *class_pointer){
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    std::cout << class_pointer->local_class_data.val_a << '\n';
    class_pointer->local_class_data.val_a = 5;
    std::cout << class_pointer->local_class_data.val_a << '\n';
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    class_pointer->get_local_data()->val_a = 15;
    std::cout << class_pointer->local_class_data.val_a << '\n';
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    std::cout << static_cast<overridden_class_data*>(class_pointer->get_local_data())->val_c << '\n';
}

int main(void){
    std::cout << "From main\n";
    DerivedClass DClass;
    DerivedClassII EClass;
    std::cout << "DClass: \n";
    test_func(&DClass);
    std::cout << "EClass: \n";
    test_func(&EClass);
    return 0;
}

Here I have two derived classes, each of which uses a different derived data class as its member variable. To access the data of those classes I have to apply static_cast to the returned base-class pointer to cast it back to the derived data class. Still, I do not want to rewrite the function test_func() for both classes, but instead use the same function for them.

Initially, I thought that I had to write the last line of the function twice, casting the member pointer once to overridden_class_data* and once to overridden_class_data_II*, depending on the input class. But after testing I noticed that I do not have to do that: I can cast it to overridden_class_data*, and it still acts as if I had cast it to overridden_class_data_II*. Why? Is it because both classes contain the same elements, and therefore the pointer can point to the same spot?

Upvotes: 0

Views: 161

Answers (1)

tangy

Reputation: 3276

As for your original question: yes, this only appears to work because (1) the data members of your two classes are laid out identically and (2) static_cast performs no run-time check, so it is not safe for this kind of downcast and will not tell you when the target type is wrong.

A simple counterexample that breaks test_func:

class overridden_class_data : public class_data{
public:
    int val_pad = 0;
    int val_c = 23;
}; 

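With this layout, the cast in test_func reads val_c from the wrong offset in the EClass case, so it would incorrectly print 0 instead of 12. Strictly speaking the behavior is undefined, so the exact value is not guaranteed.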
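To see why the layout matters, here is a minimal sketch that measures where each member sits inside an object. The names base_data, data_plain, data_padded and offset_of_member are hypothetical stand-ins, not part of the question's code. On common 64-bit ABIs I would expect data_plain::val_c and data_padded::val_pad to land at the same offset, with data_padded::val_c one int later, but the standard does not promise specific numbers.

```cpp
#include <cstddef>

// Hypothetical stand-ins for class_data and the two derived data classes.
struct base_data {
    int val_a = 0;
    double val_b = 0.;
};

struct data_plain : base_data {
    int val_c = 12;
};

struct data_padded : base_data {
    int val_pad = 0;  // extra member placed in front of val_c
    int val_c = 23;
};

// Byte distance from the start of `obj` to `member`, which must be a
// member of that same object.
template <typename Obj, typename Member>
std::ptrdiff_t offset_of_member(const Obj& obj, const Member& member) {
    return reinterpret_cast<const char*>(&member) -
           reinterpret_cast<const char*>(&obj);
}
```

Printing offset_of_member(p, p.val_c) for a data_plain and a data_padded object shows whether the two val_c members share an offset. When they do not, a static_cast to the wrong type reads whatever happens to be stored at the other offset.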

A few ways you could go about solving this (so that a single test_func works for both classes):

  1. Detect the above issue at run time using dynamic_cast. This requires class_data to be polymorphic (for example, by giving it a virtual destructor), and test_func would then need to be called with the appropriate explicit template arguments.

  2. Forgo the casts and get compile-time safety by making test_func a simple template and using covariant return types. You mentioned that you're concerned about too many templates. Is it code bloat you are worried about?

  3. @churill's suggestion of using a virtual getter such as get_val_c.
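For the first option, a rough sketch of a checked cast could look like the following. The types data_base, data_one, data_two and the helper try_get_val_c are hypothetical and only illustrate the idea. The virtual destructor is an assumption added here, because dynamic_cast can only check a downcast when the base class is polymorphic.

```cpp
// Hypothetical stand-ins for the data classes in the question.
// The virtual destructor makes the base polymorphic, which dynamic_cast
// needs in order to check a downcast at run time.
struct data_base {
    virtual ~data_base() = default;
    int val_a = 0;
};

struct data_one : data_base {
    int val_c = 14;
};

struct data_two : data_base {
    int val_c = 12;
};

// Returns true and writes val_c to `out` only if `p` really points to a
// Target; otherwise leaves `out` untouched and returns false.
template <typename Target>
bool try_get_val_c(data_base* p, int& out) {
    if (auto* d = dynamic_cast<Target*>(p)) {
        out = d->val_c;
        return true;
    }
    return false;
}
```

Calling try_get_val_c<data_one>(ptr, v) with a pointer to a data_two returns false instead of silently reading the wrong member, which is the check that static_cast does not give you. The caller still has to name the target type explicitly.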

Here is the snippet for the second method. I've marked the changes I made with comments:

#include <iostream>

class class_data{
public:
    int val_a = 0;
    double val_b = 0.;
};

class overridden_class_data : public class_data{
public:
    int val_pad = 23;
    int val_c = 0;
}; 

class overridden_class_data_II : public class_data{
public:
    int val_c = 12;
};

class BaseClass{
public:
    BaseClass(){};

    virtual void print_data() = 0;
    virtual class_data *get_local_data() = 0;

    class_data local_class_data;
};

class DerivedClass : public BaseClass{
public:
    DerivedClass(){
        local_class_data.val_a = 10;
        local_class_data.val_b = 100.;
        local_class_data.val_c = 14;
    };

    void print_data() override{
        std::cout << "Hello World\n";
    }

    // use covariant return type    
    overridden_class_data * get_local_data() override {
        return &local_class_data;
    }

    overridden_class_data local_class_data;
};

class DerivedClassII : public BaseClass{
public:
    DerivedClassII(){
        local_class_data.val_a = 10;
        local_class_data.val_b = 100.;
    };

    void print_data() override{
        std::cout << "Hello World\n";
    }

    // use covariant return type
    overridden_class_data_II * get_local_data() override {
        return &local_class_data;
    }

    overridden_class_data_II local_class_data;
};

template <typename T>
void test_func(T *class_pointer){ // make generic
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    std::cout << class_pointer->local_class_data.val_a << '\n';
    class_pointer->local_class_data.val_a = 5;
    std::cout << class_pointer->local_class_data.val_a << '\n';
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    class_pointer->get_local_data()->val_a = 15;
    std::cout << class_pointer->local_class_data.val_a << '\n';
    std::cout << class_pointer->get_local_data()->val_a << '\n';
    std::cout << class_pointer->get_local_data()->val_c << '\n';
}

int main(void){
    std::cout << "From main\n";
    DerivedClass DClass;
    DerivedClassII EClass;
    std::cout << "DClass: \n";
    test_func(&DClass);
    std::cout << "EClass: \n";
    test_func(&EClass);
    return 0;
}
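For comparison, the third option (a virtual getter) might be sketched as below. The names holder_base, holder_one, holder_two and read_val_c are hypothetical; holder_base plays the role of BaseClass. Each derived class knows its own data type, so no cast and no template is needed at the call site.

```cpp
// Hypothetical data classes with different layouts.
struct data_one {
    int val_c = 14;
};

struct data_two {
    int val_pad = 0;
    int val_c = 12;
};

// Plays the role of BaseClass: exposes val_c through a virtual getter.
struct holder_base {
    virtual ~holder_base() = default;
    virtual int get_val_c() const = 0;
};

struct holder_one : holder_base {
    int get_val_c() const override { return data.val_c; }
    data_one data;
};

struct holder_two : holder_base {
    int get_val_c() const override { return data.val_c; }
    data_two data;
};

// A single non-template function works for every derived holder.
int read_val_c(const holder_base& h) {
    return h.get_val_c();
}
```

The trade-off is one virtual function per value you want to expose, in exchange for a plain, non-template function that stays correct even when the data classes have different layouts.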

Upvotes: 1
