Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 Quantile 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

本文缘由是因为想分析GBDT,发现GBDT涉及到Quantile的使用,所以只能先分析Quantile 。

0x01 背景概念

1.1 离散化

离散化:就是把无限空间中有限的个体映射到有限的空间中(分箱处理)。数据离散化操作大多是针对连续数据进行的,处理之后的数据值域分布将从连续属性变为离散属性。

离散化方式会影响后续数据建模和应用效果:

  • 使用决策树往往倾向于少量的离散化区间,过多的离散化将使得规则过多受到碎片区间的影响。
  • 关联规则需要对所有特征一起离散化,关联规则关注的是所有特征的关联关系,如果对每个列单独离散化将失去整体规则性。

连续数据的离散化结果可以分为两类:

  • 一类是将连续数据划分为特定区间的集合,例如{(0,10],(10,20],(20,50],(50,100]};
  • 一类是将连续数据划分为特定类,例如类1、类2、类3;

1.2 分位数

分位数(Quantile),亦称分位点,是指将一个随机变量的概率分布范围分为几个等份的数值点,常用的有中位数(即二分位数)、四分位数、百分位数等。

假如有1000个数字(正数),这些数字的5%,30%,50%,70%,99%分位数分别是 [3.0,5.0,6.0,9.0,12.0],这表明

  • 有5%的数字分布在0-3.0之间
  • 有25%的数字分布在3.0-5.0之间
  • 有20%的数字分布在5.0-6.0之间
  • 有20%的数字分布在6.0-9.0之间
  • 有29%的数字分布在9.0-12.0之间
  • 有1%的数字大于12.0

这就是分位数的统计学理解。

因此求解某一组数字中某个数的分位数,只需要将该组数字进行排序,然后再统计小于等于该数的个数,除以总的数字个数即可。

确定p分位数位置的两种方法

  • position = (n+1)p
  • position = 1 + (n-1)p

1.3 四分位数

这里我们用四分位数做进一步说明。

四分位数 概念:把给定的乱序数值由小到大排列并分成四等份,处于三个分割点位置的数值就是四分位数。

第1四分位数 (Q1),又称“较小四分位数”,等于该样本中所有数值由小到大排列后第25%的数字。

第2四分位数 (Q2),又称“中位数”,等于该样本中所有数值由小到大排列后第50%的数字。

第3四分位数 (Q3),又称“较大四分位数”,等于该样本中所有数值由小到大排列后第75%的数字。

四分位距(InterQuartile Range,IQR)= 第3四分位数与第1四分位数的差距。

0x02 示例代码

Alink中完成分位数功能的是QuantileDiscretizerQuantileDiscretizer输入连续的特征列,输出分箱的类别特征。

  • 分位点离散可以计算选定列的分位点,然后使用这些分位点进行离散化。生成选中列对应的q-quantile,其中可以所有列指定一个,也可以每一列对应一个。
  • 分箱数(所需离散的数目,即分为几段)是通过参数numBuckets(桶数目)来指定的。 箱的范围是通过使用近似算法来得到的。

本文示例代码如下。

public class QuantileDiscretizerExample {
    public static void main(String[] args) throws Exception {
        NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001,2000,"col0"); // 就是把1001 ~ 2000 这个连续数值分段

        Pipeline pipeline = new Pipeline()
                .add(new QuantileDiscretizer()
                        .setNumBuckets(6) // 指定分箱数数目
                        .setSelectedCols(new String[]{"col0"}));

        List<Row> result = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect();
        System.out.println(result);
    }
}

输出

[0,1,2,3,4,5,.....
0,1
.....
5,5]

0x03 总体逻辑

我们首先给出总体逻辑图例

-------------------------------- 准备阶段 --------------------------------
       │
       │
       │  
┌───────────────────┐ 
│  getSelectedCols  │ 获取需要分位的列名字
└───────────────────┘ 
       │
       │
       │
┌─────────────────────┐ 
│     quantileNum     │ 获取分箱数
└─────────────────────┘ 
       │
       │
       │
┌──────────────────────┐ 
│ Preprocessing.select │ 从输入中根据列名字select出数据
└──────────────────────┘ 
       │
       │
       │
-------------------------------- 预处理阶段 --------------------------------
       │ 
       │
       │
┌──────────────────────┐ 
│       quantile       │ 后续步骤 就是 计算分位数
└──────────────────────┘ 
       │
       │
       │ 
┌────────────────────────────────┐ 
│   countElementsPerPartition    │ 在每一个partition中获取该分区的所有元素个数
└────────────────────────────────┘ 
       │ <task id,count in this task>
       │
       │
┌──────────────────────┐ 
│       sum(1)         │ 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数
└──────────────────────┘ 
       │  
       │
       │
┌──────────────────────┐ 
│        map           │ 取出所有元素个数,cnt在后续会使用
└──────────────────────┘ 
       │    
       │    
       │
       │    
┌──────────────────────┐ 
│     missingCount     │ 分区查找应选的列中,有哪些数据没有被查到,比如zeroAsMissing,null,isNaN
└──────────────────────┘ 
       │
       │
       │
┌────────────────┐ 
│  mapPartition  │ 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来
└────────────────┘ 
       │ <idx in row,item in row>,即<row中第几个元素,元素>
       │
       │  
┌──────────────┐ 
│    pSort     │ 将flatten数据进行排序
└──────────────┘ 
       │ 返回的是二元组
       │ f0: dataset which is indexed by partition id
       │ f1: dataset which has partition id and count
       │ 
       │  
-------------------------------- 计算阶段 --------------------------------
       │ 
       │
       │ 
┌─────────────────┐ 
│  MultiQuantile  │ 后续都是具体计算步骤
└─────────────────┘ 
       │
       │ 
       │
┌─────────────────┐ 
│      open       │ 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)
└─────────────────┘ 
       │
       │ 
       │
┌─────────────────┐ 
│  mapPartition   │ 具体计算
└─────────────────┘         
       │
       │ 
       │
┌─────────────────┐ 
│    groupBy(0)   │ 依据 列idx 分组
└─────────────────┘   
       │
       │ 
       │
┌─────────────────┐ 
│   reduceGroup   │ 归并排序
└─────────────────┘    
       │set(Tuple2<column idx,真实数据值>)
       │ 
       │ 
-------------------------------- 序列化模型 --------------------------------
       │ 
       │
       │    
┌──────────────┐ 
│  reduceGroup │ 分组归并
└──────────────┘ 
       │ 
       │
       │   
┌─────────────────┐ 
│  SerializeModel │ 序列化模型
└─────────────────┘ 
  

下面图片是为了在手机上缩放适配展示。

QuantileDiscretizerTrainBatchOp.linkFrom如下:

public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
   BatchOperator<?> in = checkAndGetFirst(inputs);

   // 示例中设置了 .setSelectedCols(new String[]{"col0"}));, 所以这里 quantileColNames 的数值是"col0 
   String[] quantileColNames = getSelectedCols();

   int[] quantileNum = null;

   // 示例中设置了 .setNumBuckets(6),所以这里 quantileNum 是 quantileNum = {int[1]@2705} 0 = 6
   if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
      quantileNum = new int[quantileColNames.length];
      Arrays.fill(quantileNum,getNumBuckets());
   } else {
      quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
   }

   /* filter the selected column from input */
   // 获取了 选择的列 "col0"
   DataSet<Row> input = Preprocessing.select(in,quantileColNames).getDataSet();

   // 计算分位数
   DataSet<Row> quantile = quantile(
      input,quantileNum,getParams().get(HasRoundMode.ROUND_MODE),getParams().get(Preprocessing.ZERO_AS_MISSING)
   );

   // 序列化模型
   quantile = quantile.reduceGroup(
      new SerializeModel(
         getParams(),quantileColNames,TableUtil.findColTypesWithAssertAndHint(in.getSchema(),quantileColNames),BinTypes.BinDivideType.QUANTILE
      )
   );

   /* set output */
   setOutput(quantile,new QuantileDiscretizerModelDataConverter().getModelSchema());

   return this;
}

其总体逻辑如下:

  • 获取需要分位的列名字
  • 获取分箱数
  • 从输入中根据列名字select出数据
  • 调用 quantile 计算分位数
    • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id,count in this task>,然后 对于元素个数进行累积 sum(1) ,即"count in this task"进行累积,得出所有元素的个数 cnt;
    • 分区查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing,isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
    • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了, 返回flatten = <idx in row,即<row中第几个元素,元素>;
    • 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类。pSort返回的是二元组sortedData,f0: dataset which is indexed by partition id,f1: dataset which has partition id and count;
    • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数;具体是分区计算 mapPartition:
      • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
      • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
      • 把数据插入 allRows.add(value); value 可认为是 <partition id,真实数据>;
      • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
      • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168.
    • 依据 列idx 分组,得到 set(Tuple2<column idx,真实数据值>);
  • 序列化模型

0x04 训练

4.1 quantile

训练是通过 quantile 完成的,大致包含以下步骤。

  • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id,isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
  • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了,返回flatten = <idx in row,f1: dataset which has partition id and count;
  • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数。

具体如下

public static DataSet<Row> quantile(
   DataSet<Row> input,final int[] quantileNum,final HasRoundMode.RoundMode roundMode,final boolean zeroAsMissing) {
  
   /* instance count of dataset */
   // countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数,返回<task id,count in this task>。
   DataSet<Long> cnt = DataSetUtils
      .countElementsPerPartition(input)
      .sum(1) // 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数。
      .map(new MapFunction<Tuple2<Integer,Long>,Long>() {
         @Override
         public Long map(Tuple2<Integer,Long> value) throws Exception {
            return value.f1; // 取出所有元素个数
         }
      }); // cnt在后续会使用

   /* missing count of columns */
   // 会查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing,isNaN这几种情况
   DataSet<Tuple2<Integer,Long>> missingCount = input
      .mapPartition(new RichMapPartitionFunction<Row,Tuple2<Integer,Long>>() {
         public void mapPartition(Iterable<Row> values,Collector<Tuple2<Integer,Long>> out) {
            StreamSupport.stream(values.spliterator(),false)
               .flatMap(x -> {
                  long[] counts = new long[x.getArity()];

                  Arrays.fill(counts,0L);
   
                  // 如果发现有数据没有查到,就增加counts
                  for (int i = 0; i < x.getArity(); ++i) {
                     if (x.getField(i) == null
                     || (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
                     || Double.isNaN(((Number)x.getField(i)).doubleValue())) {
                        counts[i]++;
                     }
                  }

                  return IntStream.range(0,x.getArity())
                     .mapToObj(y -> Tuple2.of(y,counts[y]));
               })
               .collect(Collectors.groupingBy(
                  x -> x.f0,Collectors.mapping(x -> x.f1,Collectors.reducing((a,b) -> a + b))
                  )
               )
               .entrySet()
               .stream()
               .map(x -> Tuple2.of(x.getKey(),x.getValue().get()))
               .forEach(out::collect);
         }
      })
      .groupBy(0) //按第一个元素分组
      .reduce(new RichReduceFunction<Tuple2<Integer,Long>>() {
         @Override
         public Tuple2<Integer,Long> reduce(Tuple2<Integer,Long> value1,Long> value2) {
            return Tuple2.of(value1.f0,value1.f1 + value2.f1); //累积求和
         }
      });

   /* flatten dataset to 1d */
   // 把输入数据打散。
   DataSet<PairComparable> flatten = input
      .mapPartition(new RichMapPartitionFunction<Row,PairComparable>() {
         PairComparable pairBuff;
         public void mapPartition(Iterable<Row> values,Collector<PairComparable> out) {
            for (Row value : values) { // 遍历分区内所有输入元素
               for (int i = 0; i < value.getArity(); ++i) { // 如果输入元素Row本身包含多个子元素
                  pairBuff.first = i; // 则对于这些子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了
                  if (value.getField(i) == null
                     || (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
                     || Double.isNaN(((Number)value.getField(i)).doubleValue())) {
                     pairBuff.second = null;
                  } else {
                     pairBuff.second = (Number) value.getField(i);
                  }
                  out.collect(pairBuff); // 返回<idx in row,即<row中第几个元素,元素>
               }
            }
         }
      });

   /* sort data */
   // 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类
   // pSort返回的是二元组,f0: dataset which is indexed by partition id,f1: dataset which has partition id and count.
   Tuple2<DataSet<PairComparable>,DataSet<Tuple2<Integer,Long>>> sortedData
      = SortUtilsNext.pSort(flatten);

   /* calculate quantile */
   return sortedData.f0 //f0: dataset which is indexed by partition id
      .mapPartition(new MultiQuantile(quantileNum,roundMode))
      .withBroadcastSet(sortedData.f1,"counts") //f1: dataset which has partition id and count
      .withBroadcastSet(cnt,"totalCnt")
      .withBroadcastSet(missingCount,"missingCounts")
      .groupBy(0) // 依据 列idx 分组
      .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer,Number>,Row>() {
         @Override
         public void reduce(Iterable<Tuple2<Integer,Number>> values,Collector<Row> out) {
            TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
               @Override
               public int compare(Number o1,Number o2) {
                  return SortUtils.OBJECT_COMPARATOR.compare(o1,o2);
               }
            });

            int id = -1;
            for (Tuple2<Integer,Number> val : values) {
               // Tuple2<column idx,数据>
               id = val.f0;
               set.add(val.f1); 
            }

// runtime变量           
set = {TreeSet@9379}  size = 5
 0 = {Long@9389} 167 // 就是第 0 列的第一段 idx
 1 = {Long@9392} 333 // 就是第 0 列的第二段 idx
 2 = {Long@9393} 500 
 3 = {Long@9394} 667
 4 = {Long@9382} 833
  
            out.collect(Row.of(id,set.toArray(new Number[0])));
         }
      });
}

下面会对几个重点函数做说明。

4.2 countElementsPerPartition

countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数。

public static <T> DataSet<Tuple2<Integer,Long>> countElementsPerPartition(DataSet<T> input) {
   return input.mapPartition(new RichMapPartitionFunction<T,Long>>() {
      @Override
      public void mapPartition(Iterable<T> values,Long>> out) throws Exception {
         long counter = 0;
         for (T value : values) {
            counter++; // 在每一个partition中获取该分区的所有元素个数
         }
         out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(),counter));
      }
   });
}

4.3 MultiQuantile

MultiQuantile用来计算具体的分位点。

open函数中会从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)等等。

mapPartition函数则做具体计算,大致步骤如下:

  • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
  • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
  • 把数据插入 allRows.add(value); value 可认为是 <partition id,真实数据>;
  • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
  • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168;

具体代码是:

public static class MultiQuantile
   extends RichMapPartitionFunction<PairComparable,Number>> {
		private List<Tuple2<Integer,Long>> counts;
		private List<Tuple2<Integer,Long>> missingCounts;
		private long totalCnt = 0;
		private int[] quantileNum;
		private HasRoundMode.RoundMode roundType;
		private int taskId;

		@Override
		public void open(Configuration parameters) throws Exception {
      // 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)。
      // 之前设置广播变量.withBroadcastSet(sortedData.f1,"counts"),其中 f1 的格式是: dataset which has partition id and count,所以就是用 partition id来排序
			this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"counts",new BroadcastVariableInitializer<Tuple2<Integer,List<Tuple2<Integer,Long>>>() {
					@Override
					public List<Tuple2<Integer,Long>> initializeBroadcastVariable(
						Iterable<Tuple2<Integer,Long>> data) {
						ArrayList<Tuple2<Integer,Long>> sortedData = new ArrayList<>();
						for (Tuple2<Integer,Long> datum : data) {
							sortedData.add(datum);
						}
            //排序
						sortedData.sort(Comparator.comparing(o -> o.f0));
            
// runtime的数据如下,本机有4核,所以数据分为4个 partition,每个partition的数据分别为251,250,250,250        
sortedData = {ArrayList@9347}  size = 4
 0 = {Tuple2@9350} "(0,251)" // partition 0,数据个数是251
 1 = {Tuple2@9351} "(1,250)"
 2 = {Tuple2@9352} "(2,250)"
 3 = {Tuple2@9353} "(3,250)"         
            
						return sortedData;
					}
				});

			this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",new BroadcastVariableInitializer<Long,Long>() {
					@Override
					public Long initializeBroadcastVariable(Iterable<Long> data) {
						return data.iterator().next();
					}
				});

			this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"missingCounts",Long>> data) {
						return StreamSupport.stream(data.spliterator(),false)
							.sorted(Comparator.comparing(o -> o.f0))
							.collect(Collectors.toList());
					}
				}
			);

			taskId = getRuntimeContext().getIndexOfThisSubtask();
      
// runtime的数据如下        
this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348} 
 counts = {ArrayList@9347}  size = 4
  0 = {Tuple2@9350} "(0,251)"
  1 = {Tuple2@9351} "(1,250)"
  2 = {Tuple2@9352} "(2,250)"
  3 = {Tuple2@9353} "(3,250)"
 missingCounts = {ArrayList@9375}  size = 1
  0 = {Tuple2@9381} "(0,0)"
 totalCnt = 1001
 quantileNum = {int[1]@9376} 
  0 = 6
 roundType = {HasRoundMode$RoundMode@9377} "ROUND"
 taskId = 2
		}

		@Override
		public void mapPartition(Iterable<PairComparable> values,Number>> out) throws Exception {

			long start = 0;
			long end;

			int curListIndex = -1;
			int size = counts.size(); // 分成4份,所以这里是4

			for (int i = 0; i < size; ++i) {
				int curId = counts.get(i).f0; // 取出输入元素中的 partition id

				if (curId == taskId) {
					curListIndex = i; // 当前 task 对应哪个 partition id
					break; // 到了当前task,就可以跳出了
				}

				start += counts.get(i).f1; // 累积,得到当前 task 的起始位置,即1000个数据中从哪个数据开始计算
			}

      // 根据 taskId 从counts中得到了本 task 应该处理哪些数据,即数据的start,end位置
      // 本 partition 是 0,其中有251个数据
			end = start + counts.get(curListIndex).f1; // end = 起始位置 + 此partition的数据个数 

			ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));

			for (PairComparable value : values) {
				allRows.add(value); // value 可认为是 <partition id,真实数据>
			}

			allRows.sort(Comparator.naturalOrder());

// runtime变量
start = 0
curListIndex = 0
size = 4
end = 251
allRows = {ArrayList@9406}  size = 251
 0 = {PairComparable@9408} 
  first = {Integer@9397} 0
  second = {Long@9434} 0
 1 = {PairComparable@9409} 
  first = {Integer@9397} 0
  second = {Long@9435} 1
 2 = {PairComparable@9410} 
  first = {Integer@9397} 0
  second = {Long@9439} 2
 ......
      
      // size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1
			size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;

			int localStart = 0;
			for (int i = 0; i < size; ++i) {
				int fIdx = (int) (start / totalCnt + i);
				int subStart = 0;
				int subEnd = (int) totalCnt;

				if (i == 0) {
					subStart = (int) (start % totalCnt); // 0
				}

				if (i == size - 1) {
					subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251
				}

				if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
					localStart += subEnd - subStart;
					continue;
				}

				QIndex qIndex = new QIndex(
					totalCnt - missingCounts.get(fIdx).f1,quantileNum[fIdx],roundType);

// runtime变量
qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548} 
 totalCount = 1001.0
 q1 = 0.16666666666666666
 roundMode = {HasRoundMode$RoundMode@9377} "ROUND"      
        
        // 遍历,一直到分箱数。
				for (int j = 1; j < quantileNum[fIdx]; ++j) {
          // 获取每个分箱的index 
					long index = qIndex.genIndex(j); // j = 1 ---> index = 167,就是把 1001 个分为6段,第一段终点是167
          //对应本 task = 0,subStart = 0,subEnd = 251。则index = 167,直接从allRows获取第167个,数值是 1168。因为连续区域是 1001 ~ 2000,所以第167个对应数值就是1168
          //如果本 task = 1,subStart = 251,subEnd = 501。则index = 333,直接从allRows获取第 (333 + 0 - 251)= 第 82 个,获取其中的数值。这里因为数值区域是 1001 ~ 2000,所以数值是1334。
					if (index >= subStart && index < subEnd) { // idx刚刚好在本分区的数据中
						PairComparable pairComparable = allRows.get(
							(int) (index + localStart - subStart)); // 
            
              
// runtime变量            
pairComparable = {PairComparable@9581} 
 first = {Integer@9507} 0 // first是column idx
 second = {Long@9584} 167 // 真实数据     
   
						out.collect(Tuple2.of(pairComparable.first,pairComparable.second));
					}
				}

				localStart += subEnd - subStart;
			}
		}
	}

4.4 QIndex

其中 QIndex 是本文关键所在,就是具体计算分位数。

  • 构造函数中会得倒所有元素个数,每段大小;
  • genIndex函数中会具体计算,比如假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
public static class QIndex {
   private double totalCount;
   private double q1;
   private HasRoundMode.RoundMode roundMode;

   public QIndex(double totalCount,int quantileNum,HasRoundMode.RoundMode type) {
      this.totalCount = totalCount; // 1001,所有元素的个数
      this.q1 = 1.0 / (double) quantileNum; // 1.0 / 6 = 16666666666666666。quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6
      this.roundMode = type;
   }

   public long genIndex(int k) {
      // 假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
      return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
   }
}

0x05 输出模型

输出模型是通过 reduceGroup 调用 SerializeModel 来完成。

具体逻辑是:

  • 先构建分箱点元数据信息;
  • 然后序列化成模型;
// 序列化模型
quantile = quantile.reduceGroup(
      new SerializeModel(
         getParams(),BinTypes.BinDivideType.QUANTILE
      )
);

SerializeModel 的具体实现是:

public static class SerializeModel implements GroupReduceFunction<Row,Row> {
   private Params meta;
   private String[] colNames;
   private TypeInformation<?>[] colTypes;
   private BinTypes.BinDivideType binDivideType;

   @Override
   public void reduce(Iterable<Row> values,Collector<Row> out) throws Exception {
      Map<String,FeatureBorder> m = new HashMap<>();
      for (Row val : values) {
         int index = (int) val.getField(0);
         Number[] splits = (Number[]) val.getField(1);
         m.put(
            colNames[index],QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
               colNames[index],colTypes[index],splits,meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),binDivideType
            )
         );
      }

      for (int i = 0; i < colNames.length; ++i) {
         if (m.containsKey(colNames[i])) {
            continue;
         }

         m.put(
            colNames[i],QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
               colNames[i],colTypes[i],binDivideType
            )
         );
      }

      QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m,meta);

      model.save(model,out);
   }
}

这里用到了 FeatureBorder 类。

数据分箱是按照某种规则将数据进行分类。就像可以将水果按照大小进行分类,售卖不同的价格一样。

FeatureBorder 就是专门为了 Featureborder for binning,discrete Featureborder and continuous Featureborder。

我们能够看出来,该分箱对应的列名,index,各个分割点。

m = {HashMap@9380}  size = 1
 "col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"

0x06 预测

预测是在 QuantileDiscretizerModelMapper 中完成的。

6.1 加载模型

模型数据是

model = {QuantileDiscretizerModelDataConverter@9582} 
 meta = {Params@9670} "Params {selectedCols=["col0"],version="v2",numBuckets=6}"
 data = {HashMap@9584}  size = 1
  "col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","binCount":6}"

loadModel会完成加载。

@Override
public void loadModel(List<Row> modelRows) {
   QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
   model.load(modelRows);

   for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) {
      FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]);
      List<Bin.BaseBin> norm = border.bin.normBins;
      int size = norm.size();
      Long maxIndex = norm.get(0).getIndex();
      Long lastIndex = norm.get(size - 1).getIndex();
      for (int j = 0; j < norm.size(); ++j) {
         if (maxIndex < norm.get(j).getIndex()) {
            maxIndex = norm.get(j).getIndex();
         }
      }

      long maxIndexWithNull = Math.max(maxIndex,border.bin.nullBin.getIndex());

      switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) {
         case KEEP:
            mapperBuilder.vectorSize.put(i,maxIndexWithNull + 1);
            break;
         case SKIP:
         case ERROR:
            mapperBuilder.vectorSize.put(i,maxIndex + 1);
            break;
         default:
            throw new UnsupportedOperationException("Unsupported now.");
      }

      if (mapperBuilder.paramsBuilder.dropLast) {
         mapperBuilder.dropIndex.put(i,lastIndex);
      }

      mapperBuilder.discretizers[i] = createQuantileDiscretizer(border,model.meta);
   }

   mapperBuilder.setAssembledVectorSize();
}

加载中,最后调用 createQuantileDiscretizer 生成 LongQuantileDiscretizer。这就是针对Long类型的离散器。

public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
   long[] bounds;
   boolean isLeftOpen;
   int[] boundIndex;
   int nullIndex;
   boolean zeroAsMissing;

   @Override
   public int findIndex(Object number) {
      if (number == null) {
         return nullIndex;
      }

      long lVal = ((Number) number).longValue();

      if (isMissing(lVal,zeroAsMissing)) {
         return nullIndex;
      }

      int hit = Arrays.binarySearch(bounds,lVal);

      if (isLeftOpen) {
         hit = hit >= 0 ? hit - 1 : -hit - 2;
      } else {
         hit = hit >= 0 ? hit : -hit - 2;
      }

      return boundIndex[hit];
   }
}

其数值如下:

this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
 bounds = {long[7]@9757} 
  0 = -9223372036854775807
  1 = 1168
  2 = 1334
  3 = 1501
  4 = 1667
  5 = 1834
  6 = 9223372036854775807
 isLeftOpen = true
 boundIndex = {int[7]@9743} 
  0 = 0 // -9223372036854775807 ~ 1168 之间对应的最终分箱离散值是 0 
  1 = 1
  2 = 2
  3 = 3
  4 = 4
  5 = 5
  6 = 5 // 1834 ~ 9223372036854775807 之间对应的最终分箱离散值是 5 
 nullIndex = 6
 zeroAsMissing = false

6.2 预测

预测 QuantileDiscretizerModelMapper 的 DiscretizerMapperBuilder 完成。

Row map(Row row){
  
// 这里的 row 举例是: row = {Row@9743} "1003"
   for (int i = 0; i < paramsBuilder.selectedCols.length; i++) {
      int colIdxInData = selectedColIndicesInData[i];
      Object val = row.getField(colIdxInData);
      int foundIndex = discretizers[i].findIndex(val); // 找到 1003对应的index,就是调用Discretizer完成,这里找到 foundIndex 是0
      predictIndices[i] = (long) foundIndex;
   }

   return paramsBuilder.outputColsHelper.getResultRow(
      row,setResultRow(
         predictIndices,paramsBuilder.encode,dropIndex,vectorSize,paramsBuilder.dropLast,assembledVectorSize) // 最后返回离散值是0
   );
}

this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744} 
 paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752} 
 selectedColIndicesInData = {int[1]@9754} 
 vectorSize = {HashMap@9758}  size = 1
 dropIndex = {HashMap@9759}  size = 1
 assembledVectorSize = {Integer@9760} 6
 discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761} 
  0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} 
   bounds = {long[7]@9776} 
   isLeftOpen = true
   boundIndex = {int[7]@9777} 
   nullIndex = 6
   zeroAsMissing = false
 predictIndices = {Long[1]@9763} 

0xFF 参考

QuantileDiscretizer的用法

Spark QuantileDiscretizer 分位数离散器

机器学习——数据离散化(时间离散,多值离散化,分位数,聚类法,频率区间,二值化)

如何通俗地理解分位数?

分位数通俗理解

Python解释数学系列——分位数Quantile

spark之QuantileDiscretizer源码解析

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


文章浏览阅读5.3k次,点赞10次,收藏39次。本章详细写了mysql的安装,环境的搭建以及安装时常见的问题和解决办法。_mysql安装及配置超详细教程
文章浏览阅读1.8k次,点赞50次,收藏31次。本篇文章讲解Spark编程基础这门课程的期末大作业,主要围绕Hadoop基本操作、RDD编程、SparkSQL和SparkStreaming编程展开。_直接将第4题的计算结果保存到/user/root/lisi目录中lisipi文件里。
文章浏览阅读7.8k次,点赞9次,收藏34次。ES查询常用语法目录1. ElasticSearch之查询返回结果各字段含义2. match 查询3. term查询4. terms 查询5. range 范围6. 布尔查询6.1 filter加快查询效率的原因7. boosting query(提高查询)8. dis_max(最佳匹配查询)9. 分页10. 聚合查询【内含实际的demo】_es查询语法
文章浏览阅读928次,点赞27次,收藏18次。
文章浏览阅读1.1k次,点赞24次,收藏24次。作用描述分布式协调和一致性协调多个节点的活动,确保一致性和顺序。实现一致性、领导选举、集群管理等功能,确保系统的稳定和可靠性。高可用性和容错性Zookeeper是高可用的分布式系统,通过多个节点提供服务,容忍节点故障并自动进行主从切换。作为其他分布式系统的高可用组件,提供稳定的分布式协调和管理服务,保证系统的连续可用性。配置管理和动态更新作为配置中心,集中管理和分发配置信息。通过订阅机制,实现对配置的动态更新,以适应系统的变化和需求的变化。分布式锁和并发控制。
文章浏览阅读1.5k次,点赞26次,收藏29次。为贯彻执行集团数字化转型的需要,该知识库将公示集团组织内各产研团队不同角色成员的职务“职级”岗位的评定标准;
文章浏览阅读1.2k次,点赞26次,收藏28次。在安装Hadoop之前,需要进行以下准备工作:确认操作系统:Hadoop可以运行在多种操作系统上,包括Linux、Windows和Mac OS等。选择适合你的操作系统,并确保操作系统版本符合Hadoop的要求。安装Java环境:Hadoop是基于Java开发的,因此需要先安装和配置Java环境。确保已经安装了符合Hadoop版本要求的Java Development Kit (JDK),并设置好JAVA_HOME环境变量。确认硬件要求:Hadoop是一个分布式系统,因此需要多台计算机组成集群。
文章浏览阅读974次,点赞19次,收藏24次。# 基于大数据的K-means广告效果分析毕业设计 基于大数据的K-means广告效果分析。
文章浏览阅读1.7k次,点赞6次,收藏10次。Hadoop入门理论
文章浏览阅读1.3w次,点赞28次,收藏232次。通过博客和文献调研整理的一些农业病虫害数据集与算法。_病虫害数据集
文章浏览阅读699次,点赞22次,收藏7次。ZooKeeper使用的是Zab(ZooKeeper Atomic Broadcast)协议,其选举过程基于一种名为Fast Leader Election(FLE)的算法进行。:每个参与选举的ZooKeeper服务器称为一个“Follower”或“Candidate”,它们都有一个唯一的标识ID(通常是一个整数),并且都知道集群中其他服务器的ID。总之,ZooKeeper的选举机制确保了在任何时刻集群中只有一个Leader存在,并通过过半原则保证了即使部分服务器宕机也能维持高可用性和一致性。
文章浏览阅读10w+次,点赞62次,收藏73次。informatica 9.x是一款好用且功能强大的数据集成平台,主要进行各类数据库的管理操作,是使用相当广泛的一款ETL工具(注: ETL就是用来描述将数据从源端经过抽取(extract)、转换(transform)、加载(load)到目的端的过程)。本文主要为大家图文详细介绍Windows10下informatica powercenter 9.6.1安装与配置步骤。文章到这里就结束了,本人是在虚拟机中装了一套win10然后在此基础上测试安装的这些软件,因为工作学习要分开嘛哈哈哈。!!!!!_informatica客户端安装教程
文章浏览阅读7.8w次,点赞245次,收藏2.9k次。111个Python数据分析实战项目,代码已跑通,数据可下载_python数据分析项目案例
文章浏览阅读1.9k次,点赞61次,收藏64次。TDH企业级一站式大数据基础平台致力于帮助企业更全面、更便捷、更智能、更安全的加速数字化转型。通过数年时间的打磨创新,已帮助数千家行业客户利用大数据平台构建核心商业系统,加速商业创新。为了让大数据技术得到更广泛的使用与应用从而创造更高的价值,依托于TDH强大的技术底座,星环科技推出TDH社区版(Transwarp Data Hub Community Edition)版本,致力于为企业用户、高校师生、科研机构以及其他专业开发人员提供更轻量、更简单、更易用的数据分析开发环境,轻松应对各类人员数据分析需求。_星环tdh没有hive
文章浏览阅读836次,点赞21次,收藏19次。
文章浏览阅读1k次,点赞21次,收藏15次。主要介绍ETL相关工作的一些概念和需求点
文章浏览阅读1.4k次。本文以Android、java为开发技术,实现了一个基于Android的博物馆线上导览系统 app。基于Android的博物馆线上导览系统 app的主要使用者分为管理员和用户,app端:首页、菜谱信息、甜品信息、交流论坛、我的,管理员:首页、个人中心、用户管理、菜谱信息管理、菜谱分类管理、甜品信息管理、甜品分类管理、宣传广告管理、交流论坛、系统管理等功能。通过这些功能模块的设计,基本上实现了整个博物馆线上导览的过程。
文章浏览阅读897次,点赞19次,收藏26次。1.背景介绍在当今的数字时代,数据已经成为企业和组织中最宝贵的资源之一。随着互联网、移动互联网和物联网等技术的发展,数据的产生和收集速度也急剧增加。这些数据包括结构化数据(如数据库、 spreadsheet 等)和非结构化数据(如文本、图像、音频、视频等)。这些数据为企业和组织提供了更多的信息和见解,从而帮助他们做出更明智的决策。业务智能(Business Intelligence,BI)...
文章浏览阅读932次,点赞22次,收藏16次。也就是说,一个类应该对自己需要耦合或调用的类知道的最少,类与类之间的关系越密切,耦合度越大,那么类的变化对其耦合的类的影响也会越大,这也是我们面向对象设计的核心原则:低耦合,高内聚。优秀的架构和产品都是一步一步迭代出来的,用户量的不断增大,业务的扩展进行不断地迭代升级,最终演化成优秀的架构。其根本思想是强调了类的松耦合,类之间的耦合越弱,越有利于复用,一个处在弱耦合的类被修改,不会波及有关系的类。缓存,从操作系统到浏览器,从数据库到消息队列,从应用软件到操作系统,从操作系统到CPU,无处不在。
文章浏览阅读937次,点赞22次,收藏23次。大数据可视化是关于数据视觉表现形式的科学技术研究[9],将数据转换为图形或图像在屏幕上显示出来,并进行各种交互处理的理论、方法和技术。将数据直观地展现出来,以帮助人们理解数据,同时找出包含在海量数据中的规律或者信息,更多的为态势监控和综合决策服务。数据可视化是大数据生态链的最后一公里,也是用户最直接感知数据的环节。数据可视化系统并不是为了展示用户的已知的数据之间的规律,而是为了帮助用户通过认知数据,有新的发现,发现这些数据所反映的实质。大数据可视化的实施是一系列数据的转换过程。