https://leetcode.cn/problems/maximum-balanced-subsequence-sum/description/
这题最开始的想法是维护一个排序结构 `st = [(a, b)]`
- a 表示 `i-nums[i]`
- b 表示满足 `i-nums[i]` 是a的情况下的最大值
- a做升序,b则做降序
其实这个结构挺好的,唯一的问题就在于更新,维护这个数据结构的成本比较高:需要将某些元素从st里面删除掉,这个成本有点高。 我花了比较长时间进行调试,但是还有有一些test cases出错了。 **错误的思路,糟糕的数据结构,从一开始就会让整个调试成本增加。**
后面换了一个思路,就是其实 `i-nums[i]` 这个值是可以枚举出来的。如果我们将这些值映射成为偏移量的话,那么其实我们想求解的就是某个区间内的最大值, 并且每次只是更新其中一个点,这个数据结构正好就是上周使用的线段树。
#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt
from typing import List
class Solution:
def maxBalancedSubsequenceSum(self, nums: List[int]) -> int:
# dp[i] = max(dp[j]) + nums[i] if (i - j) <= nums[i] - nums[j]
# i - nums[i] <= j - nums[j]
diff = [i - nums[i] for i in range(len(nums))]
diff.sort(reverse=True)
pos = {}
for d in diff:
if d not in pos:
pos[d] = len(pos)
N = len(pos)
INF = 1 << 63
SZ = 1
while SZ < N:
SZ = SZ * 2
MAX = [-INF] * (2 * SZ)
def update_max(p, v):
k = p + SZ
MAX[k] = max(MAX[k], v)
while k != 1:
p = k // 2
MAX[p] = max(MAX[2 * p], MAX[2 * p + 1])
k = p
def query_max(p):
def do(i, j, k, s, sz):
if i <= s <= (s + sz - 1) <= j:
return MAX[k]
mid = s + sz // 2
res = -INF
if i < mid:
a = do(i, j, 2 * k, s, sz // 2)
res = max(res, a)
if j >= mid:
a = do(i, j, 2 * k + 1, mid, sz // 2)
res = max(res, a)
return res
return do(0, p, 1, 0, SZ)
ans = -INF
for i in range(len(nums)):
d = i - nums[i]
p = pos[d]
value = nums[i]
last = query_max(p)
value += max(last, 0)
update_max(p, value)
ans = max(ans, value)
return ans