Commit 25a2aa8

Merge branch 'master' into homophily

mufeili authored Mar 23, 2023
2 parents 66c0291 + 170203a commit 25a2aa8
Showing 9 changed files with 93 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/dgl/_sparse_ops.py
@@ -549,7 +549,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target="u", rhs_target="v"):
out_shp = (gidx.number_of_edges(0),) + infer_broadcast_shape(
op, lhs_shp[1:], rhs_shp[1:]
)
- out = F.zeros(out_shp, dtype, ctx)
+ out = F.empty(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMM(
gidx,
@@ -615,7 +615,7 @@ def _gsddmm_hetero(
out_shp = (gidx.number_of_edges(etid),) + infer_broadcast_shape(
op, lhs_shp[1:], rhs_shp[1:]
)
- out_list[etid] = F.zeros(out_shp, dtype, ctx)
+ out_list[etid] = F.empty(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMMHetero(
gidx,
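The switch from F.zeros to F.empty skips a redundant fill: the SDDMM kernel overwrites every row of out when the graph has edges, and the buffer has zero rows otherwise. A minimal sketch of the pattern, assuming the PyTorch backend:

import torch as th

num_edges, feat_dim = 1024, 16
# Uninitialized output buffer; its contents are unspecified until written.
out = th.empty((num_edges, feat_dim))
# Stands in for the SDDMM kernel, which writes every row exactly once.
out.copy_(th.randn(num_edges, feat_dim))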
20 changes: 20 additions & 0 deletions python/dgl/backend/backend.py
@@ -967,6 +967,26 @@ def swapaxes(input, axis1, axis2):
pass


def empty(shape, dtype, ctx):
"""Create a tensor filled with uninitialized data.
Parameters
----------
shape : tuple of int
The tensor shape.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns
-------
Tensor
The emtpy tensor.
"""
pass


def zeros(shape, dtype, ctx):
"""Create a zero tensor.
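A minimal usage sketch of the new backend-agnostic empty() (assumed here: the PyTorch backend is active, and F.float32 and F.cpu() are the backend's existing dtype and context helpers):

import dgl.backend as F

# Dispatches to th.empty(shape, dtype=..., device=...) under the PyTorch backend.
buf = F.empty((4, 3), F.float32, F.cpu())
print(buf.shape)  # (4, 3); values are unspecified until written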
4 changes: 4 additions & 0 deletions python/dgl/backend/mxnet/tensor.py
@@ -347,6 +347,10 @@ def swapaxes(input, axis1, axis2):
return nd.swapaxes(input, axis1, axis2)


def empty(shape, dtype, ctx):
return nd.empty(shape, dtype=dtype, ctx=ctx)


def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx)

6 changes: 3 additions & 3 deletions python/dgl/backend/pytorch/sparse.py
@@ -970,7 +970,7 @@ class SEGMENTMM(th.autograd.Function):
def forward(ctx, A, B, seglen_A):
if B.dim() != 3:
raise ValueError("segment_mm expects B to be a 3D tensor.")
- C = th.zeros((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
+ C = th.empty((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
C = _segment_mm(A, B, C, seglen_A)
ctx.backward_cache = A, B, seglen_A
return C
@@ -981,11 +981,11 @@ def backward(ctx, dZ):
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
- A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
+ A_grad = th.empty(A.shape, device=A.device, dtype=A.dtype)
A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
- B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
+ B_grad = th.empty(B.shape, device=B.device, dtype=B.dtype)
B_grad = _segment_mm_backward_B(A, dZ, B_grad, seglen_A)
return A_grad, B_grad, None

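For reference, segment_mm multiplies each row segment of A with the matching relation matrix in B and writes every output row exactly once, which is why the buffers above can come from th.empty. A plain-PyTorch sketch of the semantics (segment_mm_reference is a hypothetical helper, assuming seglen_A holds the per-relation segment lengths):

import torch as th

def segment_mm_reference(A, B, seglen_A):
    # A: (N, k), B: (num_rel, k, n), seglen_A: number of rows of A per relation.
    C = th.empty((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
    offset = 0
    for rel, length in enumerate(seglen_A.tolist()):
        # Each output row is written exactly once, so zero-initialization is unnecessary.
        C[offset:offset + length] = A[offset:offset + length] @ B[rel]
        offset += length
    return C

A = th.randn(6, 4)
B = th.randn(2, 4, 5)
print(segment_mm_reference(A, B, th.tensor([2, 4])).shape)  # torch.Size([6, 5])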
4 changes: 4 additions & 0 deletions python/dgl/backend/pytorch/tensor.py
@@ -279,6 +279,10 @@ def swapaxes(input, axis1, axis2):
return th.transpose(input, axis1, axis2)


def empty(shape, dtype, ctx):
return th.empty(shape, dtype=dtype, device=ctx)


def zeros(shape, dtype, ctx):
return th.zeros(shape, dtype=dtype, device=ctx)

5 changes: 5 additions & 0 deletions python/dgl/backend/tensorflow/tensor.py
@@ -336,6 +336,11 @@ def swapaxes(input, axis1, axis2):
return tf.transpose(input, perm=t)


def empty(shape, dtype, ctx):
# tf doesn't have tf.empty(), use zeros() as a workaround
return zeros(shape, dtype, ctx)


def zeros(shape, dtype, ctx):
with tf.device(ctx):
t = tf.zeros(shape, dtype=dtype)
22 changes: 22 additions & 0 deletions python/dgl/graph_index.py
@@ -161,6 +161,27 @@ def readonly(self, readonly_state=True):
state = (n_nodes, readonly_state, src, dst)
self.__setstate__(state)

def num_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes.
"""
return _CAPI_DGLGraphNumVertices(self)

def num_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges.
"""
return _CAPI_DGLGraphNumEdges(self)

# TODO(#5485): remove this method.
def number_of_nodes(self):
"""Return the number of nodes.
@@ -171,6 +192,7 @@ def number_of_nodes(self):
"""
return _CAPI_DGLGraphNumVertices(self)

# TODO(#5485): remove this method.
def number_of_edges(self):
"""Return the number of edges.
32 changes: 32 additions & 0 deletions python/dgl/heterograph_index.py
@@ -359,6 +359,37 @@ def is_readonly(self):
"""
return bool(_CAPI_DGLHeteroIsReadonly(self))

def num_nodes(self, ntype):
"""Return the number of nodes.
Parameters
----------
ntype : int
Node type.
Returns
-------
int
The number of nodes.
"""
return _CAPI_DGLHeteroNumVertices(self, int(ntype))

def num_edges(self, etype):
"""Return the number of edges.
Parameters
----------
etype : int
Edge type.
Returns
-------
int
The number of edges.
"""
return _CAPI_DGLHeteroNumEdges(self, int(etype))

# TODO(#5485): remove this method.
def number_of_nodes(self, ntype):
"""Return the number of nodes.
@@ -374,6 +405,7 @@ def number_of_nodes(self, ntype):
"""
return _CAPI_DGLHeteroNumVertices(self, int(ntype))

# TODO(#5485): remove this method.
def number_of_edges(self, etype):
"""Return the number of edges.
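The num_nodes/num_edges additions in graph_index.py and heterograph_index.py align the internal index classes with the public DGLGraph API, where the short names already exist alongside the number_of_* forms slated for removal in #5485. A small illustration at the public-API level (assumes a PyTorch-backed DGL install):

import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
# The short forms are the preferred spelling; number_of_nodes()/number_of_edges() return the same values.
print(g.num_nodes(), g.num_edges())  # 3 2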
2 changes: 1 addition & 1 deletion src/array/cuda/gather_mm.cu
@@ -256,7 +256,7 @@ void SegmentMMBackwardB(
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen.NumElements();
- DType alpha = 1., beta = 1.;
+ DType alpha = 1., beta = 0.;

auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
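Context for the beta change, not taken from the diff itself: cuBLAS GEMM computes C = alpha * op(A) * op(B) + beta * C, and beta == 0 instructs it not to read C at all. Since B_grad is now allocated with th.empty rather than th.zeros, accumulating with beta = 1 would add the product to uninitialized memory, so the kernel must overwrite the buffer instead. A tiny PyTorch sketch of the overwrite behavior:

import torch as th

A = th.randn(4, 3)
dZ = th.randn(4, 2)
dB = th.empty(3, 2)    # uninitialized gradient buffer, as after this commit
dB.copy_(A.t() @ dZ)   # beta = 0 behavior: plain overwrite, nothing is accumulated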
