讨论Spark上分布式机器学习库的实现。
Spark的机器学习库主要分为ml和mllib,其中ml较新,本文主要围绕ml来讲。
ml和mllib底层会用到Breeze库(类似于numpy的线性代数库)和BLAS(Basic Linear Algebra Subroutines,更基础的线性代数库),在这里也进行介绍。
LR回归的实现
总体实现
ml库的整个训练和预测过程可以用下面三行代码来概括
1 | val classifier = new LogisticRegression() |
训练过程
LogisticRegression
类的训练入口是fit
。fit
首先定义在基类Predictor.scala中,它会首先做一些事情,比如转换labelCol
的类型到Double,然后调用train
。train
会返回一个Model,但在Predictor
中并没有实现,而是交付给子类来实现。LogisticRegression
类的train
是自己实现的。主要包括下面几个部分。
校验打印参数
参数包含下面的一些列名1
labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol, probabilityCol
这里的
weightCol
可以给样本指定权重,用来解决数据不平衡问题。详见balanceDataset
的实现。
还包括下面的一些训练用参数1
regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept
计算
summarizer
和labelSummarizer
可以得到一些基础信息,比如直方图、特征数量、特征的均值方差和数量,class数量。其中直方图(histogram),大概内容类似[0: 34.0, 1: 66.0]
,意味着分类1拥有34样本,分类2有66个样本。均值(mean)大概内容类似[0: -0.03327765069036007 ]
。
根据这些信息还会判断任务类型,是不是多类的(设置isMultinomial
)计算
coefficientMatrix
、interceptVector
和objectiveHistory
这是主要过程,分别是返回系数矩阵、偏置值向量和loss之类的东西。
返回值是(denseCoefficientMatrix.compressed,interceptVec.compressed,arrayBuilder.result())
计算
summaryModel
、probabilityColName
、predictionColName
模型表示
模型存储
首先Model会继承一个org.apache.spark.ml.util.MLWritable
。这样就支持model.save()
方法。另外,可以通过MLWriter
来保存,对应model.write.overwrite().save()
预测过程
LogisticRegressionModel
的transform
方法实际上实现在ProbabilisticClassificationModel
里面,最终会调用predictRaw
计算rawPrediction列。predictRaw
和下面提到的一些方法都定义在ClassificationModel
里面,不过是虚的。
1 | if ($(rawPredictionCol).nonEmpty) { |
然后再调用predictProbability
或者raw2probability
计算probability列,这个列是一个Vector,表示每一个label的probability,加起来是等于1的(见ProbabilisticClassificationModel
里面的注释)。所有继承了ProbabilisticClassifier
的分类器都会有在这一列。
1 | if ($(probabilityCol).nonEmpty) { |
然后再调用raw2prediction
或者probability2prediction
或者predict
计算prediction列。
训练过程详解
Summary过程
首先使用MetadataUtils.getNumClasses
来获得Class数量。
1 | val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { |
MetadataUtils
是一个私有对象,里面会计算得到labelSchema
对应的Attribute,然后来判断有多少类。Attribute对象的实现比较复杂,单独提到后面讲。
1 | def getNumClasses(labelSchema: StructField): Option[Int] = { |
接着判断是否是多类分类任务,主要是根据family
的值来判断,再根据numClasses
的值来校验。family
可以取下面的值:
auto
(默认值)
自动根据class数量选择,如果numClasses == 1 || numClasses == 2
,则是binomial
,否则是multinomial
binomial
基于pivoting的LRmultinomial
softmax的LR,不基于pivoting1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19"2.1.0") (
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the label distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
(value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
/** @group getParam */
"2.1.0") (
def getFamily: String = $(family)
val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"...")
false
case "multinomial" => true
case "auto" => numClasses > 2
case other => throw new IllegalArgumentException(s"Unsupported family: $other")
}
val numCoefficientSets = if (isMultinomial) numClasses else 1
优化过程
主要是创建了损失函数costFun
,并对于这个损失函数使用优化器进行优化。
1 | val regParamL1 = $(elasticNetParam) * $(regParam) |
首先来看损失函数RDDLossFunction
的类定义。
首先ClassTag
用来实现保障类型擦除后类型安全的功能。
1 | private[ml] class RDDLossFunction[ |
1 | override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { |
优化器
LBGFS
牛顿法的特点是收敛很快,但是运用牛顿法需要计算二阶偏导数,而且目标函数的Hesse矩阵可能非正定。对此,可以使用拟牛顿法,也就是用不包含二阶导数的矩阵近似牛顿法中的Hesse矩阵的逆矩阵。 是一种一阶方法。由于构造近似矩阵的方法不同,因而出现不同的拟牛顿法,BFGS是其中一种。因为BFGS仍然存在较大的内存占用。因此有了LBFGS算法,其中L指的是Limited的意思。
下面,简单查看一下LBFGS的代码,可以发现,它继承了FirstOrderMinimizer
优化器,印证了LBFGS是一个一阶方法。
1 | class LBFGS[T](convergenceCheck: ConvergenceCheck[T], m: Int)(implicit space: MutableInnerProductModule[T, Double]) |
OWLQN
OWL-QN(Orthant-Wise Limited-Memory Quasi-Newton)算法,该算法是基于L-BFGS算法的可用于求解L1正则的算法。
Loss函数
treeAggregate
seqOp
操作会聚合各分区中的元素,然后combOp
操作把所有分区的聚合结果再次聚合,两个操作的初始值都是zeroValue
。treeAggregate
和aggregate
的区别是,aggregate
在每个分区处理完之后就直接交给Driver合并了,但treeAggregate
会在Executor直接树形地Aggregate。
1 | val sc = spark.sparkContext |
在六台机器上跑了,耗时如下
1 | reduceResult 705082704 7073 aggregateResult 705082704 831 treeAggregateResult 705082704 755 |
BLAS库
BLAS库的选择
Spark的BLAS库实现在mllib/linalg下面。其底层借助于com.github.fommil.netlib
,而netelib
有两个BLAS库的实现,其中f2jBLAS是从Fortran代码翻译到Java里面的,nativeBLAS是用系统原生的。两个库的使用规则如下所示,至于为什么,参考爆栈网的解释的意思是f2j的实现在Level较低的情况的性能要更好一点。
1 | private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = { |
这里的Level,指的是BLAS函数分为三个Level,Level越高,计算越复杂。Level1是向量之间的操作,Level2是 向量和矩阵之间的操作,Level3是矩阵和矩阵之间的操作。
axpy和Vector的实现
axpy指的是线代里面的$y = a * x$这样的操作。
axpy里面根据是否稀疏,进行讨论。对于DenseVector
会直接调用daxpy
。这里daxpy
的d不是dense,而是指double,对应的saxpy
指的是单精度,还有caxpy
代表复数单精度和zaxpy
代表复数双精度。
对于SparseVector
则要首先进行解码为DenseVector
。SparseVector
用两个数组,一个记录原始向量中非零元素的值,另一个记录非零元素对应到原始向量的位置。这有点类似于一维的COO表示形式。
但不是所有的计算都需要转换为DenseVector
的,例如向量的dot就不需要。观摩下面的代码发现,稀疏向量x中为0的列就不需要乘了。事实上有人进行了比较,Sparse算得快,但是构造慢。其中nnz表示number of nonzero entries。
1 | private def dot(x: SparseVector, y: DenseVector): Double = { |
gemm和Matrix的实现
gemm,即通用矩阵乘(GEMM,General Matrix Multiplication),其公式为$ C = alpha * A * B + beta * C $,其定义如下。其中A大小$m*k$,B大小$k*n$,C大小$m*n$,且C.isTransposed
必须是false。根据A是否为Dense,可以分为2个版本。
1 | def gemm(alpha: Double, |
这里的Matrix同样分为Dense和Sparse。isTransposed
表示有没有进行过转置。我理解用这个的原因是为了实现lazy的转置。DenseMatrix
的四要素是numRows
、numCols
、 values
和isTransposed
,values
是一维数组,因为我们可以通过numRows
和numCols
找到定位。SparseMatrix
是六要素,多了colPtrs
和rowIndices
,values
是一维数组,按列存储,实际上就是Compressed Sparse Columns(CSC)格式存储,如果isTransposed
是true,那么就按照Compressed Sparse Row(CSR)存储。以isTransposed
为false的情况为例,colPtrs
表示每一列的起始在values
中的位置,rowIndices
表示在values
中的每一个元素在是这一列中的第几行。
下面展示一下CSC的表示方法,例如矩阵
1 | 1.0 0.0 4.0 |
下面看gemm
的实现,首先要对A是不是Dense进行讨论,对Dense而言是平凡的,调用dgemm
就行了。对于Spark而言,需要根据A和B有没有transpose过的四种情况进行讨论。下面我们来看A是Sparse,B是Dense对应的gemm
重载下的四种情况。
首先来看A和B都没有转置的情况,我们现在是需要用A的对应行乘以B的对应列,可以看做左边的矩阵的每一列,分别去乘上右边矩阵对应的行,得到n个矩阵,然后我们将它们对应位置加起来就能得到C。这种方式可以参考我的MIT线性代数的学习报告
首先先要放缩一下C,这样我们的列乘行的操作可以直接在C.values
上面累加,节约空间。
1 | if (beta != 1.0) { |
下面是主体,包含两个循环。外层循环变量colCounterForB
每轮递增,用来枚举到B的所有列nB
,内层循环变量colCounterForA
用来枚举A的所有列
1 | while (colCounterForB < nB) { |
现在来看A转置B没有转置的情况,。我们现在是需要用A的对应行乘以B的对应列。因为A被转置了,所以问题实际上变简单了,因为A转置后的行,实际上就是A转置前的列,在物理上是连续存储的。因为我们恰恰需要的是A转置后的行,所以实际可以当做转置前的列来处理。
1 | // colCounterForB在每个while后递增,枚举到B的所有列nB |
有gemm
就有gemv
,也就是C变成Vector了。除此之外,还有对于各种不同m的特化实现。例如SYMV
是对称矩阵乘法,TRMV
是三角矩阵乘法。com.github.fommil.netlib.BLAS
里面有dgbmv
带状矩阵,dsbmv
,dtrmv
等。此外,这些在Breeze库里面也有提供。
启发
- transpose可以做到lazy,但得是对二维特化的
- 矩阵可以Sparse存储
- 矩阵和多维数组最好单独实现
Breeze库
breeze库是一个类似numpy的线性代数库
其中BDM是一个Matrix,BDV是一个Vector,其实都是别名import breeze.linalg.{DenseMatrix => BDM}
其他代码
Attribute对象
Attribute
继承自AttributeFactory
这个object(原来伴生对象也能继承),主要是从StructField
生成Attribute
对象。StructField
是构成StructType
的成员。
1 | def fromStructField(field: StructField): Attribute = decodeStructField(field, false) |