需求
有时候我们会开启很多线程(go中是协程)去做一件事件,然后希望主线程等待这些线程都完成后才结束,一个简单的想法是,我在主线程sleep一段时间,譬如3s钟,但是明显这样的做法不科学,因为这些任务很有可能在200ms内就都完成了。如果你用过Java的话,那你很快就会想到CountDownLatch
类,在Go中,也有类似的结构,就是本文要讨论的WaitGroup
。
使用
示例代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
| package main
import (
"fmt"
"sync"
"time"
)
func main() {
learnWaitGroup()
}
func learnWaitGroup() {
num := 10
wg := sync.WaitGroup{}
wg.Add(num)
for i := 0; i < num; i++ {
go func(idx int) {
fmt.Printf("%d Doing something...\n", idx)
time.Sleep(time.Second)
wg.Done()
}(i)
}
wg.Wait()
fmt.Println("All is done...")
}
|
WaitGroup
对象内部有一个计数器,最初从0开始,它有三个方法:Add()
, Done()
, Wait()
用来控制计数器的数量。Add(n)
把计数器设置为n ,Done()
每次把计数器-1 ,Wait()
会阻塞代码的运行,直到计数器地值减为0。
注意问题
WaitGroup
对象不是一个引用类型,所以在作为参数的时候,你应该要使用指针。在上面的示例提取一个任务函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
| package main
import (
"fmt"
"sync"
"time"
)
func main() {
learnWaitGroup()
}
func learnWaitGroup() {
num := 10
wg := sync.WaitGroup{}
wg.Add(num)
for i := 0; i < num; i++ {
go runTask(i, &wg)
}
wg.Wait()
fmt.Println("All is done...")
}
func runTask(idx int, wg *sync.WaitGroup) {
fmt.Printf("%d Doing something...\n", idx)
time.Sleep(time.Second)
wg.Done()
}
|
Java类比
Java中可以使用CountDownLatch
类实现这个功能,它暴露出三个方法:
1
2
3
4
5
| // 调用此方法的线程会被阻塞,直到 CountDownLatch 的 count 为 0
public void await() throws InterruptedException
// 会将 count 减 1,直至为 0
public void countDown()
|
countDown()
跟WaitGroup
的Done()
函数类似,我们还是很容易实现的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
| public class Main {
public static void main(String[] args) {
try {
testCountDownLatch();
} catch (Exception ex) {
ex.printStackTrace();
}
}
static class TaskThread extends Thread {
CountDownLatch latch;
public TaskThread(CountDownLatch latch) {
this.latch = latch;
}
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
System.out.println(getName() + " Task is Done");
latch.countDown();
}
}
}
public static void testCountDownLatch() throws InterruptedException {
final int threadNum = 10;
CountDownLatch latch = new CountDownLatch(threadNum);
for(int i = 0; i < threadNum; i++) {
TaskThread task = new TaskThread(latch);
task.start();
}
System.out.println("Task Start!");
latch.await();
System.out.println("All Task is Done!");
}
}
|