Skip to content

Commit

Permalink
Change get_src_ip_continent and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkst-tom committed Jul 24, 2024
1 parent 7ec9d57 commit c365413
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
5 changes: 3 additions & 2 deletions canarytokens/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,8 +1326,9 @@ def get_additional_data_for_notification(self) -> Dict[str, Any]:
if self.src_data and key in self.src_data:
self.src_data[replacement] = self.src_data[key]

continent = get_src_ip_continent(additional_data)
additional_data["geo_info"]["continent"] = continent
if additional_data.get("geo_info") is not None:
continent = get_src_ip_continent(additional_data["geo_info"])
additional_data["geo_info"]["continent"] = continent

time = datetime.utcnow()
additional_data["time_hm"] = time.strftime("%H:%M")
Expand Down
6 changes: 3 additions & 3 deletions canarytokens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ def get_deployed_commit_sha(commit_sha_file: Path = Path("/COMMIT_SHA")):
# return inner


def get_src_ip_continent(additional_data: dict) -> str:
def get_src_ip_continent(geo_data: dict) -> str:
"""Helper function that returns the continent of country given it's ISO 3166-2 code.
Args:
additional_data (dict): The "country" key contains an ISO 3166-2 code
geo_data (dict): The "country" key contains an ISO 3166-2 code
Returns:
str: A two character code representing a continent
"""
country = additional_data.get("geo_info", {}).get("country")
country = geo_data.get("country")
if country is not None:
# AQ is the ISO 3166-2 code for Antarctica, and is returned from IPinfo,
# but it's not included in pycountry_convert.
Expand Down
25 changes: 20 additions & 5 deletions tests/units/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
DNSTokenRequest,
DownloadContentTypes,
DownloadMSWordResponse,
GeoIPBogonInfo,
GeoIPInfo,
LegacyTokenHistory,
LegacyTokenHit,
Expand Down Expand Up @@ -476,11 +475,19 @@ def test_all_requests_have_a_response():
WebBugTokenHit,
{
"useragent": "python 3.10",
"geo_info": GeoIPBogonInfo(ip="127.0.0.1", bogon=True),
"geo_info": {
"ip": "127.0.0.1",
"bogon": True,
"continent": "NO_CONTINENT",
},
},
{
"useragent": "python 3.10",
"geo_info": GeoIPBogonInfo(ip="127.0.0.1", bogon=True),
"geo_info": {
"ip": "127.0.0.1",
"bogon": True,
"continent": "NO_CONTINENT",
},
},
),
(
Expand Down Expand Up @@ -584,11 +591,19 @@ def test_get_additional_data_for_email(history_type, hit_type, seed_data):
(
{
"useragent": "python 3.10",
"geo_info": GeoIPBogonInfo(ip="127.0.0.1", bogon=True),
"geo_info": {
"ip": "127.0.0.1",
"bogon": True,
"continent": "NO_CONTINENT",
},
},
{
"useragent": "python 3.10",
"geo_info": GeoIPBogonInfo(ip="127.0.0.1", bogon=True),
"geo_info": {
"ip": "127.0.0.1",
"bogon": True,
"continent": "NO_CONTINENT",
},
},
),
(
Expand Down
26 changes: 13 additions & 13 deletions tests/units/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def test_coerce_to_float():


@pytest.mark.parametrize(
"additional_data, continent",
"geo_info, continent",
[
({"geo_info": {"country": "ZA"}}, "AF"),
({"geo_info": {"country": "AQ"}}, "AN"),
({"geo_info": {"country": "CN"}}, "AS"),
({"geo_info": {"country": "GB"}}, "EU"),
({"geo_info": {"country": "US"}}, "NA"),
({"geo_info": {"country": "AU"}}, "OC"),
({"geo_info": {"country": "AR"}}, "SA"),
({"geo_info": {"country": "Mordor"}}, "NO_CONTINENT"),
({"geo_info": {"bogon": True}}, "NO_CONTINENT"),
({"geo_info": {}}, "NO_CONTINENT"),
({"country": "ZA"}, "AF"),
({"country": "AQ"}, "AN"),
({"country": "CN"}, "AS"),
({"country": "GB"}, "EU"),
({"country": "US"}, "NA"),
({"country": "AU"}, "OC"),
({"country": "AR"}, "SA"),
({"country": "Mordor"}, "NO_CONTINENT"),
({"bogon": True}, "NO_CONTINENT"),
({}, "NO_CONTINENT"),
],
)
def test_get_src_ip_continent(additional_data, continent):
assert continent == get_src_ip_continent(additional_data)
def test_get_src_ip_continent(geo_info, continent):
assert continent == get_src_ip_continent(geo_info)

0 comments on commit c365413

Please sign in to comment.