-
Notifications
You must be signed in to change notification settings - Fork 0
/
builder.py
151 lines (129 loc) · 5.01 KB
/
builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Graph builder from pandas dataframes"""
from collections import namedtuple

import dgl
import pandas as pd
import torch

# NOTE(review): ``is_categorical`` was deprecated in pandas 1.1 and removed in
# pandas 2.0, so this import fails on current pandas -- confirm the supported
# pandas range and drop the name once nothing else in the project needs it.
from pandas.api.types import (
    is_categorical,
    is_categorical_dtype,
    is_numeric_dtype,
)
# Public API of this module: only the builder class is exported.
__all__ = ["PandasGraphBuilder"]
def _series_to_tensor(series):
    """Convert a pandas series into a torch tensor.

    Parameters
    ----------
    series : pandas.Series
        A categorical or numeric column.

    Returns
    -------
    torch.Tensor
        For a categorical series, a ``LongTensor`` holding the integer
        category codes; otherwise a ``FloatTensor`` built from the raw values.
    """
    # ``pandas.api.types.is_categorical`` was removed in pandas 2.0; checking
    # the dtype against ``CategoricalDtype`` is the documented replacement.
    if isinstance(series.dtype, pd.CategoricalDtype):
        return torch.LongTensor(series.cat.codes.values.astype("int64"))
    # Everything else is treated as numeric.
    return torch.FloatTensor(series.values)
class PandasGraphBuilder:
    """Creates a heterogeneous graph from multiple pandas dataframes.

    Entity tables are registered with :meth:`add_entities`, relation tables
    with :meth:`add_binary_relations`, and :meth:`build` turns the collected
    edges into a DGL heterograph.

    Entity types are looked up through the *name of their primary key
    column*, so every entity type must use a distinct primary key column
    name; registering a second entity type under the same column name
    overwrites the first mapping.

    Examples
    --------
    Let's say we have the following three pandas dataframes:

    User table ``users``:

    ===========  ===========  =======
    ``user_id``  ``country``  ``age``
    ===========  ===========  =======
    XYZZY        U.S.         25
    FOO          China        24
    BAR          China        23
    ===========  ===========  =======

    Game table ``games``:

    ===========  =========  ==============  ==================
    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``
    ===========  =========  ==============  ==================
    1            Minecraft  True            True
    2            Tetris 99  False           True
    ===========  =========  ==============  ==================

    Play relationship table ``plays``:

    ===========  ===========  =========
    ``user_id``  ``game_id``  ``hours``
    ===========  ===========  =========
    XYZZY        1            24
    FOO          1            20
    FOO          2            16
    BAR          2            28
    ===========  ===========  =========

    One could then create a bidirectional bipartite graph as follows:

    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.num_nodes('user')
    3
    >>> g.num_edges('plays')
    4
    """

    def __init__(self):
        """Create an empty builder with no entities or relations."""
        self.entity_tables = {}  # entity name -> entity dataframe
        self.relation_tables = {}  # relation name -> relation dataframe
        self.entity_pk_to_name = {}  # primary key name -> entity name
        self.entity_pk = {}  # entity name -> primary key name
        # entity name -> categorical series of primary key values, whose
        # category order follows the row order of the entity table
        self.entity_key_map = {}
        self.num_nodes_per_type = {}  # entity name -> number of rows
        # (srctype, relation name, dsttype) -> (source ids, destination ids)
        self.edges_per_relation = {}
        self.relation_name_to_etype = {}  # relation name -> canonical etype
        self.relation_src_key = {}  # relation name -> source key
        self.relation_dst_key = {}  # relation name -> destination key

    def add_entities(self, entity_table, primary_key, name):
        """Register an entity table as a node type.

        Row ``i`` of ``entity_table`` becomes node ``i`` of type ``name``.

        Parameters
        ----------
        entity_table : pandas.DataFrame
            One row per entity.
        primary_key : str
            Column of ``entity_table`` that uniquely identifies each row.
            Relation tables refer to this entity through the same column
            name.
        name : str
            Node type name.

        Raises
        ------
        ValueError
            If the primary key column contains null or duplicated values.
        """
        keys = entity_table[primary_key]
        # ``value_counts`` silently drops nulls, so a null key would pass the
        # uniqueness check below and only fail later inside
        # ``reorder_categories`` with an unrelated message; reject it here.
        if keys.isnull().any():
            raise ValueError(
                "Null primary key detected in entity %s." % name
            )

        entities = keys.astype("category")
        if not (entities.value_counts() == 1).all():
            raise ValueError(
                "Different entity with the same primary key detected."
            )
        # preserve the category order in the original entity table, so that
        # category codes coincide with row positions
        entities = entities.cat.reorder_categories(keys.values)

        self.entity_pk_to_name[primary_key] = name
        self.entity_pk[name] = primary_key
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_key_map[name] = entities
        self.entity_tables[name] = entity_table

    def _entity_type_for_key(self, key):
        """Return the entity name registered for primary key column ``key``."""
        try:
            return self.entity_pk_to_name[key]
        except KeyError:
            raise KeyError(
                "No entity with primary key %r has been added; call "
                "add_entities() for it before adding relations." % (key,)
            ) from None

    def add_binary_relations(
        self, relation_table, source_key, destination_key, name
    ):
        """Register a relation table as an edge type.

        Each row of ``relation_table`` becomes one edge from the entity
        identified by ``source_key`` to the entity identified by
        ``destination_key``.

        Parameters
        ----------
        relation_table : pandas.DataFrame
            One row per edge.
        source_key : str
            Column holding the source entity's primary key values. Must be
            the primary key column name of an already added entity.
        destination_key : str
            Column holding the destination entity's primary key values. Must
            be the primary key column name of an already added entity.
        name : str
            Edge type (relation) name.

        Raises
        ------
        KeyError
            If ``source_key`` or ``destination_key`` is not the primary key
            of an entity added earlier, or is not a column of
            ``relation_table``.
        ValueError
            If some row refers to a source or destination entity that is
            missing from the corresponding entity table.
        """
        srctype = self._entity_type_for_key(source_key)
        dsttype = self._entity_type_for_key(destination_key)

        # Re-encode both endpoint columns with the entity's categories, so
        # the category codes are the node ids; values unknown to the entity
        # table become null.
        src = relation_table[source_key].astype("category")
        src = src.cat.set_categories(
            self.entity_key_map[srctype].cat.categories
        )
        dst = relation_table[destination_key].astype("category")
        dst = dst.cat.set_categories(
            self.entity_key_map[dsttype].cat.categories
        )

        if src.isnull().any():
            raise ValueError(
                "Some source entities in relation %s do not exist in entity %s."
                % (name, srctype)
            )
        if dst.isnull().any():
            raise ValueError(
                "Some destination entities in relation %s do not exist in entity %s."
                % (name, dsttype)
            )

        etype = (srctype, name, dsttype)
        self.relation_name_to_etype[name] = etype
        self.edges_per_relation[etype] = (
            src.cat.codes.values.astype("int64"),
            dst.cat.codes.values.astype("int64"),
        )
        self.relation_tables[name] = relation_table
        self.relation_src_key[name] = source_key
        self.relation_dst_key[name] = destination_key

    def build(self):
        """Build the heterograph from everything added so far.

        Returns
        -------
        The object returned by ``dgl.heterograph`` when given the collected
        per-relation edge arrays and the per-type node counts.
        """
        # Create heterograph
        graph = dgl.heterograph(
            self.edges_per_relation, self.num_nodes_per_type
        )
        return graph