CART (classification and regression trees) – это аббревиатура, обозначающая методы классификации и регрессии с использованием дерева решений. Это методика обучения, основанная на деревьях решений, которая возвращает классификационные или регрессионные деревья. Как было в случае с C4.5, CART – это классификатор.
Дерево классификации выглядит так же как дерево решений? Дерево классификаций – это подвид дерева решений. Результатом работы дерева классификаций является класс.
Например, снова возьмем набор данных о пациенте. Вы можете попытаться предсказать, будет ли у пациента рак. Здесь возможно использование двух классов: «заболеет раком» и «не заболеет раком».
Что такое дерево регрессий? Дерево классификаций на выходе имеет класс, а дерево регрессий… числовую или непрерывную величину, например, время госпитализации или цену смартфона.
Деревья классификаций выводят классы, деревья регрессий – числа.
Поскольку мы уже разбирали, как деревья решений применяются для классификации данных, давайте сразу перейдем к сути алгоритма:
Как CART соотносится с C4.5?
C4.5:
Cart:
Требует ли этот метод обучения или он самообучающийся? CART требует обучения, поскольку для построения дерева классификаций и дерева регрессий необходим размеченный набор данных.
Почему именно CART? Причины, по которым вы бы использовали C4.5, применимы и к CART, поскольку оба метода – это техники обучения на основании дерева решений. Также к достоинствам CART можно отнести легкую интерпретируемость.
Как и C4.5, CART довольно быстрый, пользуется популярностью и обладает удобно читаемым выводом.
Где он используется? Реализации CART встречаются в scikit-learn. R использует CART в своем пакете работы с деревьями. CART есть в Weka и MATLAB.
CART (видео)
Реализация CART в R
Пример 1 (via rpart)
Let’s use the data frame kyphosis to predict a type of deformation (kyphosis) after surgery, from age in months (Age), number of vertebrae involved (Number), and the highest vertebrae operated on (Start).
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# Classification Tree with rpart library(rpart) # grow tree fit <- rpart(Kyphosis ~ Age + Number + Start, method="class", data=kyphosis) printcp(fit) # display the results plotcp(fit) # visualize cross-validation results summary(fit) # detailed summary of splits # plot tree plot(fit, uniform=TRUE, main="Classification Tree for Kyphosis") text(fit, use.n=TRUE, all=TRUE, cex=.8) # create attractive postscript plot of tree post(fit, file = "c:/tree.ps", title = "Classification Tree for Kyphosis") # prune the tree pfit<- prune(fit, cp= fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"]) # plot the pruned tree plot(pfit, uniform=TRUE, main="Pruned Classification Tree for Kyphosis") text(pfit, use.n=TRUE, all=TRUE, cex=.8) post(pfit, file = "c:/ptree.ps", title = "Pruned Classification Tree for Kyphosis") |
Пример 2.
In this example we will predict car mileage from price, country, reliability, and car type. The data frame is cu.summary.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# Regression Tree Example library(rpart) # grow tree fit <- rpart(Mileage~Price + Country + Reliability + Type, method="anova", data=cu.summary) printcp(fit) # display the results plotcp(fit) # visualize cross-validation results summary(fit) # detailed summary of splits # create additional plots par(mfrow=c(1,2)) # two plots on one page rsq.rpart(fit) # visualize cross-validation results # plot tree plot(fit, uniform=TRUE, main="Regression Tree for Mileage ") text(fit, use.n=TRUE, all=TRUE, cex=.8) # create attractive postcript plot of tree post(fit, file = "c:/tree2.ps", title = "Regression Tree for Mileage ") # prune the tree pfit<- prune(fit, cp=0.01160389) # from cptable # plot the pruned tree plot(pfit, uniform=TRUE, main="Pruned Regression Tree for Mileage") text(pfit, use.n=TRUE, all=TRUE, cex=.8) post(pfit, file = "c:/ptree2.ps", title = "Pruned Regression Tree for Mileage") |
Реализация CART в Python
В примере используется Bank Note датасет
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# CART on the Bank Note dataset from random import seed from random import randrange from csv import reader # Load a CSV file def load_csv(filename): file = open(filename, "rb") lines = reader(file) dataset = list(lines) return dataset # Convert string column to float def str_column_to_float(dataset, column): for row in dataset: row[column] = float(row[column].strip()) # Split a dataset into k folds def cross_validation_split(dataset, n_folds): dataset_split = list() dataset_copy = list(dataset) fold_size = int(len(dataset) / n_folds) for i in range(n_folds): fold = list() while len(fold) < fold_size: index = randrange(len(dataset_copy)) fold.append(dataset_copy.pop(index)) dataset_split.append(fold) return dataset_split # Calculate accuracy percentage def accuracy_metric(actual, predicted): correct = 0 for i in range(len(actual)): if actual[i] == predicted[i]: correct += 1 return correct / float(len(actual)) * 100.0 # Evaluate an algorithm using a cross validation split def evaluate_algorithm(dataset, algorithm, n_folds, *args): folds = cross_validation_split(dataset, n_folds) scores = list() for fold in folds: train_set = list(folds) train_set.remove(fold) train_set = sum(train_set, []) test_set = list() for row in fold: row_copy = list(row) test_set.append(row_copy) row_copy[-1] = None predicted = algorithm(train_set, test_set, *args) actual = [row[-1] for row in fold] accuracy = accuracy_metric(actual, predicted) scores.append(accuracy) return scores # Split a dataset based on an attribute and an attribute value def test_split(index, value, dataset): left, right = list(), list() for row in dataset: if row[index] < value: left.append(row) else: right.append(row) return left, right # Calculate the Gini index for a split dataset def gini_index(groups, class_values): gini = 0.0 for class_value in class_values: for group in groups: size = len(group) if size == 0: continue proportion = [row[-1] for row in group].count(class_value) / float(size) gini += (proportion * (1.0 - proportion)) return gini # Select the best split point for a dataset def get_split(dataset): class_values = list(set(row[-1] for row in dataset)) b_index, b_value, b_score, b_groups = 999, 999, 999, None for index in range(len(dataset[0])-1): for row in dataset: groups = test_split(index, row[index], dataset) gini = gini_index(groups, class_values) if gini < b_score: b_index, b_value, b_score, b_groups = index, row[index], gini, groups return {'index':b_index, 'value':b_value, 'groups':b_groups} # Create a terminal node value def to_terminal(group): outcomes = [row[-1] for row in group] return max(set(outcomes), key=outcomes.count) # Create child splits for a node or make terminal def split(node, max_depth, min_size, depth): left, right = node['groups'] del(node['groups']) # check for a no split if not left or not right: node['left'] = node['right'] = to_terminal(left + right) return # check for max depth if depth >= max_depth: node['left'], node['right'] = to_terminal(left), to_terminal(right) return # process left child if len(left) <= min_size: node['left'] = to_terminal(left) else: node['left'] = get_split(left) split(node['left'], max_depth, min_size, depth+1) # process right child if len(right) <= min_size: node['right'] = to_terminal(right) else: node['right'] = get_split(right) split(node['right'], max_depth, min_size, depth+1) # Build a decision tree def build_tree(train, max_depth, min_size): root = get_split(train) split(root, max_depth, min_size, 1) return root # Make a prediction with a decision tree def predict(node, row): if row[node['index']] < node['value']: if isinstance(node['left'], dict): return predict(node['left'], row) else: return node['left'] else: if isinstance(node['right'], dict): return predict(node['right'], row) else: return node['right'] # Classification and Regression Tree Algorithm def decision_tree(train, test, max_depth, min_size): tree = build_tree(train, max_depth, min_size) predictions = list() for row in test: prediction = predict(tree, row) predictions.append(prediction) return(predictions) # Test CART on Bank Note dataset seed(1) # load and prepare data filename = 'data_banknote_authentication.csv' dataset = load_csv(filename) # convert string attributes to integers for i in range(len(dataset[0])): str_column_to_float(dataset, i) # evaluate algorithm n_folds = 5 max_depth = 5 min_size = 10 scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size) print('Scores: %s' % scores) print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores)))) |
Источник примера (тьюториал с пояснениями)