Skip to content

Commit

Permalink
initial db backend
Browse files Browse the repository at this point in the history
  • Loading branch information
rpiazza committed Nov 5, 2024
1 parent e508f99 commit f1af0dd
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 54 deletions.
2 changes: 1 addition & 1 deletion stix2/datastore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def add(self, *args, **kwargs):
"""
try:
return self.sink.add(*args, **kwargs)
except AttributeError:
except AttributeError as ex:
msg = "%s has no data sink to put objects in"
raise AttributeError(msg % self.__class__.__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

from typing import Any
import os

from sqlalchemy import create_engine
from sqlalchemy_utils import create_database, database_exists, drop_database


class DatabaseBackend:
def __init__(self, database_connection_url, force_recreate=False, **kwargs: Any):
self.database_connection_url = database_connection_url
self.database_exists = database_exists(database_connection_url)

if force_recreate:
if self.database_exists:
drop_database(database_connection_url)
create_database(database_connection_url)
self.database_exists = database_exists(database_connection_url)

self.database_connection = create_engine(database_connection_url)

def _create_schemas(self):
pass

@staticmethod
def _determine_schema_name(stix_object):
return ""

def _create_database(self):
if self.database_exists:
drop_database(self.database_connection.url)
create_database(self.database_connection.url)
self.database_exists = database_exists(self.database_connection.url)


Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
from typing import Any
from sqlalchemy.schema import CreateSchema

from .database_backend_base import DatabaseBackend

from stix2.base import (
_DomainObject, _MetaObject, _Observable, _RelationshipObject, _STIXBase,
)


class PostgresBackend(DatabaseBackend):
default_database_connection_url = \
f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:" + \
f"{os.getenv('POSTGRES_PASSWORD', 'postgres')}@" + \
f"{os.getenv('POSTGRES_IP_ADDRESS', '0.0.0.0')}:" + \
f"{os.getenv('POSTGRES_PORT', '5432')}/postgres"

def __init__(self, database_connection_url=default_database_connection_url, force_recreate=False, **kwargs: Any):
super().__init__(database_connection_url, force_recreate=False, **kwargs)

def _create_schemas(self):
with self.database_connection.begin() as trans:
trans.execute(CreateSchema("common", if_not_exists=True))
trans.execute(CreateSchema("sdo", if_not_exists=True))
trans.execute(CreateSchema("sco", if_not_exists=True))
trans.execute(CreateSchema("sro", if_not_exists=True))

@staticmethod
def _determine_schema_name(stix_object):
if isinstance(stix_object, _DomainObject):
return "sdo"
elif isinstance(stix_object, _Observable):
return "sco"
elif isinstance(stix_object, _RelationshipObject):
return "sro"
elif isinstance(stix_object, _MetaObject):
return "common"
98 changes: 48 additions & 50 deletions stix2/datastore/relational_db/relational_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def _add(store, stix_data, allow_custom=True, version="2.1"):

class RelationalDBStore(DataStoreMixin):
def __init__(
self, database_connection_url, allow_custom=True, version=None,
instantiate_database=True, force_recreate=False, *stix_object_classes,
self, db_backend, allow_custom=True, version=None,
instantiate_database=True, *stix_object_classes,
):
"""
Initialize this store.
Expand All @@ -80,32 +80,31 @@ def __init__(
auto-detect all classes and create table schemas for all of
them.
"""
database_connection = create_engine(database_connection_url)

self.metadata = MetaData()
create_table_objects(
self.metadata, stix_object_classes,
)

super().__init__(
source=RelationalDBSource(
database_connection,
db_backend,
metadata=self.metadata,
),
sink=RelationalDBSink(
database_connection,
db_backend,
allow_custom=allow_custom,
version=version,
instantiate_database=instantiate_database,
force_recreate=force_recreate,
metadata=self.metadata,
),
)


class RelationalDBSink(DataSink):
def __init__(
self, database_connection_or_url, allow_custom=True, version=None,
instantiate_database=True, force_recreate=False, *stix_object_classes, metadata=None,
self, db_backend, allow_custom=True, version=None,
instantiate_database=True, *stix_object_classes, metadata=None,
):
"""
Initialize this sink. Only one of stix_object_classes and metadata
Expand Down Expand Up @@ -135,14 +134,16 @@ def __init__(
"""
super(RelationalDBSink, self).__init__()

if isinstance(database_connection_or_url, str):
self.database_connection = create_engine(database_connection_or_url)
else:
self.database_connection = database_connection_or_url
self.db_backend = db_backend

self.database_exists = database_exists(self.database_connection.url)
if force_recreate:
self._create_database()
# if isinstance(database_connection_or_url, str):
# self.database_connection = create_engine(database_connection_or_url)
# else:
# self.database_connection = database_connection_or_url

# self.database_exists = database_exists(self.database_connection.url)
# if force_recreate:
# self._create_database()

if metadata:
self.metadata = metadata
Expand All @@ -160,49 +161,49 @@ def __init__(
self.tables_dictionary[canonicalize_table_name(t.name, t.schema)] = t

if instantiate_database:
if not self.database_exists:
self._create_database()
self._create_schemas()
if not self.db_backend.database_exists:
self.db_backend._create_database()
self.db_backend._create_schemas()
self._instantiate_database()

def _create_schemas(self):
with self.database_connection.begin() as trans:
trans.execute(CreateSchema("common", if_not_exists=True))
trans.execute(CreateSchema("sdo", if_not_exists=True))
trans.execute(CreateSchema("sco", if_not_exists=True))
trans.execute(CreateSchema("sro", if_not_exists=True))
# def _create_schemas(self):
# with self.database_connection.begin() as trans:
# trans.execute(CreateSchema("common", if_not_exists=True))
# trans.execute(CreateSchema("sdo", if_not_exists=True))
# trans.execute(CreateSchema("sco", if_not_exists=True))
# trans.execute(CreateSchema("sro", if_not_exists=True))

def _instantiate_database(self):
self.metadata.create_all(self.database_connection)
self.metadata.create_all(self.db_backend.database_connection)

def _create_database(self):
if self.database_exists:
drop_database(self.database_connection.url)
create_database(self.database_connection.url)
self.database_exists = database_exists(self.database_connection.url)
# def _create_database(self):
# if self.database_exists:
# drop_database(self.database_connection.url)
# create_database(self.database_connection.url)
# self.database_exists = database_exists(self.database_connection.url)

def generate_stix_schema(self):
for t in self.metadata.tables.values():
print(CreateTable(t).compile(self.database_connection))
print(CreateTable(t).compile(self.db_backend.database_connection))

def add(self, stix_data, version=None):
_add(self, stix_data)
add.__doc__ = _add.__doc__

@staticmethod
def _determine_schema_name(stix_object):
if isinstance(stix_object, _DomainObject):
return "sdo"
elif isinstance(stix_object, _Observable):
return "sco"
elif isinstance(stix_object, _RelationshipObject):
return "sro"
elif isinstance(stix_object, _MetaObject):
return "common"
# @staticmethod
# def _determine_schema_name(stix_object):
# if isinstance(stix_object, _DomainObject):
# return "sdo"
# elif isinstance(stix_object, _Observable):
# return "sco"
# elif isinstance(stix_object, _RelationshipObject):
# return "sro"
# elif isinstance(stix_object, _MetaObject):
# return "common"

def insert_object(self, stix_object):
schema_name = self._determine_schema_name(stix_object)
with self.database_connection.begin() as trans:
schema_name = self.db_backend._determine_schema_name(stix_object)
with self.db_backend.database_connection.begin() as trans:
statements = generate_insert_for_object(self, stix_object, schema_name)
for stmt in statements:
print("executing: ", stmt)
Expand All @@ -211,7 +212,7 @@ def insert_object(self, stix_object):

def clear_tables(self):
tables = list(reversed(self.metadata.sorted_tables))
with self.database_connection.begin() as trans:
with self.db_backend.database_connection.begin() as trans:
for table in tables:
delete_stmt = delete(table)
print(f'delete_stmt: {delete_stmt}')
Expand All @@ -220,7 +221,7 @@ def clear_tables(self):

class RelationalDBSource(DataSource):
def __init__(
self, database_connection_or_url, *stix_object_classes, metadata=None,
self, db_backend, *stix_object_classes, metadata=None,
):
"""
Initialize this source. Only one of stix_object_classes and metadata
Expand All @@ -243,10 +244,7 @@ def __init__(
"""
super().__init__()

if isinstance(database_connection_or_url, str):
self.database_connection = create_engine(database_connection_or_url)
else:
self.database_connection = database_connection_or_url
self.db_backend = db_backend

if metadata:
self.metadata = metadata
Expand All @@ -257,7 +255,7 @@ def __init__(
)

def get(self, stix_id, version=None, _composite_filters=None):
with self.database_connection.connect() as conn:
with self.db_backend.connect() as conn:
stix_obj = read_object(
stix_id,
self.metadata,
Expand Down
7 changes: 4 additions & 3 deletions stix2/datastore/relational_db/relational_db_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from stix2.datastore.relational_db.relational_db import RelationalDBStore
import stix2.properties

from database_backends.postgres_backend import PostgresBackend

directory_stix_object = stix2.Directory(
path="/foo/bar/a",
path_enc="latin1",
Expand Down Expand Up @@ -251,14 +253,13 @@ def test_dictionary():

def main():
store = RelationalDBStore(
"postgresql://localhost/stix-data-sink",
PostgresBackend("postgresql://localhost/stix-data-sink"),
False,
None,
True,
True,
)

if store.sink.database_exists:
if store.sink.db_backend.database_exists:
store.sink.generate_stix_schema()
store.sink.clear_tables()

Expand Down

0 comments on commit f1af0dd

Please sign in to comment.