Update GNN stellargraph with DGL #1032

Merged: 116 commits, Aug 29, 2023

Commits
425d047
inference, training script and README added
tzemicheal Jul 11, 2023
3555b0b
Merge remote-tracking branch 'upstream/branch-23.07' into fsi_dgl
tzemicheal Jul 11, 2023
04f0e0f
add dgl model
tzemicheal Jul 11, 2023
26ebe10
formatting fix
tzemicheal Jul 11, 2023
ed9d713
fix pylint formatting
tzemicheal Jul 11, 2023
62aaa67
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Jul 19, 2023
b208bf3
HinSAGE added & model file updated
tzemicheal Jul 19, 2023
0b4fddf
inference code updated to match training modification
tzemicheal Jul 19, 2023
0625e8f
copyright updated
tzemicheal Jul 19, 2023
cc4bece
gnn pipeline(stages, models) added
tzemicheal Jul 19, 2023
901a19c
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Jul 19, 2023
221666b
formatting cli parameters
tzemicheal Jul 19, 2023
4d7288b
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Jul 24, 2023
9966f7c
notebook cleaned/added
tzemicheal Jul 24, 2023
c127a7e
modified: morpheus.code-workspace
tzemicheal Jul 24, 2023
086a246
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Jul 27, 2023
092c4c5
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Jul 27, 2023
856f15d
docstring fixed, conda req added
tzemicheal Jul 28, 2023
400eed2
new file: models/training-tuning-scripts/fraud-detection-models/re…
tzemicheal Jul 28, 2023
3d84ec4
converted pandas ref to cudf & model restructure
tzemicheal Jul 29, 2023
086edc7
inference, training updated
tzemicheal Jul 29, 2023
c41a53c
model files added
tzemicheal Jul 29, 2023
45486e1
models lfs tracking added
tzemicheal Jul 29, 2023
e64a470
Remove extra debugging info
tzemicheal Jul 29, 2023
23edcc9
docstring fixed
tzemicheal Aug 1, 2023
e4265c3
Merge branch 'branch-23.11' into fsi_dgl
tzemicheal Aug 3, 2023
6bb671a
pylint fixes
tzemicheal Aug 3, 2023
9c6c7b3
Remove extra requirements & doc fix
tzemicheal Aug 4, 2023
224d5d5
pin pytorch-cuda to 11.8
dagardner-nv Aug 4, 2023
ee8e796
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into david-f…
dagardner-nv Aug 4, 2023
d26d4b5
Merge pull request #1 from dagardner-nv/david-fsi_dgl-patch
tzemicheal Aug 4, 2023
f345358
Remove tensorflow requirement from gnn example
tzemicheal Aug 4, 2023
699d887
dask no longer needed for this example
dagardner-nv Aug 4, 2023
a1b4022
Merge remote-tracking branch 'origin/fsi_dgl' into fsi_dgl
tzemicheal Aug 4, 2023
6a47203
Example now only needs dgl and cuml
dagardner-nv Aug 4, 2023
45dfe32
Cleanup tests, drop tests for methods that no longer exist, process_m…
dagardner-nv Aug 4, 2023
7d6a1b4
Update examples/gnn_fraud_detection_pipeline/stages/graph_sage_stage.py
tzemicheal Aug 4, 2023
732e63e
WIP
dagardner-nv Aug 4, 2023
7e421b8
wip
dagardner-nv Aug 4, 2023
21c440c
Update test_graph_sage_stage, still failing, but failing better
dagardner-nv Aug 4, 2023
09ee198
Update tests so that they can actually execute, they fail, but still
dagardner-nv Aug 4, 2023
917b49f
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into david-f…
dagardner-nv Aug 4, 2023
cba2b09
Lint fixes
dagardner-nv Aug 4, 2023
d3796f8
Revert unintentional model changes
dagardner-nv Aug 4, 2023
9d61da2
Merge pull request #2 from dagardner-nv/david-fsi_dgl-patch
tzemicheal Aug 4, 2023
5b193a8
???
dagardner-nv Aug 22, 2023
3c08ba5
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Aug 22, 2023
c447089
Merge branch 'branch-23.11' into david-fsi_dgl-patch3
dagardner-nv Aug 22, 2023
5cab36a
Merge pull request #3 from dagardner-nv/david-fsi_dgl-patch3
tzemicheal Aug 22, 2023
cb87fc9
Merge remote-tracking branch 'origin/fsi_dgl' into fsi_dgl
tzemicheal Aug 22, 2023
7f6fdc9
Remove modules from examples after running tests, since multiple exam…
dagardner-nv Aug 22, 2023
8ff6c53
Ensure float data
dagardner-nv Aug 22, 2023
842910e
Seed dgl as well as the others
dagardner-nv Aug 22, 2023
cfadf9f
Update expected data to match current results
dagardner-nv Aug 22, 2023
255ea98
Run tests with a manual seed to get deterministic results
dagardner-nv Aug 22, 2023
5a23bf8
WIP - passing bug incomplete
dagardner-nv Aug 22, 2023
ccc6c95
Merge pull request #4 from dagardner-nv/david-fsi_dgl-patch3
tzemicheal Aug 23, 2023
673ceee
Merge remote-tracking branch 'origin/fsi_dgl' into fsi_dgl
tzemicheal Aug 23, 2023
798bc22
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Aug 23, 2023
29b12bd
Merge branch 'branch-23.11' of github.com:nv-morpheus/Morpheus into d…
dagardner-nv Aug 23, 2023
9b5554e
Resolve merge conflicts
dagardner-nv Aug 23, 2023
7139f5a
Insert the path in the beginning, it seems other things have been app…
dagardner-nv Aug 23, 2023
97b2586
Merge branch 'branch-23.11' of github.com:nv-morpheus/Morpheus into d…
dagardner-nv Aug 23, 2023
5b6a4fa
Merge pull request #5 from dagardner-nv/david-fsi_dgl-patch4
tzemicheal Aug 24, 2023
cfd3c86
Merge remote-tracking branch 'origin/fsi_dgl' into fsi_dgl
tzemicheal Aug 24, 2023
485617d
unit test for graph construction added
tzemicheal Aug 24, 2023
3d27f6e
Fixed test set variable
tzemicheal Aug 24, 2023
e27611a
fix nodeId order
tzemicheal Aug 24, 2023
4c747bd
Merge remote-tracking branch 'upstream/branch-23.11' into fsi_dgl
tzemicheal Aug 24, 2023
870fe3a
unit test for graph construction added
tzemicheal Aug 25, 2023
eb69143
Fixed test set variable
tzemicheal Aug 25, 2023
e79199a
fix nodeId order
tzemicheal Aug 25, 2023
93f232b
Adding HTTP sources & sinks (#977)
tzemicheal Aug 25, 2023
7d4f749
update tests
tzemicheal Aug 25, 2023
de532c6
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into fsi_dgl
tzemicheal Aug 25, 2023
e66dfb8
Fix linting
tzemicheal Aug 25, 2023
be8a8a9
modified: examples/gnn_fraud_detection_pipeline/stages/model.py
tzemicheal Aug 25, 2023
d5809e4
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into fsi_dgl
tzemicheal Aug 25, 2023
e1b443e
Make imports from stages explicitly from current package since the gn…
dagardner-nv Aug 25, 2023
fbdb30b
Fix sorting
dagardner-nv Aug 25, 2023
b1b2517
Add dgl to generated-members
dagardner-nv Aug 26, 2023
74047f4
Ignore no-name-in-module linting errors
dagardner-nv Aug 26, 2023
10724b8
Merge pull request #6 from dagardner-nv/david-fsi_dgl-patch6
tzemicheal Aug 26, 2023
86f1055
modified: tests/examples/gnn_fraud_detection_pipeline/conftest.py
tzemicheal Aug 26, 2023
6216906
test updated
tzemicheal Aug 26, 2023
1b98a4a
Update test data, re-enable checking expected df, use manual_seed onl…
dagardner-nv Aug 28, 2023
6c8fc27
Merge pull request #7 from dagardner-nv/david-fsi_dgl-patch7
tzemicheal Aug 28, 2023
7280b52
Fix relative imports
dagardner-nv Aug 28, 2023
f7b5efc
Merge pull request #8 from dagardner-nv/david-fsi_dgl-patch7
tzemicheal Aug 28, 2023
aa7f39d
README run command flags fixed
tzemicheal Aug 28, 2023
4c0e2f1
remove stellargraph from ci runner
dagardner-nv Aug 28, 2023
e50dced
Update container ver
dagardner-nv Aug 28, 2023
f146989
Remove additional constraint on wrapt, remove un-needed pip
dagardner-nv Aug 28, 2023
806f7cb
First pass at consolidating common sections of HinSAGE & HeteroRGCN i…
dagardner-nv Aug 28, 2023
1e9f8da
Revert naming for layers and hetro_embedding for compat with saved mo…
dagardner-nv Aug 28, 2023
ac05828
Merge pull request #9 from dagardner-nv/david-fsi_dgl-patch8
tzemicheal Aug 28, 2023
fea1f1d
More type hints
dagardner-nv Aug 28, 2023
3965f43
Document manual_seed fixture
dagardner-nv Aug 28, 2023
f5c5149
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into david-f…
dagardner-nv Aug 28, 2023
fbc86a9
Enforce dir for model_dir arg
dagardner-nv Aug 28, 2023
e38d24a
Add docstring for remove_module
dagardner-nv Aug 28, 2023
bb20434
Merge pull request #10 from dagardner-nv/david-fsi_dgl-patch9
tzemicheal Aug 28, 2023
7567042
workspace update
tzemicheal Aug 28, 2023
e64ee7c
Merge branch 'fsi_dgl' of https://github.com/tzemicheal/Morpheus into…
tzemicheal Aug 28, 2023
d91c860
Track lfs val/train dataset
tzemicheal Aug 28, 2023
4118b70
modified: examples/gnn_fraud_detection_pipeline/run.py
tzemicheal Aug 28, 2023
62f73aa
Merge branch 'fsi_dgl' of github.com:tzemicheal/Morpheus into fsi_dgl
tzemicheal Aug 29, 2023
42b68ff
Merge branch 'branch-23.11' into david-fsi-dgl-lfs
dagardner-nv Aug 29, 2023
3798562
Revert "Track lfs val/train dataset"
dagardner-nv Aug 29, 2023
f3a5acd
Merge branch 'david-fsi-dgl-lfs' of github.com:dagardner-nv/Morpheus …
dagardner-nv Aug 29, 2023
525b0c6
Merge pull request #11 from dagardner-nv/david-fsi-dgl-lfs
tzemicheal Aug 29, 2023
58c6d46
inference and evaluate are now methods on the base model class
dagardner-nv Aug 29, 2023
5f7b098
lint fixes
dagardner-nv Aug 29, 2023
020d2b5
Merge pull request #12 from dagardner-nv/david-fsi_dgl-patch10
tzemicheal Aug 29, 2023
843ed38
Fix embedding_size
dagardner-nv Aug 29, 2023
062481e
Merge pull request #13 from dagardner-nv/david-fsi_dgl-patch10
tzemicheal Aug 29, 2023
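
Several commits above ("Run tests with a manual seed to get deterministic results", "Seed dgl as well as the others") concern reproducible test results. The sketch below illustrates the general approach only and is not the PR's actual fixture: `manual_seed` is a hypothetical function, and it assumes `numpy.random.seed`, `torch.manual_seed` and `dgl.seed` are the relevant seeding calls, applying each one only when the library is installed.

```python
import importlib
import random


def manual_seed(seed: int) -> list:
    """Seed every available RNG and return the names of the libraries seeded.

    Hypothetical helper for illustration; not taken from the PR.
    """
    random.seed(seed)
    seeded = ["random"]

    # Third-party libraries are optional here: each is seeded only if it imports.
    seeders = {
        "numpy": lambda mod: mod.random.seed(seed),
        "torch": lambda mod: mod.manual_seed(seed),
        "dgl": lambda mod: mod.seed(seed),
    }
    for name, seed_fn in seeders.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue
        seed_fn(module)
        seeded.append(name)
    return seeded


if __name__ == "__main__":
    print(manual_seed(42))
```

Seeding all libraries in one place keeps a test fixture small, and returning the list of seeded libraries makes it easy to see which generators were actually covered in a given environment.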
2 changes: 2 additions & 0 deletions .gitattributes
@@ -7,3 +7,5 @@ tests/mock_triton_server/payloads/** filter=lfs diff=lfs merge=lfs -text
tests/tests_data/** filter=lfs diff=lfs merge=lfs -text
examples/basic_usage/img/** filter=lfs diff=lfs merge=lfs -text
docs/source/img/* filter=lfs diff=lfs merge=lfs -text
git filter=lfs diff=lfs merge=lfs -text
status filter=lfs diff=lfs merge=lfs -text
4 changes: 2 additions & 2 deletions .github/workflows/pull_request.yml
@@ -46,7 +46,7 @@ jobs:
uses: ./.github/workflows/ci_pipe.yml
with:
run_check: ${{ startsWith(github.ref_name, 'pull-request/') }}
container: nvcr.io/ea-nvidia-morpheus/morpheus:morpheus-ci-build-230801
test_container: nvcr.io/ea-nvidia-morpheus/morpheus:morpheus-ci-test-230801
container: nvcr.io/ea-nvidia-morpheus/morpheus:morpheus-ci-build-230828
test_container: nvcr.io/ea-nvidia-morpheus/morpheus:morpheus-ci-test-230828
secrets:
NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
7 changes: 0 additions & 7 deletions ci/runner/Dockerfile
@@ -66,13 +66,6 @@ RUN apt update && \
COPY ./docker/conda/environments/cuda${CUDA_SHORT_VER}_examples.yml /tmp/conda/cuda${CUDA_SHORT_VER}_examples.yml
COPY ./ci/scripts/download_kafka.py /tmp/scripts/download_kafka.py

# Install extra deps needed for gnn_fraud_detection_pipeline & ransomware_detection examples
RUN CONDA_ALWAYS_YES=true /opt/conda/bin/mamba env update -n ${PROJ_NAME} -q --file /tmp/conda/cuda${CUDA_SHORT_VER}_examples.yml && \
conda clean -afy && \
source activate ${PROJ_NAME} && \
pip install --ignore-requires-python stellargraph==1.2.1 && \
rm -rf /tmp/conda

# Install camouflage needed for unittests to mock a triton server
RUN source activate ${PROJ_NAME} && \
npm install -g camouflage-server@0.9 && \
2 changes: 1 addition & 1 deletion ci/scripts/run_ci_local.sh
@@ -49,7 +49,7 @@ GIT_BRANCH=$(git branch --show-current)
GIT_COMMIT=$(git log -n 1 --pretty=format:%H)

LOCAL_CI_TMP=${LOCAL_CI_TMP:-${MORPHEUS_ROOT}/.tmp/local_ci_tmp}
CONTAINER_VER=${CONTAINER_VER:-230801}
CONTAINER_VER=${CONTAINER_VER:-230828}
CUDA_VER=${CUDA_VER:-11.8}
DOCKER_EXTRA_ARGS=${DOCKER_EXTRA_ARGS:-""}

8 changes: 2 additions & 6 deletions docker/conda/environments/cuda11.8_examples.yml
@@ -23,18 +23,14 @@ channels:
- rapidsai
- nvidia
- conda-forge
- dglteam/label/cu118
dependencies:
- boto3
- chardet=5.0.0
- cuml=23.06
- dask>=2023.1.1
- dgl=1.0.2
- dill=0.3.6
- distributed>=2023.1.1
- mlflow>=2.2.1,<3
- papermill=2.3.4
- s3fs>=2023.6
- pip
- wrapt=1.14.1 # ver 1.15 breaks the keras model used by the gnn_fraud_detection_pipeline
- pip:
# tensorflow exists in conda-forge but is tied to CUDA-11.3
- tensorflow==2.12.0
71 changes: 36 additions & 35 deletions examples/gnn_fraud_detection_pipeline/README.md
@@ -22,11 +22,8 @@ Prior to running the GNN fraud detection pipeline, additional requirements must

```bash
mamba env update -n ${CONDA_DEFAULT_ENV} -f examples/gnn_fraud_detection_pipeline/requirements.yml
pip install --ignore-requires-python stellargraph==1.2.1
```

> **Note**: The `--ignore-requires-python` is needed because Stellargraph only officially supports Python versions prior to 3.9 ([stellargraph/stellargraph#1960](https://github.com/stellargraph/stellargraph/issues/1960)).

## Running

##### Setup Env Variable
@@ -44,23 +41,22 @@ python run.py --help
Usage: run.py [OPTIONS]

Options:
--num_threads INTEGER RANGE Number of internal pipeline threads to use
--num_threads INTEGER RANGE Number of internal pipeline threads to use.
[x>=1]
--pipeline_batch_size INTEGER RANGE
Internal batch size for the pipeline. Can be
much larger than the model batch size. Also
used for Kafka consumers

used for Kafka consumers. [x>=1]
--model_max_batch_size INTEGER RANGE
Max batch size to use for the model
--input_file PATH Input filepath [required]
Max batch size to use for the model. [x>=1]
--model_fea_length INTEGER RANGE
Features length to use for the model.
[x>=1]
--input_file PATH Input data filepath. [required]
--training_file PATH Training data filepath. [required]
--model_dir PATH Trained model directory path [required]
--output_file TEXT The path to the file where the inference
output will be saved.
--training_file PATH Training data file [required]
--model_fea_length INTEGER RANGE
Features length to use for the model
--model-xgb-file PATH The name of the XGB model that is deployed
--model-hinsage-file PATH The name of the trained HinSAGE model file path

--help Show this message and exit.
```

@@ -71,35 +67,41 @@ cd ${MORPHEUS_ROOT}/examples/gnn_fraud_detection_pipeline
python run.py
```
```
====Registering Pipeline====
====Building Pipeline====
Added source: <from-file-0; FileSourceStage(filename=validation.csv, iterative=None, file_type=auto, repeat=1, filter_null=False)>
Graph construction rate: 0 messages [00:00, ? me====Building Pipeline Complete!====
Inference rate: 0 messages [00:00, ? messages/s]====Registering Pipeline Complete!====
====Starting Pipeline====
====Pipeline Started==== 0 messages [00:00, ? messages/s]
====Building Segment: linear_segment_0====ges/s]
Added source: <from-file-0; FileSourceStage(filename=validation.csv, iterative=False, file_type=FileTypes.Auto, repeat=1, filter_null=False)>
└─> morpheus.MessageMeta
Added stage: <deserialize-1; DeserializeStage()>
Added stage: <deserialize-1; DeserializeStage(ensure_sliceable_index=True)>
└─ morpheus.MessageMeta -> morpheus.MultiMessage
Added stage: <fraud-graph-construction-2; FraudGraphConstructionStage(training_file=training.csv)>
└─ morpheus.MultiMessage -> stages.FraudGraphMultiMessage
Added stage: <monitor-3; MonitorStage(description=Graph construction rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None)>
Added stage: <monitor-3; MonitorStage(description=Graph construction rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None, log_level=LogLevels.INFO)>
└─ stages.FraudGraphMultiMessage -> stages.FraudGraphMultiMessage
Added stage: <gnn-fraud-sage-4; GraphSAGEStage(model_hinsage_file=model/hinsage-model.pt, batch_size=5, sample_size=[2, 32], record_id=index, target_node=transaction)>
Added stage: <gnn-fraud-sage-4; GraphSAGEStage(model_dir=model, batch_size=100, record_id=index, target_node=transaction)>
└─ stages.FraudGraphMultiMessage -> stages.GraphSAGEMultiMessage
Added stage: <monitor-5; MonitorStage(description=Inference rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None)>
Added stage: <monitor-5; MonitorStage(description=Inference rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None, log_level=LogLevels.INFO)>
└─ stages.GraphSAGEMultiMessage -> stages.GraphSAGEMultiMessage
Added stage: <gnn-fraud-classification-6; ClassificationStage(model_xgb_file=model/xgb-model.pt)>
Added stage: <gnn-fraud-classification-6; ClassificationStage(model_xgb_file=model/xgb.pt)>
└─ stages.GraphSAGEMultiMessage -> morpheus.MultiMessage
Added stage: <monitor-7; MonitorStage(description=Add classification rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None)>
Added stage: <monitor-7; MonitorStage(description=Add classification rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None, log_level=LogLevels.INFO)>
└─ morpheus.MultiMessage -> morpheus.MultiMessage
Added stage: <serialize-8; SerializeStage(include=None, exclude=['^ID$', '^_ts_'], output_type=pandas)>
└─ morpheus.MultiMessage -> pandas.DataFrame
Added stage: <monitor-9; MonitorStage(description=Serialize rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None)>
└─ pandas.DataFrame -> pandas.DataFrame
Added stage: <to-file-10; WriteToFileStage(filename=result.csv, overwrite=True, file_type=auto)>
└─ pandas.DataFrame -> pandas.DataFrame
====Building Pipeline Complete!====
====Pipeline Started====
Graph construction rate[Complete]: 265messages [00:00, 1590.22messages/s]
Inference rate[Complete]: 265messages [00:01, 150.23messages/s]
Add classification rate[Complete]: 265messages [00:01, 147.11messages/s]
Serialize rate[Complete]: 265messages [00:01, 142.31messages/s]
Added stage: <serialize-8; SerializeStage(include=[], exclude=['^ID$', '^_ts_'], fixed_columns=True)>
└─ morpheus.MultiMessage -> morpheus.MessageMeta
Added stage: <monitor-9; MonitorStage(description=Serialize rate, smoothing=0.05, unit=messages, delayed_start=False, determine_count_fn=None, log_level=LogLevels.INFO)>
└─ morpheus.MessageMeta -> morpheus.MessageMeta
Added stage: <to-file-10; WriteToFileStage(filename=output.csv, overwrite=True, file_type=FileTypes.Auto, include_index_col=True, flush=False)>
└─ morpheus.MessageMeta -> morpheus.MessageMeta
====Building Segment Complete!====
Graph construction rate[Complete]: 265 messages [00:00, 1218.88 messages/s]
Inference rate[Complete]: 265 messages [00:01, 174.04 messages/s]
Add classification rate[Complete]: 265 messages [00:01, 170.69 messages/s]
Serialize rate[Complete]: 265 messages [00:01, 166.36 messages/s]
====Pipeline Complete====
```

### CLI Example
@@ -118,9 +120,8 @@ morpheus --log_level INFO \
deserialize \
fraud-graph-construction --training_file examples/gnn_fraud_detection_pipeline/training.csv \
monitor --description "Graph construction rate" \
gnn-fraud-sage --model_hinsage_file examples/gnn_fraud_detection_pipeline/model/hinsage-model.pt \
gnn-fraud-sage --model_dir examples/gnn_fraud_detection_pipeline/model/ \
monitor --description "Inference rate" \
gnn-fraud-classification --model_xgb_file examples/gnn_fraud_detection_pipeline/model/xgb-model.pt \
monitor --description "Add classification rate" \
serialize \
to-file --filename "output.csv" --overwrite
8 binary files not shown; 1 file was deleted.
9 changes: 2 additions & 7 deletions examples/gnn_fraud_detection_pipeline/requirements.yml
@@ -17,12 +17,7 @@ channels:
- rapidsai
- nvidia
- conda-forge
- dglteam/label/cu118
dependencies:
- chardet=5.0.0
- cuml=23.06
- dask>=2023.1.1
- distributed>=2023.1.1
- pip
- pip:
# tensorflow exists in conda-forge but is tied to CUDA-11.3
- tensorflow==2.12.0
- dgl=1.0.2
44 changes: 18 additions & 26 deletions examples/gnn_fraud_detection_pipeline/run.py
@@ -27,6 +27,7 @@
from morpheus.stages.postprocess.serialize_stage import SerializeStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.utils.logger import configure_logging
# pylint: disable=no-name-in-module
from stages.classification_stage import ClassificationStage
from stages.graph_construction_stage import FraudGraphConstructionStage
from stages.graph_sage_stage import GraphSAGEStage
@@ -60,48 +61,39 @@
)
@click.option(
"--input_file",
type=click.Path(exists=True, readable=True),
type=click.Path(exists=True, readable=True, dir_okay=False),
default="validation.csv",
required=True,
help="Input data filepath.",
)
@click.option(
"--training_file",
type=click.Path(exists=True, readable=True),
type=click.Path(exists=True, readable=True, dir_okay=False),
default="training.csv",
required=True,
help="Training data filepath.",
)
@click.option(
"--model-hinsage-file",
type=click.Path(exists=True, readable=True),
default="model/hinsage-model.pt",
"--model_dir",
type=click.Path(exists=True, readable=True, file_okay=False, dir_okay=True),
default="model",
required=True,
help="Trained hinsage model filepath.",
)
@click.option(
"--model-xgb-file",
type=click.Path(exists=True, readable=True),
default="model/xgb-model.pt",
required=True,
help="Trained xgb model filepath.",
help="Path to trained Hinsage & XGB models.",
)
@click.option(
"--output_file",
type=click.Path(dir_okay=False),
default="output.csv",
help="The path to the file where the inference output will be saved.",
)
def run_pipeline(
num_threads,
pipeline_batch_size,
model_max_batch_size,
model_fea_length,
input_file,
training_file,
model_hinsage_file,
model_xgb_file,
output_file,
):
def run_pipeline(num_threads,
pipeline_batch_size,
model_max_batch_size,
model_fea_length,
input_file,
training_file,
model_dir,
output_file):
# Enable the default logger.
configure_logging(log_level=logging.INFO)

@@ -140,12 +132,12 @@ def run_pipeline(
pipeline.add_stage(MonitorStage(config, description="Graph construction rate"))

# Add a sage inference stage.
pipeline.add_stage(GraphSAGEStage(config, model_hinsage_file))
pipeline.add_stage(GraphSAGEStage(config, model_dir))
pipeline.add_stage(MonitorStage(config, description="Inference rate"))

# Add classification stage.
# This stage adds detected classifications to each message.
pipeline.add_stage(ClassificationStage(config, model_xgb_file))
pipeline.add_stage(ClassificationStage(config, os.path.join(model_dir, "xgb.pt")))
pipeline.add_stage(MonitorStage(config, description="Add classification rate"))

# Add a serialize stage.
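
With the two model-file options replaced by a single `--model_dir`, the XGB model path is now derived in `run.py` via `os.path.join(model_dir, "xgb.pt")`. A minimal standard-library sketch of that idea follows. `resolve_xgb_model_path` is a hypothetical helper that is not part of this PR, and its directory check only approximates what `click.Path(exists=True, file_okay=False, dir_okay=True)` enforces at the command line.

```python
import os
import tempfile


def resolve_xgb_model_path(model_dir: str, xgb_filename: str = "xgb.pt") -> str:
    """Return the XGB model path inside ``model_dir`` (hypothetical helper)."""
    # Approximates click.Path(exists=True, file_okay=False, dir_okay=True):
    # the argument must be an existing directory, not a file.
    if not os.path.isdir(model_dir):
        raise NotADirectoryError(f"model_dir must be an existing directory: {model_dir}")
    return os.path.join(model_dir, xgb_filename)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Prints "<tmp_dir>/xgb.pt"; the file itself is not required to exist here.
        print(resolve_xgb_model_path(tmp_dir))
```

Deriving file names from one directory keeps the CLI surface smaller, at the cost of fixing the file names expected inside that directory.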
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

import mrc
from mrc.core import operators as ops

@@ -55,13 +53,13 @@ def __init__(self, c: Config, model_xgb_file: str):
def name(self) -> str:
return "gnn-fraud-classification"

def accepted_types(self) -> typing.Tuple:
def accepted_types(self) -> (GraphSAGEMultiMessage, ):
return (GraphSAGEMultiMessage, )

def supports_cpp_node(self):
def supports_cpp_node(self) -> bool:
return False

def _process_message(self, message: GraphSAGEMultiMessage):
def _process_message(self, message: GraphSAGEMultiMessage) -> GraphSAGEMultiMessage:
ind_emb_columns = message.get_meta(message.inductive_embedding_column_names)
message.set_meta("node_id", message.node_identifiers)
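
The new return annotation `-> (GraphSAGEMultiMessage, )` is a tuple instance rather than a type expression. Python evaluates it without error, but static type checkers may reject it. The sketch below shows one conventional alternative for annotating "a tuple of classes". `Message` and `ExampleStage` are hypothetical stand-ins for illustration, not classes from this PR.

```python
import typing


class Message:
    """Stand-in for a pipeline message type (hypothetical)."""


class ExampleStage:
    """Stand-in for a pipeline stage (hypothetical)."""

    def accepted_types(self) -> typing.Tuple[typing.Type[Message], ...]:
        # The method returns classes, not instances, hence Type[...] inside Tuple[...].
        return (Message, )

    def supports_cpp_node(self) -> bool:
        return False


if __name__ == "__main__":
    stage = ExampleStage()
    print(stage.accepted_types())
    print(typing.get_type_hints(ExampleStage.accepted_types)["return"])
```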
