Scala For 推导式介绍及最佳实践

摘要:本文详细介绍 ScalaFor 推导式的内部原理

基础知识

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// src/main/scala/progscala2/forcomps/RemoveBlanks.scala
package progscala2.forcomps

object RemoveBlanks {

/**
* Remove blank lines from the specified input file.
*/
def apply(path: String, compressWhiteSpace: Boolean = false): Seq[String] =
for {
// 这是一个生成器
line <- scala.io.Source.fromFile(path).getLines.toSeq // <1>
if line.matches("""^\s*$""") == false // <2>
line2 = if (compressWhiteSpace) line replaceAll ("\\s+", " ") // <3>
else line
} yield line2 // <4>

/**
* Remove blank lines from the specified input files and echo the remaining
* lines to standard output, one after the other.
* @param args list of file paths. Prefix each with an optional "-" to
* compress remaining whitespace in the file.
*/
def main(args: Array[String]) = for {
path2 <- args // <5>
(compress, path) = if (path2 startsWith "-") (true, path2.substring(1))
else (false, path2) // <6>
line <- apply(path, compress)
} println(line) // <7>
}

<1> 由于 for 推导式无法返回 Iterator 对象, for 推导式的返回类型由初始的生成器所决定,因此我们必须将其转化成一个序列。

<2> 过滤空行,也就是说这句话是使用上一句的输出结果

<3> 是一个 guard, 通过某个条件对进过进行处理,并赋值给另外的临时变量

<4> 输出最后的结果

For 推导式原理

For 推导式实际上是语法糖,是组合了 map withFilter flatMap 的基础操作,下面给出一个等价的表示方法:

案例1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
val lara = Person("Lara", isMale = false)
val bob = Person("Bob", isMale = true)
val julie = Person("Julie", false, lara, bob)
val persons = List(lara, bob, julie)

var temp1 = persons filter (p => !p.isMale)
var temp11 = temp1 flatMap (k => k.name)
var temp12 = temp1 flatMap (k => k.children)
var temp2 = temp1 flatMap (p => p.children map (c => (p.name, c.name)))

println(temp1)
println(temp11)
println(temp12)
println(temp2)

val temp3 = persons withFilter (p => !p.isMale) flatMap (p => p.children map (c => (p.name, c.name)))
println(temp3)

val temp4 = for {
p <- persons
if !p.isMale
c <- p.children
} yield (p.name, c.name)
println(temp4)


val temp5 = for {
p <- persons // a generator
n = p.name // a definition
if n startsWith "To" // a filter
} yield n
println(temp5)

val temp6 = for (
x <- List(1, 2);
y <- List("one", "two");
) yield (x, y)
println(temp6)

val temp6Equivalence = List(1, 2) flatMap (p => { List("one", "two") map ( k => (p,k))})

println(temp6Equivalence)

val temp7 = for (
x <- List(1, 2);
y <- List("one", "two");
z <- List("a", "b")
) yield (x, y, z)
println(temp7)

val temp7Equivalence = List(1, 2) flatMap (p => { List("one", "two") flatMap ( k => { List("a", "b") map (l => (p,k,l))})})

println(temp7Equivalence)

案例2

没有 yield

1
2
3
4
5
6
7
8
9
10
val states = List("Alabama", "Alaska", "Virginia", "Wyoming")
for {
s <- states
} println(s)
// 结果值:
// Alabama
// Alaska
// Virginia
// Wyoming
states foreach println

含有 yield

1
2
3
4
5
6
7
val states = List("Alabama", "Alaska", "Virginia", "Wyoming")
for {
s <- states
} yield s.toUpperCase
// 结果值: List(ALABAMA, ALASKA, VIRGINIA, WYOMING)
states map (_.toUpperCase) // 有点类似于 Java Lambda 的 ::Method 操作,如果在 `map` 操作中只执行了一个方法。
// 结果值: List(ALABAMA, ALASKA, VIRGINIA, WYOMING)

多个生成器

1
2
3
4
5
6
7
8
9
10
11
12
val states = List("Alabama", "Alaska", "Virginia", "Wyoming")
val temp8 = for {
s <- states
c <- s
} yield s"$c-${c.toUpper}"
// 结果值: List("A-A", "l-L", "a-A", "b-B", ...)
val temp8Equivalence = states flatMap (_.toSeq map (c => s"$c-${c.toUpper}"))
val temp8Equivalence2 = states flatMap (p => p map (c => s"$c-${c.toUpper}"))
// 结果值: List("A-A", "l-L", "a-A", "b-B", ...)
println(temp8)
println(temp8Equivalence)
println(temp8Equivalence2)

加入保护式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
val temp9 = for {
s <- states
c <- s
if c.isLower
} yield s"$c-${c.toUpper}"
// 结果值: List("l-L", "a-A", "b-B", ...)
val temp9Equivalence = states flatMap (_.toSeq withFilter (_.isLower) map (c => s"$c-${c.toUpper}"))
val temp9Equivalence2 = states flatMap (p => p withFilter (k => k.isLower) map (c => s"$c-${c.toUpper}"))
val temp9Equivalence3 = states flatMap (p => p filter (k => k.isLower) map (c => s"$c-${c.toUpper}"))
println(temp9)
println(temp9Equivalence)
println(temp9Equivalence2)
println(temp9Equivalence3)
// 结果值: List("l-L", "a-A", "b-B", ...)

加入定义

1
2
3
4
5
6
7
8
9
10
11
12
13
val states = List("Alabama", "Alaska", "Virginia", "Wyoming")
for {
s <- states
c <- s
if c.isLower
c2 = s"$c-${c.toUpper} "
} yield c2
// 结果值: List("l-L", "a-A", "b-B", ...)
states flatMap (_.toSeq withFilter (_.isLower) map { c =>
val c2 = s"$c-${c.toUpper} "
c2
})
// 结果值: List("l-L", "a-A", "b-B", ...)

总结内部机制

结合 《Scala 程序设计第2版》 7.2章节和上面的例子,我们总结 for 推导式的内部机制如下:

  1. withFilter 效率比 filter 高, 因为它不会生成一次中间的临时的容器,而是和后续操作结合使用。而且 withFilter 会限制传递给后续组合器的元素类型域。
  2. yield 实际是一次 map 操作,具体 map 的操作方法是根据 yield 的表达式,比如案例1中是yield 元组,而案例2中是 yield 一个 toUpperCase 操作。
  3. 多个生成器,除了最后一个,其他的生成器都会转化为 flatMap,最后一个生成器对应这 map 调用。
  4. toSeq 可以用来简化针对 Seq 的 map 后面的操作,实际上就是 p => p do something
  5. 定义变量将会在 map 表达式中定义 (flatmap 也可以吗?)

深层理解

在像 pat <- expr 这样的生成器表达式中,pat 实际上是一个模式表达式( pattern expression),例如: (x,y) <- List((1,2),(3,4))。 Scala 会以类似的方式对值定义语句 pat2 = expr 进行处理,该语句也会被视为某一模式。

1
2
// pat <- expr
pat <- expr.withFilter { case pat => true; case _ => false }

Scala 在转化 for 推导式时,要做的第一件事便是将 pat <- expr 语句转化为上述语句,然后, Scala 将重复执行下列转化规则,直到所有的推导表达式都被替换掉。值得一提的是,某些转化会生成新的 for 推导式,而后续的迭代则会负责对这些推导式进行转化。

如果 for 推导式中包含了一个生成器和一个 yield 表达式,那么该表达式将被转化为下列
语句:

1
2
// for ( pat <- expr1 ) yield expr2
expr map { case pat => expr2 }

如果 for 循环中未使用 yield 语句,但执行的代码具有副作用,那么该语句将被转化为:

1
2
// for ( pat <- expr1 ) expr2
expr foreach { case pat => expr2 }

包含多个生成器(同时包含 yield 表达式)的 for 推导式将被转化成下列语句:

1
2
// for ( pat1 <- expr1; pat2 <- expr2; ... ) yield exprN
expr1 flatMap { case pat1 => for (pat2 <- expr2 ...) yield exprN }

请留意,嵌套的生成器会被转化成嵌套的 for 推导式。这些嵌套的 for 推导式会在下一次执行转化规则时被转化成方法调用。上面示例中 (…) 代表了省略的表达式,这些表达式可能是其他的生成器,也可能是值定义或保护式( guard)。

包含多个生成器的 for 循环将被翻译成下列语句:

1
2
// for ( pat1 <- expr1; pat2 <- expr2; ... ) exprN
expr1 foreach { case pat1 => for (pat2 <- expr2 ...) yield exprN }

我们之前所见的示例中包含保护式( guard)表达式,该表达式被编写在单独的一行中。事实上, guard 以及上一行中的代码可以编写在一行中,例如: pat1 <- expr1 if guard。

后面跟着保护式的生成器会被翻译成下列语句:

1
2
// pat1 <- expr1 if guard
pat1 <- expr1 withFilter ((arg1, arg2, ...) => guard)

此处,变量 argN 代表了传递给 withFilter 方法的参数。对于大多数的容器而言,传入的方法中只含有一个参数

生成器后尾随一个值定义

如果生成器后面尾随着一个值定义,那么转化这个生成器的复杂度会令人惊奇。如下所示:

1
2
3
4
5
6
7
// pat1 <- expr1; pat2 = expr2
(pat1, pat2) <- for { // ➊
x1 @ pat1 <- expr1 // ➋
} yield {
val x2 @ pat2 = expr2 // ➌
(x1, x2) // ➍
}

➊ for 推导式将返回包含两个模式的 pair 对象。
➋ x1 @ pat1 语句会将整个表达式中 pat1 所匹配的值赋给变量 x1,该值可能包含另一个
变量的某一部分。假如 pat1 是一个不可变变量名, x1 和 pat1 的赋值将会是冗余的。
➌ 将 pat2 值赋给 x2。
➍ 返回元组。

下面的 REPL 会话中包含了 x @ pat = expr 语句的对应示例:

1
2
3
4
scala> val z @ (x, y) = (1 -> 2)
z: (Int, Int) = (1,2)
x: Int = 1
y: Int = 2

变量 z 的值为元组 (1,2),而变量 x 和变量 y 则对应了元组中各个组成部分的值。

一个具体案例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
val map = Map("one" -> 1, "two" -> 2)
val list1 = for {
(key, value) <- map // 本行和下一行将会被翻译成什么语句呢?
i10 = value + 10
} yield (i10)
// 执行结果值: list1: scala.collection.immutable.Iterable[Int] = List(11, 12)

// 翻译后的语句:
val list2 = for {
(i, i10) <- for {
x1 @ (key, value) <- map
} yield {
val x2 @ i10 = value + 10
(x1, x2)
}
} yield (i10)
// 执行结果: list2: scala.collection.immutable.Iterable[Int] = List(11, 12)

书中给出的解释就到这一步,这实际上还是一个嵌套的 for 表达式,下面我们进一步把它转化为 map。

首先我们需要知道如何把 map 中的元素映射为元组,很简答,只需要使用 val (x,y) = p,其中 p 是 map 的元素,x和y就是元组值,后面就可以使用 x 和 y 进一步计算了。

我们这里分两步转化,首先转化内部的 for 推导式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
val list2Inter = for {
x1@(key, value) <- map
} yield {
val x2@i10 = value + 10
(x1, x2)
}

val z@(x, y) = 1 -> 2
println("z: " + z)


val list2InterEqual = map map { p => {
val (x, y) = p
(x, y) -> (y + 10)
}
}

val list2InterEqual2 = map map { p => {
val (x, y) = p
((x, y), y + 10)
}
}
println(list2Inter)
println(list2InterEqual)
println(list2InterEqual2)

我们可以看到,内部的推导式 yield 了一个元组,所以还是返回了一个 map,所以我们的 map 的返回值也应是一个map 或元组 (代码中有两个等价表达式)。

在内部的 for 推导式的基础上,我们进一步映射

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
val list2Equal = map map { p => {
val (x, y) = p
(x, y) -> (y + 10)
}
} map { k => {
val (m, n) = k
n
}
}

val list2Equal2 = map map { p => {
val (_, y) = p
y + 10
}
}

第一个是纯粹使用上一步的的内部推导式的转换结果再次进行 map 得到的。外层的 for 推导式只是将元组的第二个值返回。所以我们第二个表达式简化了过程,直接在内部的 map 就只返回元组的第二个值。当然如果第二层 for 推导式更加复杂,就仍然需要使用第一个表达式来等价翻译,而且对于 scala 本身应该也是使用第一个表达式进行翻译。

For 的应用容器

一般我们都是使用明显的容器,比如 List, Array 以及 Map,实际上只要容器支持 foreachmapflatwithFilter 的操作都可以使用 for 推导式。比如 OptionEitherTry 等等。