`
fighting_2013
  • 浏览: 14800 次
  • 性别: Icon_minigender_1
社区版块
存档分类
最新评论

数据挖掘笔记-分类-决策树-2

阅读更多

接着上面说下决策树的一些其他算法:SLIQ、SPRINT、CART。这些算法则是根据Gini指标来计算的。

SLIQ

SLIQ(Supervised Learning In Quest)利用三中数据结构来构造树,分别是属性表、类表和类直方图。

SLIQ算法在建树阶段,对连续属性采取预先排序技术与广度优先相结合的策略生成树,对离散属性采取快速求子集算法确定划分条件。

具体步骤如下:

step1:建立类表和各个属性表,并且进行预先排序,即对每个连续属性的属性表进行独立的排序,以避免在每个节点上都要给连续属性值重新排序;

step2:如果每个叶子节点中的样本都能归为一类,则算法停止;否则转step3;

step3:利用属性表计算gini值,选择最小gini值的属性和分割点作为最佳划分;

step4:根据step3得到的最佳划分节点,判断为真的样本划分为左孩子节点,否则划分为右孩子节点.这样就构成了广度优先的生成树策略;

step5:更新类表中的第二项,使之指向样本划分后所在的叶子节点;

step6:跳转到step2

 
SLIQ优异性能:
可伸缩性良好:缩短学习时间、处理常驻磁盘的数据集能力、处理结果的准确性
 
SPRINT
SPRINT(Scalable Parallelizable Induction of Classification Tree)算法是一种可扩展的、可并行的归纳决策树,它完全不受内存限制,运行速度快,且允许多个处理器协同创建一个决策树模型.
SPRINT算法是对SLIQ算法的改进,其目的有两个:一是为了能够更好的并行建立决策树,二是为了使得决策树适合更大的数据集.
SPRINT算法定义了两种数据结构,分别是属性表与直方图.属性表由一组三元组<属性值、类别属性、样本号>组成,它随节点的扩张而划分,并归附于相应的子节点.
与SLIQ算法不同,SPRINT算法采取传统的深度优先生成树策略,具体步骤如下:
step1:生成根节点,并为所有属性建立属性表,同时预先排序连续属性的属性表;
step2:如果节点中的样本都能归为一类,则算法停止;否则转step3;
step3:利用属性表寻找拥有最小gini值的划分作为最佳划分方案.算法依次扫描该节点上的每张属性表;
step4:根据划分方案,生成该节点的两个子节点;
step5:划分该节点上的各属性表;
step6:跳转到step2
 
SPRINT算法的优点是在寻找每个结点的最优分裂标准时变得更简单。其缺点是对非分裂属性的属性列表进行分裂变得很困难。解决的办法是对分裂属性进行分裂时用哈希表记录下每个记录属于哪个孩子结点,若内存能够容纳下整个哈希表,其他属性列表的分裂只需参照该哈希表即可。由于哈希表的大小与训练集的大小成正比,当训练集很大时,哈希表可能无法在内存容纳,此时分裂只能分批执行,这使得SPRINT算法的可伸缩性仍然不是很好。
 
CART

分类回归树算法:CART(Classification And Regression Tree)算法采用一种二分递归分割的技术,将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

分类树两个基本思想:第一个是将训练样本进行递归地划分自变量空间进行建树的想法,第二个想法是用验证数据进行剪枝。

 

这里只对SPRINT算法用Java进行了简单实现

@Override
public Object build(Data data) {
	//对数据集预先判断,特征属性为空时候选取最多数量的类型,数据集全部为统一类型时候直接返回类型
	Object preHandleResult = preHandle(data);
	if (null != preHandleResult) return preHandleResult;
	//创建属性表
	Map<String, List<Attribute>> attributeTableMap = 
			new HashMap<String, List<Attribute>>();
	for (Instance instance : data.getInstances()) {
		String category = String.valueOf(instance.getCategory());
		Map<String, Object> attrs = instance.getAttributes();
		for (Map.Entry<String, Object> entry : attrs.entrySet()) {
			String attrName = entry.getKey();
			List<Attribute> attributeTable = attributeTableMap.get(attrName);
			if (null == attributeTable) {
				attributeTable = new ArrayList<Attribute>();
				attributeTableMap.put(attrName, attributeTable);
			}
			attributeTable.add(new Attribute(instance.getId(), 
					attrName, String.valueOf(entry.getValue()), category));
		}
	}
	//计算属性表的基尼指数
	Set<String> attributes = data.getAttributeSet();
	String splitAttribute = null;
	String minSplitPoint = null;
	double minSplitPointGini = 1.0;
	for (Map.Entry<String, List<Attribute>> entry : attributeTableMap.entrySet()) {
		String attribute = entry.getKey();
		if (!attributes.contains(attribute)) {
			continue;
		}
		List<Attribute> attributeTable = entry.getValue();
		Object[] result = calculateMinGini(attributeTable);
		double splitPointGini = Double.parseDouble(String.valueOf(result[1]));
		if (minSplitPointGini > splitPointGini) {
			minSplitPointGini = splitPointGini;
			minSplitPoint = String.valueOf(result[0]);
			splitAttribute = attribute;
		}
	}
	System.out.println("splitAttribute: " + splitAttribute);
	TreeNode treeNode = new TreeNode(splitAttribute);
		
	//根据分割属性和分割点分割数据集
	attributes.remove(splitAttribute);
	Set<String> attributeValues = new HashSet<String>();
	List<List<Instance>> splitInstancess = new ArrayList<List<Instance>>();
	List<Instance> splitInstances1 = new ArrayList<Instance>();
	List<Instance> splitInstances2 = new ArrayList<Instance>();
	splitInstancess.add(splitInstances1);
	splitInstancess.add(splitInstances2);
	for (Instance instance : data.getInstances()) {
		Object value = instance.getAttribute(splitAttribute);
		attributeValues.add(String.valueOf(value));
		if (value.equals(minSplitPoint)) {
			splitInstances1.add(instance);
		} else {
			splitInstances2.add(instance);
		}
	}
	attributeValues.remove(minSplitPoint);
	StringBuilder sb = new StringBuilder();
	for (String attributeValue : attributeValues) {
		sb.append(attributeValue).append(",");
	}
	if (sb.length() > 0) sb.deleteCharAt(sb.length() - 1);
	String[] names = new String[]{minSplitPoint, sb.toString()};
	for (int i = 0; i < 2; i++) {
		List<Instance> splitInstances = splitInstancess.get(i);
		if (splitInstances.size() == 0) continue;
		Data subData = new Data(attributes.toArray(new String[0]),
				splitInstances);
		treeNode.setChild(names[i], build(subData));
	}
	return treeNode;
}

/** 计算基尼指数*/
public Object[] calculateMinGini(List<Attribute> attributeTable) {
	double totalNum = 0.0;
	Map<String, Map<String, Integer>> attrValueSplits = 
			new HashMap<String, Map<String, Integer>>();
	Set<String> splitPoints = new HashSet<String>();
	Iterator<Attribute> iterator = attributeTable.iterator();
	while (iterator.hasNext()) {
		Attribute attribute = iterator.next();
		String attributeValue = attribute.getValue();
		splitPoints.add(attributeValue);
		Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue);
		if (null == attrValueSplit) {
			attrValueSplit = new HashMap<String, Integer>();
			attrValueSplits.put(attributeValue, attrValueSplit);
		}
		String category = attribute.getCategory();
		Integer categoryNum = attrValueSplit.get(category);
		attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1);
		totalNum++;
	}
	String minSplitPoint = null;
	double minSplitPointGini = 1.0;
	for (String splitPoint : splitPoints) {
		double splitPointGini = 0.0;
		double splitAboveNum = 0.0;
		double splitBelowNum = 0.0;
		Map<String, Integer> attrBelowSplit = new HashMap<String, Integer>();
		for (Map.Entry<String, Map<String, Integer>> entry : attrValueSplits.entrySet()){
			String attrValue = entry.getKey();
			Map<String, Integer> attrValueSplit = entry.getValue();
			if (splitPoint.equals(attrValue)) {
				for (Integer v : attrValueSplit.values()) {
					splitAboveNum += v;
				}
				double aboveGini = 1.0;
				for (Integer v : attrValueSplit.values()) {
					aboveGini -= Math.pow((v / splitAboveNum), 2);
				}
				splitPointGini += (splitAboveNum / totalNum) * aboveGini;
			} else {
				for (Map.Entry<String, Integer> e : attrValueSplit.entrySet()) {
					String k = e.getKey();
					Integer v = e.getValue();
					Integer count = attrBelowSplit.get(k);
					attrBelowSplit.put(k, null == count ? v : v + count);
					splitBelowNum += e.getValue();
				}
			}
		}
		double belowGini = 1.0;
		for (Integer v : attrBelowSplit.values()) {
			belowGini -= Math.pow((v / splitBelowNum), 2);
		}
		splitPointGini += (splitBelowNum / totalNum) * belowGini;
		if (minSplitPointGini > splitPointGini) {
			minSplitPointGini = splitPointGini;
			minSplitPoint = splitPoint;
		}
	}
	return new Object[]{minSplitPoint, minSplitPointGini};
}
分享到:
评论

相关推荐

    基于C4.5决策树的大学生笔记本电脑购买行为的数据挖掘.pdf

    基于C4.5决策树的大学生笔记本电脑购买行为的数据挖掘.pdf

    数据挖掘十大算法详解.zip

    数据挖掘十大算法详解,数据挖掘学习笔记--决策树C4.5 、数据挖掘十大算法--K-均值聚类算法 、机器学习与数据挖掘-支持向量机(SVM)、拉格朗日对偶、支持向量机(SVM)(三)-- 最优间隔分类器 (optimal margin ...

    决策树DTC数据分析及鸢尾数据集分析.doc

    那么很自然一共就只可能有2棵决策树,如下图所示: 示例3: 第三个例子,推荐这篇文章:决策树学习笔记整理 - bourneli 决策树构建的基本步骤如下: 1. 开始,所有记录看作一个节点; 2. 遍历每个变量的每一种分割...

    20200401零基础入门数据挖掘 – 二手车交易价格预测笔记(4)

    决策树; 4.4 模型对比: 常用线性模型; 常用非线性模型; 4.5 模型调参: 贪心调参方法; 网格调参方法; 贝叶斯调参方法; 下面节选一些我学习比较多的地方进行记录: 4.1.1 线性回归 建立线性模型 from sklearn....

    机器学习&深度学习资料笔记&基本算法实现&资源整理.zip

    决策树 - Adaboost kNN - 朴素贝叶斯 EM - HMM - 条件随机场 kMeans - PCA ROC曲线&AUC值 Stacking(demo) 计算IOU 参考:《机器学习》周志华 《统计学习方法》李航 1.机器学习&深度学习 工具 | 书籍 | 课程 | ...

    数据挖掘学习笔记(三)

    典型方法:决策树、朴素贝叶斯分类、支持向量机、神经网络、规则分类器、基于模式的分类、逻辑回归… 3.聚类分析 聚类就是把一些对象划分为多个组或者“聚簇”,从而使得同组内对象间比较相似而不同组对象间的差异较

    机器学习课程笔记完整版

    作为人工智能领域(数据挖掘/机器学习方向)的提升课程,掌握更深更有效的解决问题技能 目标 应用Scikit-learn实现数据集的特征工程 掌握机器学习常见算法原理 应用Scikit-learn实现机器学习算法的应用,结合...

    案例系列:泰坦尼克号-预测幸存者-TensorFlow决策森林.ipynb jupyter 代码示例

    TensorFlow决策森林在表格数据上表现较好。本笔记将带您完成使用TensorFlow决策森林训练基线梯度提升树模型并在泰坦尼克号竞赛中提交的步骤。

    java笔试题算法-rapaio:统计、数据挖掘和机器学习工具箱

    已实现算法和功能的不完整列表包括:核心统计工具、常见分布和假设检验、朴素贝叶斯、二元逻辑回归、决策树(回归和分类)、随机森林(回归和分类)、AdaBoost、梯度提升树(回归)和分类)、BinarySMO SVM、相关...

    MachineLearningNote

    Python机器学习笔记:使用sklearn做特征工程和数据挖掘 地址: Python机器学习笔记:Grid SearchCV(网格搜索) 地址: 1,logistic Regression 关于逻辑回归文件夹中的数据和代码,详情请参考博客: Python机器学习...

    财经数据分析(第一周笔记整理)

    数据挖掘任务主要分为两种:描述性任务和预测性任务 1.描述性任务 定义:描述性任务将发掘数据中潜在的规律,找出目前可以理解与描述数据集中数据之间的联系,并刻画数据集中数据的一般特性。 描述性任务一般采用的...

    大数据预处理技术.pdf

    对于正常的数据分布⽽⾔可以使⽤均值,⽽倾斜数据分布应使⽤中位数 最可能的值填充:使⽤回归、基于推理的⼯具或者决策树归纳确定。 2.噪声数据与离群点: 噪声:被测量的变量的随机误差或者⽅差(⼀般指错误的数据...

    java版飞机大战源码-open-src:开源项目资料整理

    本课程提供了一个广泛的介绍机器学习、数据挖掘、统计模式识别的课程 莫烦大神 机器学习相关教程 B站: 个人页: AiLearning: 机器学习 AiLearning: 机器学习 - MachineLearning - ML、深度学习 - DeepLearning - DL...

Global site tag (gtag.js) - Google Analytics