2163: Minimum-Difference-in-Sums-After-Removal-of-Elements
Hard


The intuition behind this algorithm is that we can place a “separator” in nums (which has 3n elements) so that the n elements kept for the first part all lie to its left and the n elements kept for the second part all lie to its right. Let i be the number of elements to the left of the separator. The left side needs at least n elements to choose from, so i >= n, and the right side needs at least n elements as well, so i <= 2n. The “separator” must therefore satisfy n <= i <= 2n.

Then, for each possible “separator” position i (i = n, n + 1, ..., 2n), we need the minimum sum of any n elements among the first i elements and the maximum sum of any n elements among the remaining 3n - i elements. Subtracting the second from the first gives the best difference for that separator. This yields n + 1 candidate differences, and the answer is the smallest of them.
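As a sanity check on the separator idea, here is a minimal brute-force sketch that recomputes both sums from scratch for every separator by sorting. The function name minimumDifferenceBruteForce is hypothetical (it is not part of the solution below), and at roughly O(n^2 log n) it is only meant for cross-checking small inputs, not for submission.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical reference implementation: tries every separator position
// and sorts both sides to get the two sums directly.
long long minimumDifferenceBruteForce(const std::vector<int>& nums) {
    assert(nums.size() % 3 == 0);  // the problem guarantees 3n elements
    const int n = static_cast<int>(nums.size() / 3);
    long long best = LLONG_MAX;

    // k = number of elements to the left of the separator, n <= k <= 2n.
    for (int k = n; k <= 2 * n; ++k) {
        std::vector<int> left(nums.begin(), nums.begin() + k);
        std::vector<int> right(nums.begin() + k, nums.end());

        std::sort(left.begin(), left.end());                         // ascending
        std::sort(right.begin(), right.end(), std::greater<int>());  // descending

        // Smallest n values on the left, largest n values on the right.
        long long minLeft = std::accumulate(left.begin(), left.begin() + n, 0LL);
        long long maxRight = std::accumulate(right.begin(), right.begin() + n, 0LL);

        best = std::min(best, minLeft - maxRight);
    }
    return best;
}
```

For nums = [7,9,5,8,1,3] (n = 2), the three separators give 16 - 13, 12 - 11 and 12 - 4, so the result is 1.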

Finding the maximum sum of any n elements after i mirrors finding the minimum sum of any n elements before i: scan from the right instead of the left, use a minimum heap instead of a maximum heap, and flip the comparison. An explanation of the minimum sum alone will therefore suffice.

First, iterate through the first n elements, pushing each one onto a maximum heap and adding it to a running sum. That sum is the answer for i = n, since all n elements must be kept. Then, for each index from n to 2n - 1, check whether the top of the maximum heap (the largest element currently kept) is greater than the current element. If so, pop it, push the current element, and update the running sum accordingly, so the heap always holds the n smallest elements seen so far. After each step, the running sum is the minimum sum of any n elements before the corresponding separator, and we append it to an array.
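The step above can be sketched as a standalone helper. The name prefixMinSums and its signature are assumptions made for illustration; the full solution further down inlines the same logic.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical helper: result[k] is the minimum sum of any n elements
// chosen from the first n + k elements of nums, for k = 0..n.
std::vector<long long> prefixMinSums(const std::vector<int>& nums, int n) {
    assert(n >= 0 && nums.size() >= 2 * static_cast<std::size_t>(n));

    std::priority_queue<int> kept;  // max-heap of the n smallest values so far
    long long sum = 0;
    for (int i = 0; i < n; ++i) {
        kept.push(nums[i]);
        sum += nums[i];
    }

    std::vector<long long> result{sum};
    for (int i = n; i < 2 * n; ++i) {
        // Swap out the largest kept value if the new one is smaller.
        if (kept.top() > nums[i]) {
            sum += nums[i] - kept.top();
            kept.pop();
            kept.push(nums[i]);
        }
        result.push_back(sum);
    }
    return result;
}
```

For nums = [7,9,5,8,1,3] and n = 2, the heap starts as {9, 7} with sum 16; 5 replaces 9 (sum 12), and 8 replaces nothing, so the helper returns [16, 12, 12].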

We can then repeat the same steps from the right end, with a minimum heap, to build a second array holding the maximum sum of any n elements after each separator.
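A minimal sketch of the mirrored pass, assuming a hypothetical helper named suffixMaxSums. It fills its result from right to left, so entry k covers the last n + k elements.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <vector>

// Hypothetical helper: result[k] is the maximum sum of any n elements
// chosen from the last n + k elements of nums, for k = 0..n.
std::vector<long long> suffixMaxSums(const std::vector<int>& nums, int n) {
    assert(n >= 0 && nums.size() == 3 * static_cast<std::size_t>(n));

    // Min-heap of the n largest values seen so far (scanning right to left).
    std::priority_queue<int, std::vector<int>, std::greater<int>> kept;
    long long sum = 0;
    for (int i = 3 * n - 1; i >= 2 * n; --i) {
        kept.push(nums[i]);
        sum += nums[i];
    }

    std::vector<long long> result{sum};
    for (int i = 2 * n - 1; i >= n; --i) {
        // Swap out the smallest kept value if the new one is larger.
        if (kept.top() < nums[i]) {
            sum += nums[i] - kept.top();
            kept.pop();
            kept.push(nums[i]);
        }
        result.push_back(sum);
    }
    return result;
}
```

For nums = [7,9,5,8,1,3] and n = 2, the last two elements sum to 4, then 8 replaces 1 (sum 11) and 5 replaces 3 (sum 13), giving [4, 11, 13].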

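Finally, we can iterate through both arrays simultaneously, pairing the entries that belong to the same separator, and return the minimum difference between the sums of the two parts after the removal of n elements. Because the second array is built from right to left, its entries are read in reverse order.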
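The pairing can be sketched as follows, assuming the first array is ordered left to right and the second was filled right to left, as in the solution below. The name bestDifference is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <vector>

// Hypothetical helper: minLeft[i] belongs to the separator after n + i
// elements; maxRight was filled right to left, so its matching entry is
// at the mirrored index.
long long bestDifference(const std::vector<long long>& minLeft,
                         const std::vector<long long>& maxRight) {
    assert(!minLeft.empty() && minLeft.size() == maxRight.size());

    const std::size_t m = minLeft.size();
    long long best = LLONG_MAX;
    for (std::size_t i = 0; i < m; ++i) {
        best = std::min(best, minLeft[i] - maxRight[m - 1 - i]);
    }
    return best;
}
```

With minLeft = [16, 12, 12] and maxRight = [4, 11, 13] (the arrays for nums = [7,9,5,8,1,3]), the candidates are 16 - 13, 12 - 11 and 12 - 4, so the result is 1.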

code

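#include <algorithm>   // min
#include <climits>     // LLONG_MAX
#include <cstddef>     // size_t
#include <functional>  // greater
#include <queue>       // priority_queue
#include <vector>
using namespace std;

class Solution {
public:
    long long minimumDifference(vector<int>& nums) {
        int n = nums.size() / 3;

        // mx[k]: minimum sum of any n elements among the first n + k elements
        //        (named after the max-heap used to build it).
        // mn[k]: maximum sum of any n elements among the last n + k elements
        //        (named after the min-heap used to build it).
        priority_queue<int> mxHp;
        vector<long long> mx;
        vector<long long> mn;
        long long currentSum = 0;

        // Seed the max-heap with the first n elements.
        for (int i = 0; i < n; ++i) {
            mxHp.push(nums[i]);
            currentSum += nums[i];
        }
        mx.push_back(currentSum);

        // Extend the prefix one element at a time, keeping the n smallest.
        for (int i = n; i < 2 * n; ++i) {
            long long top = mxHp.top();
            if (top > nums[i]) {
                currentSum -= top;
                mxHp.pop();
                mxHp.push(nums[i]);
                currentSum += nums[i];
            }
            mx.push_back(currentSum);
        }

        // Seed the min-heap with the last n elements.
        currentSum = 0;
        priority_queue<int, vector<int>, greater<int>> mnHp;
        for (int i = 3 * n - 1; i >= 2 * n; --i) {
            mnHp.push(nums[i]);
            currentSum += nums[i];
        }
        mn.push_back(currentSum);

        // Extend the suffix leftwards one element at a time, keeping the n largest.
        for (int i = 2 * n - 1; i >= n; --i) {
            long long top = mnHp.top();
            if (top < nums[i]) {
                currentSum -= top;
                mnHp.pop();
                mnHp.push(nums[i]);
                currentSum += nums[i];
            }
            mn.push_back(currentSum);
        }

        // mx[i] covers the first n + i elements, so the matching suffix holds
        // the last 2n - i elements, which is mn[n - i] (mn read in reverse).
        long long ans = LLONG_MAX;
        for (size_t i = 0; i < mx.size(); ++i) {
            ans = min(ans, mx[i] - mn[mn.size() - i - 1]);
        }
        return ans;
    }
};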

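complexity

Time: O(n log n), since each heap handles 2n elements and every push or pop costs O(log n) on a heap of size n.
Space: O(n) for the two heaps and the two arrays.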

time taken