Skip to content

Commit

Permalink
added restart probability
Browse files Browse the repository at this point in the history
  • Loading branch information
kerighan committed Nov 5, 2021
1 parent 176118c commit ee5491a
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ G = nx.random_partition_graph([1000] * 15, .01, .001)
# generate random walks
X = walker.random_walks(G, n_walks=50, walk_len=25)

# generate random walks with restart probability alpha
X = walker.random_walks(G, n_walks=50, walk_len=25, alpha=.1)

# you can generate random walks from specified starting nodes
X = walker.random_walks(G, n_walks=50, walk_len=25, start_nodes=[0, 1, 2])

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

setup(
name="graph-walker",
version="1.0.3",
version="1.0.4",
author="Maixent Chenebaux",
author_email="max.chbx@gmail.com",
description="Fastest library for random walks on graph",
Expand Down
1 change: 1 addition & 0 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace py = pybind11;

PYBIND11_MODULE(_walker, m) {
m.def("random_walks", &randomWalks, "random walks");
m.def("random_walks_with_restart", &randomWalksRestart, "random walks with restart");
m.def("node2vec_random_walks", &n2vRandomWalks, "node2vec random walks");
m.def("corrupt", &corruptWalks, "corrupt walks");
m.def("weighted_corrupt", &weightedCorruptWalks, "weighted corrupt walks");
Expand Down
94 changes: 94 additions & 0 deletions src/randomWalks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,100 @@ py::array_t<uint32_t> randomWalks(py::array_t<uint32_t> _indptr, py::array_t<uin

}

py::array_t<uint32_t> randomWalksRestart(
py::array_t<uint32_t> _indptr,
py::array_t<uint32_t> _indices,
py::array_t<float> _data,
py::array_t<uint32_t> _startNodes,
size_t nWalks,
size_t walkLen,
float alpha)
{
// get data buffers
py::buffer_info indptrBuf = _indptr.request();
uint32_t *indptr = (uint32_t *)indptrBuf.ptr;

py::buffer_info indicesBuf = _indices.request();
uint32_t *indices = (uint32_t *)indicesBuf.ptr;

py::buffer_info dataBuf = _data.request();
float *data = (float *)dataBuf.ptr;

py::buffer_info startNodesBuf = _startNodes.request();
uint32_t *startNodes = (uint32_t *)startNodesBuf.ptr;

// general variables
size_t nNodes = startNodesBuf.shape[0];
size_t shape = nWalks * nNodes;

// walk matrix
py::array_t<uint32_t> _walks({shape, walkLen});
py::buffer_info walksBuf = _walks.request();
uint32_t *walks = (uint32_t *)walksBuf.ptr;

// make random walks
PARALLEL_FOR_BEGIN(shape)
{
static thread_local std::random_device rd;
static thread_local std::mt19937 generator(rd());
std::uniform_real_distribution<> dist(0., 1.);
std::vector<float> draws;
draws.reserve(walkLen - 1);
for (size_t z = 0; z < walkLen - 1; z++)
{
draws[z] = dist(generator);
}

size_t step = startNodes[i % nNodes];
size_t startNode = step;
walks[i * walkLen] = step;

for (size_t k = 1; k < walkLen; k++)
{
uint32_t start = indptr[step];
uint32_t end = indptr[step + 1];

// if no neighbors, we fill in current node
if (start == end)
{
walks[i * walkLen + k] = step;
continue;
}

if (dist(generator) < alpha){
step = startNode;
} else
{
// searchsorted
float cumsum = 0;
size_t index = 0;
float draw = draws[k - 1];
for (size_t z = start; z < end; z++)
{
cumsum += data[z];
if (draw > cumsum)
{
continue;
}
else
{
index = z;
break;
}
}

// draw next index
step = indices[index];
}

// update walk
walks[i * walkLen + k] = step;
}
}
PARALLEL_FOR_END();

return _walks;
}

py::array_t<uint32_t> n2vRandomWalks(py::array_t<uint32_t> _indptr, py::array_t<uint32_t>_indices, py::array_t<float> _data, py::array_t<uint32_t> _startNodes, size_t nWalks, size_t walkLen, float p, float q){
// get data buffers
Expand Down
1 change: 1 addition & 0 deletions src/randomWalks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace py = pybind11;


py::array_t<uint32_t> randomWalks(py::array_t<uint32_t> _indptr, py::array_t<uint32_t>_indices, py::array_t<float> _data, py::array_t<uint32_t> _startNodes, size_t nWalks, size_t walkLen);
py::array_t<uint32_t> randomWalksRestart(py::array_t<uint32_t> _indptr, py::array_t<uint32_t>_indices, py::array_t<float> _data, py::array_t<uint32_t> _startNodes, size_t nWalks, size_t walkLen, float alpha);
py::array_t<uint32_t> n2vRandomWalks(py::array_t<uint32_t> _indptr, py::array_t<uint32_t>_indices, py::array_t<float> _data, py::array_t<uint32_t> _startNodes, size_t nWalks, size_t walkLen, float p, float q);
py::array_t<bool> corruptWalks(py::array_t<uint32_t> _walks, size_t nNodes, float r);
py::array_t<bool> weightedCorruptWalks(py::array_t<uint32_t> _walks, py::array_t<uint32_t> _candidates, size_t nNodes, float r);
14 changes: 10 additions & 4 deletions walker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from _walker import random_walks as _random_walks
from _walker import random_walks_with_restart as _random_walks_with_restart
from _walker import node2vec_random_walks as _node2vec_random_walks
from _walker import weighted_corrupt as _corrupt
from .preprocessing import get_normalized_adjacency
Expand All @@ -12,7 +13,7 @@ def random_walks(
n_walks=10,
walk_len=10,
sub_sampling=0.,
p=1, q=1,
p=1, q=1, alpha=0,
start_nodes=None,
verbose=True
):
Expand All @@ -29,9 +30,14 @@ def random_walks(
start_nodes = np.array(start_nodes, dtype=np.uint32)

if p == 1 and q == 1:
walks = _random_walks(
indptr, indices, data, start_nodes,
n_walks, walk_len)
if alpha == 0:
walks = _random_walks(
indptr, indices, data, start_nodes,
n_walks, walk_len)
else:
walks = _random_walks_with_restart(
indptr, indices, data, start_nodes,
n_walks, walk_len, alpha)
else:
walks = _node2vec_random_walks(
indptr, indices, data, start_nodes,
Expand Down

0 comments on commit ee5491a

Please sign in to comment.