user14430587
user14430587

Reputation: 13

Why is this merge sort implementation not working properly?

The following is my merge sort implementation. It takes an array of pairs of integers and sorts them by the second element of each pair (ties are broken by the first element; all pairs are distinct).

// Merges the adjacent runs a[l, m) and a[m, r) into res[l, r), ordering by
// .second and, on equal .second, by the smaller .first; then copies
// res[l, r) back into a[l, r).
// NOTE(review): both vectors are taken BY VALUE, so every write below lands
// in this call's local copies and is discarded on return -- the caller's
// vector is never modified. Taking them by reference (vector<...>&) fixes it.
void mer(vector<pair<int, int>> a, int l, int m, int r, vector<pair<int, int>> res) {
    int i = l;
    int j = m;
    int k = l;
    // Repeatedly move the smaller head of the two runs into res.
    while (i < m && j < r) {
        if (a[i].second < a[j].second) {
            res[k] = a[i];
            i++;
            k++;
        } else
        if (a[i].second > a[j].second) {
            res[k] = a[j];
            j++;
            k++;
        } else {
            // Equal .second: the pair with the smaller .first goes first.
            if (a[i].first > a[j].first) {
                res[k] = a[j];
                k++;
                j++;
            } else {
                res[k] = a[i];
                i++;
                k++;
            }
        }
    }
    // Copy whichever run still has elements left.
    while (i < m) {
        res[k] = a[i];
        k++;
        i++;
    }
    while (j < r) {
        res[k] = a[j];
        k++;
        j++;
    }
    // Copy the merged range back (into the local copy of a only).
    for (int i = l; i < r; i++) {
        a[i] = res[i];
    }
}
    
// Intended to merge-sort a[l, r) using res as scratch space.
// NOTE(review): this does not sort the caller's vector:
//  - a and res are passed BY VALUE, so all work happens on copies;
//  - main() passes r exclusive (b.size()), and mer() merges [l, m) with
//    [m, r), but the second recursive call starts at m + 1, so element m
//    is left out of the second sorted half. For an exclusive r the test
//    should be r - l >= 2 and the calls solve(a, l, m, ...) and
//    solve(a, m, r, ...).
void solve(vector<pair<int, int>> a, int l, int r, vector<pair<int, int>> res) {
    if (l < r) {
        int m = (l + r) / 2;
        solve(a, l, m, res);
        solve(a, m + 1, r, res);
        mer(a, l, m, r, res);
    }
}

But when I run it with this main():

// Reads n integers, counts how often each distinct value occurs (std::map
// keeps its keys in ascending order), copies the (value, count) pairs into b,
// and calls solve() on b with res as scratch space.
// NOTE(review): solve() takes b by value, so b is unchanged after the call.
// Nothing is printed here either, so the output quoted in the question
// presumably comes from printing code not included in this snippet.
int main() {
    int n;
    cin >> n;
    map<int, int> a;

    for (int i = 0; i < n; i++) {
        int u;
        cin >> u;
        a[u]++;
    }
    vector<pair<int, int>> b;
    for (auto i : a) {
        b.push_back(i);
    }
    vector<pair<int, int>> res(b.size());
    solve(b, 0, b.size(), res);
}

For example, with this input:

10
1 1 1 1 1 1 2 2 3 3
it outputs:
1 6
2 2
3 2

That is, whatever the input, the output is the same: the vector comes out unchanged instead of sorted. I have spent a lot of time looking for the problem, but I have not been able to fix it.

Upvotes: 1

Views: 69

Answers (1)

chqrlie
chqrlie

Reputation: 144695

There are multiple problems in your code:

  • the vector objects should be passed by reference; when passed by value, each call works on its own copy and the sorted result is thrown away.

  • res is not the destination of the solve() function, but rather an auxiliary vector to act as temporary storage for mer. Calling it tmp would be less confusing.

  • the r index in solve() is exclusive, so the test for a slice with fewer than 2 elements and the recursive calls should be:

    // Merge step: combines the already-sorted runs a[l, m) and a[m, r) into
    // res[l, r) -- ordering by .second, with equal .second resolved in favour
    // of the smaller .first -- and then writes res[l, r) back over a[l, r).
    // res is scratch space only and must hold at least r elements.
    void mer(vector<pair<int, int>>& a, int l, int m, int r, vector<pair<int, int>>& res) {
        int left = l;   // next unread element of the first run
        int right = m;  // next unread element of the second run
        int out = l;    // next free slot in res
        while (left < m || right < r) {
            bool takeRight;
            if (left >= m) {
                takeRight = true;   // first run exhausted
            } else if (right >= r) {
                takeRight = false;  // second run exhausted
            } else {
                const pair<int, int>& x = a[left];
                const pair<int, int>& y = a[right];
                // Take from the second run only when its head is strictly
                // smaller, so fully equal pairs keep first-run elements first.
                takeRight = y.second < x.second ||
                            (y.second == x.second && y.first < x.first);
            }
            res[out++] = takeRight ? a[right++] : a[left++];
        }
        for (int p = l; p < r; ++p) {
            a[p] = res[p];
        }
    }
    
    // Sorts a[l, r) (r exclusive) in place by .second, breaking ties on
    // .first, using tmp as scratch storage for mer(). tmp must have at least
    // r elements, since mer() writes tmp[k] for k in [l, r).
    void solve(vector<pair<int, int>>& a, int l, int r, vector<pair<int, int>>& tmp) {
        // Slices with fewer than 2 elements are already sorted.
        if (r - l >= 2) {
            // l + (r - l) / 2 cannot overflow int, unlike (l + r) / 2.
            int m = l + (r - l) / 2;
            solve(a, l, m, tmp);
            solve(a, m, r, tmp);
            mer(a, l, m, r, tmp);
        }
    }
    
  • in the main() function, you do not store the input values into the vector of pairs and you should also print the sorted pairs:

    // Reads n followed by n integers and pairs each value with its 1-based
    // input position (.first = position, .second = value). It then sorts the
    // pairs by value (ties broken by position) and prints one pair per line.
    int main() {
        int n;
        // Reject a missing or negative count: vector(n) with n < 0 would
        // convert to a huge size and throw.
        if (!(cin >> n) || n < 0) {
            return 1;
        }

        vector<pair<int, int>> a(n);

        for (int i = 0; i < n; i++) {
            a[i].first = i + 1;
            cin >> a[i].second;
        }
        // Scratch buffer for mer(); must be at least as large as a.
        vector<pair<int, int>> tmp(n);
        solve(a, 0, n, tmp);
        // Use iostreams for output too, matching the cin-based input.
        for (const auto& p : a) {
            cout << p.first << ' ' << p.second << '\n';
        }
        return 0;
    }
    

Upvotes: 1

Related Questions