Skip to content

Commit

Permalink
Merge pull request #232 from ninoseki/renew-factories
Browse files Browse the repository at this point in the history
refactor: renew factories (inject dependencies via __init__)
  • Loading branch information
ninoseki authored May 25, 2024
2 parents 157bce0 + b50ccfb commit 7295e94
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 102 deletions.
6 changes: 2 additions & 4 deletions backend/factories/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@


class AbstractFactory(ABC):
@classmethod
@abstractmethod
def call(cls, *args: typing.Any, **kwargs: typing.Any):
def call(self, *args: typing.Any, **kwargs: typing.Any):
raise NotImplementedError()


class AbstractAsyncFactory(ABC):
@classmethod
@abstractmethod
async def call(cls, *args: typing.Any, **kwargs: typing.Any):
async def call(self, *args: typing.Any, **kwargs: typing.Any):
raise NotImplementedError()
22 changes: 11 additions & 11 deletions backend/factories/emailrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@

from .abstract import AbstractAsyncFactory

NAME_OR_KEY = "EmailRep"


@future_safe
async def lookup(email: str, *, client: clients.EmailRep) -> schemas.EmailRepLookup:
return await client.lookup(email)


@future_safe
async def transform(lookup: schemas.EmailRepLookup, *, name_or_key: str = NAME_OR_KEY):
async def transform(lookup: schemas.EmailRepLookup, *, key_or_name: str):
details: list[schemas.VerdictDetail] = []
malicious = False

Expand All @@ -28,18 +26,20 @@ async def transform(lookup: schemas.EmailRepLookup, *, name_or_key: str = NAME_O
malicious = True
description = f"{lookup.email} is suspicious. See https://emailrep.io/{lookup.email} for details."

details.append(schemas.VerdictDetail(key=name_or_key, description=description))
return schemas.Verdict(name=name_or_key, malicious=malicious, details=details)
details.append(schemas.VerdictDetail(key=key_or_name, description=description))
return schemas.Verdict(name=key_or_name, malicious=malicious, details=details)


class EmailRepVerdictFactory(AbstractAsyncFactory):
@classmethod
async def call(
cls, email: str, *, client: clients.EmailRep, name_or_key: str = NAME_OR_KEY
) -> schemas.Verdict:
def __init__(self, client: clients.EmailRep, *, name: str = "EmailRep"):
self.client = client
self.name = name

async def call(self, email: str, key: str | None = None) -> schemas.Verdict:
key_or_name: str = key or self.name
f_result: FutureResultE[schemas.Verdict] = flow(
lookup(email, client=client),
bind(partial(transform, name_or_key=name_or_key)),
lookup(email, client=self.client),
bind(partial(transform, key_or_name=key_or_name)),
)
result = await f_result.awaitable()
return unsafe_perform_io(result.alt(raise_exception).unwrap())
3 changes: 1 addition & 2 deletions backend/factories/eml.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def transform(parsed: dict) -> schemas.Eml:


class EmlFactory(AbstractFactory):
@classmethod
def call(cls, data: bytes) -> schemas.Eml:
def call(self, data: bytes) -> schemas.Eml:
result: ResultE[schemas.Eml] = flow(
to_eml(data),
bind(parse),
Expand Down
17 changes: 8 additions & 9 deletions backend/factories/inquest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from backend import clients, schemas, settings, types

NAME = "InQuest"


@future_safe
async def lookup(sha256: str, *, client: clients.InQuest) -> schemas.InQuestLookup:
Expand All @@ -36,7 +34,7 @@ async def bulk_lookup(


@future_safe
async def transform(lookups: list[schemas.InQuestLookup], *, name: str = NAME):
async def transform(lookups: list[schemas.InQuestLookup], *, name: str):
malicious_lookups = [lookup for lookup in lookups if lookup.malicious]

if len(malicious_lookups) == 0:
Expand Down Expand Up @@ -67,24 +65,25 @@ async def transform(lookups: list[schemas.InQuestLookup], *, name: str = NAME):


class InQuestVerdictFactory:
@classmethod
def __init__(self, client: clients.InQuest, *, name: str = "InQuest"):
self.client = client
self.name = name

async def call(
cls,
self,
sha256s: types.ListSet[str],
*,
client: clients.InQuest,
name: str = NAME,
max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND,
max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE,
) -> schemas.Verdict:
f_result: FutureResultE[schemas.Verdict] = flow(
bulk_lookup(
sha256s,
client=client,
client=self.client,
max_at_once=max_at_once,
max_per_second=max_per_second,
),
bind(partial(transform, name=name)),
bind(partial(transform, name=self.name)),
)
result = await f_result.awaitable()
return unsafe_perform_io(result.alt(raise_exception).unwrap())
11 changes: 6 additions & 5 deletions backend/factories/oldid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from .abstract import AbstractFactory

NAME = "oleid"


@safe
def parse(attachment: schemas.Attachment) -> OleID:
Expand Down Expand Up @@ -74,9 +72,12 @@ def inner(oleid: OleID):


class OleIDVerdictFactory(AbstractFactory):
@classmethod
def __init__(self, name: str = "oleid"):
self.name = name

def call(
cls, attachments: list[schemas.Attachment], *, name: str = NAME
self,
attachments: list[schemas.Attachment],
) -> schemas.Verdict:
details = list(
itertools.chain.from_iterable(
Expand All @@ -95,4 +96,4 @@ def call(
description="There is no suspicious OLE file in attachments.",
)
)
return schemas.Verdict(name=name, malicious=malicious, details=details)
return schemas.Verdict(name=self.name, malicious=malicious, details=details)
18 changes: 8 additions & 10 deletions backend/factories/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,46 +28,46 @@ def log_exception(exception: Exception):
@future_safe
async def parse(eml_file: bytes) -> schemas.Response:
return schemas.Response(
eml=EmlFactory.call(eml_file), id=hashlib.sha256(eml_file).hexdigest()
eml=EmlFactory().call(eml_file), id=hashlib.sha256(eml_file).hexdigest()
)


@future_safe
async def get_spam_assassin_verdict(
eml_file: bytes, *, client: clients.SpamAssassin
) -> schemas.Verdict:
return await SpamAssassinVerdictFactory.call(eml_file, client=client)
return await SpamAssassinVerdictFactory(client).call(eml_file)


@future_safe
async def get_oleid_verdict(attachments: list[schemas.Attachment]) -> schemas.Verdict:
return OleIDVerdictFactory.call(attachments)
return OleIDVerdictFactory().call(attachments)


@future_safe
async def get_email_rep_verdicts(from_, *, client: clients.EmailRep) -> schemas.Verdict:
return await EmailRepVerdictFactory.call(from_, client=client)
return await EmailRepVerdictFactory(client).call(from_)


@future_safe
async def get_urlscan_verdict(
urls: types.ListSet[str], *, client: clients.UrlScan
) -> schemas.Verdict:
return await UrlScanVerdictFactory.call(urls, client=client)
return await UrlScanVerdictFactory(client).call(urls)


@future_safe
async def get_inquest_verdict(
sha256s: types.ListSet[str], *, client: clients.InQuest
) -> schemas.Verdict:
return await InQuestVerdictFactory.call(sha256s, client=client)
return await InQuestVerdictFactory(client).call(sha256s)


@future_safe
async def get_vt_verdict(
sha256s: types.ListSet[str], *, client: clients.VirusTotal
) -> schemas.Verdict:
return await VirusTotalVerdictFactory.call(sha256s, client=client)
return await VirusTotalVerdictFactory(client).call(sha256s)


@future_safe
Expand Down Expand Up @@ -109,9 +109,7 @@ async def set_verdicts(
return response


class ResponseFactory(
AbstractAsyncFactory,
):
class ResponseFactory(AbstractAsyncFactory):
@classmethod
async def call(
cls,
Expand Down
17 changes: 11 additions & 6 deletions backend/factories/spamassassin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

from returns.functions import raise_exception
from returns.future import FutureResultE, future_safe
from returns.pipeline import flow
Expand All @@ -8,8 +10,6 @@

from .abstract import AbstractAsyncFactory

NAME = "SpamAssassin"


@future_safe
async def report(
Expand All @@ -20,7 +20,7 @@ async def report(

@future_safe
async def transform(
report: schemas.SpamAssassinReport, *, name: str = NAME
report: schemas.SpamAssassinReport, *, name: str
) -> schemas.Verdict:
details = [
schemas.VerdictDetail(
Expand All @@ -39,12 +39,17 @@ async def transform(


class SpamAssassinVerdictFactory(AbstractAsyncFactory):
@classmethod
def __init__(self, client: clients.SpamAssassin, *, name: str = "SpamAssassin"):
self.client = client
self.name = name

async def call(
cls, eml_file: bytes, *, client: clients.SpamAssassin
self,
eml_file: bytes,
) -> schemas.Verdict:
f_result: FutureResultE[schemas.Verdict] = flow(
report(eml_file, client=client), bind(transform)
report(eml_file, client=self.client),
bind(partial(transform, name=self.name)),
)
result = await f_result.awaitable()
return unsafe_perform_io(result.alt(raise_exception).unwrap())
17 changes: 8 additions & 9 deletions backend/factories/urlscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from .abstract import AbstractAsyncFactory

NAME = "urlscan.io"


@future_safe
async def lookup(url: str, *, client: clients.UrlScan) -> schemas.UrlScanLookup:
Expand All @@ -39,7 +37,7 @@ async def bulk_lookup(


@future_safe
async def transform(lookups: list[schemas.UrlScanLookup], *, name: str = NAME):
async def transform(lookups: list[schemas.UrlScanLookup], *, name: str):
results = itertools.chain.from_iterable([lookup.results for lookup in lookups])
malicious_results = [result for result in results if result.verdicts.malicious]

Expand Down Expand Up @@ -71,24 +69,25 @@ async def transform(lookups: list[schemas.UrlScanLookup], *, name: str = NAME):


class UrlScanVerdictFactory(AbstractAsyncFactory):
@classmethod
def __init__(self, client: clients.UrlScan, *, name: str = "urlscan.io"):
self.client = client
self.name = name

async def call(
cls,
self,
urls: types.ListSet[str],
*,
client: clients.UrlScan,
name: str = NAME,
max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND,
max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE,
):
f_result: FutureResultE[schemas.Verdict] = flow(
bulk_lookup(
urls,
client=client,
client=self.client,
max_at_once=max_at_once,
max_per_second=max_per_second,
),
bind(partial(transform, name=name)),
bind(partial(transform, name=self.name)),
)
result = await f_result.awaitable()
return unsafe_perform_io(result.alt(raise_exception).unwrap())
13 changes: 7 additions & 6 deletions backend/factories/virustotal.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,25 @@ async def transform(objects: list[vt.Object], *, name: str = NAME) -> schemas.Ve


class VirusTotalVerdictFactory(AbstractAsyncFactory):
@classmethod
def __init__(self, client: clients.VirusTotal, *, name: str = "VirusTotal"):
self.client = client
self.name = name

async def call(
cls,
self,
sha256s: types.ListSet[str],
*,
client: clients.VirusTotal,
name: str = NAME,
max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND,
max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE,
) -> schemas.Verdict:
f_result: FutureResultE[schemas.Verdict] = flow(
bulk_get_file_objects(
sha256s,
client=client,
client=self.client,
max_at_once=max_at_once,
max_per_second=max_per_second,
),
bind(partial(transform, name=name)),
bind(partial(transform, name=self.name)),
)
result = await f_result.awaitable()
return unsafe_perform_io(result.alt(raise_exception).unwrap())
17 changes: 8 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,20 @@ async def is_spam_assassin_responsive(port: int) -> bool:
return False


if not ci.is_ci():
if ci.is_ci():

@pytest.fixture(scope="session", autouse=True)
def docker_compose(docker_ip: str, docker_services: Services): # type: ignore
def docker_compose(): # type: ignore
return
else:

@pytest.fixture(scope="session", autouse=True)
def docker_compose(docker_ip: str, docker_services: Services):
port = docker_services.port_for("spamassassin", 783)
docker_services.wait_until_responsive(
timeout=60.0, pause=0.1, check=lambda: is_spam_assassin_responsive(port)
)

else:

@pytest.fixture
def docker_compose():
return


@pytest.fixture
def spam_assassin() -> clients.SpamAssassin:
Expand Down Expand Up @@ -119,7 +118,7 @@ def test_html() -> str:

@pytest.fixture
def docx_attachment(encrypted_docx_eml: bytes) -> schemas.Attachment:
eml = factories.EmlFactory.call(encrypted_docx_eml)
eml = factories.EmlFactory().call(encrypted_docx_eml)
return eml.attachments[0]


Expand Down
Loading

0 comments on commit 7295e94

Please sign in to comment.