Spark ML库简析

讨论Spark上分布式机器学习库的实现。
Spark的机器学习库主要分为ml和mllib,其中ml较新,本文主要围绕ml来讲。
ml和mllib底层会用到Breeze库(类似于numpy的线性代数库)和BLAS(Basic Linear Algebra Subroutines,更基础的线性代数库),在这里也进行介绍。

LR回归的实现

总体实现

ml库的整个训练和预测过程可以用下面三行代码来概括

1
2
3
val classifier = new LogisticRegression()
val model:LogisticRegressionModel = classifier.fit(train)
val predictions = model.transform(test)

训练过程

LogisticRegression类的训练入口是fitfit首先定义在基类Predictor.scala中,它会首先做一些事情,比如转换labelCol的类型到Double,然后调用traintrain会返回一个Model,但在Predictor中并没有实现,而是交付给子类来实现。
LogisticRegression类的train是自己实现的。主要包括下面几个部分。

  1. 校验打印参数
    参数包含下面的一些列名

    1
    labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol, probabilityCol

    这里的weightCol可以给样本指定权重,用来解决数据不平衡问题。详见balanceDataset的实现。
    还包括下面的一些训练用参数

    1
    regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept
  2. 计算summarizerlabelSummarizer
    可以得到一些基础信息,比如直方图、特征数量、特征的均值方差和数量,class数量。其中直方图(histogram),大概内容类似[0: 34.0, 1: 66.0],意味着分类1拥有34样本,分类2有66个样本。均值(mean)大概内容类似[0: -0.03327765069036007 ]
    根据这些信息还会判断任务类型,是不是多类的(设置isMultinomial

  3. 计算coefficientMatrixinterceptVectorobjectiveHistory
    这是主要过程,分别是返回系数矩阵、偏置值向量和loss之类的东西。
    返回值是(denseCoefficientMatrix.compressed,interceptVec.compressed,arrayBuilder.result())

  4. 计算summaryModelprobabilityColNamepredictionColName

模型表示

模型存储

首先Model会继承一个org.apache.spark.ml.util.MLWritable。这样就支持model.save()方法。另外,可以通过MLWriter来保存,对应model.write.overwrite().save()

预测过程

LogisticRegressionModeltransform方法实际上实现在ProbabilisticClassificationModel里面,最终会调用predictRaw计算rawPrediction列。predictRaw和下面提到的一些方法都定义在ClassificationModel里面,不过是虚的。

1
2
3
4
5
6
7
if ($(rawPredictionCol).nonEmpty) {
val predictRawUDF = udf { (features: Any) =>
predictRaw(features.asInstanceOf[FeaturesType])
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
numColsOutput += 1
}

然后再调用predictProbability或者raw2probability计算probability列,这个列是一个Vector,表示每一个label的probability,加起来是等于1的(见ProbabilisticClassificationModel里面的注释)。所有继承了ProbabilisticClassifier的分类器都会有在这一列。

1
2
3
4
5
6
7
8
9
10
11
12
if ($(probabilityCol).nonEmpty) {
val probUDF = if ($(rawPredictionCol).nonEmpty) {
udf(raw2probability _).apply(col($(rawPredictionCol)))
} else {
val probabilityUDF = udf { (features: Any) =>
predictProbability(features.asInstanceOf[FeaturesType])
}
probabilityUDF(col($(featuresCol)))
}
outputData = outputData.withColumn($(probabilityCol), probUDF)
numColsOutput += 1
}

然后再调用raw2prediction或者probability2prediction或者predict计算prediction列。

训练过程详解

Summary过程

首先使用MetadataUtils.getNumClasses来获得Class数量。

1
2
3
4
5
6
7
val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) =>
require(n >= histogram.length, s"Specified number of classes $n was " +
s"less than the number of unique labels ${histogram.length}.")
n
case None => histogram.length
}

MetadataUtils是一个私有对象,里面会计算得到labelSchema对应的Attribute,然后来判断有多少类。Attribute对象的实现比较复杂,单独提到后面讲。

1
2
3
4
5
6
7
def getNumClasses(labelSchema: StructField): Option[Int] = {
Attribute.fromStructField(labelSchema) match {
case binAttr: BinaryAttribute => Some(2)
case nomAttr: NominalAttribute => nomAttr.getNumValues
case _: NumericAttribute | UnresolvedAttribute => None
}
}

接着判断是否是多类分类任务,主要是根据family的值来判断,再根据numClasses的值来校验。family可以取下面的值:

  1. auto(默认值)
    自动根据class数量选择,如果numClasses == 1 || numClasses == 2,则是binomial,否则是multinomial
  2. binomial
    基于pivoting的LR
  3. multinomial
    softmax的LR,不基于pivoting
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    @Since("2.1.0")
    final val family: Param[String] = new Param(this, "family",
    "The name of family which is a description of the label distribution to be used in the " +
    s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
    (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))

    /** @group getParam */
    @Since("2.1.0")
    def getFamily: String = $(family)

    val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
    case "binomial" =>
    require(numClasses == 1 || numClasses == 2, s"...")
    false
    case "multinomial" => true
    case "auto" => numClasses > 2
    case other => throw new IllegalArgumentException(s"Unsupported family: $other")
    }
    val numCoefficientSets = if (isMultinomial) numClasses else 1

优化过程

主要是创建了损失函数costFun,并对于这个损失函数使用优化器进行优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
val regParamL1 = $(elasticNetParam) * $(regParam)
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
val getAggregatorFunc = new LogisticAggregator(bcFeaturesStd, numClasses, $(fitIntercept),
multinomial = isMultinomial)(_)
val regularization = if (regParamL2 != 0.0) {
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures * numCoefficientSets
Some(new L2Regularization(regParamL2, shouldApply,
if ($(standardization)) None else Some(getFeaturesStd)))
} else {
None
}
val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
$(aggregationDepth))
val states = optimizer.iterations(new CachedDiffFunction(costFun),
new BDV[Double](initialCoefWithInterceptMatrix.toArray))

首先来看损失函数RDDLossFunction的类定义。
首先ClassTag用来实现保障类型擦除后类型安全的功能

1
2
3
4
5
6
7
8
private[ml] class RDDLossFunction[
T: ClassTag,
Agg <: DifferentiableLossAggregator[T, Agg]: ClassTag](
instances: RDD[T],
getAggregator: (Broadcast[Vector] => Agg),
regularization: Option[DifferentiableRegularization[Vector]],
aggregationDepth: Int = 2)
extends DiffFunction[BDV[Double]] {
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val bcCoefficients = instances.context.broadcast(Vectors.fromBreeze(coefficients))
val thisAgg = getAggregator(bcCoefficients)
val seqOp = (agg: Agg, x: T) => agg.add(x)
val combOp = (agg1: Agg, agg2: Agg) => agg1.merge(agg2)
val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth)
val gradient = newAgg.gradient
val regLoss = regularization.map { regFun =>
val (regLoss, regGradient) = regFun.calculate(Vectors.fromBreeze(coefficients))
BLAS.axpy(1.0, regGradient, gradient)
regLoss
}.getOrElse(0.0)
bcCoefficients.destroy()
(newAgg.loss + regLoss, gradient.asBreeze.toDenseVector)
}
}

优化器

LBGFS

牛顿法的特点是收敛很快,但是运用牛顿法需要计算二阶偏导数,而且目标函数的Hesse矩阵可能非正定。对此,可以使用拟牛顿法,也就是用不包含二阶导数的矩阵近似牛顿法中的Hesse矩阵的逆矩阵。 是一种一阶方法。由于构造近似矩阵的方法不同,因而出现不同的拟牛顿法,BFGS是其中一种。因为BFGS仍然存在较大的内存占用。因此有了LBFGS算法,其中L指的是Limited的意思。
下面,简单查看一下LBFGS的代码,可以发现,它继承了FirstOrderMinimizer优化器,印证了LBFGS是一个一阶方法。

1
2
3
class LBFGS[T](convergenceCheck: ConvergenceCheck[T], m: Int)(implicit space: MutableInnerProductModule[T, Double])
extends FirstOrderMinimizer[T, DiffFunction[T]](convergenceCheck)
with SerializableLogging {

OWLQN

OWL-QN(Orthant-Wise Limited-Memory Quasi-Newton)算法,该算法是基于L-BFGS算法的可用于求解L1正则的算法。

Loss函数

treeAggregate

seqOp操作会聚合各分区中的元素,然后combOp操作把所有分区的聚合结果再次聚合,两个操作的初始值都是zeroValuetreeAggregateaggregate的区别是,aggregate在每个分区处理完之后就直接交给Driver合并了,但treeAggregate会在Executor直接树形地Aggregate。

1
2
3
4
5
6
7
8
9
10
val sc = spark.sparkContext
val rdd = sc.parallelize(1 to 100000).repartition(6)
val start = System.currentTimeMillis();
val reduceResult = rdd.reduce{(x, y) => x + y}
val endReduce = System.currentTimeMillis();
val aggregateResult = rdd.aggregate(0)((x, y) => x + y, (x, y) => x + y)
val endAgg = System.currentTimeMillis();
val treeAggregateResult = rdd.treeAggregate(0)((x, y) => x + y, (x, y) => x + y)
val endTreeAgg = System.currentTimeMillis();
println(s"reduceResult ${reduceResult} ${(endReduce-start).toString} aggregateResult ${aggregateResult} ${(endAgg-endReduce).toString} treeAggregateResult ${treeAggregateResult} ${(endTreeAgg-endAgg).toString}")

在六台机器上跑了,耗时如下

1
reduceResult 705082704 7073 aggregateResult 705082704 831 treeAggregateResult 705082704 755

BLAS库

BLAS库的选择

Spark的BLAS库实现在mllib/linalg下面。其底层借助于com.github.fommil.netlib,而netelib有两个BLAS库的实现,其中f2jBLAS是从Fortran代码翻译到Java里面的,nativeBLAS是用系统原生的。两个库的使用规则如下所示,至于为什么,参考爆栈网的解释的意思是f2j的实现在Level较低的情况的性能要更好一点。

1
2
3
4
5
6
7
private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = {
if (vectorSize < nativeL1Threshold) {
f2jBLAS
} else {
nativeBLAS
}
}

这里的Level,指的是BLAS函数分为三个Level,Level越高,计算越复杂。Level1是向量之间的操作,Level2是 向量和矩阵之间的操作,Level3是矩阵和矩阵之间的操作。

axpy和Vector的实现

axpy指的是线代里面的$y = a * x$这样的操作。
axpy里面根据是否稀疏,进行讨论。对于DenseVector会直接调用daxpy。这里daxpy的d不是dense,而是指double,对应的saxpy指的是单精度,还有caxpy代表复数单精度和zaxpy代表复数双精度。
对于SparseVector则要首先进行解码为DenseVectorSparseVector 用两个数组,一个记录原始向量中非零元素的值,另一个记录非零元素对应到原始向量的位置。这有点类似于一维的COO表示形式。

但不是所有的计算都需要转换为DenseVector的,例如向量的dot就不需要。观摩下面的代码发现,稀疏向量x中为0的列就不需要乘了。事实上有人进行了比较,Sparse算得快,但是构造慢。其中nnz表示number of nonzero entries。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
private def dot(x: SparseVector, y: DenseVector): Double = {
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val nnz = xIndices.length

var sum = 0.0
var k = 0
while (k < nnz) {
sum += xValues(k) * yValues(xIndices(k))
k += 1
}
sum
}

gemm和Matrix的实现

gemm,即通用矩阵乘(GEMM,General Matrix Multiplication),其公式为$ C = alpha * A * B + beta * C $,其定义如下。其中A大小$m*k$,B大小$k*n$,C大小$m*n$,且C.isTransposed必须是false。根据A是否为Dense,可以分为2个版本。

1
2
3
def gemm(alpha: Double,
A: Matrix, B: DenseMatrix,
beta: Double, C: DenseMatrix): Unit = {

这里的Matrix同样分为Dense和Sparse。isTransposed表示有没有进行过转置。我理解用这个的原因是为了实现lazy的转置。
DenseMatrix的四要素是numRowsnumColsvaluesisTransposedvalues是一维数组,因为我们可以通过numRowsnumCols找到定位。
SparseMatrix是六要素,多了colPtrsrowIndicesvalues是一维数组,按列存储,实际上就是Compressed Sparse Columns(CSC)格式存储,如果isTransposed是true,那么就按照Compressed Sparse Row(CSR)存储。以isTransposed为false的情况为例,colPtrs表示每一列的起始在values中的位置,rowIndices表示在values中的每一个元素在是这一列中的第几行。

下面展示一下CSC的表示方法,例如矩阵

1
2
3
4
5
6
1.0 0.0 4.0
0.0 3.0 5.0
2.0 0.0 6.0
values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
rowIndices=[0, 2, 1, 0, 1, 2]
colPointers=[0, 2, 3, 6]

下面看gemm的实现,首先要对A是不是Dense进行讨论,对Dense而言是平凡的,调用dgemm就行了。对于Spark而言,需要根据A和B有没有transpose过的四种情况进行讨论。下面我们来看A是Sparse,B是Dense对应的gemm重载下的四种情况。

首先来看A和B都没有转置的情况,我们现在是需要用A的对应行乘以B的对应列,可以看做左边的矩阵的每一列,分别去乘上右边矩阵对应的行,得到n个矩阵,然后我们将它们对应位置加起来就能得到C。这种方式可以参考我的MIT线性代数的学习报告

首先先要放缩一下C,这样我们的列乘行的操作可以直接在C.values上面累加,节约空间。

1
2
3
if (beta != 1.0) {
getBLAS(C.values.length).dscal(C.values.length, beta, C.values, 1)
}

下面是主体,包含两个循环。外层循环变量colCounterForB每轮递增,用来枚举到B的所有列nB,内层循环变量colCounterForA用来枚举A的所有列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
// kB=B行数=A列数kA
// 因为按列存储,BStart表示第colCounterForB列开始的位置。
// 因为B是Dense的,所以可以直接乘
val Bstart = colCounterForB * kB
// mA = A行数
// Cstart表示第colCounterForB列开始的位置,是对C而言的,且C有mA行
val Cstart = colCounterForB * mA

while (colCounterForA < kA) {
// 计算C的第colCounterForB列上所有行的部分值
var i = AcolPtrs(colCounterForA)
val indEnd = AcolPtrs(colCounterForA + 1)
// indEnd-i表示第colCounterForA列有多少个元素
// Avals(i)于对CSC格式的A的行进行遍历
// BVal是B的第colCounterForB列第colCounterForA行
// 遍历A的colCounterForA列上的每个行元素
// 可以看到,Cstart在内层循环是不变的,所以对Cstart对应的第colCounterForB列的操作会重复kA次
val Bval = Bvals(Bstart + colCounterForA) * alpha
while (i < indEnd) {
// Cvals是c.values
// 对于A的colCounterForA上的每个行元素,加到C的对应位置上面,注意C是Dense的。
// 这Avals(i)相当于对CSC格式的行进行遍历
Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
}
colCounterForB += 1
}

现在来看A转置B没有转置的情况,。我们现在是需要用A的对应行乘以B的对应列。因为A被转置了,所以问题实际上变简单了,因为A转置后的行,实际上就是A转置前的列,在物理上是连续存储的。因为我们恰恰需要的是A转置后的行,所以实际可以当做转置前的列来处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// colCounterForB在每个while后递增,枚举到B的所有列nB
while (colCounterForB < nB) {
var rowCounterForA = 0
// mA=A行数,因为按列存储,表示第colCounterForB列开始的位置
val Cstart = colCounterForB * mA
// kA=A列数=B行数kB,实际上是B的第colCounterForB列开始的位置
val Bstart = colCounterForB * kA
// rowCounterForA枚举A的所有行
while (rowCounterForA < mA) {
// 目标:设置C的第colCounterForB列第rowCounterForA行的元素
var i = AcolPtrs(rowCounterForA)
val indEnd = AcolPtrs(rowCounterForA + 1)
// i和indEnd表示A的第rowCounterForA行有多少元素,
// 稀疏矩阵A的其他行是0,所以就可以不管了
var sum = 0.0
while (i < indEnd) {
// 因为B是Dense的,这里要算一下实际的坐标,
// 我们要知道B的是第colCounterForB列,对应到就是A的第colCounterForB行,
// 所以查询ArowIndices表就能知道对应的行便宜
sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
i += 1
}

val Cindex = Cstart + rowCounterForA
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
}

gemm就有gemv,也就是C变成Vector了。除此之外,还有对于各种不同m的特化实现。例如SYMV是对称矩阵乘法,TRMV是三角矩阵乘法。com.github.fommil.netlib.BLAS里面有dgbmv带状矩阵,dsbmvdtrmv等。此外,这些在Breeze库里面也有提供。

启发

  1. transpose可以做到lazy,但得是对二维特化的
  2. 矩阵可以Sparse存储
  3. 矩阵和多维数组最好单独实现

Breeze库

breeze库是一个类似numpy的线性代数库
其中BDM是一个Matrix,BDV是一个Vector,其实都是别名import breeze.linalg.{DenseMatrix => BDM}

其他代码

Attribute对象

Attribute继承自AttributeFactory这个object(原来伴生对象也能继承),主要是从StructField生成Attribute对象。StructField是构成StructType的成员。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def fromStructField(field: StructField): Attribute = decodeStructField(field, false)

// 下面这个函数主要就是从filed.metadata中解析得到Attribute对象
private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = {
require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR // 值为ML_ATTR
if (metadata.contains(mlAttr)) {
// 是使用metadata里面的名字,还是field.name的名字
val attr = fromMetadata(metadata.getMetadata(mlAttr))
if (preserveName) {
attr
} else {
// Copy with a new name
attr.withName(field.name)
}
} else {
UnresolvedAttribute
}
}

private[attribute] override def fromMetadata(metadata: Metadata): Attribute = {
import org.apache.spark.ml.attribute.AttributeKeys._
val attrType = if (metadata.contains(TYPE)) {
metadata.getString(TYPE)
} else {
AttributeType.Numeric.name
}
// 首先得到一个AttributeFactory,然后结合metadata得到一个Attribute
getFactory(attrType).fromMetadata(metadata)
}

private def getFactory(attrType: String): AttributeFactory = {
if (attrType == AttributeType.Numeric.name) {
NumericAttribute
} else if (attrType == AttributeType.Nominal.name) {
NominalAttribute
} else if (attrType == AttributeType.Binary.name) {
BinaryAttribute
} else {
throw new IllegalArgumentException(s"Cannot recognize type $attrType.")
}
}

Reference