https://leetcode.com/problems/reverse-pairs/
之所以想写这题,主要是因为自己在上面做了很多种尝试,虽然这些尝试在时间复杂度面前都是无效的。
LC上只要你的时间复杂度是错的,那么基本上就是TLE,没有办法在代码层面或者是特殊case上优化来AC. 说说这道题目我试过的算法吧。
首先我觉得这个问题是一个线段树,树节点定义如下:
- se 是一个tuple (start, end), 闭区间,表示这个树所覆盖的范围
- max_value 表示在这个范围内最大值
- min_value 表示在这个范围内最小值
- left, right 左右节点. 左边是 (s, (s+e)/2), 右边是 ((s+e) 2+1, e)
class Tree:
def __init__(self, se, max_value, min_value):
self.se = se
self.max_value = max_value
self.min_value = min_value
self.left = None
self.right = None
如果我们想查询从 `idx` 开始到数组结尾有多少个值是 `<=v` 的话,可以使用下面过程:
- 如果当前范围 `s >=idx` 并且最大值比v小的话,那么整个范围都符合条件
- 如果当前范围内最小值比v大的话,那么说明这个范围不符合条件
- 如果没有办法立刻决定,就需要搜索左右两棵子树
def query_tree(t: Tree, idx, v, ctx):
if t is None:
return 0
(s, e) = t.se
sz = (e - s) + 1
if idx <= s and v >= t.max_value:
return sz
if t.min_value > v:
return 0
m = (s + e) // 2
if idx > m:
a = 0
else:
a = query_tree(t.left, idx, v, ctx)
b = query_tree(t.right, idx, v, ctx)
return a + b
上面这段搜索代码要求搜索范围是一边固定(搜索到数组结尾)一边可变,但是改成两边可变的范围也不是什么难事。
这样整个外围代码就是这样的
# note(yan): 不知道这种区间树是否正确足够高效
class Solution:
def reversePairs(self, nums: List[int]) -> int:
if len(nums) < 2: return 0
t = build_tree(nums, 0, len(nums) - 1)
ans = 0
for i in range(len(nums) - 1):
j = i + 1
max_exp = (nums[i] - 1) // 2
res = query_tree(t, j, max_exp)
ans += res
return ans
但是搜索线段树远没有我想像的那么高效,因为这个强烈取决于里面的值分布。如果你运气好,访问第一棵树就可以停止, 如果运气不好,那么就需要搜索整棵树了,访问节点的数量就是线性的。这样下来最坏情况还是O(n^2).
但是我忽视这个时间复杂度,尝试继续从代码上或者是特殊路径上做优化,是否可以避免最坏的情况呢?一个优化点是这样的。
考虑数组 `[10 10 9(X) 4 9(Y) 4 4 4 4 3]` 如果我们从后向前遍历,当我们知道9(Y)后面有5个满足条件点的话, 那么当访问9(X)的话,其实只需要判断两个点 `[4 9(Y)]`, 而其余后面的点肯定是都满足的。
建立这个索引不复杂,时间空间复杂度都是O(n). 这个优化上了之后依然是TLE, 因为这个优化同样很依赖数据的分布。 比如对于 `[8 7 6 5 4 3 2 1]` 没有任何优化效果。
当意识到这个方向走不下去的时候,最好的方法还是从头思考问题,把原来的一些想法全部丢弃掉。如果我们考虑从后往前遍历, 每看到一个点就将这个点保存到数据库S中,我们还是有办法立刻查询到 “在这个数据库S中,有多少个点它的值是小于某个值v的”. 我们可以很容易地想到使用二叉树当这个数据库,这样程序大体框架就是这样的了。
class Solution:
def reversePairs(self, nums: List[int]) -> int:
if len(nums) < 2: return 0
t = None
ans = 0
for i in reversed(range(len(nums))):
v = nums[i]
max_exp = (v - 1) // 2
ans += query_tree(t, max_exp)
t = insert_tree(t, v)
return ans
不过简单的二叉树还是不能避免顺序数组造成的最坏情况,所以最好配合AVL来保持平衡。虽然这样的确可以通过,但是时间也比较长, 最重要的是代码也非常多。 *肯定有比AVL更好的方法,LC肯定不会让你去手写旋转树的*。如果要写旋转树,一定就是还有更好的方法。
最后就是使用归并排序算法。归并排序算法时间复杂度有保障,可以处理任何数据分布。而且根据我的经验,很适合解决这类需要重复扫描区间的计数问题。
class Solution:
def reversePairs(self, nums: List[int]) -> int:
self.ans = 0 # 直接把全局变量记在这里
def msort(lst):
n = len(lst)
if n <= 1:
return lst
return merge(msort(lst[:n // 2]), msort(lst[n // 2:]))
def merge(a, b):
i, j = 0, 0
while i < len(a) and j < len(b):
if a[i] <= 2 * b[j]:
i += 1
else:
self.ans += len(a) - i
j += 1
return sorted(a + b)
msort(nums)
return self.ans