Python语言基础


基本数据结构

Python 提供了丰富的内置数据结构,如 listdequedictset 等。以下是一些常用数据结构的介绍及其使用方法。

列表 list(动态数组)

list 是 Python 的可变序列类型,可以用作动态数组。

初始化方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 初始化一个空列表
nums = []

# 初始化一个包含元素 1, 3, 5 的列表
nums = [1, 3, 5]

# 初始化大小为 n,元素都为0的列表
n = 10
nums = [0] * n

# 二维列表,m 行 n 列,元素都为 1
m, n = 3, 4
matrix = [[1] * n for _ in range(m)]

Python 列表的常用方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
nums = [0] * 10

# 输出:False
print(len(nums) == 0)
# 输出:10
print(len(nums))

# 在列表尾部插入一个元素 20
nums.append(20)
# 输出:11
print(len(nums))

# 得到列表最后一个元素,输出:20
print(nums[-1])

# 删除列表的最后一个元素
nums.pop()
# 输出:10
print(len(nums))

# 索引访问与修改
nums[0] = 11
# 输出:11
print(nums[0])

# 在索引 3 处插入一个元素 99
nums.insert(3, 99)

# 删除索引 2 处的元素
nums.pop(2)

# 交换 nums[0] 和 nums[1]
nums[0], nums[1] = nums[1], nums[0]

# 遍历列表
# 输出示例:
# 0 11 99 0 0 0 0 0 0 0
for num in nums:
print(num, end=" ")
print()

双端队列 deque

deque 是 collections 模块提供的双端队列,可以高效地在两端插入和删除元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from collections import deque

# 初始化双端队列
lst = deque([1, 2, 3, 4, 5])

# 检查是否为空,输出:False
print(len(lst) == 0)

# 获取大小,输出:5
print(len(lst))

# 在头部插入 0,尾部插入 6
lst.appendleft(0)
lst.append(6)

# 获取头部和尾部元素,输出:0 6
print(lst[0], lst[-1])

# 删除头部和尾部元素
lst.popleft()
lst.pop()

# 在索引 2 处插入 99
lst.insert(2, 99)

# 删除索引 1 处的元素
del lst[1]

# 遍历双端队列
# 输出:1 99 3 4 5
for val in lst:
print(val, end=" ")
print()

队列 Queue

队列是一种操作受限制的数据结构:只允许在队尾插入元素,在队头删除元素。

Python 没有专门的队列类型,但可以使用 deque 来模拟队列,append 相当于入队,popleft 相当于出队。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from collections import deque

# 初始化队列
q = deque()

# 向队尾插入元素
q.append(10)
q.append(20)
q.append(30)

# 是否为空,输出:False
print(len(q) == 0)

# 大小,输出:3
print(len(q))

# 获取队头元素,不出队,输出:10
print(q[0])

# 队头元素出队
q.popleft()

# 新的队头元素,输出:20
print(q[0])

栈 Stack

虽然 Python 没有专门的栈类型,但可以使用 list 或 deque 来模拟栈。append 相当于压栈,pop 相当于出栈。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 使用 list 作为栈
s = []

# 压栈
s.append(10)
s.append(20)
s.append(30)

# 是否为空,输出:False
print(len(s) == 0)

# 大小,输出:3
print(len(s))

# 栈顶元素,输出:30
print(s[-1])

# 出栈
s.pop()

# 新的栈顶元素,输出:20
print(s[-1])

字典 dict(哈希表)

dict 是 Python 的哈希表实现,通过键值对存储数据,查找、插入和删除操作平均时间复杂度为 O(1)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# 初始化字典
hashmap = {1: "one", 2: "two", 3: "three"}

# 是否为空,输出:False
print(len(hashmap) == 0)

# 大小,输出:3
print(len(hashmap))

# 查找键
# 输出:Key 2 -> two
if 2 in hashmap:
print(f"Key 2 -> {hashmap[2]}")
else:
print("Key 2 not found.")

# 获取键对应的值,不存在则返回 None
# 输出:None
print(hashmap.get(4))

# 插入新键值对
hashmap[4] = "four"

# 获取新插入的值,输出:four
print(hashmap[4])

# 删除键 3
del hashmap[3]

# 检查删除后
if 3 in hashmap:
print(f"Key 3 -> {hashmap[3]}")
else:
print("Key 3 not found.")
# 输出:Key 3 not found.

# 遍历字典
# 输出:
# 1 -> one
# 2 -> two
# 4 -> four
for k, v in hashmap.items():
print(f"{k} -> {v}")

集合 set(哈希集合)

set 是 Python 的哈希集合,用于存储不重复元素,常用于去重和快速查询元素是否存在。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 初始化集合
hashset = {1, 2, 3, 4}

# 是否为空,输出:False
print(len(hashset) == 0)

# 大小,输出:4
print(len(hashset))

# 查找元素
if 3 in hashset:
print("Element 3 found.")
else:
print("Element 3 not found.")
# 输出:Element 3 found.

# 插入新元素
hashset.add(5)

# 删除元素 2
# discard 不存在的元素不会报错
hashset.discard(2)

# 检查删除后
if 2 in hashset:
print("Element 2 found.")
else:
print("Element 2 not found.")
# 输出:Element 2 not found.

# 遍历集合,输出:
# 1
# 3
# 4
# 5
for element in hashset:
print(element)

时间空间复杂度入门


1、时空复杂度用 Big O 表示法表示(类似  等)。它们都是估计值,不需要精确计算,常数项和低增长项都可以忽略,仅需保留最高增长项

比方说  等同于  等同于 

2、我们分析算法复杂度时,分析的是最坏情况的复杂度。这一点会在下面的示例中体现。

3、时间复杂度用来衡量一个算法的执行效率,空间复杂度用来衡量算法的内存消耗,它们都是越小越好。

比方说时间复杂度 的算法比  的算法执行效率高,空间复杂度  的算法比  的算法内存消耗小。

当然,一般我们要说明这个 代表什么,比如 代表输入的数组的长度。

4、如何估算?现在你可以简单理解:时间复杂度大部分情况下就是看 for 循环的最大嵌套层数;空间复杂度就看算法申请了多少空间来存储数据

注意

以上的分析方法中,有些细节并不严谨:

1、按照 for 循环的嵌套层数来估算时间复杂度是简化的方法,其实不完全准确。

2、大部分时候我们是分析最坏情况下的复杂度,但是对于数据结构 API 的复杂度衡量,我们会分析平均复杂度。

完善的复杂度分析方法会在 算法时空复杂度分析实用指南具体介绍,以上估算方法对于学习本章内容足够了。

举几个例子来说比较直观。

时间/空间复杂度案例分析

**示例一,时间复杂度 ,空间复杂度 **:

1
2
3
4
5
6
# 输入一个整数数组,返回所有元素的和
def getSum(nums: List[int]) -> int:
sum = 0
for i in range(len(nums)):
sum += nums[i]
return sum

算法包含一个 for 循环遍历 nums 数组,所以时间复杂度是 ,其中 n 代表 nums 数组的长度。

我们的算法只使用了一个 sum 变量,这个 nums 是题目给的输入,不算在我们算法的空间复杂度里面,所以空间复杂度是 

**示例二,时间复杂度 ,空间复杂度 **:

1
2
3
4
5
6
7
8
# 当 n 是 10 的倍数时,计算累加和,否则返回 -1
def sum(n: int) -> int:
if n % 10 != 0:
return -1
sum = 0
for i in range(n + 1):
sum += i
return sum

其实只有当 n 是 10 的倍数时,算法才会执行 for 循环,时间复杂度是 。其他情况下算法会直接返回,时间复杂度是 

但是算法复杂度只考察最坏情况,所以这个算法的时间复杂度是 ,空间复杂度是 

**示例三,时间复杂度 ,空间复杂度 **:

1
2
3
4
5
6
7
# 数组是否存在两个数,它们的和为 target?
def hasTargetSum(nums: List[int], target: int) -> bool:
for i in range(len(nums)):
for j in range(i + 1, len(nums)):
if nums[i] + nums[j] == target:
return True
return False

算法嵌套了两层 for 循环,所以时间复杂度是 ,其中 代表 nums 数组的长度。

我们的算法只使用了 i, j 两个变量,这是常数级别的空间消耗,所以空间复杂度是 

你也许会说,内层的 for 循环并没有遍历整个数组,且有可能提前 return,算法实际执行的次数应该是小于  的,时间复杂度还是  吗?

是的,还是 。具体到不同的输入,算法的实际执行次数确实会小于 ,但我们不需要关心这些细节,估算一个最坏情况的时间复杂度就可以了。

每层 for 循环在最坏情况下都是  的时间复杂度,套在一起,总的时间复杂度是 

**示例四,时间复杂度 ,空间复杂度 **:

1
2
def exampleFn(n: int):
nums = [0] * n

这个函数中创建了一个大小为 n 的数组,所以空间复杂度是 

上述代码申请数组空间并将 n 个元素初始化为 0。内存申请操作的时间复杂度可以认为是 ,但为所有元素赋值的操作相当于一个隐藏的 for 循环(由编程语言为我们自动完成),时间复杂度是 。所以总的时间复杂度是 

时间复杂度并不仅仅体现在你看得到的 for 循环,每一行代码都可能有隐藏的时间复杂度。所以说要了解编程语言提供的常用数据结构实现原理,这是准确分析时间复杂度的基础。

**示例五,时间复杂度 ,空间复杂度 **:

1
2
3
4
5
6
# 输入一个整数数组,返回一个新的数组,新数组的每个元素是原数组对应元素的平方
def squareArray(nums: List[int]) -> List[int]:
res = [0] * len(nums)
for i in range(len(nums)):
res[i] = nums[i] * nums[i]
return res

算法初始化 res 数组需要  的时间复杂度,包含一个 for 循环,时间复杂度也是 ,总的时间复杂度是还是 其中 n 代表 nums 数组的长度。

我们声明了一个新的数组 res,这个数组的长度和 nums 数组一样,所以空间复杂度是 

数组(顺序存储)基本原理


我们在说「数组」的时候有多种不同的语境,因为不同的编程语言提供的数组类型和 API 是不一样的,所以开头先统一一下说辞,方便后面的讲解。

我认为暂且可以把「数组」分为两大类,一类是「静态数组」,一类是「动态数组」。

「静态数组」就是一块连续的内存空间,我们可以通过索引来访问这块内存空间中的元素,这才是数组的原始形态

而「动态数组」是编程语言为了方便我们使用,在静态数组的基础上帮我们添加了一些常用的 API,比如 push, insert, remove 等等方法,这些 API 可以让我们更方便地操作数组元素,不用自己去写代码实现这些操作。

本章的内容就是带大家仅仅使用最原始的静态数组,自己实现一个动态数组,实现增删查改的常见 API。以后你在使用标准库提供的数据结构时,就知道它们的底层运行原理了。

有了动态数组,后面讲到的队列、栈、哈希表等复杂数据结构都会依赖它进行实现。

静态数组

静态数组在创建的时候就要确定数组的元素类型和元素数量。只有在 C++、Java、Golang 这类语言中才提供了创建静态数组的方式,类似 Python、JavaScript 这类语言并没有提供静态数组的定义方式。

静态数组的用法比较原始,实际软件开发中很少用到,写算法题也没必要用,我们一般直接用动态数组。但为了理解原理,在这里还是要讲解一下。

定义一个静态数组的方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
# 严格来说,Python 没有静态数组的定义方式
# 我们暂且使用列表模拟静态数组

# 定义一个大小为 10 的静态数组
arr = [0] * 10

# 使用索引赋值
arr[0] = 1
arr[1] = 2

# 使用索引取值
a = arr[0]

就这,没有其他什么操作了。

拿 C++ 来举例吧,int arr[10] 这段代码到底做了什么事情呢?主要有这么几件事:

1、在内存中开辟了一段连续的内存空间,大小是 10 * sizeof(int) 字节。一个 int 在计算机内存中占 4 字节,也就是总共 40 字节。

2、定义了一个名为 arr 的数组指针,指向这段内存空间的首地址。

那么 arr[1] = 2 这段代码又做了什么事情呢?主要有这么几件事:

1、计算 arr 的首地址加上 1 * sizeof(int) 字节(4 字节)的偏移量,找到了内存空间中的第二个元素的首地址

2、从这个地址开始的 4 个字节的内存空间中写入了整数 2

[!写给初学者]

我记得以前刚上大学的时候要学 C 语言基础,有些同学就绕不清楚什么指针的数组,数组的指针,绕来绕去的。其实只要明白了上面这个简单的流程,一切就很清楚了。

1、为什么数组的索引从 0 开始?就是方便取地址。arr[0] 就是 arr 的首地址,从这个地址往后的 4 个字节存储着第一个元素的值;arr[1] 就是 arr 的首地址加上 1 * 4 字节,也就是第二个元素的首地址,这个地址往后的 4 个字节存储着第二个元素的值。arr[2], arr[3] 以此类推。

2、因为数组的名字 arr 就指向整块内存的首地址,所以数组名 arr 就是一个指针。你直接取这个地址的值,就是第一个元素的值。也就是说,*arr 的值就是 arr[0],即第一个元素的值。

3、如果不用 memset 这种函数初始化数组的值,那么数组内的值是不确定的。因为 int arr[10] 这个语句只是请操作系统在内存中开辟了一块连续的内存空间,你也不知道这块空间是谁使用过的二手内存,你也不知道里面存了什么奇奇怪怪的东西。所以一般我们会用 memset 函数把这块内存空间的值初始化一下再使用。

当然,上面讲的这些内容都是针对 C/C++,因为大家学习计算机基础的时候都接触过。其他比如 Java Golang 这种语言,静态数组创建出来后会自动帮你把元素值都初始化为 0,所以不需要再显式进行初始化。

我梳理一下上面的因果逻辑,静态数组本质上就是一块连续的内存空间,int arr[10] 这个语句我们可以得知:

1、我们知道这块内存空间的首地址(数组名 arr 就指向这块内存空间的首地址)。

2、我们知道了每个元素的类型(比如 int),也就是知道了每个元素占用的内存空间大小(比如一个 int 占 4 字节,32 bit)。

3、这块内存空间是连续的,其大小为 10 * sizeof(int) 即 40 字节。

所以,我们获得了数组的超能力「随机访问」:只要给定任何一个数组索引,我可以在  的时间内直接获取到对应元素的值

因为我可以通过首地址和索引直接计算出目标元素的内存地址。计算机的内存寻址时间可以认为是 ,所以数组的随机访问时间复杂度是 

但是,一个人最大的优势往往也是他的最大劣势。数组连续内存的特性给了他随机访问的超能力,但它也因此吃了不少苦,下面介绍。

增删查改

数据结构的职责就是增删查改,再无其他。

那么刚刚介绍数组这种数据结构的底层原理,我们其实只介绍了「查」和「改」的部分,也就是通过索引修改和访问对应元素的值。那么「增删」这两个操作又是如何实现的呢?

要想给静态数组增加元素,这就有些复杂了,需要分情况讨论。

[!情况一,数组末尾追加(append)元素]

比方说,我有一个大小为 10 的数组,里面装了 4 个元素,现在想在末尾追加一个元素,怎么办?

比较简单,直接在对应的索引赋值就行了,这是大概的代码逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 大小为 10 的数组已经装了 4 个元素
arr = [0] * 10
for i in range(4):
arr[i] = i

# 现在想在数组末尾追加一个元素 4
arr[4] = 4

# 再在数组末尾追加一个元素 5
arr[5] = 5

# 依此类推
# ...

**可以看到,由于只是对索引赋值,所以在数组末尾追加元素的时间复杂度是 **。

[!情况二,数组中间插入(insert)元素]

比方说,我有一个大小为 10 的数组 arr,前 4 个位置装了元素,现在想在第 3 个位置(索引 2 arr[2])插入一个新元素,怎么办?

这就要涉及「数据搬移」,给新元素腾出空位,然后再才能插入新元素。大概的代码逻辑是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 大小为 10 的数组已经装了 4 个元素
arr = [0] * 10
for i in range(4):
arr[i] = i

# 在索引 2 置插入元素 666
# 需要把索引 2 以及之后的元素都往后移动一位
# 注意要倒着遍历数组中已有元素避免覆盖,不懂的话请看下方可视化面板
for i in range(4, 2, -1):
arr[i] = arr[i - 1]

# 现在第 3 个位置空出来了,可以插入新元素
arr[2] = 666

综上,在数组中间插入元素的时间复杂度是 ,因为涉及到数据搬移,给新元素腾地方

[!情况三,数组空间已满]

静态数组在创建时就要确定大小,比方说现在我创建了一个数组 int arr[10](一块 40 字节的连续内存空间),然后往里面存了 10 个元素,这时候我想再插入一个元素,怎么办?无论是追加在尾部还是插入到中间,都没有位置留给新元素了。

有的读者可能说,这个简单呀,在这 40 字节后面再加上 4 个字节的连续内存空间,用来存储新的元素,不就行了吗?

不行的,连续内存必须一次性分配,分配完了之后就不能随意增减了。因为你这块连续内存后面的内存空间可能已经被其他程序占用了,不能说你想要就给你。

那怎么办呢?只能重新申请一块更大的内存空间,把原来的数据复制过去,再插入新的元素,这就是数组的「扩容」操作。

比方说,我重新创建一个更大的数组 int arr[20],然后把原来的 10 个元素复制过去,这样就有空余位置插入新的元素了。

大概的逻辑是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 大小为 10 的数组已经装满了
arr = [i for i in range(10)]

# 现在想在数组末尾追加一个元素 10
# 需要先扩容数组
newArr = [0] * 20

# 把原来的 10 个元素复制过去
for i in range(10):
newArr[i] = arr[i]

# 释放旧数组的内存空间
# ...

# 在新的大数组中追加新元素
newArr[10] = 10

**综上,数组的扩容操作会涉及到新数组的开辟和数据的复制,时间复杂度是 **。

删除元素的操作和增加元素的操作类似,也需要分情况讨论。

[!情况一,删除末尾元素]

比方说,我有一个大小为 10 的数组,里面装了 5 个元素,现在想删除末尾的元素,怎么办?

很简单,直接把末尾元素标记为一个特殊值代表已删除就行了,我们这里简单举例,就用 -1 作为特殊值代表已删除好了。后面带大家具体实现动态数组的时候,会有更完善的方法删除数组元素,**这里只是为了说明删除数组尾部元素的本质就是进行一次随机访问,时间复杂度是 **。

大概的代码逻辑是这样的:

1
2
3
4
5
6
7
# 大小为 10 的数组已经装了 5 个元素
arr = [0] * 10
for i in range(5):
arr[i] = i

# 删除末尾元素,暂时用 -1 代表元素已删除
arr[4] = -1

[!情况二,删除中间元素]

比方说,我有一个大小为 10 的数组,里面装了 5 个元素,现在想删除第 2 个元素(arr[1]),怎么办?

这也要涉及「数据搬移」,把被删元素后面的元素都往前移动一位,保持数组元素的连续性。

大概的代码逻辑是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 大小为 10 的数组已经装了 5 个元素
arr = [0] * 10
for i in range(5):
arr[i] = i

# 删除 arr[1]
# 需要把 arr[1] 之后的元素都往前移动一位
# 注意要正着遍历数组中已有元素避免覆盖,不懂的话请看下方可视化面板
for i in range(1, 4):
arr[i] = arr[i + 1]

# 最后一个元素置为 -1 代表已删除
arr[4] = -1

综上,在数组中间删除元素的时间复杂度是 ,因为涉及到数据搬移

总结

综上,静态数组的增删查改操作的时间复杂度是:

  • 增:
    • 在末尾追加元素:
    • 在中间(非末尾)插入元素:
  • 删:
    • 删除末尾元素:
    • 删除中间(非末尾)元素:
  • 查:给定指定索引,查询索引对应的元素的值,时间复杂度 
  • 改:给定指定索引,修改索引对应的元素的值,时间复杂度 

有读者可能问,刚才不是还探讨过数组的扩容操作吗,扩容涉及到新数组空间的开辟和数据的复制,时间复杂度是 ,这个复杂度为什么没有算到「增」的复杂度里面呢?

还有个问题初学者要注意,我们说数组的查、改复杂度是 ,这个仅仅适用于给定索引的情况。如果反过来,比方说给你一个值,让你去找这个值在数组中对应的索引,那你只能遍历整个数组去寻找对吧,这个复杂度就是  了。

所以说要搞清楚原理,而不要去背概念。原理懂了,概念你自己都能推导出来的。

动态数组

刚才讲了静态数组的超能力和种种局限性,现在讲动态数组,动态数组是静态数组的强化版,也是我们在实际软件开发或者写算法题时最常用的数据结构之一。

首先,你不要以为动态数组可以解决静态数组在中间增删元素效率差的问题,不可能解决的。数组随机访问的超能力源于数组连续的内存空间,而连续的内存空间就不可避免地面对数据搬移和扩缩容的问题。

动态数组底层还是静态数组,只是自动帮我们进行数组空间的扩缩容,并把增删查改操作进行了封装,让我们使用起来更方便而已

简单列举一下各个语言的动态数组使用方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 创建动态数组
# 不用显式指定数组大小,它会根据实际存储的元素数量自动扩缩容
arr = []

for i in range(10):
# 在末尾追加元素,时间复杂度 O(1)
arr.append(i)

# 在中间插入元素,时间复杂度 O(N)
# 在索引 2 的位置插入元素 666
arr.insert(2, 666)

# 在头部插入元素,时间复杂度 O(N)
arr.insert(0, -1)

# 删除末尾元素,时间复杂度 O(1)
arr.pop()

# 删除中间元素,时间复杂度 O(N)
# 删除索引 2 的元素
arr.pop(2)

# 根据索引查询元素,时间复杂度 O(1)
a = arr[0]

# 根据索引修改元素,时间复杂度 O(1)
arr[0] = 100

# 根据元素值查找索引,时间复杂度 O(N)
index = arr.index(666)

在后面的章节,我会手把手带大家实现一个动态数组,让大家更加深入地理解动态数组的原理。

动态数组代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
class MyArrayList:
# 默认初始容量
INIT_CAP = 1

def __init__(self, init_capacity=None):
self.data = [None] * (init_capacity if init_capacity is not None else self.__class__.INIT_CAP)
self.size = 0

# 增
def add_last(self, e):
cap = len(self.data)
# 看 data 数组容量够不够
if self.size == cap:
self._resize(2 * cap)
# 在尾部插入元素
self.data[self.size] = e
self.size += 1

def add(self, index, e):
# 检查索引越界
self._check_position_index(index)

cap = len(self.data)
# 看 data 数组容量够不够
if self.size == cap:
self._resize(2 * cap)

# 搬移数据 data[index..] -> data[index+1..]
# 给新元素腾出位置
for i in range(self.size-1, index-1, -1):
self.data[i+1] = self.data[i]

# 插入新元素
self.data[index] = e

self.size += 1

def add_first(self, e):
self.add(0, e)

# 删
def remove_last(self):
if self.size == 0:
raise Exception("NoSuchElementException")
cap = len(self.data)
# 可以缩容,节约空间
if self.size == cap // 4:
self._resize(cap // 2)

deleted_val = self.data[self.size - 1]
# 删除最后一个元素
self.data[self.size - 1] = None
self.size -= 1

return deleted_val

def remove(self, index):
# 检查索引越界
self._check_element_index(index)

cap = len(self.data)
# 可以缩容,节约空间
if self.size == cap // 4:
self._resize(cap // 2)

deleted_val = self.data[index]

# 搬移数据 data[index+1..] -> data[index..]
for i in range(index + 1, self.size):
self.data[i - 1] = self.data[i]

self.data[self.size - 1] = None
self.size -= 1

return deleted_val

def remove_first(self):
return self.remove(0)

# 查
def get(self, index):
# 检查索引越界
self._check_element_index(index)

return self.data[index]

# 改
def set(self, index, element):
# 检查索引越界
self._check_element_index(index)
# 修改数据
old_val = self.data[index]
self.data[index] = element
return old_val

# 工具方法
def get_size(self):
return self.size

def is_empty(self):
return self.size == 0

# 将 data 的容量改为 newCap
def _resize(self, new_cap):
temp = [None] * new_cap
for i in range(self.size):
temp[i] = self.data[i]
self.data = temp

def _is_element_index(self, index):
return 0 <= index < self.size

def _is_position_index(self, index):
return 0 <= index <= self.size

def _check_element_index(self, index):
if not self._is_element_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

def _check_position_index(self, index):
if not self._is_position_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

def display(self):
print(f"size = {self.size}, cap = {len(self.data)}")
print(self.data)


# Usage example
if __name__ == "__main__":
arr = MyArrayList(init_capacity=3)

# 添加 5 个元素
for i in range(1, 6):
arr.add_last(i)

arr.remove(3)
arr.add(1, 9)
arr.add_first(100)
val = arr.remove_last()

# 100 1 9 2 3
for i in range(arr.get_size()):
print(arr.get(i))

链表(链式存储)基本原理


刷过力扣的读者肯定对单链表非常熟悉,力扣上的单链表节点定义如下:

1
2
3
4
class ListNode:
def __init__(self, x):
self.val = x
self.next = None

这仅仅是一个最简单的单链表节点,方便力扣出算法题来考你。在实际的编程语言中,我们使用的链表节点会稍微复杂一点,类似这样:

1
2
3
4
5
class Node:
def __init__(self, prev, element, next):
self.val = element
self.next = next
self.prev = prev

主要区别有两个:

1、编程语言标准库一般都会提供泛型,即你可以指定 val 字段为任意类型,而力扣的单链表节点的 val 字段只有 int 类型。

2、编程语言标准库一般使用的都是双链表而非单链表。单链表节点只有一个 next 指针,指向下一个节点;而双链表节点有两个指针,prev 指向前一个节点,next 指向下一个节点。

有了 prev 前驱指针,链表支持双向遍历,但由于要多维护一个指针,增删查改时会稍微复杂一些,后面带大家实现双链表时会具体介绍。

为什么需要链表

前面介绍了 数组(顺序存储)的底层原理,说白了就是一块连续的内存空间,有了这块内存空间的首地址,就能直接通过索引计算出任意位置的元素地址。

链表不一样,一条链表并不需要一整块连续的内存空间存储元素。链表的元素可以分散在内存空间的天涯海角,通过每个节点上的 next, prev 指针,将零散的内存块串联起来形成一个链式结构。

这样做的好处很明显,首先就是可以提高内存的利用效率,链表的节点不需要挨在一起,给点内存 new 出来一个节点就能用,操作系统会觉得这娃好养活。

另外一个好处,它的节点要用的时候就能接上,不用的时候拆掉就行了,从来不需要考虑扩缩容和数据搬移的问题,理论上讲,链表是没有容量限制的(除非把所有内存都占满,这不太可能)。

当然,不可能只有好处没有局限性。数组最大的优势是支持通过索引快速访问元素,而链表就不支持。

这个不难理解吧,因为元素并不是紧挨着的,所以如果你想要访问第 3 个链表元素,你就只能从头结点开始往顺着 next 指针往后找,直到找到第 3 个节点才行。

上面是对链表这种数据结构的基本介绍,接下来我们就结合代码实现单/双链表的几个基本操作。

单链表的基本操作

我先写一个工具函数,用于创建一条单链表,方便后面的讲解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class ListNode:
def __init__(self, x):
self.val = x
self.next = None


# 输入一个数组,转换为一条单链表
def createLinkedList(arr: 'List[int]') -> 'ListNode':
if arr is None or len(arr) == 0:
return None

head = ListNode(arr[0])
cur = head
for i in range(1, len(arr)):
cur.next = ListNode(arr[i])
cur = cur.next

return head

查/改

[!单链表的遍历/查找/修改]

比方说,我想访问单链表的每一个节点,并打印其值,可以这样写:

1
2
3
4
5
6
7
8
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 遍历单链表
p = head
while p is not None:
print(p.val)
p = p.next

类似的,如果是要通过索引访问或修改链表中的某个节点,也只能用 for 循环从头结点开始往后找,直到找到索引对应的节点,然后进行访问或修改。

[!在单链表头部插入新元素]

我们会持有单链表的头结点,所以只需要将插入的节点接到头结点之前,并将新插入的节点作为头结点即可。

1
2
3
4
5
6
7
8
9
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 在单链表头部插入一个新节点 0
newNode = ListNode(0)
newNode.next = head
head = newNode

# 现在链表变成了 0 -> 1 -> 2 -> 3 -> 4 -> 5

[!在单链表尾部插入新元素]

直接看代码吧,很简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 在单链表尾部插入一个新节点 6
p = head
# 先走到链表的最后一个节点
while p.next is not None:
p = p.next
# 现在 p 就是链表的最后一个节点
# 在 p 后面插入新节点
p.next = ListNode(6)

# 现在链表变成了 1 -> 2 -> 3 -> 4 -> 5 -> 6

当然,如果我们持有对链表尾节点的引用,那么在尾部插入新节点的操作就会变得非常简单,不用每次从头去遍历了。这个优化会在后面具体实现双链表时介绍。

[!在单链表中间插入新元素]

这个操作稍微有点复杂,我们还是要先找到要插入位置的前驱节点,然后操作前驱节点把新节点插入进去:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 在第 3 个节点后面插入一个新节点 66
# 先要找到前驱节点,即第 3 个节点
p = head
for _ in range(2):
p = p.next
# 此时 p 指向第 3 个节点
# 组装新节点的后驱指针
new_node = ListNode(66)
new_node.next = p.next

# 插入新节点
p.next = new_node

# 现在链表变成了 1 -> 2 -> 3 -> 66 -> 4 -> 5

[!在单链表中删除一个节点]

删除一个节点,首先要找到要被删除节点的前驱节点,然后把这个前驱节点的 next 指针指向被删除节点的下一个节点。这样就能把被删除节点从链表中摘除了。

1
2
3
4
5
6
7
8
9
10
11
12
13
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 删除第 4 个节点,要操作前驱节点
p = head
for i in range(2):
p = p.next

# 此时 p 指向第 3 个节点,即要删除节点的前驱节点
# 把第 4 个节点从链表中摘除
p.next = p.next.next

# 现在链表变成了 1 -> 2 -> 3 -> 5

[!在单链表尾部删除元素]

这个操作比较简单,找到倒数第二个节点,然后把它的 next 指针置为 null 就行了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 删除尾节点
p = head
# 找到倒数第二个节点
while p.next.next is not None:
p = p.next

# 此时 p 指向倒数第二个节点
# 把尾节点从链表中摘除
p.next = None

# 现在链表变成了 1 -> 2 -> 3 -> 4

[!在单链表头部删除元素]

这个操作比较简单,直接把 head 移动到下一个节点就行了,直接看代码吧:

1
2
3
4
5
6
7
# 创建一条单链表
head = createLinkedList([1, 2, 3, 4, 5])

# 删除头结点
head = head.next

# 现在链表变成了 2 -> 3 -> 4 -> 5

不过可能有读者疑惑,之前那个旧的头结点 1 的 next 指针依然指向着节点 2,这样会不会造成内存泄漏?

不会的,这个节点 1 指向其他的节点是没关系的,只要保证没有其他引用指向这个节点 1,它就能被垃圾回收器回收掉。

当然,如果你非要显式把节点 1 的 next 指针置为 null,这是个很好的习惯,在其他场景中可能可以避免指针错乱的潜在问题。

在下面这个可视化面板中,我显式地把待删除节点的 next 指针置为 null 了:

是不是觉得复杂?

链表的增删查改操作确实比数组复杂。这是因为链表的节点不是紧挨着的,所以要增删一个节点,必须先找到它的前驱和后驱节点进行协同,然后才能通过指针操作把它插入或删除。

上面给出的代码还仅仅是最简单的例子,你会发现在头部、尾部、中间增删元素的代码都不一样。如果要实现一个真正可用的链表,你还要考虑到很多边界情况,比如链表可能为空、前后驱节点可能为空等,这些情况都得保证不出错。

而且,上面只是介绍了「单链表」,而我们下一章要实现的是「双链表」,双链表要同时维护前驱和后驱指针,指针操作会更复杂一些。

双链表的基本操作

我先写一个工具函数,用于创建一条双链表,方便后面的讲解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class DoublyListNode:
def __init__(self, x):
self.val = x
self.next = None
self.prev = None

def createDoublyLinkedList(arr: List[int]) -> Optional[DoublyListNode]:
if not arr:
return None

head = DoublyListNode(arr[0])
cur = head

# for 循环迭代创建双链表
for val in arr[1:]:
new_node = DoublyListNode(val)
cur.next = new_node
new_node.prev = cur
cur = cur.next

return head

查/改

[!双链表的遍历/查找/修改]

对于双链表的遍历和查找,我们可以从头节点或尾节点开始,根据需要向前或向后遍历:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])
tail = None

# 从头节点向后遍历双链表
p = head
while p:
print(p.val)
tail = p
p = p.next

# 从尾节点向前遍历双链表
p = tail
while p:
print(p.val)
p = p.prev

访问或修改节点时,可以根据索引是靠近头部还是尾部,选择合适的方向遍历,这样可以一定程度上提高效率。

[!在双链表头部插入新元素]

在双链表头部插入元素,需要调整新节点和原头节点的指针:

1
2
3
4
5
6
7
8
9
# 创建一条双链表
head = create_doubly_linked_list([1, 2, 3, 4, 5])

# 在双链表头部插入新节点 0
new_head = DoublyListNode(0)
new_head.next = head
head.prev = new_head
head = new_head
# 现在链表变成了 0 -> 1 -> 2 -> 3 -> 4 -> 5

[!在双链表尾部插入新元素]

在双链表尾部插入元素时,如果我们持有尾节点的引用,这个操作会非常简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])

tail = head
# 先走到链表的最后一个节点
while tail.next is not None:
tail = tail.next

# 在双链表尾部插入新节点 6
newNode = DoublyListNode(6)
tail.next = newNode
newNode.prev = tail
# 更新尾节点引用
tail = newNode

# 现在链表变成了 1 -> 2 -> 3 -> 4 -> 5 -> 6

[!在双链表中间插入新元素]

在双链表的指定位置插入新元素,需要调整前驱节点和后继节点的指针。

比如下面的例子,把元素 66 插入到索引 3(第 4 个节点)的位置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])

# 想要插入到索引 3(第 4 个节点)
# 需要操作索引 2(第 3 个节点)的指针
p = head
for _ in range(2):
p = p.next

# 组装新节点
newNode = DoublyListNode(66)
newNode.next = p.next
newNode.prev = p

# 插入新节点
p.next.prev = newNode
p.next = newNode

# 现在链表变成了 1 -> 2 -> 3 -> 66 -> 4 -> 5

[!在双链表中删除一个节点]

在双链表中删除节点时,需要调整前驱节点和后继节点的指针来摘除目标节点:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])

# 删除第 4 个节点
# 先找到第 3 个节点
p = head
for i in range(2):
p = p.next

# 现在 p 指向第 3 个节点,我们将它后面的那个节点摘除出去
toDelete = p.next

# 把 toDelete 从链表中摘除
p.next = toDelete.next
toDelete.next.prev = p

# 把 toDelete 的前后指针都置为 null 是个好习惯(可选)
toDelete.next = None
toDelete.prev = None

# 现在链表变成了 1 -> 2 -> 3 -> 5

[!在双链表头部删除元素]

在双链表头部删除元素需要调整头节点的指针:

1
2
3
4
5
6
7
8
9
10
11
12
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])

# 删除头结点
toDelete = head
head = head.next
head.prev = None

# 清理已删除节点的指针
toDelete.next = None

# 现在链表变成了 2 -> 3 -> 4 -> 5

[!在双链表尾部删除元素]

在单链表中,由于缺乏前驱指针,所以删除尾节点时需要遍历到倒数第二个节点,操作它的 next 指针,才能把尾节点摘除出去。

但在双链表中,由于每个节点都存储了前驱节点的指针,所以我们可以直接操作尾节点,把它自己从链表中摘除:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 创建一条双链表
head = createDoublyLinkedList([1, 2, 3, 4, 5])

# 删除尾节点
p = head
# 找到尾结点
while p.next is not None:
p = p.next

# 现在 p 指向尾节点
# 把尾节点从链表中摘除
p.prev.next = None

# 把被删结点的指针都断开是个好习惯(可选)
p.prev = None

# 现在链表变成了 1 -> 2 -> 3 -> 4

链表代码实现


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
707. Design Linked List 707. 设计链表

前置知识

阅读本文前,你需要先学习:

  • [链表(链式存储)基础]

几个关键点

下面我会分别用双链表和单链给出一个简单的 MyLinkedList 代码实现,包含了基本的增删查改功能。这里给出几个关键点,等会你看代码的时候可以着重注意一下。

关键点一、同时持有头尾节点的引用

在力扣做题时,一般题目给我们传入的就是单链表的头指针。但是在实际开发中,用的都是双链表,而双链表一般会同时持有头尾节点的引用。

因为在软件开发中,在容器尾部添加元素是个非常高频的操作,双链表持有尾部节点的引用,就可以在  的时间复杂度内完成尾部添加元素的操作。

对于单链表来说,持有尾部节点的引用也有优化效果。比如你要在单链表尾部添加元素,如果没有尾部节点的引用,你就需要遍历整个链表找到尾部节点,时间复杂度是 ;如果有尾部节点的引用,就可以在  的时间复杂度内完成尾部添加元素的操作。

细心的读者可能会说,即便如此,如果删除一次单链表的尾结点,那么之前尾结点的引用就失效了,还是需要遍历一遍链表找到尾结点。

是的,但你再仔细想想,删除单链表尾结点的时候,是不是也得遍历到倒数第二个节点(尾结点的前驱),才能通过指针操作把尾结点删掉?那么这个时候,你不就可以顺便把尾结点的引用给更新了吗?

关键点二、虚拟头尾节点

在上一篇文章 [链表基础]中我提到过「虚拟头尾节点」技巧,它的原理很简单,就是在创建双链表时就创建一个虚拟头节点和一个虚拟尾节点,无论双链表是否为空,这两个节点都存在。这样就不会出现空指针的问题,可以避免很多边界情况的处理。

举例来说,假设虚拟头尾节点分别是 dummyHead 和 dummyTail,那么一条空的双链表长这样:

1
dummyHead <-> dummyTail

如果你添加 1,2,3 几个元素,那么链表长这样:

1
dummyHead <-> 1 <-> 2 <-> 3 <-> dummyTail

你以前要把在头部插入元素、在尾部插入元素和在中间插入元素几种情况分开讨论,现在有了头尾虚拟节点,无论链表是否为空,都只需要考虑在中间插入元素的情况就可以了,这样代码会简洁很多。

当然,虚拟头结点会多占用一点内存空间,但是比起给你解决的麻烦,这点空间消耗是划算的。

对于单链表,虚拟头结点有一定的简化作用,但虚拟尾节点没有太大作用。

虚拟节点是内部实现,对外不可见

虚拟节点是你内部实现数据结构的技巧,对外是不可见的。比如按照索引获取元素的 get(index) 方法,都是从真实节点开始计算索引,而不是从虚拟节点开始计算。

关键点三、内存泄露?

在前文 [动态数组实现]中,我提到了删除元素时,要注意内存泄露的问题。那么在链表中,删除元素会不会也有内存泄露的问题呢?

尤其是这样的写法,你觉得有没有问题:

1
2
3
4
5
6
// 假设单链表头结点 head = 1 -> 2 -> 3 -> 4 -> 5

// 删除单链表头结点
head = head.next;

// 此时 head = 2 -> 3 -> 4 -> 5

细心的读者可能认为这样写会有内存泄露的问题,因为原来的那个头结点 1 的 next 指针没有断开,依然指向着节点 2

但实际上这样写是 OK 的,因为 Java 的垃圾回收的判断机制是看这个对象是否被别人引用,而并不会 care 这个对象是否还引用着别人。

那个节点 1 的 next 指针确实还指向着节点 2,但是并没有别的指针引用节点 1 了,所以节点 1 最终会被垃圾回收器回收释放。所以说这个场景和数组中删除元素的场景是不一样的,你可以再仔细思考一下。

不过呢,删除节点时,最好还是把被删除节点的指针都置为 null,这是个好习惯,不会有什么代价,还可能避免一些潜在的问题。所以在下面的实现中,无论是否有必要,我都会把被删除节点上的指针置为 null。

如何验证你的实现?

你可以借助力扣第 707 题「设计链表」来验证自己的实现是否正确。注意 707 题要求的增删查改 API 名字和本文给出的不一样,所以需要修改一下才能通过。

双链表代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class Node:
def __init__(self, val):
self.val = val
self.next = None
self.prev = None

class MyLinkedList:
# 虚拟头尾节点
def __init__(self):
self.head = Node(None)
self.tail = Node(None)
self.head.next = self.tail
self.tail.prev = self.head
self.size = 0

# ***** 增 *****

def add_last(self, e):
x = Node(e)
temp = self.tail.prev

temp.next = x
x.prev = temp
# temp <-> x

x.next = self.tail
self.tail.prev = x
# temp <-> x <-> tail
self.size += 1

def add_first(self, e):
x = Node(e)
temp = self.head.next
# head <-> temp
temp.prev = x
x.next = temp

self.head.next = x
x.prev = self.head
# head <-> x <-> temp
self.size += 1

def add(self, index, element):
self.check_position_index(index)
if index == self.size:
self.add_last(element)
return

# 找到 index 对应的 Node
p = self.get_node(index)
temp = p.prev
# temp <-> p

# 新要插入的 Node
x = Node(element)

p.prev = x
temp.next = x

x.prev = temp
x.next = p

# temp <-> x <-> p

self.size += 1

# ***** 删 *****

def remove_first(self):
if self.size < 1:
raise IndexError("No elements to remove")
# 虚拟节点的存在是我们不用考虑空指针的问题
x = self.head.next
temp = x.next
# head <-> x <-> temp
self.head.next = temp
temp.prev = self.head

# head <-> temp

self.size -= 1
return x.val

def remove_last(self):
if self.size < 1:
raise IndexError("No elements to remove")
x = self.tail.prev
temp = x.prev
# temp <-> x <-> tail

self.tail.prev = temp
temp.next = self.tail

# temp <-> tail

self.size -= 1
return x.val

def remove(self, index):
self.check_element_index(index)
# 找到 index 对应的 Node
x = self.get_node(index)
prev = x.prev
next = x.next
# prev <-> x <-> next
prev.next = next
next.prev = prev

self.size -= 1

return x.val

# ***** 查 *****

def get(self, index):
self.check_element_index(index)
# 找到 index 对应的 Node
p = self.get_node(index)

return p.val

def get_first(self):
if self.size < 1:
raise IndexError("No elements in the list")

return self.head.next.val

def get_last(self):
if self.size < 1:
raise IndexError("No elements in the list")

return self.tail.prev.val

# ***** 改 *****

def set(self, index, val):
self.check_element_index(index)
# 找到 index 对应的 Node
p = self.get_node(index)

old_val = p.val
p.val = val

return old_val

# ***** 其他工具函数 *****

def size(self):
return self.size

def is_empty(self):
return self.size == 0

def get_node(self, index):
self.check_element_index(index)
p = self.head.next
# TODO: 可以优化,通过 index 判断从 head 还是 tail 开始遍历
for _ in range(index):
p = p.next
return p

def is_element_index(self, index):
return 0 <= index < self.size

def is_position_index(self, index):
return 0 <= index <= self.size

# 检查 index 索引位置是否可以存在元素
def check_element_index(self, index):
if not self.is_element_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

# 检查 index 索引位置是否可以添加元素
def check_position_index(self, index):
if not self.is_position_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

def display(self):
print(f"size = {self.size}")
p = self.head.next
while p != self.tail:
print(f"{p.val} <-> ", end="")
p = p.next
print("null\n")

if __name__ == "__main__":
list = MyLinkedList()
list.add_last(1)
list.add_last(2)
list.add_last(3)
list.add_first(0)
list.add(2, 100)

list.display()
# size = 5
# 0 <-> 1 <-> 100 <-> 2 <-> 3 <-> null

单链表代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class MyLinkedList2:

class Node:
def __init__(self, val):
self.val = val
self.next = None

def __init__(self):
self.head = self.Node(None)
self.tail = self.head
self.size = 0

def add_first(self, e):
new_node = self.Node(e)
new_node.next = self.head.next
self.head.next = new_node
if self.size == 0:
self.tail = new_node
self.size += 1

def add_last(self, e):
new_node = self.Node(e)
self.tail.next = new_node
self.tail = new_node
self.size += 1

def add(self, index, element):
self.check_position_index(index)

if index == self.size:
self.add_last(element)
return

prev = self.head
for i in range(index):
prev = prev.next
new_node = self.Node(element)
new_node.next = prev.next
prev.next = new_node
self.size += 1

def remove_first(self):
if self.is_empty():
raise Exception("NoSuchElementException")
first = self.head.next
self.head.next = first.next
if self.size == 1:
self.tail = self.head
self.size -= 1
return first.val

def remove_last(self):
if self.is_empty():
raise Exception("NoSuchElementException")

prev = self.head
while prev.next != self.tail:
prev = prev.next
val = self.tail.val
prev.next = None
self.tail = prev
self.size -= 1
return val

def remove(self, index):
self.check_element_index(index)

prev = self.head
for i in range(index):
prev = prev.next

node_to_remove = prev.next
prev.next = node_to_remove.next
# 删除的是最后一个元素
if index == self.size - 1:
self.tail = prev
self.size -= 1
return node_to_remove.val

# ***** 查 *****

def get_first(self):
if self.is_empty():
raise Exception("NoSuchElementException")
return self.head.next.val

def get_last(self):
if self.is_empty():
raise Exception("NoSuchElementException")
return self.tail.val

def get(self, index):
self.check_element_index(index)
p = self.get_node(index)
return p.val

# ***** 改 *****

def set(self, index, element):
self.check_element_index(index)
p = self.get_node(index)

old_val = p.val
p.val = element

return old_val

# ***** 其他工具函数 *****
def get_size(self):
return self.size

def is_empty(self):
return self.size == 0

def is_element_index(self, index):
return 0 <= index < self.size

def is_position_index(self, index):
return 0 <= index <= self.size

# 检查 index 索引位置是否可以存在元素
def check_element_index(self, index):
if not self.is_element_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

# 检查 index 索引位置是否可以添加元素
def check_position_index(self, index):
if not self.is_position_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")

# 返回 index 对应的 Node
# 注意:请保证传入的 index 是合法的
def get_node(self, index):
p = self.head.next
for i in range(index):
p = p.next
return p

if __name__ == "__main__":
list = MyLinkedList2()
list.add_first(1)
list.add_first(2)
list.add_last(3)
list.add_last(4)
list.add(2, 5)

print(list.remove_first()) # 2
print(list.remove_last()) # 4
print(list.remove(1)) # 5

print(list.get_first()) # 1
print(list.get_last()) # 3
print(list.get(1)) # 3
LeetCode 力扣 难度
707. Design Linked List 707. 设计链表

环形数组技巧及实现


[!一句话总结]

环形数组技巧利用求模(余数)运算,将普通数组变成逻辑上的环形数组,可以让我们用  的时间在数组头部增删元素。

环形数组原理

数组可能是环形的么?不可能。数组就是一块线性连续的内存空间,怎么可能有环的概念?

但是,我们可以在「逻辑上」把数组变成环形的,比如下面这段代码:

1
2
3
4
5
6
7
# 长度为 5 的数组
arr = [1, 2, 3, 4, 5]
i = 0
# 模拟环形数组,这个循环永远不会结束
while i < len(arr):
print(arr[i])
i = (i + 1) % len(arr)

这段代码的关键在于求模运算 %,也就是求余数。当 i 到达数组末尾元素时,i + 1 和 arr.length 取余数又会变成 0,即会回到数组头部,这样就在逻辑上形成了一个环形数组,永远遍历不完。

这就是环形数组技巧。这个技巧如何帮助我们在  的时间在数组头部增删元素呢?

是这样,假设我们现在有一个长度为 6 的数组,现在其中只装了 3 个元素,如下(未装元素的位置用 _ 标识):

1
[1, 2, 3, _, _, _]

现在我们要在数组头部删除元素 1,那么我们可以把数组变成这样:

1
[_, 2, 3, _, _, _]

即,我们仅仅把元素 1 的位置标记为空,但并不做数据搬移。

此时,如果我们要在数组头部增加元素 4 和元素 5,我们可以把数组变成这样:

1
[4, 2, 3, _, _, 5]

你可以看到,当头部没有位置添加新元素时,它转了一圈,把新元素加到尾部了。

核心原理

上面只是让大家对环形数组有一个直观地印象,环形数组的关键在于,它维护了两个指针 start 和 endstart 指向第一个有效元素的索引,end 指向最后一个有效元素的下一个位置索引。

这样,当我们在数组头部添加或删除元素时,只需要移动 start 索引,而在数组尾部添加或删除元素时,只需要移动 end 索引。

当 start, end 移动超出数组边界(< 0 或 >= arr.length)时,我们可以通过求模运算 % 让它们转一圈到数组头部或尾部继续工作,这样就实现了环形数组的效果。

动手环节

纸上得来终觉浅,绝知此事要躬行。

我在可视化面板实现了一个简单的环形数组,你可以点击下面代码中的 arr.addLast 或 arr.addFirst,注意观察 start, end 指针以及 arr 数组中元素的变化:

代码实现

[!关键点、注意开闭区间]

在我的代码中,环形数组的区间被定义为左闭右开的,即 [start, end) 区间包含数组元素。所以其他的方法都是以左闭右开区间为基础实现的。

那么肯定就会有读者问,为啥要左闭右开,我就是想两端都开,或者两端都闭,不行么?

理论上,你可以随意设计区间的开闭,但一般设计为左闭右开区间是最方便处理的

因为这样初始化 start = end = 0 时,区间 [0, 0) 中没有元素,但只要让 end 向右移动(扩大)一位,区间 [0, 1) 就包含一个元素 0 了。

如果你设置为两端都开的区间,那么让 end 向右移动一位后开区间 (0, 1) 仍然没有元素;如果你设置为两端都闭的区间,那么初始区间 [0, 0] 就已经包含了一个元素。这两种情况都会给边界处理带来不必要的麻烦,如果你非要使用的话,需要在代码中做一些特殊处理。

最后,请看代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class CycleArray:
def __init__(self, size=1):
self.size = size
# 因为 Python 支持直接创建泛型数组,所以不需要类型转换
self.arr = [None] * size
# start 指向第一个有效元素的索引,闭区间
self.start = 0
# 切记 end 是一个开区间,
# 即 end 指向最后一个有效元素的下一个位置索引
self.end = 0
self.count = 0

# 自动扩缩容辅助函数
def resize(self, newSize):
# 创建新的数组
new_arr = [None] * newSize
# 将旧数组的元素复制到新数组中
for i in range(self.count):
new_arr[i] = self.arr[(self.start + i) % self.size]
self.arr = new_arr
# 重置 start 和 end 指针
self.start = 0
self.end = self.count
self.size = newSize

# 在数组头部添加元素,时间复杂度 O(1)
def add_first(self, val):
# 当数组满时,扩容为原来的两倍
if self.is_full():
self.resize(self.size * 2)
# 因为 start 是闭区间,所以先左移,再赋值
self.start = (self.start - 1 + self.size) % self.size
self.arr[self.start] = val
self.count += 1

# 删除数组头部元素,时间复杂度 O(1)
def remove_first(self):
if self.is_empty():
raise Exception("Array is empty")
# 因为 start 是闭区间,所以先赋值,再右移
self.arr[self.start] = None
self.start = (self.start + 1) % self.size
self.count -= 1
# 如果数组元素数量减少到原大小的四分之一,则减小数组大小为一半
if self.count > 0 and self.count == self.size // 4:
self.resize(self.size // 2)

# 在数组尾部添加元素,时间复杂度 O(1)
def add_last(self, val):
if self.is_full():
self.resize(self.size * 2)
# 因为 end 是开区间,所以是先赋值,再右移
self.arr[self.end] = val
self.end = (self.end + 1) % self.size
self.count += 1

# 删除数组尾部元素,时间复杂度 O(1)
def remove_last(self):
if self.is_empty():
raise Exception("Array is empty")
# 因为 end 是开区间,所以先左移,再赋值
self.end = (self.end - 1 + self.size) % self.size
self.arr[self.end] = None
self.count -= 1
# 缩容
if self.count > 0 and self.count == self.size // 4:
self.resize(self.size // 2)

# 获取数组头部元素,时间复杂度 O(1)
def get_first(self):
if self.is_empty():
raise Exception("Array is empty")
return self.arr[self.start]

# 获取数组尾部元素,时间复杂度 O(1)
def get_last(self):
if self.is_empty():
raise Exception("Array is empty")
# end 是开区间,指向的是下一个元素的位置,所以要减 1
return self.arr[(self.end - 1 + self.size) % self.size]

def is_full(self):
return self.count == self.size

def size(self):
return self.count

def is_empty(self):
return self.count == 0

思考题

数组增删头部元素的效率真的只能是  么?

我们都说,在数组增删头部元素的时间复杂度是 ,因为需要搬移元素。但是,如果我们使用环形数组,其实是可以实现在 的时间复杂度内增删头部元素的。

当然,上面实现的这个环形数组只提供了 addFirst, removeFirst, addLast, removeLast 这几个方法,并没有提供 我们之前实现的动态数组 的某些方法,比如删除指定索引的元素,获取指定索引的元素,在指定索引插入元素等等。

但是你可以思考一下,难道环形数组实现不了这些方法么?环形数组实现这些方法,时间复杂度相比普通数组,有退化吗?

好像没有吧。

环形数组也可以删除指定索引的元素,也要做数据搬移,和普通数组一样,复杂度是 

环形数组也可以获取指定索引的元素(随机访问),只不过不是直接访问对应索引,而是要通过 start 计算出真实索引,但计算和访问的时间复杂度依然是

环形数组也可以在指定索引插入元素,当然也要做数据搬移,和普通数组一样,复杂度是 
# 跳表核心原理

在实际的面试中,几乎不会让你手写跳表的实现代码,但可能会问你跳表的基本原理及复杂度分析,所以本站需要讲解这种数据结构。

本文处在基础章节,不会具体讲解跳表的实现细节,只介绍跳表的核心原理。初学者学习本文,知道有这么一种数据结构,了解它的基本原理和时间复杂度即可。具体的代码实现将放到数据结构设计章节。

在 链表基础 中我们说到,在单链表中增删查改指定索引的元素所需的时间复杂度是 

其实,如果拿到了待操作的链表节点,操作几次指针就能完成删除、修改、插入操作,时间复杂度是 

时间主要消耗在查询操作,因为通过索引查询对应的节点,只能从头结点开始,逐个遍历到目标节点,然后才做删除、修改、插入操作。

那么,我们是否可以通过一些优化方式,让链表支持快速的查找操作呢?

有一种方式是借助键值映射,用  的时间直接拿到目标节点,避免了遍历查找的时间消耗,这个思路在后面的 [哈希链表(LinkedHashMap)]中会详细介绍。

另一种方式,这就是本文介绍的跳表(Skip List),利用空间换时间的思想,用额外的空间记录额外的信息,增删查改的时间复杂度都能优化到 

跳表核心原理

我们就以查询指定索引的元素为例,来看看跳表是如何优化单链表的。

一条普通的单链表长这样:

1
2
index  0  1  2  3  4  5  6  7  8  9
node a->b->c->d->e->f->g->h->i->j

如果我们想查询索引为 7 的元素是什么,只能从索引 0 头结点开始往后遍历,直到遍历到索引 7,找到目标节点 h

而跳表则是这样的:

1
2
3
4
5
indexLevel   0-----------------------8-----10
indexLevel 0-----------4-----------8-----10
indexLevel 0-----2-----4-----6-----8-----10
indexLevel 0--1--2--3--4--5--6--7--8--9--10
nodeLevel a->b->c->d->e->f->g->h->i->j->k

跳表相当于在原链表的基础上,增加了多层索引,每向上一层,索引节点的数量减少一半,索引的间隔变为 2 倍,所以索引的高度是 ⁡, 代表链表中元素的个数。

此时,如果我们想查询索引为 7 的元素,可以从最高层索引开始一层一层地往下找:

首先最高层的第一个索引区间是 [0, 8],可以确定索引 7 在这个区间内,所以从下一层的节点 0 开始搜索;

第二层从节点 0 开始,索引区间 [0, 4] 不包含索引 7,继续往右移动到节点 4,索引区间 [4, 8] 包含索引 7,所以从下一层的节点 4 开始搜索;

第三层从节点 4 开始,索引区间 [4, 6] 不包含索引 7,继续往右移动到节点 6,索引区间 [6, 8] 包含索引 7,所以从下一层的节点 6 开始搜索;

第四层从节点 6 开始,索引区间 [6, 7] 包含索引 7,最终找到目标节点 h

这个搜索过程中,会经过  层索引,在每层索引中移动的次数不会超过 2 次(因为上层索引区间在下一层被分为两半),所以跳表的查询时间复杂度是 

总结

上面这个简化的例子应该能让你对跳表的核心原理有个直观的认识,跳表是典型的空间换时间设计思路,额外维护多层索引,增加空间复杂度,降低增删查改的时间复杂度。

跳表的具体实现还是有一些复杂,而且和上面的简化示例有一些不同,下面补充几点:

1、上面的例子只展示了查询操作,但跳表肯定得支持插入和删除操作,这就涉及到索引层中节点的动态调整,你需要保证每一层的索引区间尽可能二分,这样才能保证索引层的高度为 ⁡,否则时间复杂度就会退化。

2、不仅仅是查找索引对应的节点,跳表还可以运用到更通用的场景,比如说有序键值对的存储和查找。实际上,跳表的使用场景和后面我们会学习到的二叉搜索树非常类似,只不过跳表的代码实现相较于自平衡二叉搜索树要简单很多。

队列/栈基本原理


计算机的两种存储方式,顺序存储(数组)和链式存储(链表)都讲完了,之后的所有数据结构都是基于这两种存储方式之上玩花活。

本文讲解队列和栈的基本原理,后面的文章会讲解如何用代码具体实现。

先说概念吧,其实队列和栈都是「操作受限」的数据结构。说它操作受限,主要是和基本的数组和链表相比,它们提供的 API 是不完整的。

比方说我们前面实现的数组和链表,增删查改的 API 都实现过了,你可以对任意一个索引元素进行增删查改,只要索引不越界,就随便你。

但是对于队列和栈,它们的操作是受限的:队列只能在一端插入元素,另一端删除元素;栈只能在某一端插入和删除元素。说白了就是把数组链表提供的 API 删掉了一部分,只保留头尾操作元素的 API 给你用。

形象地理解,队列只允许在队尾插入元素,在队头删除元素,栈只允许在栈顶插入元素,从栈顶删除元素。这个图中把栈竖着画,队列横着画,只是为了更形象,但实际上它们底层都是数组和链表实现的,后面会讲到:

队列就像排队买票,先来的先离开,后来的后离开;栈就像一摞盘子,最先放的压在最下面,最后放的留在最上面,拿的时候也是最上面的先被拿走。所以我们常说,队列是一种「先进先出」的数据结构,栈是一种「先进后出」的数据结构,就是这个道理。

这两种数据结构的基本 API 如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 队列的基本 API
class MyQueue:

# 向队尾插入元素,时间复杂度 O(1)
def push(self, e):
pass

# 从队头删除元素,时间复杂度 O(1)
def pop(self):
pass

# 查看队头元素,时间复杂度 O(1)
def peek(self):
pass

# 返回队列中的元素个数,时间复杂度 O(1)
def size(self):
pass

# 栈的基本 API
class MyStack:

# 向栈顶插入元素,时间复杂度 O(1)
def push(self, e):
pass

# 从栈顶删除元素,时间复杂度 O(1)
def pop(self):
pass

# 查看栈顶元素,时间复杂度 O(1)
def peek(self):
pass

# 返回栈中的元素个数,时间复杂度 O(1)
def size(self):
pass

不同编程语言中,队列和栈提供的方法名称可能不一样,但每个方法的效果肯定是一样的。

有些语言的标准库可能没有直接提供队列和栈,你可以自己用数组或者链表模拟出队列和栈的效果。下一章我就会先带你用链表实现队列和栈。

用链表实现队列/栈


用链表实现栈

一些读者应该已经知道该怎么用链表作为底层数据结构实现队列和栈了。因为实在是太简单了,直接调用双链表的 API 就可以了。

注意我这里是直接用标准库的链表容器,如果你用之前我们实现的 MyLinkedList,也是一样的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from collections import deque

# 用链表作为底层数据结构实现栈
# Python 的 deque 就是双链表
class MyLinkedStack:
def __init__(self):
self.list = deque()

# 向栈顶加入元素,时间复杂度 O(1)
def push(self, e):
self.list.append(e)

# 从栈顶弹出元素,时间复杂度 O(1)
def pop(self):
return self.list.pop()

# 查看栈顶元素,时间复杂度 O(1)
def peek(self):
return self.list[-1]

# 返回栈中的元素个数,时间复杂度 O(1)
def size(self):
return len(self.list)


if __name__ == "__main__":
stack = MyLinkedStack()
stack.push(1)
stack.push(2)
stack.push(3)
print(stack.pop())
print(stack.peek())
print(stack.size())

提示

上面这段代码相当于是把双链表的尾部作为栈顶,在双链表尾部增删元素的时间复杂度都是 ,符合要求。

当然,你也可以把双链表的头部作为栈顶,因为双链表头部增删元素的时间复杂度也是 ,所以这样实现也是一样的。只要做几个修改 addLast -> addFirstremoveLast -> removeFirstgetLast -> getFirst 就行了。

用链表实现队列

同理,用链表实现队列也是一样的,也直接调用双链表的 API 就可以了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# deque 是 Python 的双链表
from collections import deque

# 用链表作为底层数据结构实现队列
# Python 的 deque 就是双链表
class MyLinkedQueue:
def __init__(self):
self.list = deque()

# 向队尾插入元素,时间复杂度 O(1)
def push(self, e):
self.list.append(e)

# 从队头删除元素,时间复杂度 O(1)
def pop(self):
return self.list.popleft()

# 查看队头元素,时间复杂度 O(1)
def peek(self):
return self.list[0]

# 返回队列中的元素个数,时间复杂度 O(1)
def size(self):
return len(self.list)

if __name__ == "__main__":
queue = MyLinkedQueue()
queue.push(1)
queue.push(2)
queue.push(3)
print(queue.peek()) # 1
print(queue.pop()) # 1
print(queue.pop()) # 2
print(queue.peek()) # 3

提示

上面这段代码相当于是把双链表的尾部作为队尾,把双链表的头部作为队头,在双链表的头尾增删元素的复杂度都是 ,符合队列 API 的要求。

当然,你也可以反过来,把双链表的头部作为队尾,双链表的尾部作为队头。类似栈的实现,只要改一改 list 的调用方法就行了。

用数组实现队列/栈


用数组实现栈

先用数组实现栈,这个不难,你把动态数组的尾部作为栈顶,然后调用动态数组的 API 就行了。因为数组尾部增删元素的时间复杂度都是 ,符合栈的要求。

我这里直接用标准库提供的动态数组,如果你想用之前我们实现的 MyArrayList,也是一样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 用数组作为底层数据结构实现栈
class MyArrayStack:
def __init__(self):
self.arr = []

# 向栈顶加入元素,时间复杂度 O(1)
def push(self, e):
self.arr.append(e)

# 从栈顶弹出元素,时间复杂度 O(1)
def pop(self):
return self.arr.pop()

# 查看栈顶元素,时间复杂度 O(1)
def peek(self):
return self.arr[-1]

# 返回栈中的元素个数,时间复杂度 O(1)
def size(self):
return len(self.arr)

能否让数组的头部作为栈顶?

按照我们之前实现 MyArrayList 的逻辑,是不行的。因为数组头部增删元素的时间复杂度都是 ,不符合要求。

但是我们可以改用前文 环形数组技巧中实现的 CycleArray 类,这个数据结构在头部增删元素的时间复杂度是 ,符合栈的要求。

你直接调用 CycleArray 的 addFirst 和 removeFirst 方法实现栈的 API 就行,我这里就不写了。

用数组实现队列

有了前文 [环形数组](## 环形数组原理)中实现的 CycleArray 类,用数组作为底层数据结构实现队列就不难了吧。直接复用我们实现的 CycleArray,就可以实现标准队列了。当然,一些编程语言也有内置的环形数组实现,你也可以自行搜索使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class MyArrayQueue:
def __init__(self):
self.arr = CycleArray()

def push(self, t):
self.arr.add_last(t)

def pop(self):
return self.arr.remove_first()

def peek(self):
return self.arr.get_first()

def size(self):
return self.arr.size()

双端队列(Deque)原理及实现


基本原理

如果你理解了前面讲解的内容,这个双端队列其实没啥可讲的了。所谓双端队列,主要是对比标准队列(FIFO 先进先出队列)多了一些操作罢了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class MyDeque:
# 从队头插入元素,时间复杂度 O(1)
def add_first(self, e):
pass

# 从队尾插入元素,时间复杂度 O(1)
def add_last(self, e):
pass

# 从队头删除元素,时间复杂度 O(1)
def remove_first(self):
pass

# 从队尾删除元素,时间复杂度 O(1)
def remove_last(self):
pass

# 查看队头元素,时间复杂度 O(1)
def peek_first(self):
pass

# 查看队尾元素,时间复杂度 O(1)
def peek_last(self):
pass

[标准队列] 只能在队尾插入元素,队头删除元素,而双端队列的队头和队尾都可以插入或删除元素。

普通队列就好比排队买票,先来的先买,后来的后买;而双端队列就好比一个过街天桥,两端都可以随意进出。当然,双端队列的元素就不再满足「先进先出」了,因为它比较灵活嘛。

在做算法题的场景中,双端队列用的不算很多。感觉只有 Python 用到的多一些,因为 Python 标准库没有提供内置的栈和队列,一般会用双端队列来模拟标准队列。

用链表实现双端队列

很简单吧,直接复用我们之前实现的 [MyLinkedList]类,或者使用编程语言标准库提供的双链表结构就行了。因为双链表本就支持  时间复杂度在链表的头尾增删元素:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class MyListDeque:
def __init__(self):
# 使用我们之前实现的 `MyLinkedList` 类
self.list = MyLinkedList()

# 从队头插入元素,时间复杂度 O(1)
def add_first(self, e):
self.list.add_first(e)

# 从队尾插入元素,时间复杂度 O(1)
def add_last(self, e):
self.list.add_last(e)

# 从队头删除元素,时间复杂度 O(1)
def remove_first(self):
return self.list.remove_first()

# 从队尾删除元素,时间复杂度 O(1)
def remove_last(self):
return self.list.remove_last()

# 查看队头元素,时间复杂度 O(1)
def peek_first(self):
return self.list.get_first()

# 查看队尾元素,时间复杂度 O(1)
def peek_last(self):
return self.list.get_last()

# 使用示例
my_deque = MyListDeque()

my_deque.add_first(1)
my_deque.add_first(2)
my_deque.add_last(3)
my_deque.add_last(4)

print(my_deque.remove_first()) # 2
print(my_deque.remove_last()) # 4
print(my_deque.peek_first()) # 1
print(my_deque.peek_last()) # 3

用数组实现双端队列

也很简单吧,直接复用我们在 [环形数组技巧]中实现的 CycleArray 提供的方法就行了。环形数组头尾增删元素的复杂度都是 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class MyArrayDeque:
def __init__(self):
self.arr = CycleArray()

# 从队头插入元素,时间复杂度 O(1)
def add_first(self, e):
self.arr.add_first(e)

# 从队尾插入元素,时间复杂度 O(1)
def add_last(self, e):
self.arr.add_last(e)

# 从队头删除元素,时间复杂度 O(1)
def remove_first(self):
return self.arr.remove_first()

# 从队尾删除元素,时间复杂度 O(1)
def remove_last(self):
return self.arr.remove_last()

# 查看队头元素,时间复杂度 O(1)
def peek_first(self):
return self.arr.get_first()

# 查看队尾元素,时间复杂度 O(1)
def peek_last(self):
return self.arr.get_last()

哈希表核心原理


首先,我需要先阐明一个初学者很容易犯的概念错误。

请问,哈希表和我们常说的 Map(键值映射)是不是同一个东西?不是。

这一点用 Java 来讲解就很清楚,Map 是一个 Java 接口,仅仅声明了若干个方法,并没有给出方法的具体实现:

1
2
3
4
5
6
interface Map<K, V> {
V get(K key);
void put(K key, V value);
V remove(K key);
// ...
}

Map 接口本身只定义了键值映射的一系列操作,HashMap 这种数据结构根据自身特点实现了这些操作。还有其他数据结构也实现了这个接口,比如 TreeMapLinkedHashMap 等等。

换句话说,你可以说 HashMap 的 get, put, remove 方法的复杂度都是 的,但你不能说 Map 接口的复杂度都是 。因为如果换成其他的实现类,比如底层用二叉树结构实现的 TreeMap,这些方法的复杂度就变成  了。

我为什么要强调这一点呢?主要是针对使用非 Java 语言的读者。

其他编程语言可能没有 Java 这么清晰的接口定义,所以很容易让读者把哈希表和 Map 键值对混为一谈,听到键值对操作,就认为其增删查改的复杂度一定是 。这是不对的,具体要看这个底层的数据结构是如何实现键值操作的。

那么这一章节我会带大家动手实现一个哈希表,探讨哈希表为什么能做到增删查改  复杂度,以及解决哈希冲突的两种办法。

哈希表的基本原理

哈希表可以理解为一个加强版的数组

数组可以通过索引在  的时间复杂度内查找到对应元素,索引是一个非负整数。

哈希表是类似的,可以通过 key 在  的时间复杂度内查找到这个 key 对应的 valuekey 的类型可以是数字、字符串等多种类型。

怎么做的?特别简单,哈希表的底层实现就是一个数组(我们不妨称之为 table)。它先把这个 key 通过一个哈希函数(我们不妨称之为 hash)转化成数组里面的索引,然后增删查改操作和数组基本相同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 哈希表伪码逻辑
class MyHashMap:

def __init__(self):
self.table = [None] * 1000

# 增/改,复杂度 O(1)
def put(self, key, value):
index = self.hash(key)
self.table[index] = value

# 查,复杂度 O(1)
def get(self, key):
index = self.hash(key)
return self.table[index]

# 删,复杂度 O(1)
def remove(self, key):
index = self.hash(key)
self.table[index] = None

# 哈希函数,把 key 转化成 table 中的合法索引
# 时间复杂度必须是 O(1),才能保证上述方法的复杂度都是 O(1)
def hash(self, key):
# ...
pass

具体实现上有不少细节需要处理,比如哈希函数的设计、哈希冲突的处理等等。但你只要明白了上面的核心原理,就已经成功了一半了,剩下的就是写代码了,这有何难呢?

下面我们来具体介绍一下上述增删查改过程中几个关键的概念和可能出现的问题。

几个关键概念及原理

key 是唯一的,value 可以重复

哈希表中,不可能出现两个相同的 key,而 value 是可以重复的。

明白了上面讲的原理应该很好理解,你直接类比数组就行了:

数组里面每个索引都是唯一的,不可能说你这个数组有两个索引 0。至于数组里面存什么元素,随便你,没人 care

所以哈希表是一样的,key 的值不可能出现重复,而 value 的值可以随意。

哈希函数

哈希函数的作用是把任意长度的输入(key)转化成固定长度的输出(索引)

你也看到了,增删查改的方法中都会用到哈希函数来计算索引,如果你设计的这个哈希函数复杂度是 ,那么哈希表的增删查改性能就会退化成 所以说这个函数的性能很关键

这个函数还要保证的一点是,输入相同的 key,输出也必须要相同,这样才能保证哈希表的正确性。不能说现在你计算 hash("123") = 5,待会儿计算 hash("123") = 6,这样的话哈希表就废了。

那么哈希函数是如何把非整数类型的 key 转化成整数索引的?又是如何保证这个索引是合法的呢?

如何把 key 转化成整数

这个问题可以有很多种答案,不同的哈希函数设计会有不同的方法,我这里就结合 Java 语言说一个简单的办法。其他编程语言也是类似的,可以参考这个思路,查询相关的标准库文档。

任意 Java 对象都会有一个 int hashCode() 方法,在实现自定义的类时,如果不重写这个方法,那么它的默认返回值可以认为是该对象的内存地址。一个对象的内存地址显然是全局唯一的一个整数。

所以我们只要调用 key 的 hashCode() 方法就相当于把 key 转化成了一个整数,且这个整数是全局唯一的。

当然,这个方法也有一些问题,下面会讲解,但现在至少找到了一种把任意对象转化为整数的方法。

如何保证索引合法

hashCode 方法返回的是 int 类型,首先一个问题就是,这个 int 值可能是负数,而数组的索引是非负整数。

那么你肯定想这样写代码,把这个值转化成非负数:

1
2
int h = key.hashCode();
if (h < 0) h = -h;

但这样有问题,int 类型可以表示的最小值是 -2^31,而最大值是 2^31 - 1。所以如果 h = -2^31,那么 -h = 2^31 就会超出 int 类型的最大值,这叫做整型溢出,编译器会报错,甚至产生不可预知的结果。

为什么 int 的最小值是 -2^31,而最大值是 2^31 - 1?这涉及计算机补码编码的原理,简单说,int 就是 32 个二进制位,其中最高位(最左边那位)是符号位,符号位是 0 时表示正数,是 1 时表示负数。

现在的问题是,我想保证 h 非负,但又不能用负号直接取反。那么一个简单直接的办法是利用这种补码编码的原理,直接把最高位的符号位变成 0,就可以保证 h 是非负数了:

1
2
3
4
5
6
7
8
int h = key.hashCode();
// 位运算,把最高位的符号位去掉
// 另外,位运算的运行速度也会比一般的算术运算快
// 所以你看标准库的源码,能用位运算的地方它都会优先使用位运算
h = h & 0x7fffffff;
// 这个 0x7fffffff 的二进制表示是 0111 1111 ... 1111
// 即除了最高位(符号位)是 0,其他位都是 1
// 把 0x7fffffff 和其他 int 进行 & 运算之后,最高位(符号位)就会被清零,即保证了 h 是非负数

关于补码编码的原理我这里就不详细展开了,有兴趣的话你可以自己搜索学习一下。

好的,上面解决了 hashCode 可能是负数的问题,但还有一个问题,就是这个 hashCode 一般都很大,我们需要把它映射成 table 数组的合法索引。

这个问题对你来说应该不难吧,我们之前在 [环形数组原理及实现]里面用 % 求模运算来保证索引永远落在数组的合法范围内。所以这里也可以用 % 运算来保证索引的合法性,完整的 hash 函数实现如下:

1
2
3
4
5
6
7
int hash(K key) {
int h = key.hashCode();
// 保证非负数
h = h & 0x7fffffff;
// 映射到 table 数组的合法索引
return h % table.length;
}

当然,直接使用 % 也有问题,因为 % 这个求余数的运算比较消耗性能,一般在追求运行效率的标准库源码中会尽量避免使用 % 运算,而是使用位运算提升性能。

不过本章主要目的是带你理解实现一个简单的哈希表,就忽略这些细节优化了。有兴趣的话你可以去看一下 Java HashMap 的源码,看看它是如何实现这个 hash 函数的。

哈希冲突

上面给出了 hash 函数的实现,那么你肯定也会想到,如果两个不同的 key 通过哈希函数得到了相同的索引,怎么办呢?这种情况就叫做「哈希冲突」。

哈希冲突是否可以避免?

哈希冲突不可能避免,只能在算法层面妥善处理出现哈希冲突的情况

哈希冲突是一定会出现的,因为这个 hash 函数相当于是把一个无穷大的空间映射到了一个有限的索引空间,所以必然会有不同的 key 映射到同一个索引上。

就好比三维物体映射到二维影子一样,这种有损压缩必然会出现信息丢失,有损信息本就无法和原信息一一对应。

出现哈希冲突的情况怎么解决?两种常见的解决方法,一种是拉链法,另一种是线性探查法(也经常被叫做开放寻址法)。

名字听起来高大上,说白了就是纵向延伸和横向延伸两种思路嘛:

拉链法相当于是哈希表的底层数组并不直接存储 value 类型,而是存储一个链表,当有多个不同的 key 映射到了同一个索引上,这些 key -> value 对儿就存储在这个链表中,这样就能解决哈希冲突的问题。

而线性探查法的思路是,一个 key 发现算出来的 index 值已经被别的 key 占了,那么它就去 index + 1 的位置看看,如果还是被占了,就继续往后找,直到找到一个空的位置为止。

比方说上图,key 的插入顺序是 k2, k4, k5, k3, k1,那么哈希表底层就会变成这样:

这里先讲一下原理,后面的章节我会手把手带大家分别实现这两种方法来解决哈希冲突。

扩容和负载因子

相信大家都听说过「负载因子」这个专业术语,现在你明白了哈希冲突的问题,就能理解负载因子的意义了。

拉链法和线性探查法虽然能解决哈希冲突的问题,但是它们会导致性能下降。

比如拉链法,你算出来 index = hash(key) 这个索引了,结果过去查出来的是个链表,你还得遍历一下这个链表,才能在里面找到你要的 value。这个过程的时间复杂度是 K 是这个链表的长度。

线性探查法也是类似的,你算出来 index = hash(key) 这个索引了,你去这个索引位置查看,发现存储的不是要找的 key,但由于线性探查法解决哈希冲突的方式,你并不能确定这个 key 真的不存在,你必须顺着这个索引往后找,直到找到一个空的位置或者找到这个 key 为止,这个过程的时间复杂度也是 K 为连续探查的次数。

所以说,如果频繁出现哈希冲突,那么 K 的值就会增大,这个哈希表的性能就会显著下降。这是我们需要避免的。

那么为什么会频繁出现哈希冲突呢?两个原因呗:

1、哈希函数设计的不好,导致 key 的哈希值分布不均匀,很多 key 映射到了同一个索引上。

2、哈希表里面已经装了太多的 key-value 对了,这种情况下即使哈希函数再完美,也没办法避免哈希冲突。

对于第一个问题没什么好说的,开发编程语言标准库的大佬们已经帮你设计好了哈希函数,你只要调用就行了。

对于第二个问题是我们可以控制的,即避免哈希表装太满,这就引出了「负载因子」的概念。

负载因子

负载因子是一个哈希表装满的程度的度量。一般来说,负载因子越大,说明哈希表里面存储的 key-value 对越多,哈希冲突的概率就越大,哈希表的操作性能就越差。

**负载因子的计算公式也很简单,就是 size / table.length**。其中 size 是哈希表里面的 key-value 对的数量,table.length 是哈希表底层数组的容量。

你不难发现,用拉链法实现的哈希表,负载因子可以无限大,因为链表可以无限延伸;用线性探查法实现的哈希表,负载因子不会超过 1。

像 Java 的 HashMap,允许我们创建哈希表时自定义负载因子,不设置的话默认是 0.75,这个值是经验值,一般保持默认就行了。

当哈希表内元素达到负载因子时,哈希表会扩容。和之前讲解 [动态数组的实现] 是类似的,就是把哈希表底层 table 数组的容量扩大,把数据搬移到新的大数组中。size 不变,table.length 增加,负载因子就减小了。

为什么不能依赖哈希表的遍历顺序

你大概也听过一个编程常识,即哈希表中键的遍历顺序是无序的,不能依赖哈希表的遍历顺序来编写程序。这是为什么呢?

哈希表的遍历本质上就是遍历那个底层 table 数组:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 遍历所有 key 的伪码逻辑

// 哈希表底层的 table 数组
KVNode[] table = new KVNode[1000];

// 获取哈希表中的所有键
// 我们不能依赖这个 keys 列表的顺序
List<KeyType> keys = new ArrayList<>();

for (int i = 0; i < table.length; i++) {
KVNode node = table[i];
if (node != null) {
keys.add(node.key);
}
}

你如果理解了前面讲的内容,应该已经能够理解这个问题了。

首先,由于 hash 函数要把你的 key 进行映射,所以 key 在底层 table 数组中的分布是随机的,不像数组/链表结构那样有个明确的元素顺序。

其次,刚才讲了哈希表达到负载因子时会怎样?会扩容对吧,也就是 table.length 会变化,且会搬移元素。

那么这个搬移数据的过程,是不是要用 hash 函数重新计算 key 的哈希值,然后放到新的 table 数组中?

而这个 hash 函数,它计算出的索引值依赖 table.length。也就是说,哈希表自动扩缩容后,同一个 key 存储在 table 的索引可能发生变化,所以遍历结果的顺序就和之前不一样了

你观察到的现象就是,这次遍历的第一个键是 key1,但是增删几个元素再遍历,可能发现 key1 跑到最后去了。

所以说,这些东西没必要背的,原理搞明白了,你稍微推理下自己都能想通。

为什么不建议在 for 循环中增/删哈希表的 key

注意我这里说的是不建议,并不是一定不可以。因为不同的编程语言标准库对哈希表的实现不同,有些语言针对这种情况做了优化,所以到底行不行,要查阅文档。

我们这里仅从哈希表的原理上分析,在 for 循环中增/删哈希表的 key,是很容易出现问题的,原因和上面相同,还是扩缩容导致的哈希值变化。

遍历哈希表的 key,本质就是遍历哈希表底层的 table 数组,如果一边遍历一边增删元素,如果遍历到一半,插入/删除操作触发了扩缩容,整个 table 数组都变了,那么请问,接下来应该是什么行为?还有,在遍历过程中新插入/删除的元素,是否应该被遍历到?

扩缩容导致 key 顺序变化是哈希表的特有行为,但即便排除这个因素,任何其他数据结构,也都不建议在遍历的过程中同时进行增删,否则很容易导致非预期的行为。

如果你非要这样做,请确保查阅了相关文档,明确这个操作的行为是什么,做到心里有数。

必须是不可变的

只有那些不可变类型,才能作为哈希表的 key,这一点很重要

所谓不可变类型,就是说这个对象一旦创建,它的值就不能再改变了。比如 Java 中的 String, Integer 等类型,一旦创建了这些对象,你就只能读取它的值,而不能再修改它的值了。

作为对比,Java 中的 ArrayListLinkedList 这些对象,它们创建出来之后,可以往里面随意增删元素,所以它们是可变类型。

因此,你可以把 String 对象作为哈希表的 key,但不能把 ArrayList 对象作为哈希表的 key

1
2
3
4
5
6
7
// 可以把不可变类型作为 key
Map<String, AnyOtherType> map1 = new HashMap<>();
Map<Integer, AnyOtherType> map2 = new HashMap<>();

// 不应该把可变类型作为 key
// 注意,这样写并不会产生语法错误,但是代码非常容易出 bug
Map<ArrayList<Integer>, AnyOtherType> map3 = new HashMap<>();

为啥不建议把可变类型作为 key 呢?就比如这个 ArrayList 吧,它的 hashCode 方法的实现逻辑如下:

1
2
3
4
5
public int hashCode() {
for (int i = 0; i < elementData.length; i++) {
h = 31 * h + elementData[i];
}
}

第一个就是效率问题,每次计算 hashCode 都要遍历整个数组,复杂度是 ,这样就会导致哈希表的增删查改操作的复杂度退化成 

更严重的问题是,ArrayList 的 hashCode 是根据它里面的元素计算出来的,如果你往这个 ArrayList 里面增删元素,或者其中某个元素的 hashCode 值发生改变,那么这个 ArrayList 的 hashCode 返回值也会发生改变。

比方说,你现在用一个 ArrayList 类型的 arr 变量作为哈希表的 key 在哈希表中保存了对应的 value。但如果 arr 中的某个元素在程序的其他位置被修改了,那么 arr 的 hashCode 就会变化。此时你再用这个 arr 变量去哈希表中查询,发现找不到任何值了。

也就是说,你存入哈希表的 key-value 意外丢失了,这是非常非常严重的 bug,还会带来潜在的内存泄漏问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class Test {
public static void main(String[] args) {
// 错误示例
// 把可变类型作为 HashMap 的 key
Map<ArrayList<Integer>, Integer> map = new HashMap<>();

ArrayList<Integer> arr = new ArrayList<>();
arr.add(1);
arr.add(2);

map.put(arr, 999);
System.out.println(map.containsKey(arr)); // true
System.out.println(map.get(arr)); // 999

arr.add(3);
// 出现严重 bug,键值对丢失
System.out.println(map.containsKey(arr)); // false
System.out.println(map.get(arr)); // null

// 此时 map 底层的 table 中,arr 的键值对数据依然存在
// 但是由于 arr 的 hashCode 改变了,此键值对无法被查找到
// 这也会导致内存泄漏,因为这个 arr 变量被 map 引用着,无法被垃圾回收
}
}

上面就是一个简单的错误示例。你也许会说,把元素 3 删掉,arr -> 999 这个键值对不就又出现了?或者,直接遍历哈希表底层的 table 数组,应该也可以看到这个键值对。

拜托🙏🏻,你这是在写代码还是在写盗墓笔记呢?一会儿出现一会儿消失,你这个哈希表是幽灵附体了吗?

开个玩笑。实际上可变类型本身就是一种不确定性,在代码构成的屎山里,你怎么知道这个 arr 传递到哪里被修改了呢?

所以正确的做法是,使用不可变类型作为哈希表的 key,比方说用 String 类型作为 key。因为 Java 中的 String 对象一旦创建出来,它的值就不允许被改变,你就不会遇到上面的问题。

String 类型的 hashCode 方法也需要遍历所有字符,但是由于它的不可变性,这个值只要算出来一次,就可以缓存下来,不用每次都重新计算,所以 [平均时间复杂度]依然是 

我这里是用 Java 举的例子,其他语言也是类似的,你需要查询相关文档,了解标准库提供的哈希表是如何计算对象哈希值的,避免产生类似的问题。

总结

上面的说明应该已经吧哈希表的底层原理全部串起来了,最后模拟几个面试问题来总结一下本文的内容:

**1、为什么我们常说,哈希表的增删查改效率都是 **?

因为哈希表底层就是操作一个数组,其主要的时间复杂度来自于哈希函数计算索引和哈希冲突。只要保证哈希函数的复杂度在 ,且合理解决哈希冲突的问题,那么增删查改的复杂度就都是 

2、哈希表的遍历顺序为什么会变化

因为哈希表在达到负载因子时会扩容,这个扩容过程会导致哈希表底层的数组容量变化,哈希函数计算出来的索引也会变化,所以哈希表的遍历顺序也会变化。

3、哈希表的增删查改效率一定是  吗

不一定,正如前面分析的,只有哈希函数的复杂度是 ,且合理解决哈希冲突的问题,才能保证增删查改的复杂度是 

哈希冲突好解决,都是有标准答案的。关键是哈希函数的计算复杂度。如果使用了错误的 key 类型,比如前面用 ArrayList 作为 key 的例子,那么哈希表的复杂度就会退化成 

**4、为啥一定要用不可变类型作为哈希表的 key**?

因为哈希表的主要操作都依赖于哈希函数计算出来的索引,如果 key 的哈希值会变化,会导致键值对意外丢失,产生严重的 bug。

要对自己使用的编程语言标准库中的源码有一定的了解,才能保证写出高效的代码。

二叉树基础及常见类型


我认为二叉树是最重要的基本数据结构,没有之一

如果你是初学者,现在这个阶段我很难给你彻底解释清楚得出这个结论的原因,你需要认真学习本站后面的内容才能逐渐理解。我暂且总结两个点:

1、二叉树本身是比较简单的基础数据结构,但是很多复杂的数据结构都是基于二叉树的,比如 [红黑树](二叉搜索树)、[多叉树]、[二叉堆]、[图]、[字典树]、[并查集]、[线段树] 等等。你把二叉树玩明白了,这些数据结构都不是问题;如果你不把二叉树搞明白,这些高级数据结构你也很难驾驭。

2、二叉树不单纯是一种数据结构,更是一种常用的算法思维。一切暴力穷举算法,比如 [回溯算法]、[BFS 算法]、[动态规划] 本质上也是把具体问题抽象成树结构,你只要抽象出来了,这些问题最终都回归二叉树的问题。同样看一段算法代码,在别人眼里是一串文本,每个字都认识,但连起来就不认识了;而在你眼里的代码就是一棵树,想咋改就咋改,咋改都能改对,实在是太简单了。

后面的数据结构章节包含大量关于二叉树的讲解和习题,你按照本站的目录顺序学习,我会带你把二叉树彻底搞懂,到时候你就明白我为什么这么重视二叉树了。

几种常见的二叉树

二叉树的主要难点在于做算法题,它本身其实没啥难的,就是这样一种树形结构嘛:

300

上面就是一棵普通的二叉树,几个术语你要了解一下:

1、每个节点下方直接相连的节点称为子节点,上方直接相连的节点称为父节点。比方说节点 3 的父节点是 1,左子节点是 5,右子节点是 6;节点 5 的父节点是 3,左子节点是 7,没有右子节点。

2、以子节点为根的树称为子树。比方说节点 3 的左子树是节点 5 和 7 组成的树,右子树是节点 6 和 8 组成的树。

3、我们称最上方那个没有父节点的节点 1 为根节点,称最下层没有子节点的节点 478 为叶子节点

4、我们称从根节点到最下方叶子节点经过的节点个数为二叉树的最大深度/高度,上面这棵树的最大深度是 4,即从根节点 1 到叶子节点 7 或 8 的路径上的节点个数。

没啥别的可说的了,就是这么简单。

有一些稍微特殊一些的二叉树,有他们自己的名字,你要了解一下,后面做题时见到这些专业术语,你就知道题目在说啥了。

满二叉树

直接看图比较直观,满二叉树就是每一层节点都是满的,整棵树像一个正三角形:

满二叉树有个优势,就是它的节点个数很好算。假设深度为 h,那么总节点数就是 2^h - 1,等比数列求和嘛,我们应该都学过的。

完全二叉树

完全二叉树是指,二叉树的每一层的节点都紧凑靠左排列,且除了最后一层,其他每层都必须是满的:

500

不难发现,满二叉树其实是一种特殊的完全二叉树。

完全二叉树的特点:由于它的节点紧凑排列,如果从左到右从上到下对它的每个节点编号,那么父子节点的索引存在明显的规律

这个特点在讲到 [二叉堆核心原理] 和 [线段树核心原理]时会用到:完全二叉树可以用数组来存储,不需要真的构建链式节点。

完全二叉树还有个比较难发觉的性质:完全二叉树的左右子树也是完全二叉树

或者更准确地说应该是:完全二叉树的左右子树中,至少有一棵是满二叉树

这个性质在做算法题的时候会用到,比如 [巧算完全二叉树的节点数],这里就先提一下。

中英文的定义有区别

关于完全二叉树和满二叉树的定义,中文语境和英文语境似乎有点区别。

我们说的完全二叉树对应英文 Complete Binary Tree,这个没问题,说的是同一种树。

我们说的满二叉树,按理说应该翻译成 Full Binary Tree 对吧,但其实不是,满二叉树的定义对应英文的 Perfect Binary Tree。

而英文中的 Full Binary Tree 是指一棵二叉树的所有节点要么没有孩子节点,要么有两个孩子节点。

以上定义出自 wikipedia,这里就是顺便一提。其实名词叫什么都无所谓,你知道有这个区别,在看英文资料时留意一下就行了。

二叉搜索树

二叉搜索树(Binary Search Tree,简称 BST)是一种很常见的二叉树,它的定义是:

对于树中的每个节点,其左子树的每个节点的值都要小于这个节点的值,右子树的每个节点的值都要大于这个节点的值。你可以简单记为「左小右大」。

我把「子树的每个节点」加粗了,这是初学者常犯的错误,不要只看子节点,而要看整棵子树的所有节点。

比方说,下面这棵树就是一棵 BST:

300

节点 7 的左子树所有节点的值都小于 7,右子树所有节点的值都大于 7;节点 4 的左子树所有节点的值都小于 4,右子树所有节点的值都大于 4,以此类推。

相反的,下面这棵树就不是 BST:

300

如果你只注意每个节点的左右子节点,似乎看不出问题。你应该看整棵子树,注意看节点 7 的左子树中有个节点 8,比 7 大,这就不符合 BST 的定义了。

BST 是非常常用的数据结构。因为左小右大的特性,可以让我们在 BST 中快速找到某个节点,或者找到某个范围内的所有节点,这是 BST 的优势所在

比方说,对于一棵普通的二叉树,其中的节点大小没有任何规律可言,那么你要找到某个值为 x 的节点,只能从根节点开始遍历整棵树。

而对于 BST,你可以先对比根节点和 x 的大小关系,如果 x 比根节点大,那么根节点的整棵左子树就可以直接排除了,直接从右子树开始找,这样就可以快速定位到值为 x 的那个节点。

高度平衡二叉树

高度平衡二叉树(Height-Balanced Binary Tree)是一种特殊的二叉树,它的「每个节点」的左右子树的高度差不超过 1

要注意是每个节点,而不仅仅是根节点。

比如下面这棵二叉树,根节点 1 的左子树高度是 2,右子树高度是 3;节点 2 的左子树高度是 1,右子树高度是 0;节点 3 的左子树高度是 2,右子树高度是 1,以此类推,每个节点的左右子树高度差都不超过 1,所以这是一棵高度平衡的二叉树:

300

下面这棵树就不是高度平衡的二叉树,因为节点 2 的左子树高度是 2,右子树高度是 0,高度差超过 1,不符合条件:

300

**假设高度平衡二叉树中共有  个节点,那么高度平衡二叉树的高度是 **。这是非常重要的性质,本站后面的章节会讲解几种基于二叉树的数据结构,如果能保证树的高度为 ,那么这些数据结构的增删查改效率就会很高。

反之,如果树很不平衡,比如这种极端情况:

300
那么这棵树其实就等同于单链表,在树中进行增删查改的效率就会大幅降低。

自平衡二叉树

上面介绍了高度平衡二叉树,说到它的高度为 ,增删查改的效率高。

如果我们可以在增删二叉树节点时对树的结构进行一些调整,那么就可以让树的高度始终是平衡的,这就是自平衡二叉树(Self-Balanced Binary Tree)。

自平衡的二叉树有很多种实现方式,最经典的就是 [红黑树],一种自平衡的二叉搜索树。

保持树的平衡性,最关键的就是「旋转」操作,下面这个可视化面板展示了红黑树的旋转操作,你可以点击左右旋和左旋的代码,查看旋转的效果:

二叉树的实现方式

最常见的二叉树就是类似链表那样的链式存储结构,每个二叉树节点有指向左右子节点的指针,这种方式比较简单直观。

力扣/LeetCode 上给你输入的二叉树一般都是用这种方式构建的,二叉树节点类 TreeNode 一般长这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class TreeNode:
def __init__(self, x: int):
self.val = x
self.left = None
self.right = None

# 你可以这样构建一棵二叉树:
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.right.left = TreeNode(5)
root.right.right = TreeNode(6)

# 构建出来的二叉树是这样的:
# 1
# / \
# 2 3
# / / \
# 4 5 6

既然说上面是比较常见的实现方式,那言下之意就是还有其他实现方式,对吧?

是的,在 [二叉堆原理及实现]和 [并查集算法详解]中,我们会根据具体的需求场景选择用数组来存储二叉树。

另外,在一般的算法题中,我们可能会把实际问题抽象成二叉树结构,但我们并不需要真的用 TreeNode 创建一棵二叉树出来,而是直接用类似 [哈希表] 的结构来表示二叉树/多叉树。

比方说这棵二叉树:

300

我可以用一个哈希表,其中的键是父节点 id,值是子节点 id 的列表(每个节点的 id 是唯一的),那么一个键值对就是一个多叉树节点了,这棵多叉树就可以表示成这样:

1
2
3
4
5
6
7
8
9
# 1 -> [2, 3]
# 2 -> [4]
# 3 -> [5, 6]

tree = {
1: [2, 3],
2: [4],
3: [5, 6]
}

这样就可以模拟和操作二叉树/多叉树结构,后文讲到图论的时候你就会知道,它有一个新的名字叫做 [邻接表]。

二叉树的递归/层序遍历

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
144. Binary Tree Preorder Traversal 144. 二叉树的前序遍历
94. Binary Tree Inorder Traversal 94. 二叉树的中序遍历
145. Binary Tree Postorder Traversal 145. 二叉树的后序遍历
102. Binary Tree Level Order Traversal 102. 二叉树的层序遍历

一句话总结

二叉树只有递归遍历层序遍历这两种,再无其他。递归遍历可以衍生出 DFS 算法,层序遍历可以衍生出 BFS 算法。

递归遍历二叉树节点的顺序是固定的,但是有三个关键位置,在不同位置插入代码,会产生不同的效果。

层序遍历二叉树节点的顺序也是固定的,但是有三种不同的写法,对应不同的场景。

二叉树的遍历算法主要分为递归遍历和层序遍历两种,都有代码模板。递归代码模板可以延伸出后面要讲的 DFS 算法、回溯算法,层序代码模板可以延伸出后面要讲的 BFS 算法,所以我经常强调二叉树结构的重要性。

大家熟知的前序遍历、中序遍历、后序遍历,都属于二叉树的递归遍历,只不过是把自定义代码插入到了代码模板的不同位置而已,下面我会结合可视化面板来讲解。

递归遍历(DFS)

递归遍历二叉树的代码模板如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 基本的二叉树节点
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right

# 二叉树的递归遍历框架
def traverse(root: TreeNode):
if root is None:
return
traverse(root.left)
traverse(root.right)

请问,这段短小精干的代码为什么能遍历二叉树?又是以什么顺序遍历二叉树的?

traverse 函数的遍历顺序就是一直往左子节点走,直到遇到空指针不能再走了,才尝试往右子节点走一步;然后再一直尝试往左子节点走,如此循环;如果左右子树都走完了,则返回上一层父节点。

看代码也能看出来,先递归调用的 root.left,然后才递归调用的 root.right,每次进入 traverse 函数,都会先往左子节点递归遍历,直到遇到空指针走不动了,才轮到往右子节点走一次。

那么我们简单拓展一下,如果修改前面的 traverse 函数,先递归遍历 root.right,再递归遍历 root.left,会是什么效果?

1
2
3
4
5
6
7
# 修改标准的二叉树遍历框架
def traverseFlip(root: TreeNode) -> None:
if root is None:
return
# 反过来,先递归遍历右子树,再递归遍历左子树
traverseFlip(root.right)
traverseFlip(root.left)

我举这个 traverseFlip 的例子,是想告诉你:

递归遍历节点的顺序 仅取决于左右子节点的递归调用顺序,与其他代码无关**。

我们说二叉树遍历时,一般不会像 traverseFlip 这样遍历二叉树,默认还是按照先左后右的顺序,所以当我们说二叉树遍历的代码模板时,指的是先左后右的遍历顺序:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 基本的二叉树节点
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right

# 二叉树的递归遍历框架
def traverse(root: TreeNode):
if root is None:
return
traverse(root.left)
traverse(root.right)

只要这个先左后右的调用顺序不变,那么 traverse 函数访问节点的顺序就是固定的,你插入一万行代码进去,也不会变。

有一些数据结构基础的读者可能有点晕了:

不对呀,只要上过大学的数据结构课程,就知道二叉树有前/中/后序三种遍历,会得到三种不同顺序的结果。为啥你这里说递归遍历节点的顺序是固定的呢?

这个问题很好,下面来解答。

理解前/中/后序遍历

递归遍历的顺序,即 traverse 函数访问节点的顺序确实是固定的。正如可视化面板所示,root 指针在树上移动的顺序是固定的:

但是,你在 traverse 函数中不同位置写代码,效果是可以不一样的。前中后序遍历的结果不同,原因是因为你把代码写在了不同位置,所以产生了不同的效果

比方说,刚进入一个节点的时候,你还对它的子节点一无所知,而当你要离开一个节点的时候,它的所有子节点你都遍历过了。那么在这两种情况下写的代码,肯定是可以有不同的效果的。

所谓的前中后序遍历,其实就是在二叉树遍历框架的不同位置写代码:

1
2
3
4
5
6
7
8
9
# 二叉树的遍历框架
def traverse(root):
if root is None:
return
# 前序位置
traverse(root.left)
# 中序位置
traverse(root.right)
# 后序位置

前序位置的代码会在进入节点时立即执行;中序位置的代码会在左子树遍历完成后,遍历右子树之前执行;后序位置的代码会在左右子树遍历完成后执行
300

划重点

特别强调,三种位置的关键区别在于执行时机不同。

实际的算法题中不会简单的让你计算前中后序的遍历结果,而是需要你把正确的代码写到正确的位置,所以你必须准确理解三个位置的代码产生的不同效果,才能写出准确的代码。

最后一个知识点,[二叉搜索树(BST)]的中序遍历结果是有序的,这是 BST 的一个重要性质。

层序遍历(BFS)

上面讲的递归遍历是依赖函数堆栈递归遍历二叉树的,遍历顺序是从最左侧开始,一列一列地走到最右侧。

二叉树的层序遍历,顾名思义,就是一层一层地遍历二叉树:

层序遍历需要借助队列来实现,而且根据不同的需求,可以有三种不同的写法,下面一一列举。

写法一

这是最简单的写法,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from collections import deque

def levelOrderTraverse(root):
if root is None:
return
q = deque()
q.append(root)
while q:
cur = q.popleft()
# 访问 cur 节点
print(cur.val)

# 把 cur 的左右子节点加入队列
if cur.left is not None:
q.append(cur.left)
if cur.right is not None:
q.append(cur.right)

这种写法的优缺点

这种写法最大的优势就是简单。每次把队头元素拿出来,然后把它的左右子节点加入队列,就完事了。

但是这种写法的缺点是,无法知道当前节点在第几层。知道节点的层数是个常见的需求,比方说让你收集每一层的节点,或者计算二叉树的最小深度等等。

所以这种写法虽然简单,但用的不多,下面介绍的写法会更常见一些。

写法二

对上面的解法稍加改造,就得出了下面这种写法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from collections import deque

def levelOrderTraverse(root):
if root is None:
return
q = deque()
q.append(root)
# 记录当前遍历到的层数(根节点视为第 1 层)
depth = 1

while q:
sz = len(q)
for i in range(sz):
cur = q.popleft()
# 访问 cur 节点,同时知道它所在的层数
print(f"depth = {depth}, val = {cur.val}")

# 把 cur 的左右子节点加入队列
if cur.left is not None:
q.append(cur.left)
if cur.right is not None:
q.append(cur.right)
depth += 1

注意代码中的内层 for 循环:

1
2
3
4
int sz = q.size();
for (int i = 0; i < sz; i++) {
...
}

这个变量 i 记录的是节点 cur 是当前层的第几个,大部分算法题中都不会用到这个变量,所以你完全可以改用下面的写法:

1
2
3
4
int sz = q.size();
while (sz-- > 0) {
...
}

这个属于细节问题,按照自己的喜好来就行。

但是注意队列的长度 sz 一定要在循环开始前保存下来,因为在循环过程中队列的长度是会变化的,不能直接用 q.size() 作为循环条件。

这种写法就可以记录下来每个节点所在的层数,可以解决诸如二叉树最小深度这样的问题,是我们最常用的层序遍历写法。

写法三

既然写法二是最常见的,为啥还有个写法三呢?因为要给后面的进阶内容做铺垫。

现在我们只是在探讨二叉树的层序遍历,但是二叉树的层序遍历可以衍生出 [多叉树的层序遍历],[图的 BFS 遍历],以及经典的 [BFS 暴力穷举算法框架],所以这里要拓展延伸一下。

回顾写法二,我们每向下遍历一层,就给 depth 加 1,可以理解为每条树枝的权重是 1,二叉树中每个节点的深度,其实就是从根节点到这个节点的路径权重和,且同一层的所有节点,路径权重和都是相同的

那么假设,如果每条树枝的权重和可以是任意值,现在让你层序遍历整棵树,打印每个节点的路径权重和,你会怎么做?

这样的话,同一层节点的路径权重和就不一定相同了,写法二这样只维护一个 depth 变量就无法满足需求了。

写法三就是为了解决这个问题,在写法一的基础上添加一个 State 类,让每个节点自己负责维护自己的路径权重和,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class State:
def __init__(self, node, depth):
self.node = node
self.depth = depth

def levelOrderTraverse(root):
if root is None:
return
q = deque()
# 根节点的路径权重和是 1
q.append(State(root, 1))

while q:
cur = q.popleft()
# 访问 cur 节点,同时知道它的路径权重和
print(f"depth = {cur.depth}, val = {cur.node.val}")

# 把 cur 的左右子节点加入队列
if cur.node.left is not None:
q.append(State(cur.node.left, cur.depth + 1))
if cur.node.right is not None:
q.append(State(cur.node.right, cur.depth + 1))

你可以打开这个可视化面板,点击其中的  这一行代码,就可以看到还是一层一层,从左到右的遍历二叉树节点,还会输出节点所在的层数:

这样每个节点都有了自己的 depth 变量,是最灵活的,可以满足所有 BFS 算法的需求。但是由于要额外定义一个 State 类比较麻烦,所以非必要的话,用写法二就够了。

其他遍历?

二叉树的遍历方式只有上面两种,也许有其他的写法,但都是表现形式上的差异,本质上不可能跳出上面两种遍历方式。

比方说,你可能看到用栈来迭代遍历二叉树的代码。但这本质还是是递归遍历,只不过他手动维护栈模拟递归调用罢了。

再比如,你还可能看到递归地一层层遍历二叉树的代码。但这本质还是层序遍历,只不过他把层序遍历代码中的 for 循环用递归的形式展现了。

总之,不要被表象迷惑,二叉树的遍历方式就上面两种,结合后面的教程和习题,你把这两种遍历方式玩明白,一切暴力穷举算法都小菜一碟。

图论中的基本术语

一幅图结构由若干 节点 (Vertex) 和 边 (Edge) 构成,其中:

  • 每个节点有一个唯一 ID。
  • 边可以是有向的(有向图,Directional Graph),也可以是无向的(无向图,Undirected Graph)。
  • 边上可以有权重(加权图,Weighted Graph),也可以没有权重(无权图,Unweighted Graph)。

边的权重和方向

下图是一个有向无权图:
400
图中有一条从节点 1 指向节点 3 的有向边,这说明可以从节点 1 直接到达节点 3;但由于没有从节点 3 指向节点 1 的有向边,所以节点 3 不能直接到达节点 1

下图是一个无向无权图:
400
图中节点 1 和节点 3 之间有一条无向边,这说明可以从节点 1 到达节点 3,也可以从节点 3 到达节点 1

你可以把无向图理解成「双向图」,实际上我们在用代码实现图结构的时候就是这么做的。

下图是一个有向加权图:
400
下图是一个无向加权图:
400
加权图在实际场景中非常常见,比如在地图 App 中,边的权重可以是两个地点之间的距离;在物流网络中,边的权重可以是两个地点之间的运输成本等等。

围绕着加权图,又会有很多经典的图论算法,比如计算最短路径,最小生成树等等,这些都会在后面的章节逐步讲解。

对于图中的每个节点,有一个度 (degree) 的概念。

在无向图中,度就是每个节点相连的边的条数。

比方下面这幅无向图中,节点 1 的度为 2,节点 4 的度为 4。
400
由于有向图的边有方向,所以有向图中每个节点的度被细分为入度 (indegree)出度(outdegree)

比如下图中节点 3 的入度为 2(有两条边指向它),出度为 1(它有 1 条边指向别的节点):
400

边和节点的数量关系

我们一般讨论的图结构都是简单图(Simple Graph),即没有自环边(Self loop)和多4重边(Multiple edges)的图

在简单图中,假设包含  条边,个节点,我们想一下边的条数  的取值范围是多少?

 的最小值可以是 0,相当于图结构中只有若干互不相连的节点,这是可以的。

考虑  的最大值,图中的每个节点最多可以有  条边与其他  个节点相连,所以最多能有的边数为 

如果几乎每两个节点之间都有一条边,即  接近 ,我们说这幅图是 稠密图(Dense Graph);如果只有很少的边,即  远小于 ,我们说这幅图是 稀疏图(Sparse Graph)

子图

在图论中,子图是一个重要的基本概念。

**子图 (Subgraph)**:如果图  的所有节点和边都包含在图 GG 中,则称  是  的一个子图。简单来说,子图是从原图中删除一些节点和边后得到的图。

假设上面这幅图为 ,我们举例说明子图的概念。子图有两种特殊类型:

**生成子图 (Spanning Subgraph)**:包含原图中所有节点,但只包含部分边的子图。

下图是图  的一个生成子图,它包含了所有节点,但移除了节点 3 和节点 4 之间的边。

**导出子图 (Induced Subgraph)**:选择原图的一部分节点,以及这些节点之间在原图中的所有边所构成的子图。

下图是图  的一个导出子图,它包含节点 1,2,3,4 及它们之间在原图中的所有边。

子图的概念在很多图算法中都有应用,比如在寻找最小生成树时,我们实际上是在寻找一个包含所有节点的带权重最小的生成子图。

连通性

在图论中,连通性是一个非常重要的概念,它描述了图中节点之间是否存在路径。

无向图的连通性

连通图 (Connected Graph): 如果无向图中任意两个节点之间都存在一条路径,我们称这个图是连通的。

上图是一个连通图,从任意一个节点出发,都能到达其他所有节点。

**连通分量 (Connected Component)**:对于非连通的无向图,其中的多个连通子图被称为连通分量,一个图可以有多个连通分量。

比如下面这幅图有两个连通分量:节点 1~5 形成一个连通分量,节点 6,7 形成另一个连通分量。

有向图的连通性

有向图的连通性概念稍微复杂一些,因为考虑到边的方向,所以有向图的连通性分为强连通和弱连通。这块知识点有个印象就行了,实际的面试题中主要都是考察无向图的连通性。

**强连通图 (Strongly Connected Graph)**:如果有向图中任意两个节点之间都存在一条有向路径,我们称这个图是强连通的。

比如下面这幅图是一个强连通图,从任意节点出发都能到达其他所有节点。

**弱连通图 (Weakly Connected Graph)**:如果将有向图中的所有有向边都变成无向边后,该图变成连通的,那么原来的有向图就是弱连通的。

比如下面这幅图不是强连通的(无法从节点 4 到达节点 1),但它是弱连通的,因为忽略边的方向后,所有节点之间都是连通的。

**强连通分量 (Strongly Connected Component, SCC)**:有向图中的若干个最大的强连通子图称为强连通分量。

比如下面这幅图有两个强连通分量:节点 1~3 形成一个强连通分量,节点 4~6 形成另一个强连通分量。

**弱连通分量 (Weakly Connected Component, WCC)**:将有向图的所有有向边变为无向边后,形成的连通分量称为原有向图的弱连通分量。

图论中还有很多其他的复杂术语,不过对于数据结构和算法的学习,理解上面这些名词就绰绰够用了。后面我们讲到具体的图论算法时,会结合实际场景运用这些概念。

最小生成树算法概览

最小生成树是图论中的经典问题,在现实生活中有广泛的应用,比如设计最低成本的通信网络、电路布线、管道铺设等。

考虑到最小生成树的算法实现需要一些其他算法作为铺垫,且本文处在基础章节,所以不会详细讲解算法代码。

本文主要介绍最小生成树的定义及应用场景,并阐述两种经典的最小生成树算法的核心原理。具体的代码实现安排在数据结构设计章节。

什么是生成树

首先理解什么是生成树。给定一个无向连通图 ,其生成树是  的一个子图,它包含  中的所有顶点,并且是一棵树(即无环连通图)。

换句话说,生成树具有以下特性:

  • 包含原图中的所有顶点。
  • 边的数量为顶点数减一(V-1条边)。
  • 连通且无环。

一个图可以有多个不同的生成树,例如这幅加权图:

可以有以下生成树,其中属于生成树的边被标记为了红色:

下面是一个不同的生成树:

什么是最小生成树

如果图是加权图,那么最小生成树就是边权重总和最小的生成树。

比如上面展示的例子,第二种生成树是该图的最小生成树,总权重为:2 + 3 + 5 = 10,没有其他的生成树能够得到更小的权重和了。

最小生成树在现实生活中有很多应用场景,边的权重可能代表距离、成本、时间等。

比方说想在若干城市之间修建公路,图中的节点代表城市,边代表城市之间的公路,边的权重代表修建公路的成本,我们希望找到一种方案能够连接所有城市,且总成本最小,这就是典型的最小生成树问题。

最小生成树算法

有两种经典的算法用于求解最小生成树问题:Kruskal 算法和 Prim 算法。它们都基于贪心思想,但实现方式不同。

Kruskal 算法相对简单一些,只需要先对图中的所有边按照权重排序,然后借助 [Union-Find 并查集算法]即可找到最小生成树。

Prim 算法可以由 [Dijkstra 算法]拓展而来,借助 [优先级队列] 动态排序的特性,逐步构造最小生成树。

具体的代码实现在 [Kruskal 算法]和 [Prim 算法]中讲解。

随机地图构造问题

最小生成树算法经过一些巧妙的改造后,可以被用于生成游戏中的随机化迷宫、洞穴等场景。

其核心思想是利用最小生成树算法能够连接所有顶点且无环路的特性,来确保生成地图的连通性。通过引入随机性,可以创造出每次都不同、看起来自然且复杂的地图结构。

本站包含一个迷宫小游戏,要求你编写 mazeGenerate 函数生成迷宫地图,要求必须存在至少一条起点到终点的路径,且地图需要尽可能随机:

我们可以借助游戏面板直观体会一下最小生成树算法生成的地图的特点。

在游戏面板中可以选择「生成算法」和「求解算法」,你可以切换不同的生成算法,然后点击「生成」按钮,即可查看不同的算法生成地图的过程。

先来观察 Krusual 算法,地图被初始化为一个网格图结构,然后从图中的多个位置开始出现随机路径,最终连接成一个完整的迷宫地图。

再来观察 Prim 算法,地图的初始状态全部都是障碍物,然后从起点开始向周围扩展路径,最终连接成一个完整的迷宫地图。

不只是生成地图的过程不同,生成的地图特点也不同。你可以在游戏面板上切换不同的求解算法,点击「求解」按钮,即可对比查看不同的算法求解地图的过程。

我会建议观察 BFS/DFS 算法求解地图的过程,仔细体会一下不同算法生成地图的特点。在后文讲解完最小生成树算法实现之后,我们再具体讲解随机迷宫地图的生成算法。

排序算法的关键指标


[时空复杂度]

首先一个指标肯定是时间复杂度和空间复杂度。

正如 时空复杂度入门 中所说,对于任意一个算法,其时间复杂度和空间复杂度都是越小越好的。

排序稳定性

稳定性是排序算法的一个重要性质,我们可以简单总结为:

对于序列中的相同元素,如果排序之后它们的相对位置没有发生改变,则称该排序算法为「稳定排序」,反之则为「不稳定排序」

如果单单排序 int 数组,那么稳定性没有什么意义。但如果排序一些结构比较复杂的数据,那么稳定排序就会有一定的优势。

比如说现在你有若干订单数据,已经按照交易日期排好序了,现在你想对用户 ID 再进行排序,这样一来相同用户 ID 的订单就会聚集在一起,方便查看。稳定排序和不稳定排序的区别就体现在这里:

如果你用稳定排序算法,那么排序完成后,相同用户 ID 的订单依然会按照交易日期有序排列:

1
2
3
4
5
6
7
8
9
   Date    UserID
2020-02-01 1001
2020-02-02 1001
2020-02-03 1001

2020-01-01 1002
2020-01-02 1002
2020-01-03 1002
...

因为之前已经按照日期排好序了,对用户 ID 稳定排序之后,相同用户 ID 的订单的相对位置保持不变,所以在日期上依然是有序的。

如果你用不稳定排序算法,相同用户 ID 的订单相对位置可能变化,所以对于相同用户 ID 的订单,交易日期的有序性会丧失,相当于你之前对日期的排序白做了。

可以看到,稳定性是个很重要的性质,所以你在使用排序算法时要特别注意,避免出现预期之外的结果。

是否原地排序

原地排序就是指排序过程中不需要额外的辅助空间,只需要常数级别的额外空间,直接操作原数组进行排序

注意,关键是是否需要额外的空间,而不是是否返回一个新的数组。具体来说就是类似这样的区别:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 非原地排序
void sort(int[] nums) {
// 排序过程中需要额外的辅助数组,消耗 O(N) 的空间
int[] tmp = new int[nums.length];

// 对 nums 进行排序
for ...
}

// 原地排序
void sort(int[] nums) {
// 直接操作 nums,不需要额外的辅助数组,消耗 O(1) 的空间
for ...
}

不难想到,对于大数据量的排序,原地排序算法是比较有优势的。

排序算法的几个关键指标就是这些,后面我会介绍几种常见的排序算法,都会根据这些指标来分析它们的优劣。

选择排序所面临的问题

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
912. Sort an Array 912. 排序数组

[!一句话总结]

选择排序是最简单朴素的排序算法,但是时间复杂度较高,且不是稳定排序。其他基础排序算法都是基于选择排序的优化。

如果你是没接触过排序算法的初学者,那是最好的,不要急着看定义之类的东西;如果你之前了解过排序算法,现在请你忘记定义,忘记曾经背诵过的算法代码。

有了前面内容的铺垫,你已经有了一定的编程能力,能够解决一些基础的算法问题了。那么在这个前提下,我有一个学习方法分享,供你参考:

遇到一个新问题的时候,不要急着找人要一个标准答案,而应该启动自己的思考。被灌输一次标准答案,就错失一次机缘,少一分灵气。被灌得多了,人就傻了。

总有些读者,愁眉苦脸地找我诉苦,说算法题刷完了就忘怎么办啊。我还觉得这是好事呢,念念不忘的是执念,忘了才好,说明还没被塞满,这就是独立思考的机缘呀。

所以回到问题,让我们抓住这次机缘。现在就是给你输入一个数组,让你写个排序算法把所有元素从小到大排序,你来说,怎么写?如果你从来没有思考过这个问题,可以停下几分钟想一想。

1
2
3
void sort(int[] nums) {
// 你的代码,将 nums 中的元素从小到大排序
}

我第一次思考这个问题时,想到的最直接的方法是这样的:

先遍历一遍数组,找到数组中的最小值,然后把它和数组的第一个元素交换位置;接着再遍历一遍数组,找到第二小的元素,和数组的第二个元素交换位置;以此类推,直到整个数组有序。

这个算法有一个被大家熟知的名字,叫做「选择排序」,即每次都去遍历选择最小的元素。写成代码就是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def sort(nums: List[int]) -> None:
n = len(nums)
# sortedIndex 是一个分割线
# 索引 < sortedIndex 的元素都是已排序的
# 索引 >= sortedIndex 的元素都是未排序的
# 初始化为 0,表示整个数组都是未排序的
sortedIndex = 0
while sortedIndex < n:
# 找到未排序部分 [sortedIndex, n) 中的最小值
minIndex = sortedIndex
# 每次都找出最小的
for i in range(sortedIndex + 1, n):
if nums[i] < nums[minIndex]:
minIndex = i
# 交换最小值和 sortedIndex 处的元素
nums[sortedIndex], nums[minIndex] = nums[minIndex], nums[sortedIndex]

# sortedIndex 后移一位
sortedIndex += 1

上述算法的可视化过程如下:

这个算法是正确的,稍加改动就可以作为力扣第 912 题「排序数组」的解法代码。

但这个算法无法通过 912 题的所有测试用例,最后会得到一个超时的错误,这说明算法的逻辑是正确的,只是时间复杂度较高,超出了题目的限制。

暂且不管如何通过 912 题,我们先来按照 [排序算法的几个关键指标]来分析一下这个排序算法。

是否是原地排序

是的。因为算法并没有使用额外的数组空间进行辅助,只是用了几个变量,空间复杂度是 

时空复杂度分析

这个 sort 函数中包含一个 while 循环嵌套一个 for 循环,相当于是这样:

1
2
3
4
5
for (int sortedIndex = 0; sortedIndex < n; sortedIndex++) {
for (int i = sortedIndex + 1; i < n; i++) {
// ...
}
}

你看到了,这就是嵌套 for 循环,总的循环次数是 (n - 1) + (n - 2) + (n - 3) +... + 1,这是等差数列求和,结果近似是 n^2 / 2,所以这个排序算法的时间复杂度用 Big O 表示法就是 ,其中 n 是待排序数组的元素个数。

而且你注意这个算法有个特点,即便整个数组已经是有序的,它还是会执行 n^2 / 2 次,即原始数据的有序度对算法的时间复杂度没有任何影响。

要关注排序算法的实际执行次数

对于一般的算法时空复杂度分析,我们只需要从 Big O 表示法的角度来分析即可,即仅关心量级(最高次项)的大小,而不关心系数和低次项。

但是在分析不同排序算法的场景下,实际的执行次数,以及一些特殊情况(比如数组本身就有序的情况),还是有必要关注的。

因为有多种排序算法从 Big O 的视角来看都是  复杂度,那么我们要根据他们的实际执行次数以及特殊情况下的表现,来分析它们的优劣。

时间都去哪了?优化思路?

现在,请你观察这个算法的逻辑,仔细思考几分钟,时间复杂度是否还有优化的可能?

不要小看这里是基础章节,我讲的都是思维方法,未来你做任何题目,优化时间复杂度的思路和这里一模一样

首先,如果代码没有写错,算法时间复杂度还是太高,那只有一种可能,就是存在冗余计算

上述算法中出现冗余计算的地方比较容易看出来:

它首先遍历 nums[0..] 寻找最小值,然后遍历 nums[1..] 寻找最小值,然后遍历 nums[2..] 寻找最小值,以此类推。

那么请问,在遍历 nums[0..] 的时候,其实已经遍历过 nums[1..] 和 nums[2..] 的所有元素了,你为什么要再次遍历呢?

理论上,你应该可以在遍历 nums[0..] 的时候,顺便找到 nums[1..] 和 nums[2..] 的最小元素,对吧?如果能做到这一点,是不是就可以消掉内层的 for 循环,从把时间复杂度降低一个数量级?

好,现在我们已经找到了冗余计算的症结所在,并且有了一个优化思路。那么这个思路是否可以实现呢?你是否能够在遍历 nums[0..] 的时候,顺便找到 nums[1..] 和 nums[2..] 的最小元素?

我将进行抽象,把这个优化场景转化成一个全新的问题

给你一个数组 nums,请你计算一个新数组 suffixMin 数组,其中 suffixMin[i] 表示 nums[i..] 中的最小值。

如果正着思考,假设现在我知道了 nums[0..] 中的最小元素,我是否能够推导出 nums[1..] 中的最小元素呢?

答案是不可能。信息不足,我实在不知道如何根据 min(nums[0..]) 推导出 min(nums[1..]),只能重新遍历一遍 nums[1..]

但是,我自己都不相信,就是算个最小值,咋可能这么难搞呢?我的脑子被智子锁死了吗???

如果反过来思考,假设现在我知道了 nums[1..] 中的最小元素,我是否能够推导出 nums[0..] 中的最小元素呢?

答案是可以的,min(nums[0..]) = min(nums[0], min(nums[1..]))

有了这个思路,这个 suffixMin 数组就能算出来了,关键是倒着计算:

1
2
3
4
5
6
7
8
9
10
11
12
int[] nums = new int[]{3, 1, 4, 2};
// suffixMin[i] 表示 nums[i..] 中的最小值
int[] suffixMin = new int[nums.length];

// 从后往前计算 suffixMin
suffixMin[nums.length - 1] = nums[nums.length - 1];
for (int i = nums.length - 2; i >= 0; i--) {
suffixMin[i] = Math.min(nums[i], suffixMin[i + 1]);
}

// [1, 1, 2, 2]
System.out.println(suffixMin);

好了,这个计算 suffixMin 数组的问题解决了,现在回到选择排序的优化,我现在只需要花  的时间遍历一遍 nums 数组算出 suffixMin 数组,就可以在  的时间内得到 nums[1..], nums[2..], ... 任意子数组的最小值。

按理说,现在我可以把选择排序的内层 for 循环消掉,时间复杂度优化成  了,对吗?答案是不行

请你思考几分钟,为什么不行,关键的问题在哪里?

综上,所有尝试都是错误的,选择排序无法进行任何优化。

那么我们花了那么多时间,尝试了种种方法,最后啥名堂也没弄出来,是不是很失败?

不,我认为这些才是有效的思考,是真正能够帮助读者掌握算法思维的。

拥有稳定性:冒泡排序

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
912. Sort an Array 912. 排序数组

一句话总结

冒泡算法是对 [选择排序]的一种优化,通过交换 nums[sortedIndex] 右侧的逆序对完成排序,是一种稳定排序算法。

前文讲解了 [选择排序]这种最简单直接的排序算法,其中分析了选择排序的几个待优化的问题:

1、选择排序算法是个不稳定排序算法,因为每次都要交换最小元素和当前元素的位置,这样可能会改变相同元素的相对位置。

2、选择排序的时间复杂度和初始数据的有序度完全没有关系,即便输入的是一个已经有序的数组,选择排序的时间复杂度依然是 

3、选择排序的时间复杂度是 ,具体的操作次数大概是 次,常规的优化思路无法降低时间复杂度。

那么本文就围绕着选择排序的种种缺陷,看看能不能想办法帮它解决一下。

重获排序稳定性

前文分析过选择排序失去稳定性的原因,即每次都要交换最小元素(nums[minIndex])和当前元素(nums[sortedIndex]),这样可能会改变相同元素的相对位置。

你仔细思考这个交换过程,其实它的目标是把 nums[minIndex] 放到到 nums[sortedIndex],至于 nums[sortedIndex] 这个位置的元素应该去哪里,它并不关心。之所以它用交换操作,只是因为交换操作最简单,不需要涉及数据搬移

在交换过程中,把 nums[minIndex] 放到 nums[sortedIndex] 的操作是不影响相同元素的相对顺序的:

1
2
3
4
5
[2, 2', 2'', 1, 1']
^ ^
[1, 2', 2'', _, 1']
^ ^
sortedIndex minIndex

真正破坏稳定性的,是让 nums[sortedIndex] 去 nums[minIndex] 的位置这一步:

1
2
[1, 2', 2'', 2, 1']
^ ^

可以看到 2, 2', 2'' 这三个元素的相对顺序被打乱了。

**所以优化的方向就在这里,你不要图省事儿直接把 nums[sortedIndex] 交换到 nums[minIndex],而是模仿 [在数组中部插入元素的操作]**,将 nums[sortedIndex..minIndex] 的元素整体向后移动一位,把 nums[sortedIndex + 1] 的位置空出来让 nums[sortedIndex] 这个元素去那里待着。

1
2
3
4
5
6
7
8
9
[2, 2', 2'', 1, 1']
^ ^
[1, 2', 2'', _, 1']
^ ^
[1, _, 2', 2'', 1']
^ ^
[1, 2, 2', 2'', 1']
^ ^
sortedIndex minIndex

可以看到,这次 2, 2', 2'' 和 1, 1' 的相对顺序都没有发生改变,选择排序就变成了稳定排序了。

具体代码如下,只需要把 [选择排序]代码中交换元素的部分换一下即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 对选择排序进行第一波优化,获得了稳定性
def sort(nums):
n = len(nums)
sortedIndex = 0
while sortedIndex < n:
# 在未排序部分中找到最小值 nums[minIndex]
minIndex = sortedIndex
for i in range(sortedIndex + 1, n):
if nums[i] < nums[minIndex]:
minIndex = i

# 优化:将 nums[minIndex] 插入到 nums[sortedIndex] 的位置
# 将 nums[sortedIndex..minIndex] 的元素整体向后移动一位
minVal = nums[minIndex]
# 数组搬移数据的操作
for i in range(minIndex, sortedIndex, -1):
nums[i] = nums[i - 1]
nums[sortedIndex] = minVal

sortedIndex += 1

你可以拿着这个算法去力扣第 912 题「排序数组」提交一下,虽然最后会超时无法通过,但是可以证明这个算法的正确性是没有问题的。

这个算法对比标准的选择排序,虽然拥有了稳定性,但是执行效率会下降,虽然从 Big O 表示法的角度来看,两层嵌套循环的时间复杂度还是 ,但毕竟又加了一个 for 循环,实际执行次数肯定会大于标准选择排序的 次。

下面我们再来看看,能不能进一步优化,避免这个额外的 for 循环。

优化时间复杂度

仔细观察上面的算法代码,while 循环内部主要做了两件事:

1、第一个 for 循环寻找 nums[sortedIndex..] 中的最小值。

2、第二个 for 循环将这个最小值插入到 nums[sortedIndex] 的位置。

那么我们能否将这两个步骤合在一起呢?具体来说,你在寻找 nums[sortedIndex..] 中的最小值的时候能不能做些力所能及的事情,能不能做到找到最小值后,它就已经被放在正确的位置上,不需要再进行数据搬移了?

答案是可以的,看我操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 对选择排序进行第二波优化,获得稳定性的同时避免额外的 for 循环
# 这个算法有另一个名字,叫做冒泡排序
def sort_list(nums):
n = len(nums)
sorted_index = 0
while sorted_index < n:
# 寻找 nums[sorted_index..] 中的最小值
# 同时将这个最小值逐步移动到 nums[sorted_index] 的位置
for i in range(n - 1, sorted_index, -1):
if nums[i] < nums[i - 1]:
# swap(nums[i], nums[i - 1])
tmp = nums[i]
nums[i] = nums[i - 1]
nums[i - 1] = tmp
sorted_index += 1

这个优化就比较巧妙了,倒序遍历 nums[sortedIndex..],如果发现逆序对儿,就交换顺序,这样最小值就会逐步移动到 nums[sortedIndex] 的位置。

而且由于我们只交换相邻的逆序对儿,不会去碰值相同的元素,所以这个算法是稳定排序。

这个算法的时间复杂度依然是 ,实际执行次数和选择排序类似,也是一个等差数列求和,大约是  次。

冒泡排序

这个算法的名字叫做冒泡排序,因为它的执行过程就像从数组尾部向头部冒出水泡,每次都会将最小值顶到正确的位置。

提前终止算法

上面说到选择排序的一个问题是,其时间复杂度和初始数据的有序度完全没有关系,即便输入的数组已经有序,选择排序依然会执行 ) 次操作。

在上面的一些列优化之后,就可以解决这个问题了,具体看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 进一步优化,数组有序时提前终止算法
def sort(nums):
n = len(nums)
sorted_index = 0
while sorted_index < n:
# 加一个布尔变量,记录是否进行过交换操作
swapped = False
for i in range(n - 1, sorted_index, -1):
if nums[i] < nums[i - 1]:
# swap(nums[i], nums[i - 1])
tmp = nums[i]
nums[i] = nums[i - 1]
nums[i - 1] = tmp
swapped = True
# 如果一次交换操作都没有进行,说明数组已经有序,可以提前终止算法
if not swapped:
break
sorted_index += 1

好了,以上就是针对选择排序的一系列优化,最终使它拥有了排序稳定性,并支持在数组有序时提前终止算法。唯一的遗憾是,时间复杂度依然是 ,并没有降低。

下面我们继续探讨,看看还有什么方法能够改进选择排序。

运用逆向思维:插入排序

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
912. Sort an Array 912. 排序数组

一句话总结

插入排序是基于 [选择排序] 的一种优化,将 nums[sortedIndex] 插入到左侧的有序数组中。对于有序度较高的数组,插入排序的效率比较高。

前文 [选择排序所面临的问题]中分析了选择排序遇到的几个问题,然后逐步优化写出了 [冒泡排序],使得排序算法具有稳定性,且能够在输入数组的有序度较高时提前终止,提升效率。

回顾一下,冒泡排序的关键点在于对下面这段代码的优化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 对选择排序进行第一波优化,获得了稳定性
def sort(nums):
n = len(nums)
sortedIndex = 0
while sortedIndex < n:
# 在未排序部分中找到最小值 nums[minIndex]
minIndex = sortedIndex
for i in range(sortedIndex + 1, n):
if nums[i] < nums[minIndex]:
minIndex = i

# 优化:将 nums[minIndex] 插入到 nums[sortedIndex] 的位置
# 将 nums[sortedIndex..minIndex] 的元素整体向后移动一位
minVal = nums[minIndex]
# 数组搬移数据的操作
for i in range(minIndex, sortedIndex, -1):
nums[i] = nums[i - 1]
nums[sortedIndex] = minVal

sortedIndex += 1

为了避免 while 内存在两个 for 循环,我们使用了一种类似冒泡的方式逐步交换 nums[sortedIndex..] 中的逆序对,将最小值换到 nums[sortedIndex] 的位置。

好的,先停在这一步,让我们忘记冒泡排序的优化方法,你来思考一下,是否还有其他方法能够优化上述代码,把 while 循环中的两个 for 循环优化成一个 for 循环?

反向思维

上面的算法思路是:在 nums[sortedIndex..] 中找到最小值,然后将其插入到 nums[sortedIndex] 的位置。

那么我们能不能反过来想,在 nums[0..sortedIndex-1] 这个部分有序的数组中,找到 nums[sortedIndex] 应该插入的位置,然后进行插入呢

当年我思考如何对插入排序进行优化时,是想到过这个思路的,因为我想利用数组的有序性呀:既然 nums[0..sortedIndex-1] 这部分是已经排好序的,那么我就可以用二分搜索来寻找 nums[sortedIndex] 应该插入的位置。

这样一来,上述代码中的内层第一个 for 循环,我可以给他优化成对数级别的复杂度。

但是仔细想想,用二分搜索好像是多此一举的。因为就算我用二分搜索找到了 nums[sortedIndex] 应该插入的位置,我还是需要搬移元素进行插入,那还不如一边遍历一遍交换元素的方法简单高效呢:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 对选择排序进一步优化,向左侧有序数组中插入元素
# 这个算法有另一个名字,叫做插入排序
def sort(nums):
n = len(nums)
# 维护 [0, sorted_index) 是有序数组
sorted_index = 0
while sorted_index < n:
# 将 nums[sorted_index] 插入到有序数组 [0, sorted_index) 中
for i in range(sorted_index, 0, -1):
if nums[i] < nums[i - 1]:
# swap(nums[i], nums[i - 1])
tmp = nums[i]
nums[i] = nums[i - 1]
nums[i - 1] = tmp
else:
break
sorted_index += 1

插入排序

这个算法的名字叫做插入排序,它的执行过程就像是打扑克牌时,将新抓到的牌插入到手中已经排好序的牌中。

插入排序的空间复杂度是 ,是原地排序算法。时间复杂度是 ,具体的操作次数和选择排序类似,是一个等差数列求和,大约是 次。

插入排序是一种稳定排序,因为只有在 nums[i] < nums[i - 1] 的情况下才会交换元素,所以相同元素的相对位置不会发生改变。

初始有序度越高,效率越高

显然,插入排序的效率和输入数组的有序度有很大关系,可以举极端例子来理解:

如果输入数组已经有序,或者仅有个别元素逆序,那么插入排序的内层 for 循环几乎不需要执行元素交换,所以时间复杂度接近 

如果输入的数组是完全逆序的,那么插入排序的效率就会很低,内层 for 循环每次都要对 nums[0..sortedIndex-1] 的所有元素进行交换,算法的总时间复杂度就接近 

如果对比插入排序和冒泡排序,插入排序的综合性能应该要高于冒泡排序

直观地说,插入排序的内层 for 循环,只需要对 sortedIndex 左侧 nums[0..sortedIndex-1] 这部分有序数组进行遍历和元素交换,大部分非极端情况下,可能不需要遍历完 nums[0..sortedIndex-1] 的所有元素;而冒泡排序的内层 for 循环,每次都需要遍历sortedIndex 右侧 nums[sortedIndex..] 的所有元素。

所以冒泡排序的操作数大约是 n2/2n2/2,而插入排序的操作数会小于 n2/2n2/2。

你可以把插入排序的代码拿去力扣第 912 题「排序数组」提交,它最终依然会超时,但可以说明算法代码的逻辑是正确的。之后的文章我们继续探讨如何对排序算法进行优化。

突破 O():希尔排序


妙用二叉树前序位置:快速排序


妙用二叉树后序位置:归并排序


二叉堆结构的运用:堆排序


全新的排序原理:计数排序


数据结构和算法的框架思维


总结一切数据结构和算法

种种数据结构,皆为数组(顺序存储)和链表(链式存储)的变换。

数据结构的关键点在于遍历和访问,即增删查改等基本操作。

种种算法,皆为穷举

穷举的关键点在于无遗漏和无冗余。熟练掌握算法框架,可以做到无遗漏;充分利用信息,可以做到无冗余。

数据结构的存储方式

**数据结构的存储方式只有两种:[数组(顺序存储)]和 [链表(链式存储)]**。

这句话怎么理解,不是还有哈希表、栈、队列、堆、树、图等等各种数据结构吗?

我们分析问题,一定要有递归的思想,自顶向下,从抽象到具体。你上来就列出这么多,那些都属于上层建筑,而数组和链表才是结构基础。因为那些多样化的数据结构,究其源头,都是在链表或者数组上的特殊操作,API 不同而已。

比如说 [队列、栈] 这两种数据结构既可以使用链表也可以使用数组实现。用数组实现,就要处理扩容缩容的问题;用链表实现,没有这个问题,但需要更多的内存空间存储节点指针。

[图结构] 的两种存储方式,邻接表就是链表,邻接矩阵就是二维数组。邻接矩阵判断连通性迅速,并可以进行矩阵运算解决一些问题,但是如果图比较稀疏的话很耗费空间。邻接表比较节省空间,但是很多操作的效率上肯定比不过邻接矩阵。

[哈希表] 就是通过散列函数把键映射到一个大数组里。而且对于解决散列冲突的方法,[拉链法] 需要链表特性,操作简单,但需要额外的空间存储指针;[线性探查法] 需要数组特性,以便连续寻址,不需要指针的存储空间,但操作稍微复杂些。

**[树结构]**,用数组实现就是「堆」,因为「堆」是一个完全二叉树,用数组存储不需要节点指针,操作也比较简单,经典应用有 [二叉堆];用链表实现就是很常见的那种「树」,因为不一定是完全二叉树,所以不适合用数组存储。为此,在这种链表「树」结构之上,又衍生出各种巧妙的设计,比如 [二叉搜索树]、AVL 树、[红黑树]、[区间树]、B 树等等,以应对不同的问题。

综上,数据结构种类很多,甚至你也可以发明自己的数据结构,但是底层存储无非数组或者链表,二者的优缺点如下:

[数组] 由于是紧凑连续存储,可以随机访问,通过索引快速找到对应元素,而且相对节约存储空间。但正因为连续存储,内存空间必须一次性分配够,所以说数组如果要扩容,需要重新分配一块更大的空间,再把数据全部复制过去,时间复杂度 ;而且你如果想在数组中间进行插入和删除,每次必须搬移后面的所有数据以保持连续,时间复杂度 

[链表] 因为元素不连续,而是靠指针指向下一个元素的位置,所以不存在数组的扩容问题;如果知道某一元素的前驱和后驱,操作指针即可删除该元素或者插入新元素,时间复杂度 。但是正因为存储空间不连续,你无法根据一个索引算出对应元素的地址,所以不能随机访问;而且由于每个元素必须存储指向前后元素位置的指针,会消耗相对更多的储存空间。

数据结构的基本操作

对于任何数据结构,其基本操作无非遍历 + 访问,再具体一点就是:增删查改

数据结构种类很多,但它们存在的目的都是在不同的应用场景,尽可能高效地增删查改,这就是数据结构的使命。

如何遍历 + 访问?我们仍然从最高层来看,各种数据结构的遍历 + 访问无非两种形式:线性的和非线性的。

线性就是 for/while 迭代为代表,非线性就是递归为代表。再具体一步,无非以下几种框架:

数组遍历框架,典型的线性迭代结构:

1
2
3
def traverse(arr: List[int]):
for i in range(len(arr)):
# 迭代访问 arr[i]

链表遍历框架,兼具迭代和递归结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 基本的单链表节点
class ListNode:
def __init__(self, val):
self.val = val
self.next = None

def traverse(head: ListNode) -> None:
p = head
while p is not None:
# 迭代访问 p.val
p = p.next

def traverse(head: ListNode) -> None:
# 递归访问 head.val
traverse(head.next)

二叉树遍历框架,典型的非线性递归遍历结构:

1
2
3
4
5
6
7
8
9
10
# 基本的二叉树节点
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right

def traverse(root: TreeNode):
traverse(root.left)
traverse(root.right)

你看二叉树的递归遍历方式和链表的递归遍历方式,相似不?再看看二叉树结构和单链表结构,相似不?如果再多几条叉,N 叉树你会不会遍历?

二叉树框架可以扩展为 N 叉树的遍历框架:

1
2
3
4
5
6
7
8
# 基本的 N 叉树节点
class TreeNode:
val: int
children: List[TreeNode]

def traverse(root: TreeNode) -> None:
for child in root.children:
traverse(child)

N 叉树的遍历又可以扩展为图的遍历,因为图就是好几 N 叉棵树的结合体。你说图是可能出现环的?这个很好办,用个布尔数组 visited 做标记就行了,[图结构遍历]中有具体讲解。

所谓框架,就是套路。不管增删查改,这些代码都是永远无法脱离的结构,你可以把这个结构作为大纲,根据具体问题在框架上添加代码就行了。

算法的本质

如果要让我一句话总结,我想说算法的本质就是「穷举」

这么说肯定有人要反驳了,真的所有算法问题的本质都是穷举吗?没有例外吗?

例外肯定是有的,比如 [一行代码就能解决的算法题],这些题目类似脑筋急转弯,都是通过观察,发现规律,然后找到最优解法,不过这类算法问题较少,不必特别纠结。再比如,密码学算法、机器学习算法,它们的本质确实不是穷举,而是数学原理的编程实现,所以这类算法的本质是数学,不在我们所探讨的「数据结构和算法」的范畴之内。

顺便强调下,「算法工程师」做的这个「算法」,和「数据结构与算法」中的这个「算法」完全是两码事,免得一些初学读者误解。

对前者来说,重点在数学建模和调参经验,计算机真就只是拿来做计算的工具而已;而后者的重点是计算机思维,需要你能够站在计算机的视角,抽象、化简实际问题,然后用合理的数据结构去解决问题。

所以,你千万别以为学好了数据结构和算法就能去做算法工程师,也不要以为只要不做算法工程师就不需要学习数据结构和算法。

坦白说,大部分开发岗位工作中都是基于现成的开发框架做事,不怎么会碰到底层数据结构和算法相关的问题,但另一个事实是,只要你想找技术相关的岗位,数据结构和算法的考察是绕不开的,因为这块知识点是公认的程序员基本功。

为了区分,不妨称算法工程师研究的算法为「数学算法」,称刷题面试的算法为「计算机算法」,我写的内容主要聚焦的是「计算机算法」

这样解释应该很清楚了吧,我猜大部分人的目标是通过算法笔试,找一份开发岗位的工作,所以你真的不需要有多少数学基础,只要学会用计算机思维解决问题就够了。

其实计算机思维也没什么高端的,你想想计算机的特点是啥?不就是快嘛,你的脑回路一秒只能转一圈,人家 CPU 转几万圈无压力。所以计算机解决问题的方式大道至简,就是穷举

我记得自己刚入门的时候,也觉得计算机算法是一个很高大上的东西,每见到一道题,就想着能不能推导出一个什么数学公式,啪的一下就能把答案算出来。

比如你和一个没学过计算机算法的人说你写了个计算排列组合的算法,他大概以为你发明了一个公式,可以直接算出所有排列组合。但实际上呢?没什么高大上的公式,我会在 [回溯算法秒杀排列组合子集问题]讲解,其实就是把排列组合的所有可能抽象成一棵多叉树结构,然后你写代码去遍历这棵树,把所有的结果收集起来罢了。这有啥神奇的?

对计算机算法的误解也许是以前学数学留下的「后遗症」,数学题一般都是你仔细观察,找几何关系,列方程,然后算出答案。如果说你需要进行大规模穷举来寻找答案,那大概率是你的解题思路出问题了。

而计算机解决问题的思维恰恰相反:有没有什么数学公式就交给你们人类去推导吧,如果能找到一些巧妙的定理那最好,但如果找不到,那就穷举呗,反正只要复杂度允许,没有什么答案是穷举不出来的。理论上讲只要不断随机打乱一个数组,总有一天能得到有序的结果呢!当然,这绝不是一个好算法,因为鬼知道它要运行多久才有结果。

技术岗笔试面试考的那些算法题,求个最大值最小值什么的,你怎么求?把所有可行解穷举出来就能找到最值了呗,说白了不就这么点事儿么。

穷举的难点

[!穷举的两个关键]

你千万不要觉得穷举这个事儿很简单,穷举有两个关键难点:无遗漏、无冗余

遗漏,会直接导致答案出错,比如让你求最小值,你穷举时恰好把那个最小值漏掉了,这不就错了嘛。

冗余,会拖慢算法的运行速度,比如你的代码把完全相同的计算流程重复了十遍,那你的算法不就慢了十倍么,就有可能超过判题平台的时间限制。

为什么会遗漏?因为你对算法框架掌握不到位,不知道正确的穷举代码。

为什么会冗余?因为你没有充分利用信息。

所以,当你看到一道算法题,可以从这两个维度去思考:

1、如何穷举?即无遗漏地穷举所有可能解。

2、如何聪明地穷举?即避免穷举过程中的冗余计算,消耗尽可能少的资源求出答案。

如何穷举

什么算法的难点在「如何穷举」呢?一般是递归类问题,比方说回溯算法、动态规划系列算法

先说回溯算法,就拿我们高中学过的排列组合问题举例,我们当时都可以找到规律在草稿纸上推导排列组合:根据第一位可能的选择,先固定第一位,然后看第二位有哪些可能的选择,然后固定第二位… 以此类推,但如果未经训练,你很难用代码来穷举所有排列组合,因为你很难把这个手动穷举的过程抽象成程序化的规律。

首先,你要把排列组合问题抽象成一棵树,其次你要精确地使用代码遍历这棵树的所有节点,不能漏不能多,才能写出正确的代码。在后面的章节中,我会先介绍 [回溯算法核心框架]动态规划比回溯算法更难一点。它俩本质上都是穷举,但思考模式不同,回溯算法是「遍历」的思维,而动态规划是「分解问题」的思维。

[!啥叫分解问题的思维?]

我都不用举正儿八经的例子,就比方说,你看那棵树,回答我,树上有多少片叶子?

你如何穷举?顺着树枝去一片片数么?当然也可以的,但这是遍历的思维模式,胜似你手动推导排列组合的过程,属于回溯算法的范畴

如果你具备分解问题的思维模式,你应该告诉我:树上只有一片叶子,和剩下的叶子

听到这个回答,就知道是个算法高手。

还有不开窍的小同学追问,那剩下的叶子有多少呢?答曰,只有一片,和剩下的叶子。不要再往下问了,只能说,谜底就在谜面上,到了那个时候,你自然知道剩多少了。

所以你知道为啥我说动态规划这类问题的难点在于「如何穷举」了吧?一个脑瓜正常的人,本来就不会用这种奇怪的思维方式来思考问题,但这种思维结合计算机就是杀手锏,所以你要练,练好了,随心所欲写算法,咋写都是对的。

我在 动态规划核心框架 阐述了动态规划系列问题的解题过程,无非就是先写出暴力穷举解法(状态转移方程),加个备忘录就成自顶向下的递归解法了,再改一改就成自底向上的递推迭代解法了,动态规划的降维打击 里也讲过如何利用空间压缩技巧优化动态规划算法的空间复杂度。

其中加备忘录、空间压缩技巧都是固定的套路,不是难点。你亲自去做动态规划的题目就会发现,自己根本想不出状态转移方程,即第一步的暴力解法都写不出来,所以说找状态转移方程(如何穷举)才是难点。

我专门写了 动态规划设计方法:数学归纳法 这篇文章,告诉你穷举的核心是数学归纳法,明确函数的定义,分解问题,然后利用这个定义递归求解子问题。

如何聪明地穷举

什么算法的难点在「如何聪明地穷举」呢?一些耳熟能详的非递归算法技巧,都可以归在这一类

最简单的例子,比方说让你在有序数组中寻找一个元素,用一个 for 循环暴力穷举谁都会,但 二分搜索算法 就是更聪明的穷举方式,拥有更好的时间复杂度。

还有前文 Union Find 并查集算法详解 告诉你一种高效计算连通分量的技巧,理论上说,想判断图中的两个节点是否连通,我用 DFS/BFS 暴力搜索(穷举)肯定可以做到,但人家 Union Find 算法硬是用数组模拟树结构,给你把连通性相关的操作复杂度给干到  了。

这就属于聪明地穷举,大佬们把这些技巧发明出来,你学过就会用,没学过恐怕很难想出这种思路。

再比如贪心算法技巧,前文 当老司机学会贪心算法 就告诉你,所谓贪心算法就是在题目中发现一些规律(专业点叫贪心选择性质),使得你不用完整穷举所有解就可以得出答案。

人家动态规划好歹是无冗余地穷举所有解,然后找一个最值,你贪心算法可好,都不用穷举所有解就可以找到答案,所以前文 贪心算法解决跳跃游戏 中贪心算法的效率比动态规划还高。当然,并不是所有问题都存在贪心选择性质让你投机取巧,所以全量穷举虽然朴实无华且枯燥,但真的是任何情况下都可以用的。

下面我概括性地列举一些常见的算法技巧,供大家学习参考。

数组/单链表系列算法

单链表常考的技巧就是双指针,属于「如何聪明地穷举」这一类单链表双指针技巧汇总 全给你总结好了,会者不难,难者不会。

比如判断单链表是否成环,拍脑袋的暴力解是什么?就是用一个 HashSet 之类的数据结构来缓存走过的节点,遇到重复的就说明有环对吧。但我们用快慢指针可以避免使用额外的空间,这就是聪明地穷举嘛。

数组常用的技巧有也是双指针相关的技巧,也都属于「如何聪明地穷举」这一类数组双指针技巧汇总 全给你总结好了,会者不难,难者不会。

首先说二分搜索技巧,可以归为两端向中心的双指针。如果让你在数组中搜索元素,一个 for 循环花  时间穷举肯定能搞定对吧,但是二分搜索告诉你,如果数组是有序的,它只 的复杂度,这不就是一种更聪明的搜索方式么。

二分搜索框架详解 给你总结了二分搜索代码模板,保证不会出现搜索边界的问题。二分搜索算法运用 给你总结了二分搜索相关题目的共性以及如何将二分搜索思想运用到实际算法中。

**再说说 滑动窗口算法技巧**,典型的快慢双指针。你用嵌套 for 循环花  的时间肯定可以穷举出所有子数组,也就必然可以找到符合题目要求的子数组。但是滑动窗口算法表示,在某些场景下,它可以用一快一慢两个指针,只需  的时间就可以找到答案,这就是更聪明地穷举方式。

滑动窗口算法框架详解 介绍了滑动窗口算法的适用场景以及通用代码模板,保你写出正确的代码。滑动窗口习题 中手把手带你运用滑动窗口框架解决各种问题。

**最后说说 前缀和技巧 和 差分数组技巧**。

如果频繁地让你计算子数组的和,每次用 for 循环去遍历肯定没问题,但前缀和技巧预计算一个 preSum 数组,就可以避免循环。

类似的,如果频繁地让你对子数组进行增减操作,也可以每次用 for 循环去操作,但差分数组技巧维护一个 diff 数组,也可以避免循环。

数组链表的技巧差不多就这些了,都比较固定,只要你都见过,运用出来的难度不算大,下面来说一说稍微有些难度的算法。

二叉树系列算法

老读者都知道,二叉树的重要性我之前说了无数次,因为二叉树模型几乎是所有高级算法的基础,尤其是那么多人说对递归的理解不到位,更应该好好刷二叉树相关题目。

[!Tip]
在本站的二叉树章节,我会按照固定的公式和思维模式讲解 150 道二叉树题目,可以手把手带你刷完二叉树分类的题目,迅速掌握递归思维。

**二叉树心法(纲领篇) 说过,二叉树题目的递归解法可以分两类思路,第一类是遍历一遍二叉树得出答案,第二类是通过分解问题计算出答案,这两类思路分别对应着 回溯算法核心框架 和 动态规划核心框架**。

遍历的思维模式

什么叫通过遍历一遍二叉树得出答案

就比如说计算二叉树最大深度这个问题让你实现 maxDepth 这个函数,你这样写代码完全没问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
def __init__(self):
# 记录最大深度
self.res = 0
# 记录当前遍历节点的深度
self.depth = 0

def maxDepth(self, root: TreeNode) -> int:
self.traverse(root)
return self.res

def traverse(self, root: TreeNode) -> None:
if not root:
# 到达叶子节点
self.res = max(self.res, self.depth)
return
# 前序遍历位置
self.depth += 1
self.traverse(root.left)
self.traverse(root.right)
# 后序遍历位置
self.depth -= 1

这个逻辑就是用 traverse 函数遍历了一遍二叉树的所有节点,维护 depth 变量,在叶子节点的时候更新最大深度。

你看这段代码,有没有觉得很熟悉?能不能和回溯算法的代码模板对应上?

不信你照着 回溯算法核心框架 中全排列问题的代码对比下,backtrack 函数就是 traverse 函数,换汤不换药,整体逻辑非常类似:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution:
def permute(self, nums: List[int]) -> List[List[int]]:
# 记录所有全排列
res = []
# 记录当前正在穷举的排列
track = []

# track 中的元素会被标记为 true,避免重复使用
used = [False] * len(nums)

# 主函数,输入一组不重复的数字,返回它们的全排列
def backtrack(nums):
# 到达叶子节点,track 中的元素就是一个全排列
if len(track) == len(nums):
res.append(track[:])
return

for i in range(len(nums)):
# 排除不合法的选择
if used[i]:
# nums[i] 已经在 track 中,跳过
continue
# 做选择
track.append(nums[i])
used[i] = True

# 进入递归树的下一层
backtrack(nums)

# 取消选择
track.pop()
used[i] = False

backtrack(nums)
return res

你看这代码虽然多,但本质不就是多叉树的遍历吗?所以说回溯算法本质就是遍历多叉树,你只要能把问题抽象成树结构,就一定能用回溯算法解决。

分解问题的思维模式

那什么叫通过分解问题计算答案

同样是计算二叉树最大深度这个问题,你也可以写出下面这样的解法:

1
2
3
4
5
6
7
8
9
10
11
# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(root: TreeNode) -> int:
if root is None:
return 0
# 递归计算左右子树的最大深度
leftMax = maxDepth(root.left)
rightMax = maxDepth(root.right)
# 整棵树的最大深度就是左右子树的最大深度加一
res = max(leftMax, rightMax) + 1

return res

你看这段代码,有没有觉得很熟悉?有没有觉得有点动态规划解法代码的形式?

不信你看 动态规划核心框架 中凑零钱问题的暴力穷举解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
def coinChange(self, coins: List[int], amount: int) -> int:
# 题目要求的最终结果是 dp(amount)
return self.dp(coins, amount)

# 定义:要凑出目标金额 amount,至少要 dp(coins, amount) 个硬币
def dp(self, coins, amount):
# base case
if amount == 0:
return 0
if amount < 0:
return -1

res = float('inf')
for coin in coins:
# 计算子问题的结果
subProblem = self.dp(coins, amount - coin)
# 子问题无解则跳过
if subProblem == -1:
continue
# 在子问题中选择最优解,然后加一
res = min(res, subProblem + 1)

return res if res != float('inf') else -1

这个暴力解法加个 memo 备忘录就是自顶向下的动态规划解法,你对照二叉树最大深度的解法代码,有没有发现很像?

思路拓展

如果你感受到最大深度这个问题两种解法的区别,那就趁热打铁,我问你,二叉树的前序遍历怎么写

我相信大家都会对这个问题嗤之以鼻,毫不犹豫就可以写出下面这段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution:
def __init__(self):
# 创建一个链表作为结果容器
self.res = []

# 返回前序遍历结果
def preorder(self, root: TreeNode) -> List[int]:
self.traverse(root)
return self.res

# 二叉树遍历函数
def traverse(self, root: TreeNode) -> None:
if not root:
return
# 前序遍历位置
self.res.append(root.val)
self.traverse(root.left)
self.traverse(root.right)

但是,你结合上面说到的两种不同的思维模式,二叉树的遍历是否也可以通过分解问题的思路解决呢?

可以观察一下二叉树前序遍历结果的特点:

你注意前序遍历的结果,根节点的值在第一位,后面接着左子树的前序遍历结果,最后接着右子树的前序遍历结果

有没有体会出点什么来?其实完全可以重写前序遍历代码,用分解问题的形式写出来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from typing import List

# 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
def preorder(root: TreeNode) -> List[int]:
res = []
if not root:
return res
# 前序遍历的结果,root.val 在第一个
res.append(root.val)
# 后面接着左子树的前序遍历结果
res.extend(preorder(root.left))
# 最后接着右子树的前序遍历结果
res.extend(preorder(root.right))
return res

你看,这就是用分解问题的思维模式写二叉树的前序遍历,如果写中序和后序遍历也是类似的。

层序遍历

除了动归、回溯(DFS):深度优先搜索、分治,还有一个常用算法就是 BFS(广度优先搜索) 了,BFS 算法核心框架 就是根据下面这段二叉树的层序遍历代码改装出来的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 输入一棵二叉树的根节点,层序遍历这棵二叉树
class Solution:
def levelOrder(self, root: TreeNode) -> List[List[int]]:
if not root:
return
q = collections.deque()
q.append(root)
depth = 0
# 从上到下遍历二叉树的每一层
while q:
sz = len(q)
# 从左到右遍历每一层的每个节点
for i in range(sz):
cur = q.popleft()

# 将下一层节点放入队列
if cur.left:
q.append(cur.left)
if cur.right:
q.append(cur.right)
depth += 1

更进一步,图论相关的算法也是二叉树算法的延续

比如 图论基础环判断和拓扑排序 和 二分图判定算法 就用到了 DFS 算法;再比如 Dijkstra 算法模板,就是改良版的 BFS 算法。

好了,说的差不多了,上述这些算法的本质都是穷举二(多)叉树,有机会的话通过剪枝或者备忘录的方式减少冗余计算,提高效率,就这么点事儿。

最后总结

很多读者问我什么刷题方式是正确的,我认为正确的刷题方式应该是刷一道题能获得刷十道题的效果,不然力扣现在 2000 道题目,你都打算刷完么?

那么怎么做到呢?要有框架思维,学会提炼重点,寻找那个不变的东西。一个算法技巧可以包装出一万道题,如果你能一眼看穿它们的本质,那么一万道题等于一道,何必浪费时间去做呢?

这就是框架的力量,能够保证你在快睡着的时候,依然能写出正确的程序;就算你啥都没学过,就这种思维方法,都能比别人高一个维度

授人以鱼不如授人以渔,算法真的没啥难的,只要有心,谁都可以学好。我希望你能在我这里培养出成体系的思维方法,享受支配算法的乐趣,而不是被算法支配。

双指针技巧秒杀七道链表题目


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
21. Merge Two Sorted Lists 21. 合并两个有序链表 green
86. Partition List 86. 分隔链表 orange
23. Merge k Sorted Lists 23. 合并 K 个升序链表 red
141. Linked List Cycle 141. 环形链表 green
142. Linked List Cycle II 142. 环形链表 II orange
876. Middle of the Linked List 876. 链表的中间结点 green
19. Remove Nth Node From End of List 19. 删除链表的倒数第 N 个结点 orange
160. Intersection of Two Linked Lists 160. 相交链表 green
LCR 140. 训练计划 II LCR 140. 训练计划 II green

[!info]

阅读本文前,你需要先学习:

本文总结一下单链表的基本技巧,每个技巧都对应着至少一道算法题:

1、合并两个有序链表
2、链表的分解
3、合并 k 个有序链表
4、寻找单链表的倒数第 k 个节点
5、寻找单链表的中点
6、判断单链表是否包含环并找出环起点
7、判断两个单链表是否相交并找出交点

这些解法都用到了双指针技巧,所以说对于单链表相关的题目,双指针的运用是非常广泛的,下面我们就来一个一个看。

合并两个有序链表

这是最基本的链表技巧,力扣第 21 题「合并两个有序链表」就是这个问题,给你输入两个有序链表,请你把他俩合并成一个新的有序链表:

将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 

示例 1:

输入: l1 = [1,2,4], l2 = [1,3,4]
输出:[1,1,2,3,4,4]

示例 2:

输入: l1 = [], l2 = []
输出:[]

示例 3:

输入: l1 = [], l2 = [0]
输出:[0]

提示:

  • 两个链表的节点数目范围是 [0, 50]
  • -100 <= Node.val <= 100
  • l1 和 l2 均按 非递减顺序 排列

题目来源:力扣 21. 合并两个有序链表

1
2
# 函数签名如下
def mergeTwoLists(l1: ListNode, l2: ListNode) -> ListNode:

这题比较简单,我们直接看解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution:
def mergeTwoLists(self, l1: ListNode, l2: ListNode) -> ListNode:
# 虚拟头结点
dummy = ListNode(-1)
p = dummy
p1 = l1
p2 = l2

while p1 is not None and p2 is not None:
# 比较 p1 和 p2 两个指针
# 将值较小的的节点接到 p 指针
if p1.val > p2.val:
p.next = p2
p2 = p2.next
else:
p.next = p1
p1 = p1.next
# p 指针不断前进
p = p.next

if p1 is not None:
p.next = p1

if p2 is not None:
p.next = p2

return dummy.next

我们的 while 循环每次比较 p1 和 p2 的大小,把较小的节点接到结果链表上,看如下 GIF:

形象地理解,这个算法的逻辑类似于拉拉链,l1, l2 类似于拉链两侧的锯齿,指针 p 就好像拉链的拉索,将两个有序链表合并。

下面是算法的可视化,你可以多次点击  这一行代码,即可看到  合并两个有序链表的过程:

代码中还用到一个链表的算法题中是很常见的「虚拟头结点」技巧,也就是 dummy 节点。你可以试试,如果不使用 dummy 虚拟节点,代码会复杂一些,需要额外处理指针 p 为空的情况。而有了 dummy 节点这个占位符,可以避免处理空指针的情况,降低代码的复杂性。

何时使用虚拟头结点

经常有读者问我,什么时候需要用虚拟头结点?我这里总结下:当你需要创造一条新链表的时候,可以使用虚拟头结点简化边界情况的处理

比如说,让你把两条有序链表合并成一条新的有序链表,是不是要创造一条新链表?再比你想把一条链表分解成两条链表,是不是也在创造新链表?这些情况都可以使用虚拟头结点简化边界情况的处理。

单链表的分解

直接看下力扣第 86 题「分隔链表」:

给你一个链表的头节点 head 和一个特定值 x ,请你对链表进行分隔,使得所有 小于 x 的节点都出现在 大于或等于 x 的节点之前。

你应当 保留 两个分区中每个节点的初始相对位置。

示例 1:

输入: head = [1,4,3,2,5,2], x = 3
输出:[1,2,2,4,3,5]

示例 2:

输入: head = [2,1], x = 2
输出:[1,2]

提示:

  • 链表中节点的数目在范围 [0, 200] 内
  • -100 <= Node.val <= 100
  • -200 <= x <= 200

题目来源:力扣 86. 分隔链表

在合并两个有序链表时让你合二为一,而这里需要分解让你把原链表一分为二。具体来说,我们可以把原链表分成两个小链表,一个链表中的元素大小都小于 x,另一个链表中的元素都大于等于 x,最后再把这两条链表接到一起,就得到了题目想要的结果。

整体逻辑和合并有序链表非常相似,细节直接看代码吧,注意虚拟头结点的运用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution:
def partition(self, head: ListNode, x: int) -> ListNode:
# 存放小于 x 的链表的虚拟头结点
dummy1 = ListNode(-1)
# 存放大于等于 x 的链表的虚拟头结点
dummy2 = ListNode(-1)
# p1, p2 指针负责生成结果链表
p1, p2 = dummy1, dummy2
# p 负责遍历原链表,类似合并两个有序链表的逻辑
# 这里是将一个链表分解成两个链表
p = head
while p:
if p.val >= x:
p2.next = p
p2 = p2.next
else:
p1.next = p
p1 = p1.next
# 不能直接让 p 指针前进,
# p = p.next
# 断开原链表中的每个节点的 next 指针
temp = p.next
p.next = None
p = temp
# 连接两个链表
p1.next = dummy2.next

return dummy1.next

我知道有很多读者会对这段代码有疑问:

1
2
3
4
5
6
// 不能直接让 p 指针前进,
// p = p.next
// 断开原链表中的每个节点的 next 指针
ListNode temp = p.next;
p.next = null;
p = temp;

借助我们的可视化面板看一下就明白了。首先看下正确的写法,你可以多次点击  这一行代码即可看到链表分解的过程:

如果你不断开原链表中的每个节点的 next 指针,那么就会出错,因为结果链表中会包含一个环,你可以多次点击  这一行代码查看:

总的来说,如果我们需要把原链表的节点接到新链表上,而不是 new 新节点来组成新链表的话,那么断开节点和原链表之间的链接可能是必要的。那其实我们可以养成一个好习惯,但凡遇到这种情况,就把原链表的节点断开,这样就不会出错了。

合并 k 个有序链表

看下力扣第 23 题「合并K个升序链表」:

给你一个链表数组,每个链表都已经按升序排列。

请你将所有链表合并到一个升序链表中,返回合并后的链表。

示例 1:

输入: lists = [[1,4,5],[1,3,4],[2,6]]
输出:[1,1,2,3,4,4,5,6]
解释: 链表数组如下:
[
1->4->5,
1->3->4,
2->6
]
将它们合并到一个有序链表中得到。
1->1->2->3->4->4->5->6

示例 2:

输入: lists = []
输出:[]

示例 3:

输入: lists = [[]]
输出:[]

提示:

  • k == lists.length
  • 0 <= k <= 10^4
  • 0 <= lists[i].length <= 500
  • -10^4 <= lists[i][j] <= 10^4
  • lists[i] 按 升序 排列
  • lists[i].length 的总和不超过 10^4

题目来源:力扣 23. 合并 K 个升序链表

1
2
# 函数签名如下
def mergeKLists(lists: List[ListNode]) -> ListNode:

合并 k 个有序链表的逻辑类似合并两个有序链表,难点在于,如何快速得到 k 个节点中的最小节点,接到结果链表上?

这里我们就要用到优先级队列这种数据结构,把链表节点放入一个最小堆,就可以每次获得 k 个节点中的最小节点。关于优先级队列可以参考 优先级队列(二叉堆)原理及实现,本文不展开。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import heapq

class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next

# 重载比较运算符,方便将 ListNode 加入最小堆
def __lt__(self, other):
return self.val < other.val

class Solution:
def mergeKLists(self, lists):
if not lists:
return None
# 虚拟头结点
dummy = ListNode(-1)
p = dummy
# 优先级队列,最小堆
pq = []
# 将 k 个链表的头结点加入最小堆
for i, head in enumerate(lists):
if head is not None:
heapq.heappush(pq, (head.val, i, head))

while pq:
# 获取最小节点,接到结果链表中
val, i, node = heapq.heappop(pq)
p.next = node
if node.next is not None:
heapq.heappush(pq, (node.next.val, i, node.next))
# p 指针不断前进
p = p.next

return dummy.next

这个算法是面试常考题,它的时间复杂度是多少呢?

优先队列 pq 中的元素个数最多是 k,所以一次 poll 或者 add 方法的时间复杂度是 ;所有的链表节点都会被加入和弹出 pq所以算法整体的时间复杂度是 ,其中  是链表的条数, 是这些链表的节点总数

[!tip]
这道题还有一个经典的解法,在 分治算法核心框架 中详细讲解,这里不展开。

单链表的倒数第 k 个节点

从前往后寻找单链表的第 k 个节点很简单,一个 for 循环遍历过去就找到了,但是如何寻找从后往前数的第 k 个节点呢?

那你可能说,假设链表有 n 个节点,倒数第 k 个节点就是正数第 n - k + 1 个节点,不也是一个 for 循环的事儿吗?

是的,但是算法题一般只给你一个 ListNode 头结点代表一条单链表,你不能直接得出这条链表的长度 n,而需要先遍历一遍链表算出 n 的值,然后再遍历链表计算第 n - k + 1 个节点。

也就是说,这个解法需要遍历两次链表才能得到出倒数第 k 个节点。

那么,我们能不能只遍历一次链表,就算出倒数第 k 个节点?可以做到的,如果是面试问到这道题,面试官肯定也是希望你给出只需遍历一次链表的解法。

这个解法就比较巧妙了,假设 k = 2,思路如下:

首先,我们先让一个指针 p1 指向链表的头节点 head,然后走 k 步:

现在的 p1,只要再走 n - k 步,就能走到链表末尾的空指针了对吧?

趁这个时候,再用一个指针 p2 指向链表头节点 head

接下来就很显然了,让 p1 和 p2 同时向前走,p1 走到链表末尾的空指针时前进了 n - k 步,p2 也从 head 开始前进了 n - k 步,停留在第 n - k + 1 个节点上,即恰好停链表的倒数第 k 个节点上:

这样,只遍历了一次链表,就获得了倒数第 k 个节点 p2

上述逻辑的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 返回链表的倒数第 k 个节点
def findFromEnd(head: ListNode, k: int) -> ListNode:
p1 = head
# p1 先走 k 步
for i in range(k):
p1 = p1.next
p2 = head
# p1 和 p2 同时走 n - k 步
while p1 != None:
p2 = p2.next
p1 = p1.next
# p2 现在指向第 n - k + 1 个节点,即倒数第 k 个节点
return p2

当然,如果用 big O 表示法来计算时间复杂度,无论遍历一次链表和遍历两次链表的时间复杂度都是 ,但上述这个算法更有技巧性。

很多链表相关的算法题都会用到这个技巧,比如说力扣第 19 题「删除链表的倒数第 N 个结点」:

给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。

示例 1:

输入: head = [1,2,3,4,5], n = 2
输出: [1,2,3,5]

示例 2:

输入: head = [1], n = 1
输出: []

示例 3:

输入: head = [1,2], n = 1
输出: [1]

提示:

  • 链表中结点的数目为 sz
  • 1 <= sz <= 30
  • 0 <= Node.val <= 100
  • 1 <= n <= sz

进阶: 你能尝试使用一趟扫描实现吗?

题目来源:力扣 19. 删除链表的倒数第 N 个结点

我们直接看解法代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 主函数
class Solution:
def removeNthFromEnd(self, head: ListNode, n: int) -> ListNode:
# 虚拟头结点
dummy = ListNode(-1)
dummy.next = head
# 删除倒数第 n 个,要先找倒数第 n + 1 个节点
x = self.findFromEnd(dummy, n + 1)
# 删掉倒数第 n 个节点
x.next = x.next.next
return dummy.next

def findFromEnd(self, head: ListNode, k: int) -> ListNode:
# 代码见上文
pass

这个逻辑就很简单了,要删除倒数第 n 个节点,就得获得倒数第 n + 1 个节点的引用,可以用我们实现的 findFromEnd 来操作。

不过注意我们又使用了虚拟头结点的技巧,也是为了防止出现空指针的情况,比如说链表总共有 5 个节点,题目就让你删除倒数第 5 个节点,也就是第一个节点,那按照算法逻辑,应该首先找到倒数第 6 个节点。但第一个节点前面已经没有节点了,这就会出错。

但有了我们虚拟节点 dummy 的存在,就避免了这个问题,能够对这种情况进行正确的删除。

单链表的中点

力扣第 876 题「链表的中间结点」就是这个题目,问题的关键也在于我们无法直接得到单链表的长度 n,常规方法也是先遍历链表计算 n,再遍历一次得到第 n / 2 个节点,也就是中间节点。

如果想一次遍历就得到中间节点,也需要耍点小聪明,使用「快慢指针」的技巧:

我们让两个指针 slow 和 fast 分别指向链表头结点 head

每当慢指针 slow 前进一步,快指针 fast 就前进两步,这样,当 fast 走到链表末尾时,slow 就指向了链表中点

上述思路的代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
class Solution:
# 快慢指针初始化指向 head
def middleNode(self, head: ListNode) -> ListNode:
slow = head
fast = head
# 快指针走到末尾时停止
while fast is not None and fast.next is not None:
# 慢指针走一步,快指针走两步
slow = slow.next
fast = fast.next.next
# 慢指针指向中点
return slow

需要注意的是,如果链表长度为偶数,也就是说中点有两个的时候,我们这个解法返回的节点是靠后的那个节点。

另外,这段代码稍加修改就可以直接用到判断链表成环的算法题上。

判断链表是否包含环

判断链表是否包含环属于经典问题了,解决方案也是用快慢指针:

每当慢指针 slow 前进一步,快指针 fast 就前进两步。

如果 fast 最终能正常走到链表末尾,说明链表中没有环;如果 fast 走着走着竟然和 slow 相遇了,那肯定是 fast 在链表中转圈了,说明链表中含有环。

只需要把寻找链表中点的代码稍加修改就行了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution:
# 快慢指针初始化指向 head
def hasCycle(self, head: ListNode) -> bool:
slow = head
fast = head
# 快指针走到末尾时停止
while fast is not None and fast.next is not None:
# 慢指针走一步,快指针走两步
slow = slow.next
fast = fast.next.next
# 快慢指针相遇,说明含有环
if slow == fast:
return True
# 不包含环
return False

当然,这个问题还有进阶版,也是力扣第 142 题「环形链表 II」:如果链表中含有环,如何计算这个环的起点?

举个例子,环的起点是指下面这幅图中的节点 2:

这里先直接看一下寻找环起点的解法代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution:
def detectCycle(self, head: ListNode):
fast, slow = head, head
while fast and fast.next:
fast = fast.next.next
slow = slow.next
if fast == slow:
break

# 上面的代码类似 hasCycle 函数
if not fast or not fast.next:
# fast 遇到空指针说明没有环
return None

# 重新指向头结点
slow = head
# 快慢指针同步前进,相交点就是环起点
while slow != fast:
fast = fast.next
slow = slow.next
return slow

当快慢指针相遇时,让其中任一个指针指向头节点,然后让它俩以相同速度前进,再次相遇时所在的节点位置就是环开始的位置。

为什么要这样呢?这里简单说一下其中的原理。

我们假设快慢指针相遇时,慢指针 slow 走了 k 步,那么快指针 fast 一定走了 2k 步:

fast 一定比 slow 多走了 k 步,这多走的 k 步其实就是 fast 指针在环里转圈圈,所以 k 的值就是环长度的「整数倍」。

假设相遇点距环的起点的距离为 m,那么结合上图的 slow 指针,环的起点距头结点 head 的距离为 k - m,也就是说如果从 head 前进 k - m 步就能到达环起点。

巧的是,如果从相遇点继续前进 k - m 步,也恰好到达环起点。因为结合上图的 fast 指针,从相遇点开始走k步可以转回到相遇点,那走 k - m 步肯定就走到环起点了:

所以,只要我们把快慢指针中的任一个重新指向 head,然后两个指针同速前进,k - m 步后一定会相遇,相遇之处就是环的起点了。

两个链表是否相交

这个问题有意思,也是力扣第 160 题「相交链表」函数签名如下:

1
def getIntersectionNode(headA: ListNode, headB: ListNode) -> ListNode:

给你输入两个链表的头结点 headA 和 headB,这两个链表可能存在相交。

如果相交,你的算法应该返回相交的那个节点;如果没相交,则返回 null。

比如题目给我们举的例子,如果输入的两个链表如下图:

那么我们的算法应该返回 c1 这个节点。

这个题直接的想法可能是用 HashSet 记录一个链表的所有节点,然后和另一条链表对比,但这就需要额外的空间。

如果不用额外的空间,只使用两个指针,你如何做呢?

难点在于,由于两条链表的长度可能不同,两条链表之间的节点无法对应:

如果用两个指针 p1 和 p2 分别在两条链表上前进,并不能同时走到公共节点,也就无法得到相交节点 c1

**解决这个问题的关键是,通过某些方式,让 p1 和 p2 能够同时到达相交节点 c1**。

所以,我们可以让 p1 遍历完链表 A 之后开始遍历链表 B,让 p2 遍历完链表 B 之后开始遍历链表 A,这样相当于「逻辑上」两条链表接在了一起。

如果这样进行拼接,就可以让 p1 和 p2 同时进入公共部分,也就是同时到达相交节点 c1

那你可能会问,如果说两个链表没有相交点,是否能够正确的返回 null 呢?

这个逻辑可以覆盖这种情况的,相当于 c1 节点是 null 空指针嘛,可以正确返回 null。

按照这个思路,可以写出如下代码:

1
2
3
4
5
6
7
8
9
10
class Solution:
def getIntersectionNode(self, headA: ListNode, headB: ListNode) -> ListNode:
# p1 指向 A 链表头结点,p2 指向 B 链表头结点
p1, p2 = headA, headB
while p1 != p2:
# p1 走一步,如果走到 A 链表末尾,转到 B 链表
p1 = headB if p1 is None else p1.next
# p2 走一步,如果走到 B 链表末尾,转到 A 链表
p2 = headA if p2 is None else p2.next
return p1

这样,这道题就解决了,空间复杂度为 ,时间复杂度为 O(N)O(N)。

以上就是单链表的所有技巧,希望对你有启发。

2022/1/24 更新

评论区有不少优秀读者对最后一题「寻找两条链表的交点」提出了一些其他思路,也补充到这里。

首先有读者提到,如果把两条链表首尾相连,那么「寻找两条链表的交点」的问题转换成了前面讲的「寻找环起点」的问题:

说实话我没有想到这种思路,不得不说这是一个很巧妙的转换!不过需要注意的是,这道题说不让你改变原始链表的结构,所以你把题目输入的链表转化成环形链表求解之后记得还要改回来,否则无法通过。

另外,还有读者提到,既然「寻找两条链表的交点」的核心在于让 p1 和 p2 两个指针能够同时到达相交节点 c1,那么可以通过预先计算两条链表的长度来做到这一点,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
def getIntersectionNode(self, headA: ListNode, headB: ListNode) -> ListNode:
lenA, lenB = 0, 0
# 计算两条链表的长度
p1, p2 = headA, headB
while p1:
lenA += 1
p1 = p1.next
while p2:
lenB += 1
p2 = p2.next

# 让 p1 和 p2 到达尾部的距离相同
p1, p2 = headA, headB
if lenA > lenB:
for _ in range(lenA - lenB):
p1 = p1.next
else:
for _ in range(lenB - lenA):
p2 = p2.next

# 看两个指针是否会相同,p1 == p2 时有两种情况:
# 1、要么是两条链表不相交,他俩同时走到尾部空指针
# 2、要么是两条链表相交,他俩走到两条链表的相交点
while p1 != p2:
p1 = p1.next
p2 = p2.next

return p1

虽然代码多一些,但是时间复杂度是还是 ,而且会更容易理解一些。

双指针技巧秒杀七道数组题目


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
26. Remove Duplicates from Sorted Array 26. 删除有序数组中的重复项
83. Remove Duplicates from Sorted List 83. 删除排序链表中的重复元素
27. Remove Element 27. 移除元素
283. Move Zeroes 283. 移动零
167. Two Sum II - Input Array Is Sorted 167. 两数之和 II - 输入有序数组
344. Reverse String 344. 反转字符串
5. Longest Palindromic Substring 5. 最长回文子串

[!前置知识]

阅读本文前,你需要先学习:

在处理数组和链表相关问题时,双指针技巧是经常用到的,双指针技巧主要分为两类:左右指针快慢指针

所谓左右指针,就是两个指针相向而行或者相背而行;而所谓快慢指针,就是两个指针同向而行,一快一慢。

对于单链表来说,大部分技巧都属于快慢指针,单链表的六大解题套路 都涵盖了,比如链表环判断,倒数第 K 个链表节点等问题,它们都是通过一个 fast 快指针和一个 slow 慢指针配合完成任务。

在数组中并没有真正意义上的指针,但我们可以把索引当做数组中的指针,这样也可以在数组中施展双指针技巧,本文主要讲数组相关的双指针算法

一、快慢指针技巧

原地修改

数组问题中比较常见的快慢指针技巧,是让你原地修改数组

比如说看下力扣第 26 题「删除有序数组中的重复项」,让你在有序数组去重:

给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。

考虑 nums 的唯一元素的数量为 k ,你需要做以下事情确保你的题解可以被通过:

  • 更改数组 nums ,使 nums 的前 k 个元素包含唯一元素,并按照它们最初在 nums 中出现的顺序排列。nums 的其余元素与 nums 的大小不重要。
  • 返回 k 。

判题标准:

系统会用下面的代码来测试你的题解:

int[] nums = […]; // 输入数组
int[] expectedNums = […]; // 长度正确的期望答案

int k = removeDuplicates(nums); // 调用

assert k == expectedNums.length;
for (int i = 0; i < k; i++) {
assert nums[i] == expectedNums[i];
}

如果所有断言都通过,那么您的题解将被 通过

示例 1:

输入: nums = [1,1,2]
输出: 2, nums = [1,2,_]
解释: 函数应该返回新的长度 2 ,并且原数组 nums 的前两个元素被修改为 1, 2 不需要考虑数组中超出新长度后面的元素。

示例 2:

输入: nums = [0,0,1,1,1,2,2,3,3,4]
输出: 5, nums = [0,1,2,3,4]
解释: 函数应该返回新的长度 5 , 并且原数组 nums 的前五个元素被修改为 0, 1, 2, 3, 4 。不需要考虑数组中超出新长度后面的元素。

提示:

  • 1 <= nums.length <= 3 * 104
  • -104 <= nums[i] <= 104
  • nums 已按 非严格递增 排列

题目来源:力扣 26. 删除有序数组中的重复项

函数签名如下:

1
def removeDuplicates(nums: List[int]) -> int:

简单解释一下什么是原地修改:

如果不是原地修改的话,我们直接 new 一个 int[] 数组,把去重之后的元素放进这个新数组中,然后返回这个新数组即可。

但是现在题目让你原地删除,不允许 new 新数组,只能在原数组上操作,然后返回一个长度,这样就可以通过返回的长度和原始数组得到我们去重后的元素有哪些了。

由于数组已经排序,所以重复的元素一定连在一起,找出它们并不难。但如果毎找到一个重复元素就立即原地删除它,由于数组中删除元素涉及数据搬移,整个时间复杂度是会达到 

高效解决这道题就要用到快慢指针技巧:

我们让慢指针 slow 走在后面,快指针 fast 走在前面探路,找到一个不重复的元素就赋值给 slow 并让 slow 前进一步。

这样,就保证了 nums[0..slow] 都是无重复的元素,当 fast 指针遍历完整个数组 nums 后,nums[0..slow] 就是整个数组去重之后的结果。

看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution:
def removeDuplicates(self, nums: List[int]) -> int:
if len(nums) == 0:
return 0
slow = 0
fast = 0
while fast < len(nums):
if nums[fast] != nums[slow]:
slow += 1
# 维护 nums[0..slow] 无重复
nums[slow] = nums[fast]
fast += 1
# 数组长度为索引 + 1
return slow + 1

再简单扩展一下,看看力扣第 83 题「删除排序链表中的重复元素」,如果给你一个有序的单链表,如何去重呢?

其实和数组去重是一模一样的,唯一的区别是把数组赋值操作变成操作指针而已,你对照着之前的代码来看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
def deleteDuplicates(self, head: ListNode) -> ListNode:
if head is None:
return None
slow = head
fast = head
while fast is not None:
if fast.val != slow.val:
# nums[slow] = nums[fast];
slow.next = fast
# slow++;
slow = slow.next
# fast++
fast = fast.next
# 断开与后面重复元素的连接
slow.next = None
return head

这里可能有读者会问,链表中那些重复的元素并没有被删掉,就让这些节点在链表上挂着,合适吗?

这就要探讨不同语言的特性了,像 Java/Python 这类带有垃圾回收的语言,可以帮我们自动找到并回收这些「悬空」的链表节点的内存,而像 C++ 这类语言没有自动垃圾回收的机制,确实需要我们编写代码时手动释放掉这些节点的内存。

不过话说回来,就算法思维的培养来说,我们只需要知道这种快慢指针技巧即可。

除了让你在有序数组/链表中去重,题目还可能让你对数组中的某些元素进行「原地删除」

比如力扣第 27 题「移除元素」,看下题目:

给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。

假设 nums 中不等于 val 的元素数量为 k,要通过此题,您需要执行以下操作:

  • 更改 nums 数组,使 nums 的前 k 个元素包含不等于 val 的元素。nums 的其余元素和 nums 的大小并不重要。
  • 返回 k

用户评测:

评测机将使用以下代码测试您的解决方案:

int[] nums = […]; // 输入数组
int val = …; // 要移除的值
int[] expectedNums = […]; // 长度正确的预期答案。
// 它以不等于 val 的值排序。

int k = removeElement(nums, val); // 调用你的实现

assert k == expectedNums.length;
sort(nums, 0, k); // 排序 nums 的前 k 个元素
for (int i = 0; i < actualLength; i++) {
assert nums[i] == expectedNums[i];
}

如果所有的断言都通过,你的解决方案将会 通过

示例 1:

输入: nums = [3,2,2,3], val = 3
输出: 2, nums = [2,2,_,_]
解释: 你的函数应该返回 k = 2, 并且 nums 中的前两个元素均为 2。
你在返回的 k 个元素之外留下了什么并不重要(因此它们并不计入评测)。

示例 2:

输入: nums = [0,1,2,2,3,0,4,2], val = 2
输出: 5, nums = [0,1,4,0,3,_,_,_]
解释: 你的函数应该返回 k = 5,并且 nums 中的前五个元素为 0,0,1,3,4。
注意这五个元素可以任意顺序返回。
你在返回的 k 个元素之外留下了什么并不重要(因此它们并不计入评测)。

提示:

  • 0 <= nums.length <= 100
  • 0 <= nums[i] <= 50
  • 0 <= val <= 100

题目来源:力扣 27. 移除元素

1
2
# 函数签名如下
def removeElement(nums: List[int], val: int) -> int:

题目要求我们把 nums 中所有值为 val 的元素原地删除,依然需要使用快慢指针技巧:

如果 fast 遇到值为 val 的元素,则直接跳过,否则就赋值给 slow 指针,并让 slow 前进一步。

这和前面说到的数组去重问题解法思路是完全一样的,直接看代码:

1
2
3
4
5
6
7
8
9
class Solution:
def removeElement(self, nums: List[int], val: int) -> int:
fast, slow = 0, 0
while fast < len(nums):
if nums[fast] != val:
nums[slow] = nums[fast]
slow += 1
fast += 1
return slow

注意这里和有序数组去重的解法有一个细节差异,我们这里是先给 nums[slow] 赋值然后再给 slow++,这样可以保证 nums[0..slow-1] 是不包含值为 val 的元素的,最后的结果数组长度就是 slow

实现了这个 removeElement 函数,接下来看看力扣第 283 题「移动零」:

给你输入一个数组 nums,请你原地修改,将数组中的所有值为 0 的元素移到数组末尾,函数签名如下:

1
def moveZeroes(nums: List[int]) -> None:

比如说给你输入 nums = [0,1,4,0,2],你的算法没有返回值,但是会把 nums 数组原地修改成 [1,4,2,0,0]

结合之前说到的几个题目,你是否有已经有了答案呢?

稍微修改上一题中的 removeElement 函数就可以完成这道题,或者直接复用 removeElement 函数也可以。

题目让我们将所有 0 移到最后,其实就相当于移除 nums 中的所有 0,然后再把后面的元素都赋值为 0:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution:
def moveZeroes(self, nums):
# 去除 nums 中的所有 0
# 返回去除 0 之后的数组长度
p = self.removeElement(nums, 0)
# 将 p 之后的所有元素赋值为 0
for i in range(p, len(nums)):
nums[i] = 0

# 双指针技巧,复用 [27. 移除元素] 的解法。
def removeElement(self, nums, val):
fast, slow = 0, 0
while fast < len(nums):
if nums[fast] != val:
nums[slow] = nums[fast]
slow += 1
fast += 1
return slow

你可以点开下面的可视化面板,多次点击  这行代码查看快慢指针的运动,然后多次点击  这行代码将后面的元素都改为 0:

到这里,原地修改数组的这些题目就已经差不多了。

滑动窗口

数组中另一大类快慢指针的题目就是「滑动窗口算法」。我在另一篇文章 滑动窗口算法核心框架详解 给出了滑动窗口的代码框架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 滑动窗口算法框架伪码
int left = 0, right = 0;

while (right < nums.size()) {
// 增大窗口
window.addLast(nums[right]);
right++;

while (window needs shrink) {
// 缩小窗口
window.removeFirst(nums[left]);
left++;
}
}

具体的题目本文就不重复了,这里只强调滑动窗口算法的快慢指针特性:

left 指针在后,right 指针在前,两个指针中间的部分就是「窗口」,算法通过扩大和缩小「窗口」来解决某些问题。

二、左右指针的常用算法

二分查找

我在另一篇文章 二分查找框架详解 中有详细探讨二分搜索代码的细节问题,这里只写最简单的二分算法,旨在突出它的双指针特性:

1
2
3
4
5
6
7
8
9
10
11
12
def binarySearch(nums: List[int], target: int) -> int:
# 一左一右两个指针相向而行
left, right = 0, len(nums) - 1
while left <= right:
mid = (right + left) // 2
if nums[mid] == target:
return mid
elif nums[mid] < target:
left = mid + 1
elif nums[mid] > target:
right = mid - 1
return -1

n 数之和

看下力扣第 167 题「两数之和 II」:

给你一个下标从 1 开始的整数数组 numbers ,该数组已按 非递减顺序排列  ,请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] ,则 1 <= index1 < index2 <= numbers.length 。

以长度为 2 的整数数组 [index1, index2] 的形式返回这两个整数的下标 index1 和 index2

你可以假设每个输入 只对应唯一的答案 ,而且你 不可以 重复使用相同的元素。

你所设计的解决方案必须只使用常量级的额外空间。

 

示例 1:

输入: numbers = [_2_,_7_,11,15], target = 9
输出:\1,2]
解释: 2 与 7 之和等于目标数 9 。因此 index1 = 1, index2 = 2 。返回 \1, 2] 。

示例 2:

输入: numbers = [_2_,3,_4_], target = 6
输出:[1,3]
解释: 2 与 4 之和等于目标数 6 。因此 index1 = 1, index2 = 3 。返回 [1, 3] 。

示例 3:

输入: numbers = [_-1_,0], target = -1
输出:[1,2]
解释:-1 与 0 之和等于目标数 -1 。因此 index1 = 1, index2 = 2 。返回 [1, 2] 。

提示:

  • 2 <= numbers.length <= 3 * 104
  • -1000 <= numbers[i] <= 1000
  • numbers 按 非递减顺序 排列
  • -1000 <= target <= 1000
  • 仅存在一个有效答案

题目来源:力扣 167. 两数之和 II - 输入有序数组

只要数组有序,就应该想到双指针技巧。这道题的解法有点类似二分查找,通过调节 left 和 right 就可以调整 sum 的大小:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution:
def twoSum(self, numbers: List[int], target: int) -> List[int]:
# 一左一右两个指针相向而行
left, right = 0, len(numbers) - 1
while left < right:
sum = numbers[left] + numbers[right]
if sum == target:
# 题目要求的索引是从 1 开始的
return [left + 1, right + 1]
elif sum < target:
# 让 sum 大一点
left += 1
elif sum > target:
# 让 sum 小一点
right -= 1
return [-1, -1]

我在另一篇文章 一个函数秒杀所有 nSum 问题 中也运用类似的左右指针技巧给出了 nSum 问题的一种通用思路,本质上利用的也是双指针技巧。

反转数组

一般编程语言都会提供 reverse 函数,其实这个函数的原理非常简单,力扣第 344 题「反转字符串」就是类似的需求,让你反转一个 char[] 类型的字符数组,我们直接看代码吧:

1
2
3
4
5
6
7
8
9
10
def reverseString(s: List[str]) -> None:
# 一左一右两个指针相向而行
left, right = 0, len(s) - 1
while left < right:
# 交换 s[left] 和 s[right]
temp = s[left]
s[left] = s[right]
s[right] = temp
left += 1
right -= 1

关于数组翻转的更多进阶问题,可以参见 二维数组的花式遍历

回文串判断

回文串就是正着读和反着读都一样的字符串。比如说字符串 aba 和 abba 都是回文串,因为它们对称,反过来还是和本身一样;反之,字符串 abac 就不是回文串。

现在你应该能感觉到回文串问题和左右指针肯定有密切的联系,比如让你判断一个字符串是不是回文串,你可以写出下面这段代码:

1
2
3
4
5
6
7
8
9
def isPalindrome(s: str) -> bool:
# 一左一右两个指针相向而行
left, right = 0, len(s) - 1
while left < right:
if s[left] != s[right]:
return False
left += 1
right -= 1
return True

那接下来我提升一点难度,给你一个字符串,让你用双指针技巧从中找出最长的回文串,你会做吗?

这就是力扣第 5 题「最长回文子串」:

给你一个字符串 s,找到 s 中最长的 回文 子串。

示例 1:

输入: s = “babad”
输出:“bab”
解释:“aba” 同样是符合题意的答案。

示例 2:

输入: s = “cbbd”
输出:“bb”

提示:

  • 1 <= s.length <= 1000
  • s 仅由数字和英文字母组成

题目来源:力扣 5. 最长回文子串

函数签名如下:

1
def longestPalindrome(s: str):

找回文串的难点在于,回文串的的长度可能是奇数也可能是偶数,解决该问题的核心是从中心向两端扩散的双指针技巧

如果回文串的长度为奇数,则它有一个中心字符;如果回文串的长度为偶数,则可以认为它有两个中心字符。所以我们可以先实现这样一个函数:

1
2
3
4
5
6
7
8
9
# 在 s 中寻找以 s[l] 和 s[r] 为中心的最长回文串
def palindrome(s: str, l: int, r: int) -> str:
# 防止索引越界
while l >= 0 and r < len(s) and s[l] == s[r]:
# 双指针,向两边展开
l -= 1
r += 1
# 此时 s[l+1..r-1] 就是最长回文串
return s[l + 1: r]

这样,如果输入相同的 l 和 r,就相当于寻找长度为奇数的回文串,如果输入相邻的 l 和 r,则相当于寻找长度为偶数的回文串。

那么回到最长回文串的问题,解法的大致思路就是:

1
2
3
4
for 0 <= i < len(s):
找到以 s[i] 为中心的回文串
找到以 s[i] 和 s[i+1] 为中心的回文串
更新答案

翻译成代码,就可以解决最长回文子串这个问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution:
def longestPalindrome(self, s: str) -> str:
res = ""
for i in range(len(s)):
# 以 s[i] 为中心的最长回文子串
s1 = self.palindrome(s, i, i)
# 以 s[i] 和 s[i+1] 为中心的最长回文子串
s2 = self.palindrome(s, i, i + 1)
# res = longest(res, s1, s2)
res = res if len(res) > len(s1) else s1
res = res if len(res) > len(s2) else s2
return res

def palindrome(self, s: str, l: int, r: int) -> str:
# 防止索引越界
while l >= 0 and r < len(s) and s[l] == s[r]:
# 向两边展开
l -= 1
r += 1
# 此时 s[l+1..r-1] 就是最长回文串
return s[l + 1:r]

你应该能发现最长回文子串使用的左右指针和之前题目的左右指针有一些不同:之前的左右指针都是从两端向中间相向而行,而回文子串问题则是让左右指针从中心向两端扩展。不过这种情况也就回文串这类问题会遇到,所以我也把它归为左右指针了。

滑动窗口算法核心代码模板


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
76. Minimum Window Substring 76. 最小覆盖子串
567. Permutation in String 567. 字符串的排列
438. Find All Anagrams in a String 438. 找到字符串中所有字母异位词
3. Longest Substring Without Repeating Characters 3. 无重复字符的最长子串

前文 双指针技巧汇总 讲解了一些较为简单的数组双指针技巧,本文就讲解一个稍微复杂的技巧:滑动窗口技巧。

滑动窗口可以归为快慢双指针,一快一慢两个指针前后相随,中间的部分就是窗口。滑动窗口算法技巧主要用来解决子数组问题,比如让你寻找符合某个条件的最长/最短子数组

滑动窗口框架概览

如果用暴力解的话,你需要嵌套 for 循环这样穷举所有子数组,时间复杂度是 

1
2
3
4
5
for (int i = 0; i < nums.length; i++) {
for (int j = i; j < nums.length; j++) {
// nums[i, j] 是一个子数组
}
}

滑动窗口算法技巧的思路也不难,就是维护一个窗口,不断滑动,然后更新答案,该算法的大致逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 索引区间 [left, right) 是窗口
int left = 0, right = 0;

while (right < nums.size()) {
// 增大窗口
window.addLast(nums[right]);
right++;

while (window needs shrink) {
// 缩小窗口
window.removeFirst(nums[left]);
left++;
}
}

基于滑动窗口算法框架写出的代码,时间复杂度是 ,比嵌套 for 循环的暴力解法效率高。

为啥是

肯定有读者要问了,你这个滑动窗口框架不也用了一个嵌套 while 循环?为啥复杂度是  呢?

简单说,指针 left, right 不会回退(它们的值只增不减),所以字符串/数组中的每个元素都只会进入窗口一次,然后被移出窗口一次,不会说有某些元素多次进入和离开窗口,所以算法的时间复杂度就和字符串/数组的长度成正比。

反观嵌套 for 循环的暴力解法,那个 j 会回退,所以某些元素会进入和离开窗口多次,所以时间复杂度就是  了。

我在 算法时空复杂度分析实用指南 有具体教大家如何从理论上估算时间空间复杂度,这里就不展开了。

为啥滑动窗口能在 的时间穷举子数组?

这个问题本身就是错误的,滑动窗口并不能穷举出所有子串。要想穷举出所有子串,必须用那个嵌套 for 循环。

然而对于某些题目,并不需要穷举所有子串,就能找到题目想要的答案。滑动窗口就是这种场景下的一套算法模板,帮你对穷举过程进行剪枝优化,避免冗余计算。

所以在 算法的本质 中我把滑动窗口算法归为「如何聪明地穷举」一类。

其实困扰大家的,不是算法的思路,而是各种细节问题。比如说如何向窗口中添加新元素,如何缩小窗口,在窗口滑动的哪个阶段更新结果。即便你明白了这些细节,代码也容易出 bug,找 bug 还不知道怎么找,真的挺让人心烦的。

所以今天我就写一套滑动窗口算法的代码框架,我连再哪里做输出 debug 都给你写好了,以后遇到相关的问题,你就默写出来如下框架然后改三个地方就行,保证不会出 bug

因为本文的例题大多是子串相关的题目,字符串实际上就是数组,所以我就把输入设置成字符串了。你做题的时候根据具体题目自行变通即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 滑动窗口算法伪码框架
def slidingWindow(s: str):
# 用合适的数据结构记录窗口中的数据,根据具体场景变通
# 比如说,我想记录窗口中元素出现的次数,就用 map
# 如果我想记录窗口中的元素和,就可以只用一个 int
window = ...

left, right = 0, 0
while right < len(s):
# c 是将移入窗口的字符
c = s[right]
window.add(c)
# 增大窗口
right += 1
# 进行窗口内数据的一系列更新
...

# *** debug 输出的位置 ***
# 注意在最终的解法代码中不要 print
# 因为 IO 操作很耗时,可能导致超时
# print(f"window: [{left}, {right})")
# ***********************

# 判断左侧窗口是否要收缩
while left < right and window needs shrink:
# d 是将移出窗口的字符
d = s[left]
window.remove(d)
# 缩小窗口
left += 1
# 进行窗口内数据的一系列更新
...

框架中两处 ... 表示的更新窗口数据的地方,在具体的题目中,你需要做的就是往这里面填代码逻辑。而且,这两个 ... 处的操作分别是扩大和缩小窗口的更新操作,等会你会发现它们操作是完全对称的。

基于这个框架,遇到子串/子数组相关的题目,你只需要回答以下三个问题:

1、什么时候应该移动 right 扩大窗口?窗口加入字符时,应该更新哪些数据?

2、什么时候窗口应该暂停扩大,开始移动 left 缩小窗口?从窗口移出字符时,应该更新哪些数据?

3、什么时候应该更新结果?

只要能回答这三个问题,就说明可以使用滑动窗口技巧解题。

下面就直接上四道力扣原题来套这个框架,其中第一道题会详细说明其原理,其他题目就直接闭眼睛秒杀了。

一、最小覆盖子串

先来看看力扣第 76 题「最小覆盖子串」难度 Hard:

给你一个字符串 s 、一个字符串 t 。返回 s 中涵盖 t 所有字符的最小子串。如果 s 中不存在涵盖 t 所有字符的子串,则返回空字符串 "" 。

注意:

  • 对于 t 中重复字符,我们寻找的子字符串中该字符数量必须不少于 t 中该字符数量。
  • 如果 s 中存在这样的子串,我们保证它是唯一的答案。

示例 1:

输入: s = “ADOBECODEBANC”, t = “ABC”
输出:“BANC”
解释: 最小覆盖子串 “BANC” 包含来自字符串 t 的 ‘A’、’B’ 和 ‘C’。

示例 2:

输入: s = “a”, t = “a”
输出:“a”
解释: 整个字符串 s 是最小覆盖子串。

示例 3:

输入: s = “a”, t = “aa”
输出: “”
解释: t 中两个字符 ‘a’ 均应包含在 s 的子串中,
因此没有符合条件的子字符串,返回空字符串。

提示:

  • m == s.length
  • n == t.length
  • 1 <= m, n <= 105
  • s 和 t 由英文字母组成

进阶: 你能设计一个在 o(m+n) 时间内解决此问题的算法吗?

题目来源:力扣 76. 最小覆盖子串

就是说要在 S(source) 中找到包含 T(target) 中全部字母的一个子串,且这个子串一定是所有可能子串中最短的。

如果我们使用暴力解法,代码大概是这样的:

1
2
3
4
for (int i = 0; i < s.length(); i++)
for (int j = i + 1; j < s.length(); j++)
if s[i:j] 包含 t 的所有字母:
更新答案

思路很直接,但是显然,这个算法的复杂度肯定大于  了,不好。

滑动窗口算法的思路是这样

1、我们在字符串 S 中使用双指针中的左右指针技巧,初始化 left = right = 0,把索引左闭右开区间 [left, right) 称为一个「窗口」。

为什么要「左闭右开」区间

理论上你可以设计两端都开或者两端都闭的区间,但设计为左闭右开区间是最方便处理的

因为这样初始化 left = right = 0 时区间 [0, 0) 中没有元素,但只要让 right 向右移动(扩大)一位,区间 [0, 1) 就包含一个元素 0 了。

如果你设置为两端都开的区间,那么让 right 向右移动一位后开区间 (0, 1) 仍然没有元素;如果你设置为两端都闭的区间,那么初始区间 [0, 0] 就包含了一个元素。这两种情况都会给边界处理带来不必要的麻烦。

2、我们先不断地增加 right 指针扩大窗口 [left, right),直到窗口中的字符串符合要求(包含了 T 中的所有字符)。

3、此时,我们停止增加 right,转而不断增加 left 指针缩小窗口 [left, right),直到窗口中的字符串不再符合要求(不包含 T 中的所有字符了)。同时,每次增加 left,我们都要更新一轮结果。

4、重复第 2 和第 3 步,直到 right 到达字符串 S 的尽头。

这个思路其实也不难,第 2 步相当于在寻找一个「可行解」,然后第 3 步在优化这个「可行解」,最终找到最优解,也就是最短的覆盖子串。左右指针轮流前进,窗口大小增增减减,就好像一条毛毛虫,一伸一缩,不断向右滑动,这就是「滑动窗口」这个名字的来历。

下面画图理解一下,needs 和 window 相当于计数器,分别记录 T 中字符出现次数和「窗口」中的相应字符的出现次数。

初始状态:

增加 right,直到窗口 [left, right) 包含了 T 中所有字符:

现在开始增加 left,缩小窗口 [left, right)

直到窗口中的字符串不再符合要求,left 不再继续移动:

之后重复上述过程,先移动 right,再移动 left… 直到 right 指针到达字符串 S 的末端,算法结束。

如果你能够理解上述过程,恭喜,你已经完全掌握了滑动窗口算法思想。现在我们来看看这个滑动窗口代码框架怎么用

首先,初始化 window 和 need 两个哈希表,记录窗口中的字符和需要凑齐的字符:

1
2
3
4
5
6
7
8
// 记录 window 中的字符出现次数
HashMap<Character, Integer> window = new HashMap<>();
// 记录所需的字符出现次数
HashMap<Character, Integer> need = new HashMap<>();
for (int i = 0; i < t.length(); i++) {
char c = t.charAt(i);
need.put(c, need.getOrDefault(c, 0) + 1);
}

然后,使用 left 和 right 变量初始化窗口的两端,不要忘了,区间 [left, right) 是左闭右开的,所以初始情况下窗口没有包含任何元素:

1
2
3
4
5
6
7
8
9
10
int left = 0, right = 0;
int valid = 0;
while (right < s.length()) {
// c 是将移入窗口的字符
char c = s.charAt(right);
// 右移窗口
right++;
// 进行窗口内数据的一系列更新
...
}

其中 valid 变量表示窗口中满足 need 条件的字符个数,如果 valid 和 need.size 的大小相同,则说明窗口已满足条件,已经完全覆盖了串 T

现在开始套模板,只需要思考以下几个问题

1、什么时候应该移动 right 扩大窗口?窗口加入字符时,应该更新哪些数据?

2、什么时候窗口应该暂停扩大,开始移动 left 缩小窗口?从窗口移出字符时,应该更新哪些数据?

3、我们要的结果应该在扩大窗口时还是缩小窗口时进行更新?

如果一个字符进入窗口,应该增加 window 计数器;如果一个字符将移出窗口的时候,应该减少 window 计数器;当 valid 满足 need 时应该收缩窗口;应该在收缩窗口的时候更新最终结果。

下面是完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution:
def minWindow(self, s: str, t: str) -> str:
need, window = {}, {}
for c in t:
need[c] = need.get(c, 0) + 1

left = 0
right = 0
valid = 0
# 记录最小覆盖子串的起始索引及长度
start = 0
length = float('inf')
while right < len(s):
# c 是将移入窗口的字符
c = s[right]
# 扩大窗口
right += 1
# 进行窗口内数据的一系列更新
if c in need:
window[c] = window.get(c, 0) + 1
if window[c] == need[c]:
valid += 1

# 判断左侧窗口是否要收缩
while valid == len(need):
# 在这里更新最小覆盖子串
if right - left < length:
start = left
length = right - left
# d 是将移出窗口的字符
d = s[left]
# 缩小窗口
left += 1
# 进行窗口内数据的一系列更新
if d in need:
if window[d] == need[d]:
valid -= 1
window[d] -= 1

# 返回最小覆盖子串
return "" if length == float('inf') else s[start: start + length]

你可以点开下面的可视化面板,多次点击  这一行代码,即可看到滑动窗口 [left, right) 的滑动过程:

使用 Java 的读者请注意

对 Java 包装类进行比较时要尤为小心,IntegerString 等类型应该用 equals 方法判定相等,而不能直接用等号 ==,否则会出错。所以在缩小窗口更新数据的时候,不能直接写为 window.get(d) == need.get(d),而要用 window.get(d).equals(need.get(d)),之后的题目代码同理。

上面的代码中,当我们发现某个字符在 window 的数量满足了 need 的需要,就要更新 valid,表示有一个字符已经满足要求。而且,你能发现,两次对窗口内数据的更新操作是完全对称的。

当 valid == need.size() 时,说明 T 中所有字符已经被覆盖,已经得到一个可行的覆盖子串,现在应该开始收缩窗口了,以便得到「最小覆盖子串」。

移动 left 收缩窗口时,窗口内的字符都是可行解,所以应该在收缩窗口的阶段进行最小覆盖子串的更新,以便从可行解中找到长度最短的最终结果。

至此,应该可以完全理解这套框架了,滑动窗口算法又不难,就是细节问题让人烦得很。以后遇到滑动窗口算法,你就按照这框架写代码,保准没有 bug,还省事儿

下面就直接利用这套框架秒杀几道题吧,你基本上一眼就能看出思路了。

二、字符串排列

这是力扣第 567 题「字符串的排列」,难度中等:

给你两个字符串 s1 和 s2 ,写一个函数来判断 s2 是否包含 s1 的排列。如果是,返回 true ;否则,返回 false 。

换句话说,s1 的排列之一是 s2 的 子串 。

示例 1:

输入: s1 = “ab” s2 = “eidbaooo”
输出: true
解释: s2 包含 s1 的排列之一 (“ba”).

示例 2:

输入: s1= “ab” s2 = “eidboaoo”
输出: false

提示:

  • 1 <= s1.length, s2.length <= 104
  • s1 和 s2 仅包含小写字母

题目来源:力扣 567. 字符串的排列

注意哦,输入的 s1 是可以包含重复字符的,所以这个题难度不小。

这种题目,是明显的滑动窗口算法,相当给你一个 S 和一个 T,请问你 S 中是否存在一个和 T 长度相同的子串,且包含 T 中所有字符

首先,先复制粘贴之前的算法框架代码,然后明确刚才提出的几个问题,即可写出这道题的答案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution:
# 判断 s 中是否存在 t 的排列
def checkInclusion(self, t: str, s: str) -> bool:
need = {}
window = {}
for c in t:
need[c] = need.get(c, 0) + 1

left = 0
right = 0
valid = 0
while right < len(s):
c = s[right]
right += 1
# 进行窗口内数据的一系列更新
if c in need:
window[c] = window.get(c, 0) + 1
if window[c] == need[c]:
valid += 1

# 判断左侧窗口是否要收缩
while right - left >= len(t):
# 在这里判断是否找到了合法的子串
if valid == len(need):
return True
d = s[left]
left += 1
# 进行窗口内数据的一系列更新
if d in need:
if window[d] == need[d]:
valid -= 1
window[d] -= 1

# 未找到符合条件的子串
return False

对于这道题的解法代码,基本上和最小覆盖子串一模一样,只需要改变几个地方:

1、本题移动 left 缩小窗口的时机是窗口大小大于 t.length() 时,因为排列嘛,显然长度应该是一样的。

2、当发现 valid == need.size() 时,就说明窗口中就是一个合法的排列,所以立即返回 true

至于如何处理窗口的扩大和缩小,和最小覆盖子串完全相同。

小优化

由于这道题中 [left, right) 其实维护的是一个定长的窗口,窗口长度为 t.length()。因为定长窗口每次向前滑动时只会移出一个字符,所以完全可以把内层的 while 改成 if,效果是一样的。

三、找所有字母异位词

这是力扣第 438 题「找到字符串中所有字母异位词」,难度中等:

给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词 的子串,返回这些子串的起始索引。不考虑答案输出的顺序。

异位词 指由相同字母重排列形成的字符串(包括相同的字符串)。

示例 1:

输入: s = “cbaebabacd”, p = “abc”
输出: [0,6]
解释:
起始索引等于 0 的子串是 “cba”, 它是 “abc” 的异位词。
起始索引等于 6 的子串是 “bac”, 它是 “abc” 的异位词。

 示例 2:

输入: s = “abab”, p = “ab”
输出: [0,1,2]
解释:
起始索引等于 0 的子串是 “ab”, 它是 “ab” 的异位词。
起始索引等于 1 的子串是 “ba”, 它是 “ab” 的异位词。
起始索引等于 2 的子串是 “ab”, 它是 “ab” 的异位词。

提示:

  • 1 <= s.length, p.length <= 3 * 104
  • s 和 p 仅包含小写字母

题目来源:力扣 438. 找到字符串中所有字母异位词

呵呵,这个所谓的字母异位词,不就是排列吗,搞个高端的说法就能糊弄人了吗?相当于,输入一个串 S,一个串 T,找到 S 中所有 T 的排列,返回它们的起始索引

直接默写一下框架,明确刚才讲的三个问题,即可秒杀这道题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution:
def findAnagrams(self, s: str, t: str) -> list[int]:
need = {}
window = {}
for c in t:
need[c] = need.get(c, 0) + 1

left = 0
right = 0
valid = 0
# 记录结果
res = []
while right < len(s):
c = s[right]
right += 1
# 进行窗口内数据的一系列更新
if c in need:
window[c] = window.get(c, 0) + 1
if window[c] == need[c]:
valid += 1
# 判断左侧窗口是否要收缩
while right - left >= len(t):
# 当窗口符合条件时,把起始索引加入 res
if valid == len(need):
res.append(left)
d = s[left]
left += 1
# 进行窗口内数据的一系列更新
if d in need:
if window[d] == need[d]:
valid -= 1
window[d] -= 1
return res

跟寻找字符串的排列一样,只是找到一个合法异位词(排列)之后将起始索引加入 res 即可。

你可以点开下面的可视化面板,多次点击  这一行代码,即可看到定长窗口滑动的过程:

四、最长无重复子串

这是力扣第 3 题「无重复字符的最长子串」,难度中等:

给定一个字符串 s ,请你找出其中不含有重复字符的 最长 子串 的长度。

示例 1:

输入: s = “abcabcbb”
输出: 3
解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。

示例 2:

输入: s = “bbbbb”
输出: 1
解释: 因为无重复字符的最长子串是 "b",所以其长度为 1。

示例 3:

输入: s = “pwwkew”
输出: 3
解释: 因为无重复字符的最长子串是 "wke",所以其长度为 3。
  请注意,你的答案必须是 子串 的长度,"pwke" 是一个_子序列,_不是子串。

提示:

  • 0 <= s.length <= 5 * 104
  • s 由英文字母、数字、符号和空格组成

题目来源:力扣 3. 无重复字符的最长子串

这个题终于有了点新意,不是一套框架就出答案,不过反而更简单了,稍微改一改框架就行了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
def lengthOfLongestSubstring(self, s: str) -> int:
window = {}

left = 0
right = 0
# 记录结果
res = 0
while right < len(s):
c = s[right]
right += 1
# 进行窗口内数据的一系列更新
window[c] = window.get(c, 0) + 1
# 判断左侧窗口是否要收缩
while window[c] > 1:
d = s[left]
left += 1
# 进行窗口内数据的一系列更新
window[d] = window.get(d, 0) - 1
# 在这里更新答案
res = max(res, right - left)
return res

这就是变简单了,连 need 和 valid 都不需要,而且更新窗口内数据也只需要简单的更新计数器 window 即可。

当 window[c] 值大于 1 时,说明窗口中存在重复字符,不符合条件,就该移动 left 缩小窗口了嘛。

唯一需要注意的是,在哪里更新结果 res 呢?我们要的是最长无重复子串,哪一个阶段可以保证窗口中的字符串是没有重复的呢?

这里和之前不一样,要在收缩窗口完成后更新 res,因为窗口收缩的 while 条件是存在重复元素,换句话说收缩完成后一定保证窗口中没有重复嘛。

好了,滑动窗口算法模板就讲到这里,希望大家能理解其中的思想,记住算法模板并融会贯通。回顾一下,遇到子数组/子串相关的问题,你只要能回答出来以下几个问题,就能运用滑动窗口算法:

1、什么时候应该扩大窗口?

2、什么时候应该缩小窗口?

3、什么时候应该更新答案?

二叉树系列算法核心纲领


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
104. Maximum Depth of Binary Tree 104. 二叉树的最大深度
543. Diameter of Binary Tree 543. 二叉树的直径
144. Binary Tree Preorder Traversal 144. 二叉树的前序遍历

[!前置知识]

阅读本文前,你需要先学习:

本文阅读方法

本文会把很多算法进行抽象和归纳,所以会包含大量其他文章链接。

第一次阅读本文的读者不要 DFS 学习本文,遇到没学过的算法或不理解的地方请跳过,只要对本文所总结的理论有些印象即可。在学习本站后面的算法技巧时,你自然可以逐渐理解本文的精髓所在,日后回来重读本文,会有更深的体会。

本站所有文章的脉络都是按照 学习数据结构和算法的框架思维 提出的框架来构建的,其中着重强调了二叉树题目的重要性,所以把本文放在第一章的必读系列中。

先在开头总结一下,二叉树解题的思维模式分两类:

1、是否可以通过遍历一遍二叉树得到答案?如果可以,用一个 traverse 函数配合外部变量来实现,这叫「遍历」的思维模式。

2、是否可以定义一个递归函数,通过子问题(子树)的答案推导出原问题的答案?如果可以,写出这个递归函数的定义,并充分利用这个函数的返回值,这叫「分解问题」的思维模式。

无论使用哪种思维模式,你都需要思考:

如果单独抽出一个二叉树节点,它需要做什么事情?需要在什么时候(前/中/后序位置)做?其他的节点不用你操心,递归函数会帮你在所有节点上执行相同的操作。

本文中会用题目来举例,但都是最最简单的题目,所以不用担心自己看不懂,我可以帮你从最简单的问题中提炼出所有二叉树题目的共性,并将二叉树中蕴含的思维进行升华,反手用到 动态规划回溯算法分治算法图论算法 中去,这也是我一直强调框架思维的原因。希望你在学习了上述高级算法后,也能回头再来看看本文,会对它们有更深刻的认识。

首先,我还是要不厌其烦地强调一下二叉树这种数据结构及相关算法的重要性。

二叉树的重要性

举个例子,比如两个经典排序算法 快速排序 和 归并排序,对于它俩,你有什么理解?

如果你告诉我,快速排序就是个二叉树的前序遍历,归并排序就是个二叉树的后序遍历,那么我就知道你是个算法高手了

为什么快速排序和归并排序能和二叉树扯上关系?我们来简单分析一下他们的算法思想和代码框架:

快速排序的逻辑是,若要对 nums[lo..hi] 进行排序,我们先找一个分界点 p,通过交换元素使得 nums[lo..p-1] 都小于等于 nums[p],且 nums[p+1..hi] 都大于 nums[p],然后递归地去 nums[lo..p-1] 和 nums[p+1..hi] 中寻找新的分界点,最后整个数组就被排序了。

快速排序的代码框架如下:

1
2
3
4
5
6
7
8
9
10
11
def sort(nums: List[int], lo: int, hi: int):
if lo >= hi:
return
# ****** 前序位置 ******
# 对 nums[lo..hi] 进行切分,将 nums[p] 排好序
# 使得 nums[lo..p-1] <= nums[p] < nums[p+1..hi]
p = partition(nums, lo, hi)

# 去左右子数组进行切分
sort(nums, lo, p - 1)
sort(nums, p + 1, hi)

先构造分界点,然后去左右子数组构造分界点,你看这不就是一个二叉树的前序遍历吗?

再说说归并排序的逻辑,若要对 nums[lo..hi] 进行排序,我们先对 nums[lo..mid] 排序,再对 nums[mid+1..hi] 排序,最后把这两个有序的子数组合并,整个数组就排好序了。

归并排序的代码框架如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 定义:排序 nums[lo..hi]
def sort(nums: List[int], lo: int, hi: int) -> None:
if lo == hi:
return
mid = (lo + hi) // 2
# 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid)
# 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi)

# ****** 后序位置 ******
# 此时两部分子数组已经被排好序
# 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi)

先对左右子数组排序,然后合并(类似合并有序链表的逻辑),你看这是不是二叉树的后序遍历框架?另外,这不就是传说中的分治算法嘛,不过如此呀。

如果你一眼就识破这些排序算法的底细,还需要背这些经典算法吗?不需要。你可以手到擒来,从二叉树遍历框架就能扩展出算法了。

说了这么多,旨在说明,二叉树的算法思想的运用广泛,甚至可以说,只要涉及递归,都可以抽象成二叉树的问题。

接下来我们从二叉树的前中后序开始讲起,让你深刻理解这种数据结构的魅力。

深入理解前中后序

我先甩给你几个问题,请默默思考 30 秒:

1、你理解的二叉树的前中后序遍历是什么,仅仅是三个顺序不同的 List 吗?

2、请分析,后序遍历有什么特殊之处?

3、请分析,为什么多叉树没有中序遍历?

答不上来,说明你对前中后序的理解仅仅局限于教科书,不过没关系,我用类比的方式解释一下我眼中的前中后序遍历。

首先,回顾一下 二叉树的 DFS/BFS 遍历 中说到的二叉树递归遍历框架:

1
2
3
4
5
6
7
8
9
# 二叉树的遍历框架
def traverse(root):
if root is None:
return
# 前序位置
traverse(root.left)
# 中序位置
traverse(root.right)
# 后序位置

先不管所谓前中后序,单看 traverse 函数,你说它在做什么事情?

其实它就是一个能够遍历二叉树所有节点的一个函数,和你遍历数组或者链表本质上没有区别:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 迭代遍历数组
def traverse(arr: List[int]) -> None:
for i in range(len(arr)):
pass

# 递归遍历数组
def traverse_recursive(arr: List[int], i: int) -> None:
if i == len(arr):
return
# 前序位置
traverse_recursive(arr, i + 1)
# 后序位置


# 迭代遍历单链表
def traverse_linked_list(head: ListNode) -> None:
p = head
while p:
p = p.next

# 递归遍历单链表
def traverse_linked_list_recursive(head: ListNode) -> None:
if not head:
return
# 前序位置
traverse_linked_list_recursive(head.next)
# 后序位置

单链表和数组的遍历可以是迭代的,也可以是递归的,二叉树这种结构无非就是二叉链表,它没办法简单改写成 for 循环的迭代形式,所以我们遍历二叉树一般都使用递归形式。

你也注意到了,只要是递归形式的遍历,都可以有前序位置和后序位置,分别在递归之前和递归之后。

所谓前序位置,就是刚进入一个节点(元素)的时候,后序位置就是即将离开一个节点(元素)的时候,那么进一步,你把代码写在不同位置,代码执行的时机也不同:

比如说,如果让你倒序打印一条单链表上所有节点的值,你怎么搞?

实现方式当然有很多,但如果你对递归的理解足够透彻,可以利用后序位置来操作:

1
2
3
4
5
6
7
# 递归遍历单链表,倒序打印链表元素
def traverse(head):
if head is None:
return
traverse(head.next)
# 后序位置
print(head.val)

结合上面那张图,你应该知道为什么这段代码能够倒序打印单链表了吧,本质上是利用递归的堆栈帮你实现了倒序遍历的效果。

那么说回二叉树也是一样的,只不过多了一个中序位置罢了。

教科书里只会问你前中后序遍历结果分别是什么,所以对于一个只上过大学数据结构课程的人来说,他大概以为二叉树的前中后序只不过对应三种顺序不同的 List<Integer> 列表。

但是我想说,前中后序是遍历二叉树过程中处理每一个节点的三个特殊时间点,绝不仅仅是三个顺序不同的 List:

前序位置的代码在刚刚进入一个二叉树节点的时候执行;

后序位置的代码在将要离开一个二叉树节点的时候执行;

中序位置的代码在一个二叉树节点左子树都遍历完,即将开始遍历右子树的时候执行。

你注意本文的用词,我一直说前中后序「位置」,就是要和大家常说的前中后序「遍历」有所区别:你可以在前序位置写代码往一个 List 里面塞元素,那最后得到的就是前序遍历结果;但并不是说你就不可以写更复杂的代码做更复杂的事。

画成图,前中后序三个位置在二叉树上是这样:

300

你可以发现每个节点都有「唯一」属于自己的前中后序位置,所以我说前中后序遍历是遍历二叉树过程中处理每一个节点的三个特殊时间点。

这里你也可以理解为什么多叉树没有中序位置,因为二叉树的每个节点只会进行唯一一次左子树切换右子树,而多叉树节点可能有很多子节点,会多次切换子树去遍历,所以多叉树节点没有「唯一」的中序遍历位置。

说了这么多基础的,就是要帮你对二叉树建立正确的认识,然后你会发现:

二叉树的所有问题,就是让你在前中后序位置注入巧妙的代码逻辑,去达到自己的目的,你只需要单独思考每一个节点应该做什么,其他的不用你管,抛给二叉树遍历框架,递归会在所有节点上做相同的操作

你也可以看到,图论算法基础 把二叉树的遍历框架扩展到了图,并以遍历为基础实现了图论的各种经典算法,不过这是后话,本文就不多说了。

两种解题思路

**二叉树题目的递归解法可以分两类思路,第一类是遍历一遍二叉树得出答案,第二类是通过分解问题计算出答案,这两类思路分别对应着 回溯算法核心框架 和 动态规划核心框架**。

[!Tip]

这里说一下我的函数命名习惯:二叉树中用遍历思路解题时函数签名一般是 void traverse(...),没有返回值,靠更新外部变量来计算结果,而用分解问题思路解题时函数名根据该函数具体功能而定,而且一般会有返回值,返回值是子问题的计算结果。

与此对应的,你会发现我在 回溯算法核心框架 中给出的函数签名一般也是没有返回值的 void backtrack(...),而在 动态规划核心框架 中给出的函数签名是带有返回值的 dp 函数。这也说明它俩和二叉树之间千丝万缕的联系。

虽然函数命名没有什么硬性的要求,但我还是建议你也遵循我的这种风格,这样更能突出函数的作用和解题的思维模式,便于你自己理解和运用。

当时我是用二叉树的最大深度这个问题来举例,重点在于把这两种思路和动态规划和回溯算法进行对比,而本文的重点在于分析这两种思路如何解决二叉树的题目。

力扣第 104 题「二叉树的最大深度」就是最大深度的题目,所谓最大深度就是根节点到「最远」叶子节点的最长路径上的节点数,比如输入这棵二叉树,算法应该返回 3:

300

你做这题的思路是什么?显然遍历一遍二叉树,用一个外部变量记录每个节点所在的深度,取最大值就可以得到最大深度,这就是遍历二叉树计算答案的思路

解法代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 遍历的思路
class Solution:

def __init__(self):
# 记录遍历到的节点的深度
self.depth = 0
# 记录最大深度
self.res = 0

def maxDepth(self, root: TreeNode) -> int:
self.traverse(root)
return self.res

# 遍历二叉树
def traverse(self, root: TreeNode):
if root is None:
return

# 前序遍历位置(进入节点)增加深度
self.depth += 1
# 遍历到叶子节点时记录最大深度
if root.left is None and root.right is None:
self.res = max(self.res, self.depth)
self.traverse(root.left)
self.traverse(root.right)

# 后序遍历位置(离开节点)减少深度
self.depth -= 1

这个解法应该很好理解,但为什么需要在前序位置增加 depth,在后序位置减小 depth

因为前面说了,前序位置是进入一个节点的时候,后序位置是离开一个节点的时候,depth 记录当前递归到的节点深度,你把 traverse 理解成在二叉树上游走的一个指针,所以当然要这样维护。

至于对 res 的更新,你放到前中后序位置都可以,只要保证在进入节点之后,离开节点之前(即 depth 自增之后,自减之前)就行了。

当然,你也很容易发现一棵二叉树的最大深度可以通过子树的最大深度推导出来,这就是分解问题计算答案的思路

解法代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 分解问题的思路
class Solution:
# 定义:输入一个节点,返回以该节点为根的二叉树的最大深度
def maxDepth(self, root: TreeNode) -> int:
if root is None:
return 0
# 利用定义,计算左右子树的最大深度
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)

# 根据左右子树的最大深度推出原二叉树的最大深度
# 整棵树的最大深度等于左右子树的最大深度取最大值,
# 然后再加上根节点自己
return 1 + max(leftMax, rightMax)

只要明确递归函数的定义,这个解法也不难理解,但为什么主要的代码逻辑集中在后序位置?

因为这个思路正确的核心在于,你确实可以通过子树的最大深度推导出原树的深度,所以当然要首先利用递归函数的定义算出左右子树的最大深度,然后推出原树的最大深度,主要逻辑自然放在后序位置。

如果你理解了最大深度这个问题的两种思路,那么我们再回头看看最基本的二叉树前中后序遍历,就比如力扣第 144 题「二叉树的前序遍历」,让你计算前序遍历结果。

我们熟悉的解法就是用「遍历」的思路,我想应该没什么好说的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 用遍历的思路计算前序遍历结果
class Solution:
def __init__(self):
self.res = []

def preorderTraversal(self, root: TreeNode) -> List[int]:
self.traverse(root)
return self.res

# 二叉树遍历函数
def traverse(self, root: TreeNode):
if root is None:
return
# 前序位置
self.res.append(root.val)
self.traverse(root.left)
self.traverse(root.right)

但你是否能够用「分解问题」的思路,来计算前序遍历的结果?

换句话说,不要用像 traverse 这样的辅助函数和任何外部变量,单纯用题目给的 preorderTraverse 函数递归解题,你会不会?

我们知道前序遍历的特点是,根节点的值排在首位,接着是左子树的前序遍历结果,最后是右子树的前序遍历结果:

那这不就可以分解问题了么,一棵二叉树的前序遍历结果 = 根节点 + 左子树的前序遍历结果 + 右子树的前序遍历结果

所以,你可以这样实现前序遍历算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution:
# 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
def preorderTraversal(self, root):
res = []
if root == None:
return res
# 前序遍历的结果,root.val 在第一个
res.append(root.val)
# 利用函数定义,后面接着左子树的前序遍历结果
res.extend(self.preorderTraversal(root.left))
# 利用函数定义,最后接着右子树的前序遍历结果
res.extend(self.preorderTraversal(root.right))
return res

中序和后序遍历也是类似的,只要把 add(root.val) 放到中序和后序对应的位置就行了。

这个解法短小精干,但为什么不常见呢?

一个原因是这个算法的复杂度不好把控,比较依赖语言特性。

Java 的话无论 ArrayList 还是 LinkedList,addAll 方法的复杂度都是 ,所以总体的最坏时间复杂度会达到 ,除非你自己实现一个复杂度为 ) 的 addAll 方法,底层用链表的话是可以做到的,因为多条链表只要简单的指针操作就能连接起来。

当然,最主要的原因还是因为教科书上从来没有这么教过……

上文举了两个简单的例子,但还有不少二叉树的题目是可以同时使用两种思路来思考和求解的,这就要靠你自己多去练习和思考,不要仅仅满足于一种熟悉的解法思路。

综上,遇到一道二叉树的题目时的通用思考过程是:

1、是否可以通过遍历一遍二叉树得到答案?如果可以,用一个 traverse 函数配合外部变量来实现。

2、是否可以定义一个递归函数,通过子问题(子树)的答案推导出原问题的答案?如果可以,写出这个递归函数的定义,并充分利用这个函数的返回值。

3、无论使用哪一种思维模式,你都要明白二叉树的每一个节点需要做什么,需要在什么时候(前中后序)做

后序位置的特殊之处

说后序位置之前,先简单说下前序和中序。

前序位置本身其实没有什么特别的性质,之所以你发现好像很多题都是在前序位置写代码,实际上是因为我们习惯把那些对前中后序位置不敏感的代码写在前序位置罢了。

中序位置主要用在 BST 场景中,你完全可以把 BST 的中序遍历认为是遍历有序数组。

划重点

仔细观察,前中后序位置的代码,能力依次增强

前序位置的代码只能从函数参数中获取父节点传递来的数据。

中序位置的代码不仅可以获取参数数据,还可以获取到左子树通过函数返回值传递回来的数据。

后序位置的代码最强,不仅可以获取参数数据,还可以同时获取到左右子树通过函数返回值传递回来的数据。

所以,某些情况下把代码移到后序位置效率最高;有些事情,只有后序位置的代码能做。

举些具体的例子来感受下它们的能力区别。现在给你一棵二叉树,我问你两个简单的问题:

1、如果把根节点看做第 1 层,如何打印出每一个节点所在的层数?

2、如何打印出每个节点的左右子树各有多少节点?

第一个问题可以这样写代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 二叉树遍历函数
void traverse(TreeNode root, int level) {
if (root == null) {
return;
}
// 前序位置
printf("节点 %s 在第 %d 层", root.val, level);
traverse(root.left, level + 1);
traverse(root.right, level + 1);
}

// 这样调用
traverse(root, 1);

第二个问题可以这样写代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
int count(TreeNode root) {
if (root == null) {
return 0;
}
int leftCount = count(root.left);
int rightCount = count(root.right);
// 后序位置
printf("节点 %s 的左子树有 %d 个节点,右子树有 %d 个节点",
root, leftCount, rightCount);

return leftCount + rightCount + 1;
}

这两个问题的根本区别在于

一个节点在第几层,你从根节点遍历过来的过程就能顺带记录,用递归函数的参数就能传递下去;而以一个节点为根的整棵子树有多少个节点,你必须遍历完子树之后才能数清楚,然后通过递归函数的返回值拿到答案。

结合这两个简单的问题,你品味一下后序位置的特点,只有后序位置才能通过返回值获取子树的信息。

那么换句话说,一旦你发现题目和子树有关,那大概率要给函数设置合理的定义和返回值,在后序位置写代码了

接下来看下后序位置是如何在实际的题目中发挥作用的,简单聊下力扣第 543 题「二叉树的直径」,让你计算一棵二叉树的最长直径长度。

所谓二叉树的「直径」长度,就是任意两个结点之间的路径长度。最长「直径」并不一定要穿过根结点,比如下面这棵二叉树:
300

它的最长直径是 3,即 [4,2,1,3][4,2,1,9] 或者 [5,2,1,3] 这几条「直径」的长度。

解决这题的关键在于,每一条二叉树的「直径」长度,就是一个节点的左右子树的最大深度之和

现在让我求整棵树中的最长「直径」,那直截了当的思路就是遍历整棵树中的每个节点,然后通过每个节点的左右子树的最大深度算出每个节点的「直径」,最后把所有「直径」求个最大值即可。

最大深度的算法我们刚才实现过了,上述思路就可以写出以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Solution:

def __init__(self):
# 记录最大直径的长度
self.maxDiameter = 0

def diameterOfBinaryTree(self, root):
# 对每个节点计算直径,求最大直径
self.traverse(root)
return self.maxDiameter

# 遍历二叉树
def traverse(self, root):
if root is None:
return
# 对每个节点计算直径
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
myDiameter = leftMax + rightMax
# 更新全局最大直径
self.maxDiameter = max(self.maxDiameter, myDiameter)

self.traverse(root.left)
self.traverse(root.right)

# 计算二叉树的最大深度
def maxDepth(self, root):
if root is None:
return 0
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
return 1 + max(leftMax, rightMax)

这个解法是正确的,但是运行时间很长,原因也很明显,traverse 遍历每个节点的时候还会调用递归函数 maxDepth,而 maxDepth 是要遍历子树的所有节点的,所以最坏时间复杂度是

这就出现了刚才探讨的情况,前序位置无法获取子树信息,所以只能让每个节点调用 maxDepth 函数去算子树的深度

那如何优化?我们应该把计算「直径」的逻辑放在后序位置,准确说应该是放在 maxDepth 的后序位置,因为 maxDepth 的后序位置是知道左右子树的最大深度的。

所以,稍微改一下代码逻辑即可得到更好的解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Solution:
def __init__(self):
# 记录最大直径的长度
self.maxDiameter = 0

def diameterOfBinaryTree(self, root):
self.maxDepth(root)
return self.maxDiameter

def maxDepth(self, root):
if root is None:
return 0
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
# 后序位置,顺便计算最大直径
myDiameter = leftMax + rightMax
self.maxDiameter = max(self.maxDiameter, myDiameter)

return 1 + max(leftMax, rightMax)

这下时间复杂度只有 maxDepth 函数的 了。

讲到这里,照应一下前文:遇到子树问题,首先想到的是给函数设置返回值,然后在后序位置做文章。

[!Info]

思考题:请你思考一下,运用后序位置的题目使用的是「遍历」的思路还是「分解问题」的思路?

反过来,如果你写出了类似一开始的那种递归套递归的解法,大概率也需要反思是不是可以通过后序遍历优化了。

更多利用后序位置的习题参见 手把手带你刷二叉树(后序篇)手把手带你刷二叉搜索树(后序篇) 和 【练习】利用后序位置解题

以树的视角看动归/回溯/DFS算法的区别和联系

前文我说动态规划/回溯算法就是二叉树算法两种不同思路的表现形式,相信能看到这里的读者应该也认可了我这个观点。但有细心的读者经常提问:你的思考方法让我豁然开朗,但你好像一直没讲过 DFS 算法?

其实我在 一文秒杀所有岛屿题目 中就是用的 DFS 算法,但我确实没有单独用一篇文章讲 DFS 算法,因为 DFS 算法和回溯算法非常类似,只是在细节上有所区别

这个细节上的差别是什么呢?其实就是「做选择」和「撤销选择」到底在 for 循环外面还是里面的区别,DFS 算法在外面,回溯算法在里面。

为什么有这个区别?还是要结合着二叉树理解。这一部分我就把回溯算法、DFS 算法、动态规划三种经典的算法思想,以及它们和二叉树算法的联系和区别,用一句话来说明:

[!important]

动归/DFS/回溯算法都可以看做二叉树问题的扩展,只是它们的关注点不同:

  • 动态规划算法属于分解问题(分治)的思路,它的关注点在整棵「子树」。
  • 回溯算法属于遍历的思路,它的关注点在节点间的「树枝」。
  • DFS 算法属于遍历的思路,它的关注点在单个「节点」。

怎么理解?我分别举三个例子你就懂了。

例子一:分解问题的思想体现

第一个例子,给你一棵二叉树,请你用分解问题的思路写一个 count 函数,计算这棵二叉树共有多少个节点。代码很简单,上文都写过了:

1
2
3
4
5
6
7
8
9
10
# 定义:输入一棵二叉树,返回这棵二叉树的节点总数
def count(root):
if root is None:
return 0
# 当前节点关心的是两个子树的节点总数分别是多少
# 因为用子问题的结果可以推导出原问题的结果
leftCount = count(root.left)
rightCount = count(root.right)
# 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1

你看,这就是动态规划分解问题的思路,它的着眼点永远是结构相同的整个子问题,类比到二叉树上就是「子树」

你再看看具体的动态规划问题,比如 动态规划框架套路详解 中举的斐波那契的例子,我们的关注点在一棵棵子树的返回值上:

1
2
3
4
5
6
# f(n) 计算第 n 个斐波那契数
def fib(n: int) -> int:
# base case
if n == 0 or n == 1:
return n
return fib(n - 1) + fib(n - 2)

例子二:回溯算法的思想体现

第二个例子,给你一棵二叉树,请你用遍历的思路写一个 traverse 函数,打印出遍历这棵二叉树的过程,你看下代码:

1
2
3
4
5
6
7
8
9
10
void traverse(TreeNode root) {
if (root == null) return;
printf("从节点 %s 进入节点 %s", root, root.left);
traverse(root.left);
printf("从节点 %s 回到节点 %s", root.left, root);

printf("从节点 %s 进入节点 %s", root, root.right);
traverse(root.right);
printf("从节点 %s 回到节点 %s", root.right, root);
}

不难理解吧,好的,我们现在从二叉树进阶成多叉树,代码也是类似的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 多叉树节点
class Node {
int val;
Node[] children;
}

void traverse(Node root) {
if (root == null) return;
for (Node child : root.children) {
printf("从节点 %s 进入节点 %s", root, cd);
traverse(child);
printf("从节点 %s 回到节点 %s", child, root);
}
}

这个多叉树的遍历框架就可以延伸出 回溯算法框架套路详解 中的回溯算法框架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 回溯算法框架
void backtrack(...) {
// base case
if (...) return;

for (int i = 0; i < ...; i++) {
// 做选择
...

// 进入下一层决策树
backtrack(...);

// 撤销刚才做的选择
...
}
}

你看,这就是回溯算法遍历的思路,它的着眼点永远是在节点之间移动的过程,类比到二叉树上就是「树枝」

你再看看具体的回溯算法问题,比如 回溯算法秒杀排列组合子集的九种题型 中讲到的全排列,我们的关注点在一条条树枝上:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 回溯算法核心部分代码
void backtrack(int[] nums) {
// 回溯算法框架
for (int i = 0; i < nums.length; i++) {
// 做选择
used[i] = true;
track.addLast(nums[i]);

// 进入下一层回溯树
backtrack(nums);

// 取消选择
track.removeLast();
used[i] = false;
}
}

例子三:DFS 的思想体现

第三个例子,我给你一棵二叉树,请你写一个 traverse 函数,把这棵二叉树上的每个节点的值都加一。很简单吧,代码如下:

1
2
3
4
5
6
7
def traverse(root):
if root is None:
return
# 遍历过的每个节点的值加一
root.val += 1
traverse(root.left)
traverse(root.right)

你看,这就是 DFS 算法遍历的思路,它的着眼点永远是在单一的节点上,类比到二叉树上就是处理每个「节点」

你再看看具体的 DFS 算法问题,比如 一文秒杀所有岛屿题目 中讲的前几道题,我们的关注点是 grid 数组的每个格子(节点),我们要对遍历过的格子进行一些处理,所以我说是用 DFS 算法解决这几道题的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// DFS 算法核心逻辑
void dfs(int[][] grid, int i, int j) {
int m = grid.length, n = grid[0].length;
if (i < 0 || j < 0 || i >= m || j >= n) {
return;
}
if (grid[i][j] == 0) {
return;
}
// 遍历过的每个格子标记为 0
grid[i][j] = 0;
dfs(grid, i + 1, j);
dfs(grid, i, j + 1);
dfs(grid, i - 1, j);
dfs(grid, i, j - 1);
}

好,请你仔细品一下上面三个简单的例子,是不是像我说的:动态规划关注整棵「子树」,回溯算法关注节点间的「树枝」,DFS 算法关注单个「节点」。

有了这些铺垫,你就很容易理解为什么回溯算法和 DFS 算法代码中「做选择」和「撤销选择」的位置不同了,看下面两段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
def dfs(root):
if root is None:
return
# 做选择
print("enter node %s" % root)
for child in root.children:
dfs(child)
# 撤销选择
print("leave node %s" % root)

# 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
def backtrack(root):
if root is None:
return
for child in root.children:
# 做选择
print("I'm on the branch from %s to %s" % (root, child))
backtrack(child)
# 撤销选择
print("I'll leave the branch from %s to %s" % (child, root))

看到了吧,你回溯算法必须把「做选择」和「撤销选择」的逻辑放在 for 循环里面,否则怎么拿到「树枝」的两个端点?

层序遍历

二叉树题型主要是用来培养递归思维的,而层序遍历属于迭代遍历,也比较简单,这里就过一下代码框架吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 输入一棵二叉树的根节点,层序遍历这棵二叉树
class Solution:
def levelOrder(self, root: TreeNode) -> List[List[int]]:
if not root:
return
q = collections.deque()
q.append(root)
depth = 0
# 从上到下遍历二叉树的每一层
while q:
sz = len(q)
# 从左到右遍历每一层的每个节点
for i in range(sz):
cur = q.popleft()

# 将下一层节点放入队列
if cur.left:
q.append(cur.left)
if cur.right:
q.append(cur.right)
depth += 1

这里面 while 循环和 for 循环分管从上到下和从左到右的遍历:

前文 [BFS 算法框架](#BFS 算法解题套路框架) 就是从二叉树的层序遍历扩展出来的,常用于求无权图的最短路径问题。

当然这个框架还可以灵活修改,题目不需要记录层数(步数)时可以去掉上述框架中的 for 循环。

值得一提的是,有些很明显需要用层序遍历技巧的二叉树的题目,也可以用递归遍历的方式去解决,而且技巧性会更强,非常考察你对前中后序的把控。

好了,本文已经够长了,围绕前中后序位置算是把二叉树题目里的各种套路给讲透了,真正能运用出来多少,就需要你亲自刷题实践和思考了。

回答评论区的问题

关于层序遍历(以及其扩展出的 [BFS 算法框架](#BFS 算法解题套路框架)),我在最后多说几句。

如果你对二叉树足够熟悉,可以想到很多方式通过递归函数得到层序遍历结果,比如下面这种写法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
def __init__(self):
self.res = []

def levelTraverse(self, root):
if root is None:
return self.res
# root 视为第 0 层
self.traverse(root, 0)
return self.res

def traverse(self, root, depth):
if root is None:
return
# 前序位置,看看是否已经存储 depth 层的节点了
if len(self.res) <= depth:
# 第一次进入 depth 层
self.res.append([])
# 前序位置,在 depth 层添加 root 节点的值
self.res[depth].append(root.val)
self.traverse(root.left, depth + 1)
self.traverse(root.right, depth + 1)

这种思路从结果上说确实可以得到层序遍历结果,但其本质还是二叉树的前序遍历,或者说 DFS 的思路,而不是层序遍历,或者说 BFS 的思路。因为这个解法是依赖前序遍历自顶向下、自左向右的顺序特点得到了正确的结果。

抽象点说,这个解法更像是从左到右的「列序遍历」,而不是自顶向下的「层序遍历」。所以对于计算最小距离的场景,这个解法完全等同于 DFS 算法,没有 BFS 算法的性能的优势。

还有优秀读者评论了这样一种递归进行层序遍历的思路:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
def __init__(self):
self.res = []

def levelTraverse(self, root):
if not root:
return self.res
nodes = [root]
self.traverse(nodes)
return self.res

def traverse(self, curLevelNodes):
# base case
if not curLevelNodes:
return
# 前序位置,计算当前层的值和下一层的节点列表
nodeValues = []
nextLevelNodes = []
for node in curLevelNodes:
nodeValues.append(node.val)
if node.left:
nextLevelNodes.append(node.left)
if node.right:
nextLevelNodes.append(node.right)
# 前序位置添加结果,可以得到自顶向下的层序遍历
self.res.append(nodeValues)
self.traverse(nextLevelNodes)
# 后序位置添加结果,可以得到自底向上的层序遍历结果
# res.append(nodeValues)

这个 traverse 函数很像递归遍历单链表的函数,其实就是把二叉树的每一层抽象理解成单链表的一个节点进行遍历。

相较上一个递归解法,这个递归解法是自顶向下的「层序遍历」,更接近 BFS 的奥义,可以作为 BFS 算法的递归实现扩展一下思维。

一个视角 + 两种思维模式搞定递归


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
104. Maximum Depth of Binary Tree 104. 二叉树的最大深度

[!前置知识]

阅读本文前,你需要先学习:

一句话总结

一个视角是指「树」的视角,两种思维模式是指「遍历」和「分解问题」两种思维模式。

本文给你讲清楚:

1、算法的本质是穷举,递归是一种重要的穷举手段,递归的正确理解方法是从「树」的角度理解。

2、编写递归算法,有两种思维模式:一种是通过「遍历」一遍树得到答案,另一种是通过「分解问题」得到答案。

从树的角度理解递归

对于初学者,递归算法确实不容易理解,我刚学算法时也不例外。

我曾设想过一些有趣的视角来理解递归,比如把两面镜子相对放置,镜子中的影像就会无限嵌套下去,这似乎也算是递归的一种体现?

再比如,从程序运行原理上来看,递归函数的调用本质上就是入栈和出栈的过程,所以应该可以从栈的角度理解递归?

随着我对算法的理解不断深入,现在我可以负责任地告诉你,理解和编写递归算法最有效的方法是从「树」的视角去理解,其他的都属于花拳绣腿,中看不中用。

下面我将用斐波那契树和全排列这两个简单的经典算法问题来论证这一点。

再次强调,本文的重点是思维方法而不是代码,所以不必太纠结代码细节。

斐波那契数列

首先来看一个既简单又经典的问题:斐波那契数列。

斐波那契数列的数学定义如下:

比方说:

那么现在请你写一个函数,输入一个整数 ,返回斐波那契数列  的值:

1
int fib(int n);

其实直接把斐波那契数列的数学定义翻译成代码,就可以得到一个递归解法:

1
2
3
4
5
6
int fib(int n) {
if (n < 2) {
return n;
}
return fib(n - 1) + fib(n - 2);
}

当然,这个解法的效率并不高,在 动态规划核心框架 中,我们会继续优化,这里暂且不优化,就看这个递归解法。

接下来,我描述一下这个算法的计算过程:

首先,我们想计算 fib(5),根据算法,我们需要计算 fib(4) 和 fib(3),然后求和。

那就先算 fib(4) 的值吧:根据定义,fib(4) 需要计算 fib(3) 和 fib(2),然后求和。

那就先算 fib(3) 的值吧:根据定义,fib(3) 需要计算 fib(2) 和 fib(1),然后求和。

那就先算 fib(2) 的值吧:根据定义,fib(2) 需要对 fib(1) = 1 和 fib(0) = 0 求和,结果是 1。

从树结构的角度,是不是很容易理解递归的计算过程?你看这个 fib 函数和二叉树的遍历函数像不像?所以这个函数抽象出来的递归树就是一棵二叉树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 斐波那契数列
int fib(int n) {
if (n < 2) {
return n;
}
return fib(n - 1)
+ fib(n - 2);
}

// 二叉树遍历函数
void traverse(TreeNode root) {
if (root == null) {
return;
}
traverse(root.left);
traverse(root.right);
}

接下来,我们再看一个稍微复杂一些的递归算法:全排列问题。

全排列问题

现在给你输入一个 nums 数组,其中有不重复的 n 个元素,请你返回这些元素的所有排列方式。

比方说输入 nums = [1,2,3],那么算法返回如下 6 种排列:

1
2
3
[1,2,3], [1,3,2],
[2,1,3], [2,3,1],
[3,1,2], [3,2,1]

我们中学时学过排列组合,就应该做过类似的题目吧,如果让你手动计算 [1,2,3] 的全排列,你会怎么做?本质上就是穷举,只不过这个穷举过程需要一些条理:

先穷举第一位,可以放 1, 2, 3 中的任意一个,我们都要尝试一遍。

如果把 1 放在第一位,接下来第二位只能放 2 或 3 了。

如果第二位放 2,那么第三位只能放 3 了,得到了第一个全排列 [1,2,3]

如果第二位放 3,那么第三位只能放 2 了,得到了第二个全排列 [1,3,2]

回头来看,把 2 放在第一位,那么第二位只能放 1 或 3 了。

如果第二位放 1,那么第三位只能放 3 了,得到了第三个全排列 [2,1,3]

如果第二位放 3,那么第三位只能放 1 了,得到了第四个全排列 [2,3,1]

回头来看,把 3 放在第一位,那么第二位只能放 1 或 2 了。

如果第二位放 1,那么第三位只能放 2 了,得到了第五个全排列 [3,1,2]

如果第二位放 2,那么第三位只能放 1 了,得到了第六个全排列 [3,2,1]

这样,就得到了 [1,2,3] 的所有全排列。

上面的的穷举过程,其实就可以抽象成一棵递归树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution:
def __init__(self):
self.res = []

# 主函数,输入一组不重复的数字,返回它们的全排列
def permute(self, nums):
# 记录「路径」
track = []
# 「路径」中的元素会被标记为 true,避免重复使用
used = [False] * len(nums)

self.backtrack(nums, track, used)
return self.res

# 路径:记录在 track 中
# 选择列表:nums 中不存在于 track 的那些元素(used[i] 为 false)
# 结束条件:nums 中的元素全都在 track 中出现
def backtrack(self, nums, track, used):
# 触发结束条件
if len(track) == len(nums):
self.res.append(track.copy())
return

for i in range(len(nums)):
# 排除不合法的选择
if used[i]:
# nums[i] 已经在 track 中,跳过
continue
# 做选择
track.append(nums[i])
used[i] = True
# 进入下一层决策树
self.backtrack(nums, track, used)
# 取消选择
track.pop()
used[i] = False

抽出递归部分,应该能看出这个算法可以抽象成一棵多叉树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 全排列算法主要结构
void backtrack(int[] nums, List<Integer> track) {
if (track.size() == nums.length) {
return;
}
for (int i = 0; i < nums.length; i++) {
backtrack(nums, track);
}
}

// 多叉树遍历函数
void traverse(TreeNode root) {
if (root == null) {
return;
}
for (TreeNode child : root.children) {
traverse(child);
}
}

你应该已经感觉到了,「树」结构是一个非常有效的数据结构。把问题抽象成树结构,然后用代码去遍历这棵树,就是递归的本质

编写递归的两种思维模式

现在你已经知道了一切递归算法都要抽象成树结构来理解,接下来要更进一步:如果你想用递归算法来求解一个问题,应该怎么写代码呢?

其实很简单,编写递归算法只可能有两种思维模式,都尝试套用一下,必然有一种能写出来:

一种是「遍历」的思维模式,另一种是「分解问题」的思维模式。

上面讲的两道例题中,它们虽然都抽象成了一棵递归树,但斐波那契数列使用的是「分解问题」的思维模式求解,全排列使用的是「遍历」的思维模式求解。

分解问题的思维模式

你看斐波那契数列问题,递归树上的每个节点,其实就是一个子问题的解。fib(5) 是怎么算出来的?是根节点 fib(5) 去问左右子节点 fib(4) 和 fib(3) 的值,然后相加得到的:

这里面就存在一个分解问题的过程:把规模较大的问题 fib(5) 分解成规模较小的问题 fib(4) 和 fib(3),然后通过子问题的解得到原问题的解,我们可以称这种思维模式为「分解问题」。

划重点

如果你想用「分解问题」的思维模式来写递归算法,那么这个递归函数一定要有一个清晰的定义,说明这个函数参数的含义是什么,返回什么结果

这样你才能利用这个定义来计算子问题,反推原问题的解。

比如斐波那契数列的递归函数 fib 就有一个清晰的定义,且算法就在利用这个定义:

1
2
3
4
5
6
7
8
9
10
11
12
// 定义:输入一个非负整数 n,返回斐波那契数列中的第 n 个数
int fib(int n) {
if (n < 2) {
return n;
}
// 利用定义,计算前两个斐波那契数(子问题)
int fib_n_1 = fib(n - 1);
int fib_n_2 = fib(n - 2);

// 通过子问题的解,计算原问题的解
return fib_n_1 + fib_n_2;
}

再来一个简单的例题吧,比如计算二叉树的最大深度,力扣第 104 题「二叉树的最大深度」:

给定一个二叉树 root ,返回其最大深度。

二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。

示例 1:

300

输入: root = [3,9,20,null,null,15,7]
输出: 3

示例 2:

输入: root = [1,null,2]
输出: 2

提示:

  • 树中节点的数量在 [0, 104] 区间内。
  • -100 <= Node.val <= 100

题目来源:力扣 104. 二叉树的最大深度

这道题可以用分解问题的思路求解:想计算整棵树的最大深度,可以先计算左右子树的最大深度,取两者的最大值加一,就是整棵树的最大深度。

那么我们可以给 maxDepth 函数一个明确的定义:输入一棵二叉树的节点,函数返回以这个节点为根的二叉树的最大深度。

然后,就可以得到一个类似斐波那契的递归公式:

maxDepth(root)={0if root=nullmax(maxDepth(root.left),maxDepth(root.right))+1otherwisemaxDepth(root)={0max(maxDepth(root.left),maxDepth(root.right))+1​if root=nullotherwise​

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 分解问题的思路
class Solution:
# 定义:输入一个节点,返回以该节点为根的二叉树的最大深度
def maxDepth(self, root: TreeNode) -> int:
if root is None:
return 0
# 利用定义,计算左右子树的最大深度
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)

# 根据左右子树的最大深度推出原二叉树的最大深度
# 整棵树的最大深度等于左右子树的最大深度取最大值,
# 然后再加上根节点自己
return 1 + max(leftMax, rightMax)

遍历的思维模式

递归树上的节点并没有一个明确的含义,只是记录了之前所做的一些选择。所有全排列,就是所有叶子节点上的结果。这种思维模式称为「遍历」。

[!improtant]

如果你想用「遍历」的思维模式来写递归算法,那么你需要一个无返回值的遍历函数,在遍历的过程中收集结果

比如全排列问题,目前你不需要完全理解全排列的代码,只需注意 backtrack 函数没有返回值,也没有一个明确的定义,它就类似 for 循环一样,单纯起到遍历递归树,收集叶子节点上的结果的作用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 全排列算法主要结构

// 全局变量,存储 backtrack 函数的遍历状态
List<List<Integer>> res = new LinkedList<>();
List<Integer> track = new LinkedList<>();

// 递归树遍历函数
void backtrack(int[] nums, List<Integer> track) {
if (track.size() == nums.length) {
// 到达叶子节点,收集结果
res.add(new LinkedList<>(track));
return;
}
for (int i = 0; i < nums.length; i++) {
// 做选择
track.add(nums[i]);

backtrack(nums, track);

// 撤销选择
track.removeLast();
}
}

有没有感觉出「遍历」和「分解问题」两种思维模式的区别?

再来看力扣第 104 题「二叉树的最大深度」,我们也可以用「遍历」的思维模式来写解法,用标准的二叉树遍历函数 traverse 来遍历整棵树,在遍历的过程更新最大深度,这样当遍历完所有节点时,必然可以求出整棵树的最大深度:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 遍历的思路
class Solution:

def __init__(self):
# 记录遍历到的节点的深度
self.depth = 0
# 记录最大深度
self.res = 0

def maxDepth(self, root: TreeNode) -> int:
self.traverse(root)
return self.res

# 遍历二叉树
def traverse(self, root: TreeNode):
if root is None:
return

# 前序遍历位置(进入节点)增加深度
self.depth += 1
# 遍历到叶子节点时记录最大深度
if root.left is None and root.right is None:
self.res = max(self.res, self.depth)
self.traverse(root.left)
self.traverse(root.right)

# 后序遍历位置(离开节点)减少深度
self.depth -= 1

总结

本文先用斐波那契数列和全排列问题的递归可视化,论证一定要从「树」的角度理解递归算法。

然后总结编写递归算法的两种思维模式:「分解问题」的思路和「遍历」的思路。文中给出的斐波那契数列的解法是「分解问题」的思路,全排列的解法是「遍历」的思路。

有些题目可以同时运用这两种思维模式,比如力扣第 104 题「二叉树的最大深度」,既可以用「分解问题」的思路,也可以用「遍历」的思路来求解,两种解法效率相同,但是代码看起来差异很大。

讲上面这些,最终目的还是希望你能够参考以下步骤,运用自如地写出递归算法

1、首先,这个问题是否可以抽象成一棵树结构?如果可以,那么就要用递归算法了。

2、如果要用递归算法,那么就思考「遍历」和「分解问题」这两种思维模式,看看哪种更适合这个问题。

3、如果用「分解问题」的思维模式,那么一定要写清楚这个递归函数的定义是什么,然后利用这个定义来分解问题,利用子问题的答案推导原问题的答案;如果用「遍历」的思维模式,那么要用一个无返回值的递归函数,单纯起到遍历递归树,收集目标结果的作用。

其实,「分解问题」的思维模式就对应着后面要讲解的 动态规划算法 和 分治算法,「遍历」的思维模式就对应着后面要讲解的 DFS/回溯算法

动态规划解题套路框架

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
509. Fibonacci Number 509. 斐波那契数
322. Coin Change 322. 零钱兑换

[!前置知识]

阅读本文前,你需要先学习:

动态规划问题(Dynamic Programming)应该是很多读者头疼的,不过这类问题也是最具有技巧性,最有意思的。本站使用了整整一个章节专门来写这个算法,动态规划的重要性也可见一斑。

本文解决几个问题:

动态规划是什么?解决动态规划问题有什么技巧?如何学习动态规划?

刷题刷多了就会发现,算法技巧就那几个套路,我们后续的动态规划系列章节,都在使用本文的解题框架思维,如果你心里有数,就会轻松很多。所以本文放在第一章,希望能够成为解决动态规划问题的一部指导方针,下面上干货。

首先,动态规划问题的一般形式就是求最值。动态规划其实是运筹学的一种最优化方法,只不过在计算机问题上应用比较多,比如说让你求最长递增子序列呀,最小编辑距离呀等等。

既然是要求最值,核心问题是什么呢?求解动态规划的核心问题是穷举。因为要求最值,肯定要把所有可行的答案穷举出来,然后在其中找最值呗。

动态规划这么简单,就是穷举就完事了?我看到的动态规划问题都很难啊!

首先,虽然动态规划的核心思想就是穷举求最值,但是问题可以千变万化,穷举所有可行解其实并不是一件容易的事,需要你熟练掌握递归思维,只有列出正确的「状态转移方程」,才能正确地穷举。

而且,你需要判断算法问题是否具备「最优子结构」,是否能够通过子问题的最值得到原问题的最值。

另外,动态规划问题存在「重叠子问题」,如果暴力穷举的话效率会很低,所以需要你使用「备忘录」或者「DP table」来优化穷举过程,避免不必要的计算。

以上提到的重叠子问题、最优子结构、状态转移方程就是动态规划三要素。具体什么意思等会会举例详解,但是在实际的算法问题中,写出状态转移方程是最困难的,这也就是为什么很多朋友觉得动态规划问题困难的原因,我来提供我总结的一个思维框架,辅助你思考状态转移方程:

明确「状态」-> 明确「选择」 -> 定义 dp 数组/函数的含义

按上面的套路走,最后的解法代码就会是如下的框架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 自顶向下递归的动态规划
def dp(状态1, 状态2, ...):
for 选择 in 所有可能的选择:
# 此时的状态已经因为做了选择而改变
result = 求最值(result, dp(状态1, 状态2, ...))
return result

# 自底向上迭代的动态规划
# 初始化 base case
dp[0][0][...] = base case
# 进行状态转移
for 状态1 in 状态1的所有取值:
for 状态2 in 状态2的所有取值:
for ...
dp[状态1][状态2][...] = 求最值(选择1,选择2...)

下面通过斐波那契数列问题和凑零钱问题来详解动态规划的基本原理。前者主要是让你明白什么是重叠子问题(斐波那契数列没有求最值,所以严格来说不是动态规划问题),后者主要举集中于如何列出状态转移方程。

一、斐波那契数列

力扣第 509 题「斐波那契数」就是这个问题,请读者不要嫌弃这个例子简单,只有简单的例子才能让你把精力充分集中在算法背后的通用思想和技巧上,而不会被那些隐晦的细节问题搞的莫名其妙。想要困难的例子,接下来的动态规划系列里有的是。

暴力递归

斐波那契数列的数学形式就是递归的,写成代码就是这样:

1
2
3
4
5
6
# f(n) 计算第 n 个斐波那契数
def fib(n: int) -> int:
# base case
if n == 0 or n == 1:
return n
return fib(n - 1) + fib(n - 2)

信息

这里我们按照力扣的题目描述,认为 base case 是 f(0) = 0 和 f(1) = 1,但在有些斐波那契数列的描述中说 base case 是 f(1) = 1 和 f(2) = 1,其实它们都是一样的。

学校老师讲递归的时候似乎都是拿这个举例。我们也知道这样写代码虽然简洁易懂,但是十分低效,低效在哪里?假设 n = 20,请画出递归树:

这个递归树怎么理解?就是说想要计算原问题 f(20),我就得先计算出子问题 f(19) 和 f(18),然后要计算 f(19),我就要先算出子问题 f(18) 和 f(17),以此类推。最后遇到 f(1) 或者 f(2) 的时候,结果已知,就能直接返回结果,递归树不再向下生长了。

借助算法可视化面板可以更好地帮你理解这个过程,f(20) 的递归树太大,我们展示一下计算 f(5) 的递归过程吧。

递归算法的时间复杂度怎么计算?就是用子问题个数乘以解决一个子问题需要的时间

首先计算子问题个数,即递归树中节点的总数。这棵递归树的高度为 ,所以二叉树的节点总数为 

然后计算解决一个子问题的时间,在本算法中,没有循环,只有 f(n - 1) + f(n - 2) 一个加法操作,时间为 

所以,这个算法的时间复杂度为二者相乘,即 ,指数级别,爆炸。

观察递归树,很明显发现了算法低效的原因:存在大量重复计算。

比如 f(18) 被计算了两次,而且你可以看到,以 f(18) 为根的这个递归树体量巨大,多算一遍,会耗费大量的时间。更何况还不止 f(18) 这一个节点被重复计算,所以这个算法效率很差。

这就是动态规划问题的第一个性质:重叠子问题。下面,我们想办法解决这个问题。

带备忘录的递归解法

即然耗时的原因是重复计算,那么我们可以造一个「备忘录」,每次算出某个子问题的答案后顺便记到「备忘录」里;每次遇到一个子问题别急着计算,先去「备忘录」里查一查,如果发现之前已经解决过这个问题了,直接把答案拿出来用,不要再耗时去计算了。

对于斐波那契数列问题,我们需要一个备忘录记录子问题 f(x) 的值,其中 x 是一个非负整数,所以一般用一个一维数组 memo 充当备忘录就可以了,让 memo[x] 存储子问题 f(x) 的返回值。

当然,你也可以用一个哈希表来存储,思想都是一样的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def fib(n: int) -> int:
# 备忘录全初始化为 -1
# 因为斐波那契数肯定是非负整数,所以初始化为特殊值 -1 表示未计算

# 因为数组的索引从 0 开始,所以需要 n + 1 个空间
# 这样才能把 `f(0) ~ f(n)` 都记录到 memo 中
memo = [-1] * (n + 1)

return self.dp(memo, n)

# 带着备忘录进行递归
def dp(memo: list, n: int) -> int:
# base case
if n == 0 or n == 1:
return n
# 已经计算过,不用再计算了
if memo[n] != -1:
return memo[n]
# 在返回结果之前,存入备忘录
memo[n] = self.dp(memo, n - 1) + self.dp(memo, n - 2)
return memo[n]

现在,画出递归树,你就知道「备忘录」到底做了什么。

实际上,带「备忘录」的递归算法,把一棵存在巨量冗余的递归树通过「剪枝」,改造成了一幅不存在冗余的递归图,极大减少了子问题(即递归图中节点)的个数,每个子问题都只会被计算一次:

递归算法的时间复杂度怎么计算?就是用子问题个数乘以解决一个子问题需要的时间

子问题个数,即图中节点的总数,由于本算法不存在冗余计算,子问题就是 f(0)f(1)f(2) … f(20),数量和输入规模 n = 20 成正比,所以子问题个数为 

解决一个子问题的时间,同上,没有什么循环,时间为 

所以,本算法的时间复杂度是 ,比起指数级复杂度的暴力算法,已经非常高效了。

自顶向下 vs 自底向上

其实如果你只掌握上面的内容,就已经掌握动态规划的解题方法了:无非就是先写出暴力解法,然后用「备忘录」剪枝消除重叠子问题嘛,动态规划就是这么简单。

不过肯定有读者会提问,为什么我见过的很多动态规划解法就是几个 for 循环,好像并不包含递归,也没见到什么备忘录之类的东西,这是怎么回事呢?

实际上,动态规划解法确实有两种表现形式:

第一种是带备忘录的递归解法,或称为「自顶向下」的解法,也就是我们上面展示的,一个递归函数带一个 memo 备忘录。

第二种是 DP table 的迭代解法,或称为「自底向上」的解法,也就是你说的,用 for 循环去迭代 dp 数组进行求解。

这两者的本质是一样的,可以互相转化。迭代解法中的那个 dp 数组,就是递归解法中的 memo 数组

为啥叫「自顶向下」?比如刚才的递归解法,多次点击  可以看到递归树从上向下生长,从一个规模较大的原问题 f(5),向下逐渐分解规模,直到 f(0) 和 f(1) 这两个 base case,然后逐层返回答案,这就叫「自顶向下」。

啥叫「自底向上」?就是反过来嘛。我们直接从最底下、最简单、问题规模最小、已知结果的 f(0) 和 f(1)(base case)开始往上推出 f(2), f(3)... 最后推出我们想要的 f(5),这就是「自底向上」。

其实「自底向上」和「自顶向下」本质是一样的,只是视角不同而已

比如我把上面写的带备忘录的递归解法稍微改一改,把对 base case n == 0 || n == 1 的处理从递归函数 dp 中移到 memo 数组中,这应该没问题吧?我们再来看 fib(5) 的计算过程。

你可以多次点击  这一行代码,请注意递归树和 memo 数组的变化

可以看到,递归树从下向上传递结果的过程,就是 memo 数组从 base case 向右推算的过程,这就叫自底向上,是不是很直观?

到这里你应该也观察出来了,其实整个计算过程就是在从左到右计算 memo 的值,那又何苦用递归了,搞这么复杂。一个 for 循环是不是就够用了?

dp 数组的迭代(递推)解法

有了上一步的启发,我们不再使用递归函数,直接创建一个数组(DP table),用一个 for 循环从 base case 开始从左到右进行计算即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
def fib(n: int) -> int:
if n == 0 or n == 1:
return n
# dp table
dp = [0] * (n + 1)
# base case
dp[0] = 0
dp[1] = 1
# 状态转移
for i in range(2, n + 1):
dp[i] = dp[i - 1] + dp[i - 2]

return dp[n]

画个图就很好理解了,而且你发现这个 DP table 特别像之前那个「剪枝」后的结果,只是反过来算而已:

实际上,带备忘录的递归解法中的那个「备忘录」memo 数组,最终完成后就是这个解法中的 dp 数组,你对比一下可视化面板中两个算法执行的过程可以更直观地看出它俩的联系。

所以说自顶向下、自底向上两种解法本质其实是差不多的,大部分情况下,效率也基本相同。

拓展延伸

这里,引出「状态转移方程」这个名词,实际上就是描述问题结构的数学形式:

为啥叫「状态转移方程」?其实就是为了听起来高端。

f(n) 的函数参数会不断变化,所以你把参数 n 想做一个状态,这个状态 n 是由状态 n - 1 和状态 n - 2 转移(相加)而来,这就叫状态转移,仅此而已。

你会发现,上面的几种解法中的所有操作,例如 return f(n - 1) + f(n - 2)dp[i] = dp[i - 1] + dp[i - 2],以及对备忘录或 DP table 的初始化操作,都是围绕这个方程式的不同表现形式。

可见列出「状态转移方程」的重要性,它是解决问题的核心,而且很容易发现,其实状态转移方程直接代表着暴力解法。

千万不要看不起暴力解,动态规划问题最困难的就是写出这个暴力解,即状态转移方程

只要写出暴力解,优化方法无非是用备忘录或者 DP table,再无奥妙可言。

这个例子的最后,讲一个细节优化。

细心的读者会发现,根据斐波那契数列的状态转移方程,当前状态 n 只和之前的 n-1, n-2 两个状态有关,其实并不需要那么长的一个 DP table 来存储所有的状态,只要想办法存储之前的两个状态就行了。

所以,可以进一步优化,把空间复杂度降为 。这也就是我们最常见的计算斐波那契数的算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
def fib(n: int) -> int:
if n == 0 or n == 1:
# base case
return n
# 分别代表 dp[i - 1] 和 dp[i - 2]
dp_i_1, dp_i_2 = 1, 0
for i in range(2, n + 1):
# dp[i] = dp[i - 1] + dp[i - 2];
dp_i = dp_i_1 + dp_i_2
# 滚动更新
dp_i_2 = dp_i_1
dp_i_1 = dp_i
return dp_i_1

这一般是动态规划问题的最后一步优化,如果我们发现每次状态转移只需要 DP table 中的一部分,那么可以尝试缩小 DP table 的大小,只记录必要的数据,从而降低空间复杂度。

上述例子就相当于把 DP table 的大小从 n 缩小到 2,即把空间复杂度下降了一个量级。一般来说用来把一个二维的 DP table 压缩成一维,即把空间复杂度从  压缩到 

有人会问,动态规划的另一个重要特性「最优子结构」,怎么没有涉及?下面会涉及。斐波那契数列的例子严格来说不算动态规划,因为没有涉及求最值,以上旨在说明重叠子问题的消除方法,演示得到最优解法逐步求精的过程。下面,看第二个例子,凑零钱问题。

二、凑零钱问题

这是力扣第 322 题「零钱兑换」:

给你 k 种面值的硬币,面值分别为 c1, c2 ... ck,每种硬币的数量无限,再给一个总金额 amount,问你最少需要几枚硬币凑出这个金额,如果不可能凑出,算法返回 -1 。算法的函数签名如下:

1
2
# coins 中是可选硬币面值,amount 是目标金额
def coinChange(coins: List[int], amount: int) -> int:

比如说 k = 3,面值分别为 1,2,5,总金额 amount = 11。那么最少需要 3 枚硬币凑出,即 11 = 5 + 5 + 1。

你认为计算机应该如何解决这个问题?显然,就是把所有可能的凑硬币方法都穷举出来,然后找找看最少需要多少枚硬币。

暴力递归

首先,这个问题是动态规划问题,因为它具有「最优子结构」的。要符合「最优子结构」,子问题间必须互相独立。啥叫相互独立?你肯定不想看数学证明,我用一个直观的例子来讲解。

比如说,假设你考试,每门科目的成绩都是互相独立的。你的原问题是考出最高的总成绩,那么你的子问题就是要把语文考到最高,数学考到最高…… 为了每门课考到最高,你要把每门课相应的选择题分数拿到最高,填空题分数拿到最高…… 当然,最终就是你每门课都是满分,这就是最高的总成绩。

得到了正确的结果:最高的总成绩就是总分。因为这个过程符合最优子结构,「每门科目考到最高」这些子问题是互相独立,互不干扰的。

但是,如果加一个条件:你的语文成绩和数学成绩会互相制约,不能同时达到满分,数学分数高,语文分数就会降低,反之亦然。

这样的话,显然你能考到的最高总成绩就达不到总分了,按刚才那个思路就会得到错误的结果。因为「每门科目考到最高」的子问题并不独立,语文数学成绩户互相影响,无法同时最优,所以最优子结构被破坏。

回到凑零钱问题,为什么说它符合最优子结构呢?假设你有面值为 1, 2, 5 的硬币,你想求 amount = 11 时的最少硬币数(原问题),如果你知道凑出 amount = 10, 9, 6 的最少硬币数(子问题),你只需要把子问题的答案加一(再选一枚面值为 1, 2, 5 的硬币),求个最小值,就是原问题的答案。因为硬币的数量是没有限制的,所以子问题之间没有相互制,是互相独立的。

那么,既然知道了这是个动态规划问题,就要思考如何列出正确的状态转移方程?

1、确定「状态」,也就是原问题和子问题中会变化的变量。由于硬币数量无限,硬币的面额也是题目给定的,只有目标金额会不断地向 base case 靠近,所以唯一的「状态」就是目标金额 amount

2、确定「选择」,也就是导致「状态」产生变化的行为。目标金额为什么变化呢,因为你在选择硬币,你每选择一枚硬币,就相当于减少了目标金额。所以说所有硬币的面值,就是你的「选择」。

3、明确 dp 函数/数组的定义。我们这里讲的是自顶向下的解法,所以会有一个递归的 dp 函数,一般来说函数的参数就是状态转移中会变化的量,也就是上面说到的「状态」;函数的返回值就是题目要求我们计算的量。就本题来说,状态只有一个,即「目标金额」,题目要求我们计算凑出目标金额所需的最少硬币数量。

所以我们可以这样定义 dp 函数:dp(n) 表示,输入一个目标金额 n,返回凑出目标金额 n 所需的最少硬币数量

那么根据这个定义,我们的最终答案就是 dp(amount) 的返回值。

搞清楚上面这几个关键点,解法的伪码就可以写出来了:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 伪码框架
def coinChange(coins: List[int], amount: int) -> int:
# 题目要求的最终结果是 dp(amount)
return dp(coins, amount)

# 定义:要凑出金额 n,至少要 dp(coins, n) 个硬币
def dp(coins: List[int], n: int) -> int:
# 做选择,选择需要硬币最少的那个结果
# 初始化res为正无穷
res = float('inf')
for coin in coins:
res = min(res, sub_problem + 1)
return res

根据伪码,我们加上 base case 即可得到最终的答案。显然目标金额为 0 时,所需硬币数量为 0;当目标金额小于 0 时,无解,返回 -1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
def coinChange(self, coins: List[int], amount: int) -> int:
# 题目要求的最终结果是 dp(amount)
return self.dp(coins, amount)

# 定义:要凑出目标金额 amount,至少要 dp(coins, amount) 个硬币
def dp(self, coins, amount):
# base case
if amount == 0:
return 0
if amount < 0:
return -1

res = float('inf')
for coin in coins:
# 计算子问题的结果
subProblem = self.dp(coins, amount - coin)
# 子问题无解则跳过
if subProblem == -1:
continue
# 在子问题中选择最优解,然后加一
res = min(res, subProblem + 1)

return res if res != float('inf') else -1

[!Info]

这里 coinChange 和 dp 函数的签名完全一样,所以理论上不需要额外写一个 dp 函数。但为了后文讲解方便,这里还是另写一个 dp 函数来实现主要逻辑。

另外,我经常看到有读者留言问,子问题的结果为什么要加 1(subProblem + 1),而不是加硬币金额之类的。我这里统一提示一下,动态规划问题的关键是 dp 函数/数组的定义,你这个函数的返回值代表什么?你回过头去搞清楚这一点,然后就知道为什么要给子问题的返回值加 1 了。

至此,状态转移方程其实已经完成了,以上算法已经是暴力解法了,以上代码的数学形式就是状态转移方程:

至此,这个问题其实就解决了,只不过需要消除一下重叠子问题,比如 amount = 11, coins = {1,2,5} 时画出递归树看看:

递归算法的时间复杂度分析:子问题总数 x 解决每个子问题所需的时间

子问题总数为递归树的节点个数,但算法会进行剪枝,剪枝的时机和题目给定的具体硬币面额有关,所以可以想象,这棵树生长的并不规则,确切算出树上有多少节点是比较困难的。对于这种情况,我们一般的做法是按照最坏的情况估算一个时间复杂度的上界。

假设目标金额为 n,给定的硬币个数为 k,那么递归树最坏情况下高度为 n(全用面额为 1 的硬币),然后再假设这是一棵满 k 叉树,则节点的总数在 k^n 这个数量级。

接下来看每个子问题的复杂度,由于每次递归包含一个 for 循环,复杂度为 ,相乘得到总时间复杂度为 ,指数级别。

带备忘录的递归

类似之前斐波那契数列的例子,只需要稍加修改,就可以通过备忘录消除子问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution:
def __init__(self):
self.memo = []

def coinChange(self, coins: List[int], amount: int) -> int:
self.memo = [-666] * (amount + 1)
# 备忘录初始化为一个不会被取到的特殊值,代表还未被计算
return self.dp(coins, amount)

def dp(self, coins, amount):
if amount == 0: return 0
if amount < 0: return -1
# 查备忘录,防止重复计算
if self.memo[amount] != -666:
return self.memo[amount]

res = float('inf')
for coin in coins:
# 计算子问题的结果
subProblem = self.dp(coins, amount - coin)
# 子问题无解则跳过
if subProblem == -1: continue
# 在子问题中选择最优解,然后加一
res = min(res, subProblem + 1)
# 把计算结果存入备忘录
self.memo[amount] = res if res != float('inf') else -1
return self.memo[amount]

不画图了,很显然「备忘录」大大减小了子问题数目,完全消除了子问题的冗余,所以子问题总数不会超过金额数 n,即子问题数目为。处理一个子问题的时间不变,仍是 ,所以总的时间复杂度是 

dp 数组的迭代解法

当然,我们也可以自底向上使用 dp table 来消除重叠子问题,关于「状态」「选择」和 base case 与之前没有区别,dp 数组的定义和刚才 dp 函数类似,也是把「状态」,也就是目标金额作为变量。不过 dp 函数体现在函数参数,而 dp 数组体现在数组索引:

dp 数组的定义:当目标金额为 i 时,至少需要 dp[i] 枚硬币凑出

根据我们文章开头给出的动态规划代码框架可以写出如下解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution:
def coinChange(self, coins: List[int], amount: int) -> int:
# 数组大小为 amount + 1,初始值也为 amount + 1
dp = [amount + 1] * (amount + 1)

dp[0] = 0
# base case
# 外层 for 循环在遍历所有状态的所有取值
for i in range(len(dp)):
# 内层 for 循环在求所有选择的最小值
for coin in coins:
# 子问题无解,跳过
if i - coin < 0:
continue
dp[i] = min(dp[i], 1 + dp[i - coin])
return -1 if dp[amount] == amount + 1 else dp[amount]

[!Info]

为啥 dp 数组中的值都初始化为 amount + 1 呢,因为凑成 amount 金额的硬币数最多只可能等于 amount(全用 1 元面值的硬币),所以初始化为 amount + 1 就相当于初始化为正无穷,便于后续取最小值。为啥不直接初始化为 int 型的最大值 Integer.MAX_VALUE 呢?因为后面有 dp[i - coin] + 1,这就会导致整型溢出。

三、最后总结

第一个斐波那契数列的问题,解释了如何通过「备忘录」或者「dp table」的方法来优化递归树,并且明确了这两种方法本质上是一样的,只是自顶向下和自底向上的不同而已。

第二个凑零钱的问题,展示了如何流程化确定「状态转移方程」,只要通过状态转移方程写出暴力递归解,剩下的也就是优化递归树,消除重叠子问题而已。

如果你不太了解动态规划,还能看到这里,真得给你鼓掌,相信你已经掌握了这个算法的设计技巧。

计算机解决问题其实没有任何特殊的技巧,它唯一的解决办法就是穷举,穷举所有可能性。算法设计无非就是先思考「如何穷举」,然后再追求「如何聪明地穷举」。

列出状态转移方程,就是在解决「如何穷举」的问题。之所以说它难,一是因为很多穷举需要递归实现,二是因为有的问题本身的解空间复杂,不那么容易穷举完整。

备忘录、DP table 就是在追求「如何聪明地穷举」。用空间换时间的思路,是降低时间复杂度的不二法门,除此之外,试问,还能玩出啥花活?

回溯算法解题套路框架


读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
46. Permutations 46. 全排列

[!前置知识]

阅读本文前,你需要先学习:

本文解决几个问题:

回溯算法是什么?解决回溯算法相关的问题有什么技巧?如何学习回溯算法?回溯算法代码是否有规律可循?

其实回溯算法和我们常说的 DFS 算法基本可以认为是同一种算法.

抽象地说,解决一个回溯问题,实际上就是遍历一棵决策树的过程,树的每个叶子节点存放着一个合法答案。你把整棵树遍历一遍,把叶子节点上的答案都收集起来,就能得到所有的合法答案

站在回溯树的一个节点上,你只需要思考 3 个问题:

1、路径:也就是已经做出的选择。

2、选择列表:也就是你当前可以做的选择。

3、结束条件:也就是到达决策树底层,无法再做选择的条件。

如果你不理解这三个词语的解释,没关系,我们后面会用「全排列」这个经典的回溯算法问题来帮你理解这些词语是什么意思,现在你先留着印象。

代码方面,回溯算法的框架:

1
2
3
4
5
6
7
8
9
10
result = []
def backtrack(路径, 选择列表):
if 满足结束条件:
result.add(路径)
return

for 选择 in 选择列表:
做选择
backtrack(路径, 选择列表)
撤销选择

其核心就是 for 循环里面的递归,在递归调用之前「做选择」,在递归调用之后「撤销选择」,特别简单。

什么叫做选择和撤销选择呢,这个框架的底层原理是什么呢?下面我们就通过「全排列」这个问题来解开之前的疑惑,详细探究一下其中的奥妙!

全排列问题解析

力扣第 46 题「全排列」就是给你输入一个数组 nums,让你返回这些数字的全排列。

我们在高中的时候就做过排列组合的数学题,我们也知道 n 个不重复的数,全排列共有 n! 个。那么我们当时是怎么穷举全排列的呢?

比方说给三个数 [1,2,3],你肯定不会无规律地乱穷举,一般是这样:

先固定第一位为 1,然后第二位可以是 2,那么第三位只能是 3;然后可以把第二位变成 3,第三位就只能是 2 了;然后就只能变化第一位,变成 2,然后再穷举后两位……

其实这就是回溯算法,我们高中无师自通就会用,或者有的同学直接画出如下这棵回溯树:

只要从根遍历这棵树,记录路径上的数字,其实就是所有的全排列。我们不妨把这棵树称为回溯算法的「决策树」

为啥说这是决策树呢,因为你在每个节点上其实都在做决策。比如说你站在下图的红色节点上:

你现在就在做决策,可以选择 1 那条树枝,也可以选择 3 那条树枝。为啥只能在 1 和 3 之中选择呢?因为 2 这个树枝在你身后,这个选择你之前做过了,而全排列是不允许重复使用数字的。

现在可以解答开头的几个名词:[2] 就是「路径」,记录你已经做过的选择;[1,3] 就是「选择列表」,表示你当前可以做出的选择;「结束条件」就是遍历到树的底层叶子节点,这里也就是选择列表为空的时候

如果明白了这几个名词,可以把「路径」和「选择」列表作为决策树上每个节点的属性,比如下图列出了几个蓝色节点的属性:

我们定义的 backtrack 函数其实就像一个指针,在这棵树上游走,同时要正确维护每个节点的属性,每当走到树的底层叶子节点,其「路径」就是一个全排列

再进一步,如何遍历一棵树?各种搜索问题其实都是树的遍历问题,而多叉树的遍历框架就是这样:

1
2
3
4
5
def traverse(root: TreeNode):
for child in root.children:
# 前序位置需要的操作
traverse(child)
# 后序位置需要的操作

[!Info]

细心的读者肯定会疑问:多叉树 DFS 遍历框架的前序位置和后序位置应该在 for 循环外面,并不应该是在 for 循环里面呀?为什么在回溯算法中跑到 for 循环里面了?

是的,DFS 算法的前序和后序位置应该在 for 循环外面,不过回溯算法和 DFS 算法略有不同

而所谓的前序遍历和后序遍历,他们只是两个很有用的时间点,我给你画张图你就明白了:

前序遍历的代码在进入某一个节点之前的那个时间点执行,后序遍历代码在离开某个节点之后的那个时间点执行

回想我们刚才说的,「路径」和「选择」是每个节点的属性,函数在树上游走要正确处理节点的属性,那么就要在这两个特殊时间点搞点动作:

现在,你是否理解了回溯算法的这段核心框架?

1
2
3
4
5
6
7
8
for 选择 in 选择列表:
# 做选择
将该选择从选择列表移除
路径.add(选择)
backtrack(路径, 选择列表)
# 撤销选择
路径.remove(选择)
将该选择再加入选择列表

我们只要在递归之前做出选择,在递归之后撤销刚才的选择,就能正确得到每个节点的选择列表和路径。

下面,直接看全排列代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution:
def __init__(self):
self.res = []

# 主函数,输入一组不重复的数字,返回它们的全排列
def permute(self, nums):
# 记录「路径」
track = []
# 「路径」中的元素会被标记为 true,避免重复使用
used = [False] * len(nums)

self.backtrack(nums, track, used)
return self.res

# 路径:记录在 track 中
# 选择列表:nums 中不存在于 track 的那些元素(used[i] 为 false)
# 结束条件:nums 中的元素全都在 track 中出现
def backtrack(self, nums, track, used):
# 触发结束条件
if len(track) == len(nums):
self.res.append(track.copy())
return

for i in range(len(nums)):
# 排除不合法的选择
if used[i]:
# nums[i] 已经在 track 中,跳过
continue
# 做选择
track.append(nums[i])
used[i] = True
# 进入下一层决策树
self.backtrack(nums, track, used)
# 取消选择
track.pop()
used[i] = False

我们这里稍微做了些变通,没有显式记录「选择列表」,而是通过 used 数组排除已经存在 track 中的元素,从而推导出当前的选择列表:

至此,我们就通过全排列问题详解了回溯算法的底层原理。当然,这个算法解决全排列不是最高效的,你可能看到有的解法连 used 数组都不使用,通过交换元素达到目的。

但是必须说明的是,不管怎么优化,都符合回溯框架,而且时间复杂度都不可能低于 ,因为穷举整棵决策树是无法避免的,你最后肯定要穷举出 N! 种全排列结果。

这也是回溯算法的一个特点,不像动态规划存在重叠子问题可以优化,回溯算法就是纯暴力穷举,复杂度一般都很高

最后总结

回溯算法就是个多叉树的遍历问题,关键就是在前序遍历和后序遍历的位置做一些操作,算法框架如下:

1
2
3
4
5
def backtrack(...):
for 选择 in 选择列表:
做选择
backtrack(...)
撤销选择

写 backtrack 函数时,需要维护走过的「路径」和当前可以做的「选择列表」,当触发「结束条件」时,将「路径」记入结果集

其实想想看,回溯算法和动态规划是不是有点像呢?我们在动态规划系列文章中多次强调,动态规划的三个需要明确的点就是「状态」「选择」和「base case」,是不是就对应着走过的「路径」,当前的「选择列表」和「结束条件」?

动态规划和回溯算法底层都把问题抽象成了树的结构,但这两种算法在思路上是完全不同的。在 二叉树心法(纲领篇) 你将看到动态规划和回溯算法更深层次的区别和联系。

BFS 算法解题套路框架

读完本文,你不仅学会了算法套路,还可以顺便解决如下题目:

LeetCode 力扣 难度
752. Open the Lock 752. 打开转盘锁
773. Sliding Puzzle 773. 滑动谜题

我多次强调,DFS/回溯/BFS 这类算法,本质上就是把具体的问题抽象成树结构,然后遍历这棵树进行暴力穷举,所以这些穷举算法的代码本质上就是树的遍历代码。

梳理一下这里面的因果关系:

DFS/回溯算法的本质就是递归遍历一棵穷举树(多叉树),而多叉树的递归遍历又是从二叉树的递归遍历衍生出来的。所以我说 DFS/回溯算法的本质是二叉树的递归遍历。

BFS 算法的本质就是遍历一幅图

而图的遍历算法其实就是多叉树的遍历算法加了个 visited 数组防止死循环;多叉树的遍历算法又是从二叉树遍历算法衍生出来的。所以我说 BFS 算法的本质就是二叉树的层序遍历。

其实所谓的最短路径,都可以类比成二叉树最小深度这类问题(寻找距离根节点最近的叶子节点),递归遍历必须要遍历整棵树的所有节点才能找到目标节点,而层序遍历不需要遍历所有节点就能搞定,所以层序遍历适合解决这类最短路径问题。

这么梳理应该够清楚了吧?

本文的重点在于,教会你如何对具体的算法问题进行抽象和转化,然后套用 BFS 算法框架进行求解。

在真实的面试笔试题目中,一般不是直接让你遍历树/图这种标准数据结构,而是给你一个具体的场景题,你需要把具体的场景抽象成一个标准的图/树结构,然后利用 BFS 算法穷举得出答案。

比方说给你一个迷宫游戏,请你计算走到出口的最小步数?如果这个迷宫还包含传送门,可以瞬间传送到另一个位置,那么最小步数又是多少?

再比如说两个单词,要求你通过某些替换,把其中一个变成另一个,每次可以替换/删除/插入一个字符,最少要操作几次?

再比如说连连看游戏,两个方块消除的条件不仅仅是图案相同,还得保证两个方块之间的最短连线不能多于两个拐点。你玩连连看,点击两个坐标,游戏是如何判断它俩的最短连线有几个拐点的?

你看上面这些例子,是不是感觉和我们前面学习的树/图结构完全扯不上关系?但实际上只要稍加抽象,它们就是树/图结构的遍历,实在是太简单枯燥了。

下面用几道例题来讲解 BFS 的套路框架,以后再也不要觉得这类问题难解决了。

算法框架

BFS 的算法框架其实就是 [图结构的 DFS/BFS 遍历]

对于实际的 BFS 算法问题,第一种写法最简单,但局限性太大,不常用;第二种写法最常用,中等难度的 BFS 算法题基本都可以用这种写法解决;第三种写法稍微复杂一点,但灵活性最高,可能会在一些难度较大的的 BFS 问题中用到。

本文的例题都是中等难度,所以本文给出的解法都以第二种写法为准:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 从 s 开始 BFS 遍历图的所有节点,且记录遍历的步数
def bfs(graph, s, target):
visited = [False] * len(graph)
q = deque([s])
visited[s] = True
# 记录从 s 开始走到当前节点的步数
step = 0

while q:
sz = len(q)
for i in range(sz):
cur = q.popleft()
print(f"visit {cur} at step {step}")
# 判断是否到达终点
if cur == target:
return step

# 将邻居节点加入队列,向四周扩散搜索
for to in neighborsOf(cur):
if not visited[to]:
q.append(to)
visited[to] = True
step += 1
# 如果走到这里,说明在图中没有找到目标节点
return -1

上面这个代码框架几乎就是从 图结构的 DFS/BFS 遍历 中复制过来的,只不过添加了一个 target 参数,当第一次走到 target 时,直接结束算法并返回走过的步数。

下面我们用几个具体的例题来看看如何运用这个框架。

滑动谜题

力扣第 773 题「滑动谜题」就是一个可以运用 BFS 框架解决的题目,题目的要求如下:

给你一个 2x3 的滑动拼图,用一个 2x3 的数组 board 表示。拼图中有数字 0~5 六个数,其中数字 0 就表示那个空着的格子,你可以移动其中的数字,当 board 变为 [[1,2,3],[4,5,0]] 时,赢得游戏。

请你写一个算法,计算赢得游戏需要的最少移动次数,如果不能赢得游戏,返回 -1。

比如说输入的二维数组 board = [[4,1,2],[5,0,3]],算法应该返回 5:

如果输入的是 board = [[1,2,3],[5,4,0]],则算法返回 -1,因为这种局面下无论如何都不能赢得游戏。

我感觉这题还挺有意思的,小时候玩过类似的拼图游戏,比如华容道:

你需要移动这些方块,想办法让曹操从初始位置移动到最下方的出口位置。

华容道应该比这道题更难一些,因为力扣的这道题中每个方块的大小可以看作是相同的,而华容道中每个方块的大小还不一样。

回到这道题,我们如何把这道题抽象成树/图的结构,从而用 BFS 算法框架来解决呢?

其实棋盘的初始状态就可以认为是起点:

1
2
[[2,4,1],
[5,0,3]]

我们最终的目标状态是把棋盘变成这样:

1
2
[[1,2,3],
[4,5,0]]

那么这就可以认为是终点。

现在这个问题不就成为了一个图的问题了吗?题目问的其实就是从起点到终点所需的最短路径是多少嘛。

起点的邻居节点是谁?把数字 0 和上下左右的数字进行交换,其实就是起点的四个邻居节点嘛(由于本题中棋盘的大小是 2x3,所以索引边界内的实际邻居节点会小于四个):

以此类推,这四个邻居节点还有各自的四个邻居节点,那这不就是一幅图结构吗?

那么我从起点开始使用 BFS 算法遍历这幅图,第一次到达终点时,走过的步数就是答案。

伪码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
int bfs(int[][] board, int[][] target) {
Queue<int[][]> q = new LinkedList<>();
HashSet visited = new HashSet<>();

// 将起点加入队列
q.offer(board);
visited.add(board);

int step = 0;
while (!q.isEmpty()) {
int sz = q.size();
for (int i = 0; i < sz; i++) {
int[][] cur = q.poll();
// 判断是否到达终点
if (cur == target) {
return step;
}
// 将当前节点的邻居节点加入队列
for (int[][] neighbor : getNeighbors(cur)) {
if (!visited.contains(neighbor)) {
q.offer(neighbor);
visited.add(neighbor);
}
}
}
step++;
}
return -1;
}

List<int[][]> getNeighbors(int[][] board) {
// 将 board 中的数字 0 和上下左右的数字进行交换,得到 4 个邻居节点
}

对于这道题,我们抽象出来的图结构也是会包含环的,所以需要一个 visited 数组记录已经走过的节点,避免成环导致死循环。

比如说我从 [[2,4,1],[5,0,3]] 节点开始,数字 0 向右移动得到新节点 [[2,4,1],[5,3,0]],但是这个新节点中的 0 也可以向左移动的,又会回到 [[2,4,1],[5,0,3]],这其实就是成环。我们也需要一个 visited 哈希集合来记录已经走过的节点,防止成环导致的死循环。

还有一个问题,这道题中 board 是一个二维数组,我们在 哈希表/哈希集合原理 中介绍过,二维数组这种可变数据结构是无法直接加入哈希集合的。

所以我们还要再用一点技巧,想办法把二维数组转化成一个不可变类型才能存到哈希集合中。常见的解决方案是把二维数组序列化成一个字符串,这样就可以直接存入哈希集合了。

其中比较有技巧性的点在于,二维数组有「上下左右」的概念,压缩成一维的字符串后后,还怎么把数字 0 和上下左右的数字进行交换

对于这道题,题目说输入的数组大小都是 2 x 3,所以我们可以直接手动写出来这个映射:

1
2
3
4
5
6
7
8
9
# 记录一维字符串的相邻索引
neighbor = [
[1, 3],
[0, 4, 2],
[1, 5],
[0, 4],
[3, 1, 5],
[4, 2]
]

**这个映射的含义就是,在一维字符串中,索引 i 在二维数组中的的相邻索引为 neighbor[i]**。

例如,我们可以知道 neighbor[4] 的周围元素为 neighbor[3], neighbor[1], neighbor[5]

这样,无论数字 0 在哪里,都可以通过这个索引映射得到它的相邻索引进行交换了。下面是完整的代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from collections import deque

class Solution:
def slidingPuzzle(self, board):
target = "123450"
# 将 2x3 的数组转化成字符串作为 BFS 的起点
start = ""
for i in range(len(board)):
for j in range(len(board[0])):
start += str(board[i][j])

# ****** BFS 算法框架开始 ******
q = deque()
visited = set()
# 从起点开始 BFS 搜索
q.append(start)
visited.add(start)

step = 0
while q:
# 当前层的节点数量
sz = len(q)
for _ in range(sz):
cur = q.popleft()
# 判断是否达到目标局面
if cur == target:
return step
# 将数字 0 和相邻的数字交换位置
for neighbor_board in self.getNeighbors(cur):
# 防止走回头路
if neighbor_board not in visited:
q.append(neighbor_board)
visited.add(neighbor_board)
step += 1
# ****** BFS 算法框架结束 ******
return -1

def getNeighbors(self, board):
# 记录一维字符串的相邻索引
mapping = [
[1, 3],
[0, 4, 2],
[1, 5],
[0, 4],
[3, 1, 5],
[4, 2]
]

idx = board.index('0')
neighbors = []
for adj in mapping[idx]:
new_board = self.swap(board, idx, adj)
neighbors.append(new_board)
return neighbors

def swap(self, board, i, j):
chars = list(board)
chars[i], chars[j] = chars[j], chars[i]
return ''.join(chars)

这道题就解决了。你会发现 BFS 算法本身的写法都是固定的套路,这道题的难点其实在于将题目转化为 BFS 穷举的模型,然后用合理的方法将多维数组转化成字符串,以便哈希集合记录访问过的节点。

下面再看一道实际场景题。

解开密码锁的最少次数

来看力扣第 752 题「打开转盘锁」,比较有意思:

你有一个带有四个圆形拨轮的转盘锁。每个拨轮都有10个数字: '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' 。每个拨轮可以自由旋转:例如把 '9' 变为 '0''0' 变为 '9' 。每次旋转都只能旋转一个拨轮的一位数字。

锁的初始数字为 '0000' ,一个代表四个拨轮的数字的字符串。

列表 deadends 包含了一组死亡数字,一旦拨轮的数字和列表里的任何一个元素相同,这个锁将会被永久锁定,无法再被旋转。

字符串 target 代表可以解锁的数字,你需要给出解锁需要的最小旋转次数,如果无论如何不能解锁,返回 -1 。

示例 1:

输入: deadends = [“0201”,”0101”,”0102”,”1212”,”2002”], target = “0202”
输出: 6
解释:
可能的移动序列为 “0000” -> “1000” -> “1100” -> “1200” -> “1201” -> “1202” -> “0202”。
注意 “0000” -> “0001” -> “0002” -> “0102” -> “0202” 这样的序列是不能解锁的,
因为当拨动到 “0102” 时这个锁就会被锁定。

示例 2:

输入: deadends = [“8888”], target = “0009”
输出: 1
解释: 把最后一位反向旋转一次即可 “0000” -> “0009”。

示例 3:

输入: deadends = [“8887”,”8889”,”8878”,”8898”,”8788”,”8988”,”7888”,”9888”], target = “8888”
输出:-1
解释: 无法旋转到目标数字且不被锁定。

提示:

  • 1 <= deadends.length <= 500
  • deadends[i].length == 4
  • target.length == 4
  • target 不在 deadends 之中
  • target 和 deadends[i] 仅由若干位数字组成

题目来源:力扣 752. 打开转盘锁

函数签名如下:

1
2
3
class Solution:
def openLock(self, deadends: List[str], target: str) -> int:
# ...

题目中描述的就是我们生活中常见的那种密码锁,如果没有任何约束,最少的拨动次数很好算。比方说想拨到 "1234",那一个个数字拨动就可以了,最少的拨动次数就是 1 + 2 + 3 + 4 = 10 次。

但现在的难点就在于,在拨动密码锁的过程中不能出现 deadends,这样就有一些难度了。如果遇到了 deadends,你该怎么处理,才能使得总的拨动次数最少呢?

千万不要陷入细节,尝试去想各种具体的情况。要知道算法的本质就是穷举,我们直接从 "0000" 开始暴力穷举,把所有可能的拨动情况都穷举出来,难道还怕找不到最少的拨动次数么?

第一步,我们不管所有的限制条件,不管 deadends 和 target 的限制,就思考一个问题:如果让你设计一个算法,穷举所有可能的密码组合,你怎么做

就从 "0000" 开始,如果你只转一下锁,有几种可能?总共有 4 个位置,每个位置可以向上转,也可以向下转,也就是可以穷举出 "1000", "9000", "0100", "0900"... 共 8 种密码。

然后,再以这 8 种密码作为基础,其中每个密码又可以转动一下衍生出 8 种密码,以此类推…

心里那棵递归树出来没有?应该是一棵八叉树,每个节点都有 8 个子节点,向下衍生。

下面这段伪码就描述了上述思路,用层序遍历一棵八叉树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import List

# 将 s[j] 向上拨动一次
def plusOne(s: str, j: int) -> str:
ch = list(s)
if ch[j] == '9':
ch[j] = '0'
else:
ch[j] = chr(ord(ch[j]) + 1)
return ''.join(ch)

# 将 s[i] 向下拨动一次
def minusOne(s: str, j: int) -> str:
ch = list(s)
if ch[j] == '0':
ch[j] = '9'
else:
ch[j] = chr(ord(ch[j]) - 1)
return ''.join(ch)

# BFS 框架,寻找最少的拨动次数
def BFS(target: str) -> int:
q = ['0000']

while q:
sz = len(q)
# 将当前队列中的所有节点向周围扩散
for _ in range(sz):
cur = q.pop(0)
# 判断是否到达终点
if cur == target:
return step
# 将一个节点的相邻节点加入队列
for neighbor in getNeighbors(cur):
q.append(neighbor)
# 在这里增加步数
step += 1
return -1

# 将 s 的每一位向上拨动一次或向下拨动一次,8 种相邻密码
def getNeighbors(s: str) -> List[str]:
neighbors = []
for i in range(4):
neighbors.append(plusOne(s, i))
neighbors.append(minusOne(s, i))
return neighbors

这个代码已经可以穷举所有可能的密码组合了,但是还有些问题需要解决。

1、会走回头路,我们可以从 "0000" 拨到 "1000",但是等从队列拿出 "1000" 时,还会拨出一个 "0000",这样的话会产生死循环。

这个问题很好解决,其实就是成环了嘛,我们用一个 visited 集合记录已经穷举过的密码,再次遇到时,不要再加到队列里就行了。

2、没有对 deadends 进行处理,按道理这些「死亡密码」是不能出现的。

这个问题也好处理,额外用一个 deadends 集合记录这些死亡密码,凡是遇到这些密码,不要加到队列里就行了。

或者还可以更简单一些,直接把 deadends 中的死亡密码作为 visited 集合的初始元素,这样也可以达到目的。

下面是完整的代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class Solution:
def openLock(self, deadends: List[str], target: str) -> int:
# 记录需要跳过的死亡密码
deads = set(deadends)
if "0000" in deads:
return -1

# 记录已经穷举过的密码,防止走回头路
visited = set()
q = collections.deque()
# 从起点开始启动广度优先搜索
step = 0
q.append("0000")
visited.add("0000")

while q:
sz = len(q)
# 将当前队列中的所有节点向周围扩散
for _ in range(sz):
cur = q.popleft()

# 判断是否到达终点
if cur == target:
return step

# 将一个节点的合法相邻节点加入队列
for neighbor in self.getNeighbors(cur):
if neighbor not in visited and neighbor not in deads:
q.append(neighbor)
visited.add(neighbor)

# 在这里增加步数
step += 1

# 如果穷举完都没找到目标密码,那就是找不到了
return -1

# 将 s[j] 向上拨动一次
def plusOne(self, s: str, j: int) -> str:
ch = list(s)
if ch[j] == '9':
ch[j] = '0'
else:
ch[j] = chr(ord(ch[j]) + 1)
return ''.join(ch)

# 将 s[i] 向下拨动一次
def minusOne(self, s: str, j: int) -> str:
ch = list(s)
if ch[j] == '0':
ch[j] = '9'
else:
ch[j] = chr(ord(ch[j]) - 1)
return ''.join(ch)

# 将 s 的每一位向上拨动一次或向下拨动一次,8 种相邻密码
def getNeighbors(self, s: str) -> List[str]:
neighbors = []
for i in range(4):
neighbors.append(self.plusOne(s, i))
neighbors.append(self.minusOne(s, i))
return neighbors

双向 BFS 优化

下面再介绍一种 BFS 算法的优化思路:双向 BFS,可以提高 BFS 搜索的效率。

你把这种技巧当做扩展阅读就行,在一般的面试笔试题中,普通的 BFS 算法已经够用了,如果遇到超时无法通过,或者面试官的追问,可以考虑解法是否需要双向 BFS 优化。

双向 BFS 就是从标准的 BFS 算法衍生出来的:

传统的 BFS 框架是从起点开始向四周扩散,遇到终点时停止;而双向 BFS 则是从起点和终点同时开始扩散,当两边有交集的时候停止

为什么这样能够能够提升效率呢?

就好比有 A 和 B 两个人,传统 BFS 就相当于 A 出发去找 B,而 B 待在原地不动;双向 BFS 则是 A 和 B 一起出发,双向奔赴。那当然第二种情况下 A 和 B 可以更快相遇。

图示中的树形结构,如果终点在最底部,按照传统 BFS 算法的策略,会把整棵树的节点都搜索一遍,最后找到 target;而双向 BFS 其实只遍历了半棵树就出现了交集,也就是找到了最短距离。

当然从 Big O 表示法分析算法复杂度的话,这两种 BFS 在最坏情况下都可能遍历完所有节点,所以理论时间复杂度都是 ,但实际运行中双向 BFS 确实会更快一些。

双向 BFS 的局限性

你必须知道终点在哪里,才能使用双向 BFS 进行优化

对于 BFS 算法,我们肯定是知道起点的,但是终点具体是什么,我们在一开始可能并不知道。

比如上面的密码锁问题和滑动拼图问题,题目都明确给出了终点,都可以用双向 BFS 进行优化。

但比如我们在 二叉树的 DFS/BFS 遍历 中讨论的二叉树最小高度的问题,起点是根节点,终点是距离根节点最近的叶子节点,你在算法开始时并不知道终点具体在哪里,所以就没办法使用双向 BFS 进行优化。

下面我们就以密码锁问题为例,看看如何将普通 BFS 算法优化为双向 BFS 算法,直接看代码吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class Solution:
def openLock(self, deadends: List[str], target: str) -> int:
deads = set(deadends)
# base case
if "0000" in deads: return -1
if target == "0000": return 0

# 用集合不用队列,可以快速判断元素是否存在
q1 = set()
q2 = set()
visited = set()

step = 0
q1.add("0000")
visited.add("0000")
q2.add(target)
visited.add(target)

while q1 and q2:
# 在这里增加步数
step += 1

# 哈希集合在遍历的过程中不能修改,所以用 newQ1 存储邻居节点
newQ1 = set()

# 获取 q1 中的所有节点的邻居
for cur in q1:
# 将一个节点的未遍历相邻节点加入集合
for neighbor in self.getNeighbors(cur):
# 判断是否到达终点
if neighbor in q2:
return step
if neighbor not in visited and neighbor not in deads:
newQ1.add(neighbor)
visited.add(neighbor)
# newQ1 存储着 q1 的邻居节点
q1 = newQ1
# 因为每次 BFS 都是扩散 q1,所以把元素数量少的集合作为 q1
if len(q1) > len(q2):
q1, q2 = q2, q1
return -1

# 将 s[j] 向上拨动一次
def plusOne(self, s: str, j: int) -> str:
ch = list(s)
if ch[j] == '9':
ch[j] = '0'
else:
ch[j] = str(int(ch[j]) + 1)
return ''.join(ch)

# 将 s[i] 向下拨动一次
def minusOne(self, s: str, j: int) -> str:
ch = list(s)
if ch[j] == '0':
ch[j] = '9'
else:
ch[j] = str(int(ch[j]) - 1)
return ''.join(ch)

def getNeighbors(self, s: str) -> List[str]:
neighbors = []
for i in range(4):
neighbors.append(self.plusOne(s, i))
neighbors.append(self.minusOne(s, i))
return neighbors

双向 BFS 还是遵循 BFS 算法框架的,但是有几个细节区别:

1、不再使用队列存储元素,而是改用 [哈希集合],方便快速判两个集合是否有交集。

2、调整了 return step 的位置。因为双向 BFS 中不再是简单地判断是否到达终点,而是判断两个集合是否有交集,所以要在计算出邻居节点时就进行判断。

3、还有一个优化点,每次都保持 q1 是元素数量较小的集合,这样可以一定程度减少搜索次数。

因为按照 BFS 的逻辑,队列(集合)中的元素越多,扩散邻居节点之后新的队列(集合)中的元素就越多;在双向 BFS 算法中,如果我们每次都选择一个较小的集合进行扩散,那么占用的空间增长速度就会慢一些,效率就会高一些。

不过话说回来,无论传统 BFS 还是双向 BFS,无论做不做优化,从 Big O 衡量标准来看,时间复杂度都是一样的,只能说双向 BFS 是一种进阶技巧,算法运行的速度会相对快一点,掌握不掌握其实都无所谓。

最关键的还是要把 BFS 通用框架记下来,并且做到熟练运用.