// Source : https://leetcode.com/problems/count-of-range-sum/ // Author : Hao Chen // Date : 2016-01-15 /*************************************************************************************** * * Given an integer array nums, return the number of range sums that lie in [lower, * upper] inclusive. * * Range sum S(i, j) is defined as the sum of the elements in nums between indices * i and * j (i ≤ j), inclusive. * * Note: * A naive algorithm of O(n2) is trivial. You MUST do better than that. * * Example: * Given nums = [-2, 5, -1], lower = -2, upper = 2, * Return 3. * The three ranges are : [0, 0], [2, 2], [0, 2] and their respective sums are: -2, -1, 2. * * Credits:Special thanks to @dietpepsi for adding this problem and creating all test * cases. * ***************************************************************************************/ /* * At first of all, we can do preprocess to calculate the prefix sums * * S[i] = S(0, i), then S(i, j) = S[j] - S[i]. * * Note: S(i, j) as the sum of range [i, j) where j exclusive and j > i. * * With these prefix sums, it is trivial to see that with O(n^2) time we can find all S(i, j) * in the range [lower, upper] * * int countRangeSum(vector<int>& nums, int lower, int upper) { * int n = nums.size(); * long[] sums = new long[n + 1]; * for (int i = 0; i < n; ++i) { * sums[i + 1] = sums[i] + nums[i]; * } * int ans = 0; * for (int i = 0; i < n; ++i) { * for (int j = i + 1; j <= n; ++j) { * if (sums[j] - sums[i] >= lower && sums[j] - sums[i] <= upper) { * ans++; * } * } * } * delete []sums; * return ans; * } * * The above solution would get time limit error. * * Recall `count smaller number after self` where we encountered the problem * * count[i] = count of nums[j] - nums[i] < 0 with j > i * * Here, after we did the preprocess, we need to solve the problem * * count[i] = count of a <= S[j] - S[i] <= b with j > i * * In other words, if we maintain the prefix sums sorted, and then are able to find out * - how many of the sums are less than 'lower', say num1, * - how many of the sums are less than 'upper + 1', say num2, * Then 'num2 - num1' is the number of sums that lie within the range of [lower, upper]. * */ class Node{ public: long long val; int cnt; //amount of the nodes Node *left, *right; Node(long long v):val(v), cnt(1), left(NULL), right(NULL) {} }; // a tree stores all of prefix sums class Tree{ public: Tree():root(NULL){ } ~Tree() { freeTree(root); } void Insert(long long val) { Insert(root, val); } int LessThan(long long sum, int val) { return LessThan(root, sum, val, 0); } private: Node* root; //general binary search tree insert algorithm void Insert(Node* &root, long long val) { if (!root) { root = new Node(val); return; } root->cnt++; if (val < root->val ) { Insert(root->left, val); }else if (val > root->val) { Insert(root->right, val); } } //return how many of the sums less than `val` // - `sum` is the new sums which hasn't been inserted // - `val` is the `lower` or `upper+1` int LessThan(Node* root, long long sum, int val, int res) { if (!root) return res; if ( sum - root->val < val) { //if (sum[j, i] < val), which means all of the right branch must be less than `val` //so we add the amounts of sums in right branch, and keep going the left branch. res += (root->cnt - (root->left ? root->left->cnt : 0) ); return LessThan(root->left, sum, val, res); }else if ( sum - root->val > val) { //if (sum[j, i] > val), which means all of left brach must be greater than `val` //so we just keep going the right branch. return LessThan(root->right, sum, val, res); }else { //if (sum[j,i] == val), which means we find the correct place, //so we just return the the amounts of right branch.] return res + (root->right ? root->right->cnt : 0); } } void freeTree(Node* root){ if (!root) return; if (root->left) freeTree(root->left); if (root->right) freeTree(root->right); delete root; } }; class Solution { public: int countRangeSum(vector<int>& nums, int lower, int upper) { Tree tree; tree.Insert(0); long long sum = 0; int res = 0; for (int n : nums) { sum += n; int lcnt = tree.LessThan(sum, lower); int hcnt = tree.LessThan(sum, upper + 1); res += (hcnt - lcnt); tree.Insert(sum); } return res; } };