-
Notifications
You must be signed in to change notification settings - Fork 0
/
infogain.py
81 lines (62 loc) · 2.04 KB
/
infogain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import types
from typing import Any, Protocol, runtime_checkable
import numpy as np
from ..cost_functions import NDArrayInt
from ._splitter_protocol import Splitter
from .tree import ChildNodes, TreeIndices
class InfoGainType(Protocol):
    """Structural type for a callable that scores a candidate split.

    Any callable with this ``__call__`` signature satisfies the protocol.
    ``calc_info_gain`` invokes it with the entries of ``y`` belonging to the
    parent node (``y_node``) and a tuple holding the entries of ``y`` for each
    child produced by the splitter (``leafs``). It must return a ``float``.
    """

    def __call__(self, y_node: NDArrayInt, leafs: tuple[NDArrayInt, ...]) -> float:
        # Signature declaration only; concrete scorers supply the body.
        ...
def calc_info_gain(
    data: np.ndarray,
    y: NDArrayInt,
    node_indices: TreeIndices,
    feature: int,
    splitter: Splitter,
    info_gainer: InfoGainType,
    **kwargs: Any
) -> float:
    """Score splitting one node on ``feature`` using ``info_gainer``.

    The node's entries of ``y`` are gathered first. Then ``splitter`` divides
    ``node_indices`` on ``feature``, receiving any extra ``kwargs`` unchanged.
    Finally ``y`` is indexed once per child the splitter produced, and the
    parent/children values are handed to ``info_gainer``.

    Args:
        data: Array passed straight through to ``splitter``.
        y: Array indexed by ``node_indices`` and by each child index set.
        node_indices: Indices selecting the node being split.
        feature: Feature identifier forwarded to ``splitter``.
        splitter: Callable ``splitter(data, node_indices, feature, **kwargs)``
            whose result is iterated, one index set per child.
        info_gainer: Callable ``info_gainer(parent_values, child_values)``.
        **kwargs: Extra keyword arguments for ``splitter`` only.

    Returns:
        The value returned by ``info_gainer``.
    """
    parent_values: NDArrayInt = y[node_indices]
    # Ask the splitter how the node's indices divide into children.
    child_index_sets: ChildNodes = splitter(data, node_indices, feature, **kwargs)
    child_values: tuple[NDArrayInt, ...] = tuple(
        [y[child_idx] for child_idx in child_index_sets]
    )
    return info_gainer(parent_values, child_values)
@runtime_checkable
class InfoGain(Protocol):
    """Structural interface for a configured information-gain callable.

    Implementations expose a ``splitter`` and an ``info_gainer``. When called
    with ``(data, y, node_indices, feature, **kwargs)``, they return a
    ``float``. Because the protocol is ``runtime_checkable``,
    ``isinstance(obj, InfoGain)`` only verifies that these members are
    present; it does not check their signatures.
    """

    # Callable invoked as splitter(data, node_indices, feature, **kwargs);
    # its result is iterated to obtain one index set per child.
    splitter: Splitter
    # Callable scoring the parent node's y-values against the children's.
    info_gainer: InfoGainType

    def __call__(
        self,
        data: np.ndarray,
        y: NDArrayInt,
        node_indices: TreeIndices,
        feature: int,
        **kwargs: Any
    ) -> float:
        # Interface stub only; concrete classes (e.g. BaseInfoGain) implement it.
        raise NotImplementedError("InfoGain is an abstract class.")
class BaseInfoGain:
    """Callable that scores a node split using two class-level callables.

    This class only *declares* ``splitter`` and ``info_gainer``. A subclass
    (such as the ones ``info_gain_factory`` builds) must provide them.
    Calling an instance delegates all work to ``calc_info_gain``.
    """

    # Declared, not assigned: a subclass must supply these.
    # NOTE(review): a plain function stored here as a class attribute would be
    # bound as a method on ``self.<attr>`` access; wrap it in ``staticmethod``.
    splitter: Splitter
    info_gainer: InfoGainType

    def __call__(
        self,
        data: np.ndarray,
        y: NDArrayInt,
        node_indices: TreeIndices,
        feature: int,
        **kwargs: Any
    ) -> float:
        """Score splitting ``node_indices`` on ``feature`` via ``calc_info_gain``.

        Extra ``kwargs`` are forwarded to the configured splitter.
        """
        return calc_info_gain(
            data=data,
            y=y,
            node_indices=node_indices,
            feature=feature,
            splitter=self.splitter,
            info_gainer=self.info_gainer,
            **kwargs,
        )
def info_gain_factory(
    name: str, splitter: Splitter, info_gainer: InfoGainType
) -> InfoGain:
    """Create a named ``BaseInfoGain`` subclass and return an instance of it.

    A new class called ``name`` is built at runtime with ``splitter`` and
    ``info_gainer`` attached as class attributes. It is then instantiated
    with no arguments. Calling the returned object delegates to
    ``calc_info_gain`` through ``BaseInfoGain.__call__``.

    Args:
        name: Used as the generated class's ``__name__``.
        splitter: Callable invoked as
            ``splitter(data, node_indices, feature, **kwargs)``.
        info_gainer: Callable scoring the parent node's values against the
            tuple of child values.

    Returns:
        An instance of the freshly generated class.
    """
    # staticmethod prevents plain functions from being turned into bound
    # methods on instance access, so ``self.splitter(...)`` does not receive
    # ``self`` as an extra first argument.
    namespace = {
        "splitter": staticmethod(splitter),
        "info_gainer": staticmethod(info_gainer),
    }
    new_info_gain_class: type[BaseInfoGain] = type(name, (BaseInfoGain,), namespace)
    return new_info_gain_class()