讨论Spark上分布式机器学习库的实现。
Spark的机器学习库主要分为ml和mllib,其中ml较新,本文主要围绕ml来讲。
ml和mllib底层会用到Breeze库(类似于numpy的线性代数库)和BLAS(Basic Linear Algebra Subroutines,更基础的线性代数库),在这里也进行介绍。
LR回归的实现
总体实现
ml库的整个训练和预测过程可以用下面三行代码来概括
1 | val classifier = new LogisticRegression() |
训练过程
LogisticRegression类的训练入口是fit。fit首先定义在基类Predictor.scala中,它会首先做一些事情,比如转换labelCol的类型到Double,然后调用train。train会返回一个Model,但在Predictor中并没有实现,而是交付给子类来实现。LogisticRegression类的train是自己实现的。主要包括下面几个部分。
校验打印参数
参数包含下面的一些列名1
labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol, probabilityCol
这里的
weightCol可以给样本指定权重,用来解决数据不平衡问题。详见balanceDataset的实现。
还包括下面的一些训练用参数1
regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept
计算
summarizer和labelSummarizer
可以得到一些基础信息,比如直方图、特征数量、特征的均值方差和数量,class数量。其中直方图(histogram),大概内容类似[0: 34.0, 1: 66.0],意味着分类1拥有34样本,分类2有66个样本。均值(mean)大概内容类似[0: -0.03327765069036007 ]。
根据这些信息还会判断任务类型,是不是多类的(设置isMultinomial)计算
coefficientMatrix、interceptVector和objectiveHistory
这是主要过程,分别是返回系数矩阵、偏置值向量和loss之类的东西。
返回值是(denseCoefficientMatrix.compressed,interceptVec.compressed,arrayBuilder.result())计算
summaryModel、probabilityColName、predictionColName
模型表示
模型存储
首先Model会继承一个org.apache.spark.ml.util.MLWritable。这样就支持model.save()方法。另外,可以通过MLWriter来保存,对应model.write.overwrite().save()
预测过程
LogisticRegressionModel的transform方法实际上实现在ProbabilisticClassificationModel里面,最终会调用predictRaw计算rawPrediction列。predictRaw和下面提到的一些方法都定义在ClassificationModel里面,不过是虚的。
1 | if ($(rawPredictionCol).nonEmpty) { |
然后再调用predictProbability或者raw2probability计算probability列,这个列是一个Vector,表示每一个label的probability,加起来是等于1的(见ProbabilisticClassificationModel里面的注释)。所有继承了ProbabilisticClassifier的分类器都会有在这一列。
1 | if ($(probabilityCol).nonEmpty) { |
然后再调用raw2prediction或者probability2prediction或者predict计算prediction列。
训练过程详解
Summary过程
首先使用MetadataUtils.getNumClasses来获得Class数量。
1 | val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { |
MetadataUtils是一个私有对象,里面会计算得到labelSchema对应的Attribute,然后来判断有多少类。Attribute对象的实现比较复杂,单独提到后面讲。
1 | def getNumClasses(labelSchema: StructField): Option[Int] = { |
接着判断是否是多类分类任务,主要是根据family的值来判断,再根据numClasses的值来校验。family可以取下面的值:
auto(默认值)
自动根据class数量选择,如果numClasses == 1 || numClasses == 2,则是binomial,否则是multinomialbinomial
基于pivoting的LRmultinomial
softmax的LR,不基于pivoting1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19("2.1.0")
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the label distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
(value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
/** @group getParam */
("2.1.0")
def getFamily: String = $(family)
val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"...")
false
case "multinomial" => true
case "auto" => numClasses > 2
case other => throw new IllegalArgumentException(s"Unsupported family: $other")
}
val numCoefficientSets = if (isMultinomial) numClasses else 1
优化过程
主要是创建了损失函数costFun,并对于这个损失函数使用优化器进行优化。
1 | val regParamL1 = $(elasticNetParam) * $(regParam) |
首先来看损失函数RDDLossFunction的类定义。
首先ClassTag用来实现保障类型擦除后类型安全的功能。
1 | private[ml] class RDDLossFunction[ |
1 | override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { |
优化器
LBGFS
牛顿法的特点是收敛很快,但是运用牛顿法需要计算二阶偏导数,而且目标函数的Hesse矩阵可能非正定。对此,可以使用拟牛顿法,也就是用不包含二阶导数的矩阵近似牛顿法中的Hesse矩阵的逆矩阵。 是一种一阶方法。由于构造近似矩阵的方法不同,因而出现不同的拟牛顿法,BFGS是其中一种。因为BFGS仍然存在较大的内存占用。因此有了LBFGS算法,其中L指的是Limited的意思。
下面,简单查看一下LBFGS的代码,可以发现,它继承了FirstOrderMinimizer优化器,印证了LBFGS是一个一阶方法。
1 | class LBFGS[T](convergenceCheck: ConvergenceCheck[T], m: Int)(implicit space: MutableInnerProductModule[T, Double]) |
OWLQN
OWL-QN(Orthant-Wise Limited-Memory Quasi-Newton)算法,该算法是基于L-BFGS算法的可用于求解L1正则的算法。
Loss函数
treeAggregate
seqOp操作会聚合各分区中的元素,然后combOp操作把所有分区的聚合结果再次聚合,两个操作的初始值都是zeroValue。treeAggregate和aggregate的区别是,aggregate在每个分区处理完之后就直接交给Driver合并了,但treeAggregate会在Executor直接树形地Aggregate。
1 | val sc = spark.sparkContext |
在六台机器上跑了,耗时如下
1 | reduceResult 705082704 7073 aggregateResult 705082704 831 treeAggregateResult 705082704 755 |
BLAS库
BLAS库的选择
Spark的BLAS库实现在mllib/linalg下面。其底层借助于com.github.fommil.netlib,而netelib有两个BLAS库的实现,其中f2jBLAS是从Fortran代码翻译到Java里面的,nativeBLAS是用系统原生的。两个库的使用规则如下所示,至于为什么,参考爆栈网的解释的意思是f2j的实现在Level较低的情况的性能要更好一点。
1 | private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = { |
这里的Level,指的是BLAS函数分为三个Level,Level越高,计算越复杂。Level1是向量之间的操作,Level2是 向量和矩阵之间的操作,Level3是矩阵和矩阵之间的操作。
axpy和Vector的实现
axpy指的是线代里面的$y = a * x$这样的操作。
axpy里面根据是否稀疏,进行讨论。对于DenseVector会直接调用daxpy。这里daxpy的d不是dense,而是指double,对应的saxpy指的是单精度,还有caxpy代表复数单精度和zaxpy代表复数双精度。
对于SparseVector则要首先进行解码为DenseVector。SparseVector 用两个数组,一个记录原始向量中非零元素的值,另一个记录非零元素对应到原始向量的位置。这有点类似于一维的COO表示形式。
但不是所有的计算都需要转换为DenseVector的,例如向量的dot就不需要。观摩下面的代码发现,稀疏向量x中为0的列就不需要乘了。事实上有人进行了比较,Sparse算得快,但是构造慢。其中nnz表示number of nonzero entries。
1 | private def dot(x: SparseVector, y: DenseVector): Double = { |
gemm和Matrix的实现
gemm,即通用矩阵乘(GEMM,General Matrix Multiplication),其公式为$ C = alpha * A * B + beta * C $,其定义如下。其中A大小$m*k$,B大小$k*n$,C大小$m*n$,且C.isTransposed必须是false。根据A是否为Dense,可以分为2个版本。
1 | def gemm(alpha: Double, |
这里的Matrix同样分为Dense和Sparse。isTransposed表示有没有进行过转置。我理解用这个的原因是为了实现lazy的转置。DenseMatrix的四要素是numRows、numCols、 values和isTransposed,values是一维数组,因为我们可以通过numRows和numCols找到定位。SparseMatrix是六要素,多了colPtrs和rowIndices,values是一维数组,按列存储,实际上就是Compressed Sparse Columns(CSC)格式存储,如果isTransposed是true,那么就按照Compressed Sparse Row(CSR)存储。以isTransposed为false的情况为例,colPtrs表示每一列的起始在values中的位置,rowIndices表示在values中的每一个元素在是这一列中的第几行。
下面展示一下CSC的表示方法,例如矩阵
1 | 1.0 0.0 4.0 |
下面看gemm的实现,首先要对A是不是Dense进行讨论,对Dense而言是平凡的,调用dgemm就行了。对于Spark而言,需要根据A和B有没有transpose过的四种情况进行讨论。下面我们来看A是Sparse,B是Dense对应的gemm重载下的四种情况。
首先来看A和B都没有转置的情况,我们现在是需要用A的对应行乘以B的对应列,可以看做左边的矩阵的每一列,分别去乘上右边矩阵对应的行,得到n个矩阵,然后我们将它们对应位置加起来就能得到C。这种方式可以参考我的MIT线性代数的学习报告
首先先要放缩一下C,这样我们的列乘行的操作可以直接在C.values上面累加,节约空间。
1 | if (beta != 1.0) { |
下面是主体,包含两个循环。外层循环变量colCounterForB每轮递增,用来枚举到B的所有列nB,内层循环变量colCounterForA用来枚举A的所有列
1 | while (colCounterForB < nB) { |
现在来看A转置B没有转置的情况,。我们现在是需要用A的对应行乘以B的对应列。因为A被转置了,所以问题实际上变简单了,因为A转置后的行,实际上就是A转置前的列,在物理上是连续存储的。因为我们恰恰需要的是A转置后的行,所以实际可以当做转置前的列来处理。
1 | // colCounterForB在每个while后递增,枚举到B的所有列nB |
有gemm就有gemv,也就是C变成Vector了。除此之外,还有对于各种不同m的特化实现。例如SYMV是对称矩阵乘法,TRMV是三角矩阵乘法。com.github.fommil.netlib.BLAS里面有dgbmv带状矩阵,dsbmv,dtrmv等。此外,这些在Breeze库里面也有提供。
启发
- transpose可以做到lazy,但得是对二维特化的
- 矩阵可以Sparse存储
- 矩阵和多维数组最好单独实现
Breeze库
breeze库是一个类似numpy的线性代数库
其中BDM是一个Matrix,BDV是一个Vector,其实都是别名import breeze.linalg.{DenseMatrix => BDM}
其他代码
Attribute对象
Attribute继承自AttributeFactory这个object(原来伴生对象也能继承),主要是从StructField生成Attribute对象。StructField是构成StructType的成员。
1 | def fromStructField(field: StructField): Attribute = decodeStructField(field, false) |