-
Notifications
You must be signed in to change notification settings - Fork 0
/
TreeNode.m
147 lines (135 loc) · 6.56 KB
/
TreeNode.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
classdef TreeNode < handle
    %TREENODE One node of a decision tree built over a cell-array data set.
    %   A node stores the training rows that reached it, the metadata that
    %   describes the attributes, the candidate splits and their gains,
    %   and (after populate_children) one child TreeNode per branch.
    %   The last column of DATA / last entry of the metadata lists is
    %   treated as the class label throughout this class.
    %
    %   The helpers determine_candidate_splits, info_gain, split_data_sets,
    %   combine_data_sets and make_subtree are defined outside this file.

    properties
        data                        % training rows at this node; data(:, end) is the class column
        metadata                    % struct; fields used here: attribute_names, attribute_values, is_attribute_numeric
        split_attribute_no          % index of the attribute chosen by find_best_split
        split_attribute_name        % name of that attribute (metadata.attribute_names entry)
        children_labels             % metadata.attribute_values entry for the split attribute
        children TreeNode;          % child nodes created by populate_children
        splits                      % candidate splits as returned by determine_candidate_splits
        gain = [];                  % per-attribute gain as returned by info_gain
        is_leaf = false;            % set to true by are_stopping_criteria_met
        class_label                 % label assigned by determine_class_label
        label_count                 % per-class instance counts at this node (column vector)
        parent_label_count          % label_count of the parent, used to break ties / empty nodes
        splitting_value_index = []; % per-attribute index into splits{attr}, as returned by info_gain
    end

    methods
        function obj = TreeNode(data, metadata, parent_label_count)
            %TREENODE Construct a node from its data, metadata and the
            %   label counts of its parent.
            %   Called with no arguments, the properties keep their
            %   defaults; MATLAB needs this form when it has to create
            %   filler elements of an object array.
            if nargin == 0
                return;
            end
            obj.data = data;
            obj.metadata = metadata;
            obj.parent_label_count = parent_label_count;
        end

        function determine_candidate_splits(obj)
            %DETERMINE_CANDIDATE_SPLITS Store the candidate splits for this
            %   node's data in obj.splits.
            %   The arguments are not TreeNode objects, so this call
            %   resolves to the external function of the same name.
            obj.splits = determine_candidate_splits(obj.data, obj.metadata);
        end

        function count_class_labels(obj)
            %COUNT_CLASS_LABELS Count how many rows carry each class label.
            %   Fills obj.label_count with one entry per value listed in
            %   metadata.attribute_values{end}, in that order.
            unique_class_labels = obj.metadata.attribute_values{end};
            obj.label_count = zeros(length(unique_class_labels), 1);
            if isempty(obj.data)
                % No rows reached this node: every count stays zero.
                % (Indexing data(:, end) on a 0x0 array would throw.)
                return;
            end
            class_labels = obj.data(:, end);
            for i = 1:length(unique_class_labels)
                obj.label_count(i) = sum(ismember(class_labels, unique_class_labels{i}));
            end
        end

        function determine_class_label(obj)
            %DETERMINE_CLASS_LABEL Choose the label this node predicts.
            %   Uses the most frequent label at this node. When the node
            %   has no data, or several labels tie for the maximum count,
            %   the most frequent label of the parent is used instead.
            %   Requires obj.label_count (see count_class_labels).
            unique_class_labels = obj.metadata.attribute_values{end};
            max_value_locations = (obj.label_count == max(obj.label_count));
            if isempty(obj.data) || sum(max_value_locations) > 1
                % Empty node or ambiguous majority: defer to the parent.
                [~, label_index] = max(obj.parent_label_count);
            else
                [~, label_index] = max(obj.label_count);
            end
            obj.class_label = unique_class_labels{label_index};
        end

        function are_criteria_met = are_stopping_criteria_met(obj, m)
            %ARE_STOPPING_CRITERIA_MET Decide whether this node is a leaf.
            %   The node becomes a leaf (obj.is_leaf = true) when any of
            %   the following holds:
            %     - it has no data, or fewer than m rows;
            %     - all rows share one class label;
            %     - gains have been computed and no feature has positive gain;
            %     - every feature column has a single distinct value.
            %   Returns obj.is_leaf (which is never reset to false here).
            if isempty(obj.data)
                obj.is_leaf = true;
                are_criteria_met = obj.is_leaf;
                return;
            end

            class_labels = obj.data(:, end);
            % The last gain entry belongs to the class column itself and
            % is never a split candidate (find_best_split ignores it too).
            feature_gain = obj.gain(1:end - 1);

            if numel(class_labels) < m
                obj.is_leaf = true;
            elseif numel(unique(class_labels)) <= 1
                obj.is_leaf = true;
            elseif ~isempty(feature_gain) && max(feature_gain) <= 0
                obj.is_leaf = true;
            else
                no_of_columns = size(obj.data, 2);
                no_of_unique_labels = ones(1, no_of_columns);
                for i = 1:no_of_columns
                    no_of_unique_labels(i) = length(unique(obj.data(:, i)));
                end
                % no_of_unique_labels(i) == 1 implies that either the
                % original dataset only had one label for feature i or
                % that feature i has already been used earlier in the tree
                if (max(no_of_unique_labels(1:end-1)) == 1)
                    obj.is_leaf = true;
                end
            end
            are_criteria_met = obj.is_leaf;
        end

        function find_best_split(obj)
            %FIND_BEST_SPLIT Evaluate every attribute and pick the best one.
            %   Stores info_gain's two outputs per attribute in obj.gain
            %   and obj.splitting_value_index, then selects the attribute
            %   with the largest gain, excluding the last (class) entry.
            for attribute_number = 1:length(obj.metadata.attribute_names)
                [obj.gain(attribute_number), obj.splitting_value_index(attribute_number)] = ...
                    info_gain(obj.data, obj.metadata, obj.splits, attribute_number);
            end
            [~, obj.split_attribute_no] = max(obj.gain(1:end - 1));
            obj.split_attribute_name = obj.metadata.attribute_names{obj.split_attribute_no};
        end

        function populate_children(obj, m)
            %POPULATE_CHILDREN Partition the data on the chosen attribute
            %   and build one subtree per partition via make_subtree.
            %   Non-numeric attribute: the partitions returned by
            %   split_data_sets are used as they are.
            %   Numeric attribute: the partitions are merged into two
            %   groups, those with index <= splitting_value_index of the
            %   split attribute and the rest.
            %   m is passed through to make_subtree unchanged.
            are_attributes_numeric = obj.metadata.is_attribute_numeric;
            partitions = split_data_sets(obj.data, obj.metadata, ...
                obj.splits, obj.split_attribute_no);

            if ~are_attributes_numeric(obj.split_attribute_no) % Non numerical features
                children_data_sets = partitions;
            else
                % Seed both groups with a single row of empty strings;
                % combine_data_sets receives this seed as its first input.
                no_of_columns = size(obj.data, 2);
                children_data_sets(1).data = repmat({''}, 1, no_of_columns);
                children_data_sets(2).data = repmat({''}, 1, no_of_columns);
                threshold_index = obj.splitting_value_index(obj.split_attribute_no);
                for j = 1:length(partitions)
                    if j <= threshold_index
                        side = 1;
                    else
                        side = 2;
                    end
                    children_data_sets(side).data = ...
                        combine_data_sets(children_data_sets(side).data, partitions(j).data);
                end
            end

            obj.children_labels = obj.metadata.attribute_values{obj.split_attribute_no};
            for i = 1:length(children_data_sets)
                obj.children(i) = make_subtree(children_data_sets(i).data, obj.metadata, m, obj.label_count);
            end
        end

        function label = find_correct_child(obj, test_data_point)
            %FIND_CORRECT_CHILD Classify one instance by descending the tree.
            %   test_data_point is a cell array indexed by attribute number.
            %   Returns the class_label of the leaf that is reached.
            %   Raises TreeNode:unknownAttributeValue when a non-numeric
            %   attribute value matches none of obj.children_labels.
            if obj.is_leaf
                label = obj.class_label;
                return;
            end

            test_value = test_data_point{obj.split_attribute_no};

            if ~obj.metadata.is_attribute_numeric(obj.split_attribute_no) % Non Numeric Attribute
                child_number = find(strcmp(test_value, obj.children_labels), 1);
                if isempty(child_number)
                    error('TreeNode:unknownAttributeValue', ...
                        'Value "%s" of attribute "%s" matches none of the child labels.', ...
                        test_value, obj.split_attribute_name);
                end
                label = obj.children(child_number).find_correct_child(test_data_point);
            else % Numeric Attribute
                possible_splits_for_attribute = obj.splits{obj.split_attribute_no};
                split_value = possible_splits_for_attribute( ...
                    obj.splitting_value_index(obj.split_attribute_no));
                if isnumeric(test_value)
                    test_attribute_value = test_value;
                else
                    % str2double parses text without evaluating it; an
                    % unparsable value becomes NaN and takes branch 2.
                    test_attribute_value = str2double(test_value);
                end
                % Only two children exist here: values <= split_value go
                % to child 1, everything else to child 2.
                if (test_attribute_value <= split_value)
                    label = obj.children(1).find_correct_child(test_data_point);
                else
                    label = obj.children(2).find_correct_child(test_data_point);
                end
            end
        end
    end
end