go并发之WaitGroup使用

需求

有时候我们会开启很多线程(go中是协程)去做一件事件,然后希望主线程等待这些线程都完成后才结束,一个简单的想法是,我在主线程sleep一段时间,譬如3s钟,但是明显这样的做法不科学,因为这些任务很有可能在200ms内就都完成了。如果你用过Java的话,那你很快就会想到CountDownLatch类,在Go中,也有类似的结构,就是本文要讨论的WaitGroup

使用

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package main

import (
"fmt"
"sync"
"time"
)

func main() {
learnWaitGroup()
}

func learnWaitGroup() {
num := 10
wg := sync.WaitGroup{}
wg.Add(num)

for i := 0; i < num; i++ {
go func(idx int) {
fmt.Printf("%d Doing something...\n", idx)
time.Sleep(time.Second)
wg.Done()
}(i)
}

wg.Wait()
fmt.Println("All is done...")
}
WaitGroup 对象内部有一个计数器,最初从0开始,它有三个方法:Add(), Done(), Wait() 用来控制计数器的数量。Add(n) 把计数器设置为n ,Done() 每次把计数器-1 ,Wait() 会阻塞代码的运行,直到计数器地值减为0。

注意问题

WaitGroup对象不是一个引用类型,所以在作为参数的时候,你应该要使用指针。在上面的示例提取一个任务函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package main

import (
"fmt"
"sync"
"time"
)

func main() {
learnWaitGroup()
}

func learnWaitGroup() {
num := 10
wg := sync.WaitGroup{}
wg.Add(num)

for i := 0; i < num; i++ {
go runTask(i, &wg)
}

wg.Wait()
fmt.Println("All is done...")
}

func runTask(idx int, wg *sync.WaitGroup) {
fmt.Printf("%d Doing something...\n", idx)
time.Sleep(time.Second)
wg.Done()
}

Java类比

Java中可以使用CountDownLatch类实现这个功能,它暴露出三个方法:

1
2
3
4
5
// 调用此方法的线程会被阻塞,直到 CountDownLatch 的 count 为 0
public void await() throws InterruptedException

// 会将 count 减 1,直至为 0
public void countDown()
countDown()WaitGroupDone()函数类似,我们还是很容易实现的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
public class Main {
public static void main(String[] args) {
try {
testCountDownLatch();
} catch (Exception ex) {
ex.printStackTrace();
}
}

static class TaskThread extends Thread {

CountDownLatch latch;

public TaskThread(CountDownLatch latch) {
this.latch = latch;
}

@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
System.out.println(getName() + " Task is Done");
latch.countDown();
}
}
}

public static void testCountDownLatch() throws InterruptedException {
final int threadNum = 10;
CountDownLatch latch = new CountDownLatch(threadNum);
for(int i = 0; i < threadNum; i++) {
TaskThread task = new TaskThread(latch);
task.start();
}

System.out.println("Task Start!");
latch.await();
System.out.println("All Task is Done!");
}
}