Skip to content

Commit

Permalink
atproto_firehose.subscribe bug fix: allow follows of protocol bot use…
Browse files Browse the repository at this point in the history
  • Loading branch information
snarfed committed Sep 23, 2024
1 parent 3ebcaac commit f83e372
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
25 changes: 21 additions & 4 deletions atproto_firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
global_cache,
global_cache_policy,
global_cache_timeout_policy,
PROTOCOL_DOMAINS,
report_error,
report_exception,
USER_AGENT,
)
from models import Object, reset_protocol_properties
from web import Web

logger = logging.getLogger(__name__)

Expand All @@ -62,13 +64,24 @@
atproto_loaded_at = datetime(1900, 1, 1)
bridged_dids = set()
bridged_loaded_at = datetime(1900, 1, 1)
protocol_bot_dids = None
dids_initialized = Event()


def load_dids():
# run in a separate thread since it needs to make its own NDB
# context when it runs in the timer thread
Thread(target=_load_dids).start()

global protocol_bot_dids
protocol_bot_dids = set()
bot_keys = [Web(id=domain).key for domain in PROTOCOL_DOMAINS]
for bot in ndb.get_multi(bot_keys):
if bot:
if did := bot.get_copy(ATProto):
logger.info(f'Loaded protocol bot user {bot.key.id()} {did}')
protocol_bot_dids.add(did)

dids_initialized.wait()
dids_initialized.clear()

Expand Down Expand Up @@ -188,9 +201,6 @@ def subscribe():
# when running locally, comment out put above and uncomment this
# cursor.updated = util.now().replace(tzinfo=None)

if payload['repo'] not in atproto_dids:
continue

blocks = {} # maps base32 str CID to dict block
if block_bytes := payload.get('blocks'):
_, blocks = libipld.decode_car(block_bytes)
Expand All @@ -204,7 +214,7 @@ def subscribe():
f'bad payload! seq {op.seq} action {op.action} path {op.path}!')
continue

if op.action == 'delete':
if op.repo in atproto_dids and op.action == 'delete':
logger.info(f'Got delete from our ATProto user: {op}')
# TODO: also detect deletes of records that *reference* our bridged
# users, eg a delete of a follow or like or repost of them.
Expand All @@ -228,6 +238,13 @@ def subscribe():
elif type not in ATProto.SUPPORTED_RECORD_TYPES:
continue

# generally we only want records from bridged Bluesky users. the one
# exception is follows of protocol bot users.
if (op.repo not in atproto_dids
and not (type == 'app.bsky.graph.follow'
and op.record['subject'] in protocol_bot_dids)):
continue

def is_ours(ref, also_atproto_users=False):
"""Returns True if the arg is a bridge user."""
if match := AT_URI_PATTERN.match(ref['uri']):
Expand Down
19 changes: 16 additions & 3 deletions tests/test_atproto_firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import atproto_firehose
from atproto_firehose import commits, handle, Op, STORE_CURSOR_FREQ
import common
from models import Object
from models import Object, Target
import protocol
from .testutil import TestCase
from .test_atproto import DID_DOC
from web import Web

A_CID = CID.decode('bafkreicqpqncshdd27sgztqgzocd3zhhqnnsv6slvzhs5uz6f57cq6lmtq')

Expand Down Expand Up @@ -109,7 +110,7 @@ def setUp(self):
atproto_firehose.bridged_loaded_at = datetime(1900, 1, 1)
atproto_firehose.dids_initialized.clear()

self.make_bridged_atproto_user()
self.user = self.make_bridged_atproto_user()
AtpRepo(id='did:alice', head='', signing_key_pem=b'').put()
self.store_object(id='did:plc:bob', raw=DID_DOC)
ATProto(id='did:plc:bob').put()
Expand Down Expand Up @@ -295,6 +296,19 @@ def test_follow_of_other(self):
'subject': 'did:eve',
})

def test_follow_of_protocol_bot_account_by_unbridged_user(self):
self.user.enabled_protocols = []
self.user.put()

self.make_user('fa.brid.gy', cls=Web, enabled_protocols=['atproto'],
copies=[Target(protocol='atproto', uri='did:fa')])
AtpRepo(id='did:fa', head='', signing_key_pem=b'').put()

self.assert_enqueues({
'$type': 'app.bsky.graph.follow',
'subject': 'did:fa',
})

def test_block_of_our_user(self):
self.assert_enqueues({
'$type': 'app.bsky.graph.block',
Expand Down Expand Up @@ -373,7 +387,6 @@ def test_load_dids_updated_atproto_user(self):
self.assertIn('did:plc:eve', atproto_firehose.atproto_dids)

def test_load_dids_atprepo(self):

FakeWebsocketClient.to_receive = [({'op': 1, 't': '#info'}, {})]
self.subscribe()

Expand Down

0 comments on commit f83e372

Please sign in to comment.