std::lower_bound 是一种分区查找算法

起因:每日一题官解看不明白

今天(2024 年 5 月 4 日)做 Leetcode 每日一题又没有做出来,最后抄了答案。题目是这样的:1235. 规划兼职工作

思路是先按照 endTime 排序,然后再 dp,然后 dp 中用二分查找求满足“自己的 endTime 小于等于当前元素 startTime ” 的元素数量。但是官方解答的 std::upper_bound 传参实在是看迷糊了。

二分查找:AoS 和 SoA 的比较

一般数组(Array of Structures 或 AoS)排序好之后,直接用 std::lower_bound / std::upper_bound 做就好了。但如果数组里面存的不是可以直接用 operator< 或者自定义 comp 比较的值,写 comp 就比较伤脑筋。比如这道题里面为了减少数据拷贝(或者内存使用量),只对下标排序:

注:comp 指的是自定义的比较器。

int jobScheduling(vector<int> &startTime, vector<int> &endTime, vector<int> &profit) {
    vector<int> idx(n);
    iota(idx.begin(), idx.end(), 0);
    sort(idx.begin(), idx.end(), [&](int i, int j) -> bool { return endTime[i] < endTime[j]; });
    // ...
}

现在有下标,但是需要拿着下标去访问不同的数组,布局方式是 Structure of Arrays,即 SoA。这个时候解题一般需要手写二分查找。能不能继续利用标准库的二分呢

看了一下 cppreference 上面 std::lower_boundstd::upper_bound 的 possible implementation,发现两者只有 comp 的使用方式有变化。拿默认比较方式 operator< 来理解:

  • std::lower_bound 函数形式是 comp(cur, pivot)。满足条件就 lo = mid + 1,否则 hi = mid。结果是把满足条件的划分在左边,找第一个不满足条件的。
  • std::upper_bound 函数形式是 comp(pivot, cur)。不满足条件就 lo = mid + 1,否则 hi = mid。结果是把不满足条件的划分在左边,找第一个满足条件的。

注意两者的传参方式不同。为了让 std::upper_bound 在当前值小于等于 pivot 的时候向右移动(实际上需要 operator<=),同时又只依赖 operator<,有必要交换两个参数的位置并对返回值逻辑取反,即 a <= b 变成 !(b < a)

这样,在使用 std::lower_bound 的时候,就应该有以下前提:数组分成左右两部分,左边满足性质 A,右边不满足性质 A,最终 std::lower_bound 会停在第一个不满足性质 A 的位置上。使用 std::upper_bound 的时候,也有类似的划分性质,左边不满足条件 B,右边满足条件 B,返回的迭代器指向第一个满足条件 B 的位置。比起来,std::upper_bound 要绕一点,但是两个函数是可以通过调整参数做到等价的。

class Solution {
   public:
    int jobScheduling(vector<int> &startTime, vector<int> &endTime, vector<int> &profit) {
        // 好难,又是看答案的一天。
        int n = startTime.size();
        vector<int> idx(n);
        iota(idx.begin(), idx.end(), 0);
        sort(idx.begin(), idx.end(), [&](int i, int j) -> bool { return endTime[i] < endTime[j]; });
        vector<int> dp(n + 1, 0);  // 前 i 份工作的最大报酬
        for (int i = 1; i <= n; ++i) {
            int j = idx[i - 1];
            /*
             * 拿默认比较方式 operator< 来理解:
             * - lower_bound
             *   - 函数形式是 comp(cur, pivot)。
             *   - 满足条件就 lo = mid + 1,否则 hi = mid。
             *   - 结果是把满足条件的划分在左边(不含停止位置),找第一个不满足条件的。
             * - upper_bound
             *   - 函数形式是 comp(pivot, cur)。
             *   - 不满足条件就 lo = mid + 1,否则 hi = mid。
             *   - 结果是把不满足条件的划分在左边(不含停止位置),找第一个满足条件的。
             */
            // auto it = upper_bound(idx.begin(), idx.begin() + i - 1, startTime[j] /* pivot */,
            //     [&](int pivot, int cur) -> bool { return endTime[cur] > pivot; });
            auto it = lower_bound(idx.begin(), idx.begin() + i - 1, startTime[j] /* pivot */,
                [&](int cur, int pivot) -> bool { return endTime[cur] <= pivot; });
            int k = it - idx.begin();
            dp[i] = max(dp[i - 1], dp[k] + profit[j]);
        }
        return dp[n];
    }
};

Hot or Cold 是 Hunt the thimble 的一个版本,搜索者通过提示人的冷(更远)热(更近)来寻找物体。是不是很像二分搜索?

std::lower_boundstd::partition_point 的特例

我们之前的写法居然根本没有用到数组的全局性质!也就是说,相邻的元素有序这一点是隐含的,并没有被我们直接使用:我们从返回值是一个分区点 + 数组有序这两个条件推断出了返回值一定是一个上界

标准库中 std::lower_bound 对参数的要求其实非常宽松:

Although std::lower_bound only requires [first, last) to be partitioned, this algorithm is usually used in the case where [first, last) is sorted, so that the binary search is valid for any value.

Unlike std::binary_searchstd::lower_bound does not require operator< or comp to be asymmetric (i.e., a < b and b < a always have different results). In fact, it does not even require value < *iter or comp(value, *iter) to be well-formed for any iterator iter in [first, last).

https://en.cppreference.com/w/cpp/algorithm/lower_bound

有了这种认识,我们完全可以用 std::partition_point 来解决 SoA 的场景,因为 std::lower_boundstd::upper_bound 参数中 pivot 的位置是不一样的,很容易弄错。而且 std::partition_point 的参数简单一些:

auto it = partition_point(idx.begin(), idx.begin() + i - 1, [&](int cur) {
    return endTime[cur] <= startTime[j];
});

这和

auto it = upper_bound(idx.begin(), idx.begin() + i - 1,
                      startTime[j] /* pivot */,
                      [&](int pivot, int cur) -> bool { return endTime[cur] > pivot; });

以及

auto it = lower_bound(idx.begin(), idx.begin() + i - 1,
                      startTime[j] /* pivot */,
                      [&](int cur, int pivot) -> bool { return endTime[cur] <= pivot; });

是完全等价的。

This algorithm is a more general form of std::lower_bound, which can be expressed in terms of std::partition_point with the predicate [&](const auto& e) { return e < value; });.

https://en.cppreference.com/w/cpp/algorithm/partition_point