Skip to content

Commit

Permalink
feature layers
Browse files Browse the repository at this point in the history
  • Loading branch information
fuhailin committed Dec 17, 2023
1 parent 253b5dc commit 0db00ad
Show file tree
Hide file tree
Showing 18 changed files with 2,242 additions and 56 deletions.
5 changes: 2 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ tf_configure(

http_archive(
name = "org_tensorflow",
sha256 = "59d09bd00eef6f07477eea2f50778582edd4b7b2850a396f1fd0c646b357a573",
strip_prefix = "tensorflow-2.9.3",
strip_prefix = "tensorflow-2.9.1",
urls = [
"https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.9.3.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.9.1.tar.gz",
],
)

Expand Down
1 change: 1 addition & 0 deletions deepray/custom_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ py_library(
srcs = glob(["**/*.py"]),
deps = [
"//deepray/custom_ops/correlation_cost",
"//deepray/custom_ops/ffm_ops",
"//deepray/custom_ops/multiplex_1:multiplex_1_op",
"//deepray/custom_ops/multiplex_2:multiplex_2_op",
"//deepray/custom_ops/multiplex_3:multiplex_3_op",
Expand Down
46 changes: 46 additions & 0 deletions deepray/custom_ops/ffm_ops/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
load("//deepray:deepray.bzl", "custom_op_library")
load("@local_config_tf//:build_defs.bzl", "CPLUSPLUS_VERSION")

licenses(["notice"]) # Apache 2.0

# Declares the `_ffm_ops.so` target through the `custom_op_library` macro
# loaded from //deepray:deepray.bzl.
custom_op_library(
name = "_ffm_ops.so",
# NOTE(review): CPLUSPLUS_VERSION comes from @local_config_tf's build_defs.bzl;
# presumably a `-std=...` compiler flag — confirm there.
copts = [CPLUSPLUS_VERSION],
# Kernel implementation, its header, and the ops source file.
srcs = [
"cc/kernels/ffm_kernels.cc",
"cc/kernels/ffm_kernels.h",
"cc/ops/ffm_ops.cc",
],
# Sources handed to the macro's `cuda_srcs` attribute; the shared header is
# listed here as well as in `srcs`.
cuda_srcs = [
"cc/kernels/ffm_kernels.h",
"cc/kernels/ffm_kernels.cu.cc",
],
)

# Python library for the FFM custom op: the Python sources matched by the glob
# below, with `_ffm_ops.so` attached as a data dependency.
py_library(
name = "ffm_ops",
srcs = glob(
[
# NOTE(review): in Bazel globs `**` can match zero directories, so
# "python/**/*.py" appears to already cover "python/*.py" — confirm before
# simplifying. The recursive pattern would also match any .py files under
# python/tests/.
"python/*.py",
"python/**/*.py",
"*.py",
],
),
data = [
":_ffm_ops.so",
],
# Usable from any package.
visibility = ["//visibility:public"],
deps = [
"//deepray/utils",
],
)

# Test target for the FFM op. It depends on the :ffm_ops library and uses
# python/tests/run_all_test.py as its entry point.
py_test(
name = "ffm_ops_test",
size = "small",
# Matches every file directly under python/tests/ (no extension filter).
srcs = glob(["python/tests/*"]),
main = "python/tests/run_all_test.py",
deps = [
":ffm_ops",
],
)
17 changes: 17 additions & 0 deletions deepray/custom_ops/ffm_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from .python.ffm_ops import ffm
114 changes: 114 additions & 0 deletions deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ffm_kernels.h"
#include <string>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace monolith_tf {

using CPUDevice = Eigen::ThreadPoolDevice;

// CPU specialization of the FFM forward computation.
//
// Each row of `left_matrix` / `right_matrix` is read as `left_feat_num` /
// `right_feat_num` consecutive slices of `dim_size` floats. For every
// (left slice l, right slice r) pair and every row b:
//   * int_type == "dot": output(b, l * right_feat_num + r) receives the dot
//     product of the two slices.
//   * any other int_type: output(b, (l * right_feat_num + r) * dim_size + k)
//     receives the element-wise product, for each k in [0, dim_size).
template <>
struct FFMImpl<CPUDevice> {
  // Fills `output` for rows [0, batch_size).
  //
  // `output` is zeroed in full before anything is written. `ctx` and the
  // matrix extents are not inspected: all indexing is derived from the
  // *_feat_num, batch_size and dim_size arguments, so the caller is
  // responsible for passing values consistent with the matrix shapes.
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes<float>::ConstMatrix left_matrix, int left_feat_num,
                      TTypes<float>::ConstMatrix right_matrix,
                      int right_feat_num, int batch_size, int dim_size,
                      TTypes<float>::Matrix output) {
    output.setZero();

    // The interaction type cannot change during the call, so compare the
    // string once instead of once per (l, r) pair.
    const bool is_dot = (int_type == "dot");

    for (int l = 0; l < left_feat_num; ++l) {
      const int l_idx = l * dim_size;
      for (int r = 0; r < right_feat_num; ++r) {
        const int r_idx = r * dim_size;
        // Flat index of the (l, r) pair; both output layouts are based on it.
        const int pair_idx = l * right_feat_num + r;
        if (is_dot) {
          // One output column per pair: accumulate the dot product in place
          // (the column starts at zero thanks to setZero above).
          for (int b = 0; b < batch_size; ++b) {
            for (int k = 0; k < dim_size; ++k) {
              output(b, pair_idx) +=
                  left_matrix(b, l_idx + k) * right_matrix(b, r_idx + k);
            }
          }
        } else {
          // dim_size output columns per pair: element-wise product.
          const int o_idx = pair_idx * dim_size;
          for (int b = 0; b < batch_size; ++b) {
            for (int k = 0; k < dim_size; ++k) {
              output(b, o_idx + k) =
                  left_matrix(b, l_idx + k) * right_matrix(b, r_idx + k);
            }
          }
        }
      }
    }
  }
};

// CPU specialization of the FFM gradient computation.
//
// For each pair index g in [0, grad_feat_num), the left slice is
// g / right_feat_num and the right slice is g % right_feat_num (each slice
// being `dim_size` consecutive columns). For every row b and every k:
//   left_grad(b, l_idx + k)  += grad * right(b, r_idx + k)
//   right_grad(b, r_idx + k) += grad * left(b, l_idx + k)
// where `grad` is grad_matrix(b, g) when int_type == "dot", and
// grad_matrix(b, g * dim_size + k) for any other int_type.
template <>
struct FFMGradImpl<CPUDevice> {
  // Accumulates into `left_grad_matrix` and `right_grad_matrix` for rows
  // [0, batch_size); both are zeroed in full first.
  //
  // `ctx` and `left_feat_num` are not read by this implementation.
  // `right_feat_num` is used as a divisor, so it must be non-zero whenever
  // grad_feat_num > 0. No shape validation is performed here.
  // NOTE(review): grad_feat_num is presumably left_feat_num * right_feat_num,
  // matching the forward pass — confirm against the calling op.
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes<float>::ConstMatrix grad_matrix, int grad_feat_num,
                      TTypes<float>::ConstMatrix left_matrix, int left_feat_num,
                      TTypes<float>::ConstMatrix right_matrix,
                      int right_feat_num, int batch_size, int dim_size,
                      TTypes<float>::Matrix left_grad_matrix,
                      TTypes<float>::Matrix right_grad_matrix) {
    left_grad_matrix.setZero();
    right_grad_matrix.setZero();

    // The interaction type cannot change during the call, so compare the
    // string once instead of once per pair index.
    const bool is_dot = (int_type == "dot");

    for (int g = 0; g < grad_feat_num; ++g) {
      const int l_idx = (g / right_feat_num) * dim_size;
      const int r_idx = (g % right_feat_num) * dim_size;

      if (is_dot) {
        // A single gradient value per pair is shared by all k.
        for (int b = 0; b < batch_size; ++b) {
          for (int k = 0; k < dim_size; ++k) {
            left_grad_matrix(b, l_idx + k) +=
                grad_matrix(b, g) * right_matrix(b, r_idx + k);

            right_grad_matrix(b, r_idx + k) +=
                grad_matrix(b, g) * left_matrix(b, l_idx + k);
          }
        }
      } else {
        // One gradient value per (pair, k).
        const int g_idx = g * dim_size;
        for (int b = 0; b < batch_size; ++b) {
          for (int k = 0; k < dim_size; ++k) {
            left_grad_matrix(b, l_idx + k) +=
                grad_matrix(b, g_idx + k) * right_matrix(b, r_idx + k);

            right_grad_matrix(b, r_idx + k) +=
                grad_matrix(b, g_idx + k) * left_matrix(b, l_idx + k);
          }
        }
      }
    }
  }
};

namespace {

// Register the CPUDevice instantiations as the DEVICE_CPU kernels for the
// "FFM" and "FFMGrad" ops.
// NOTE(review): FFMOp / FFMGradOp are not defined in this file — presumably
// they are declared in ffm_kernels.h; confirm.
REGISTER_KERNEL_BUILDER(Name("FFM").Device(DEVICE_CPU), FFMOp<CPUDevice>)

REGISTER_KERNEL_BUILDER(Name("FFMGrad").Device(DEVICE_CPU),
FFMGradOp<CPUDevice>)

} // namespace
} // namespace monolith_tf
} // namespace tensorflow
Loading

0 comments on commit 0db00ad

Please sign in to comment.