Skip to content

Commit

Permalink
Hash 3PID lookups (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
anoadragon453 authored Aug 20, 2019
1 parent 2ccded0 commit 81f0de7
Show file tree
Hide file tree
Showing 16 changed files with 586 additions and 26 deletions.
125 changes: 125 additions & 0 deletions sydent/db/hashing_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-

# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Actions on the hashing_metadata table which is defined in the migration process in
# sqlitedb.py

class HashingMetadataStore:
def __init__(self, sydent):
self.sydent = sydent

def get_lookup_pepper(self):
"""Return the value of the current lookup pepper from the db
:returns a pepper if it exists in the database, or None if one does
not exist
"""
cur = self.sydent.db.cursor()
res = cur.execute("select lookup_pepper from hashing_metadata")
row = res.fetchone()

if not row:
return None
return row[0]

def store_lookup_pepper(self, hashing_function, pepper):
"""Stores a new lookup pepper in the hashing_metadata db table and rehashes all 3PIDs
:param hashing_function: A function with single input and output strings
:type hashing_function func(str) -> str
:param pepper: The pepper to store in the database
:type pepper: str
"""
cur = self.sydent.db.cursor()

# Create or update lookup_pepper
sql = (
'INSERT OR REPLACE INTO hashing_metadata (id, lookup_pepper) '
'VALUES (0, ?)'
)
cur.execute(sql, (pepper,))

# Hand the cursor to each rehashing function
# Each function will queue some rehashing db transactions
self._rehash_threepids(cur, hashing_function, pepper, "local_threepid_associations")
self._rehash_threepids(cur, hashing_function, pepper, "global_threepid_associations")

# Commit the queued db transactions so that adding a new pepper and hashing is atomic
self.sydent.db.commit()

def _rehash_threepids(self, cur, hashing_function, pepper, table):
"""Rehash 3PIDs of a given table using a given hashing_function and pepper
A database cursor `cur` must be passed to this function. After this function completes,
the calling function should make sure to call self`self.sydent.db.commit()` to commit
the made changes to the database.
:param cur: Database cursor
:type cur:
:param hashing_function: A function with single input and output strings
:type hashing_function func(str) -> str
:param pepper: A pepper to append to the end of the 3PID (after a space) before hashing
:type pepper: str
:param table: The database table to perform the rehashing on
:type table: str
"""

# Get count of all 3PID records
# Medium/address combos are marked as UNIQUE in the database
sql = "SELECT COUNT(*) FROM %s" % table
res = cur.execute(sql)
row_count = res.fetchone()
row_count = row_count[0]

# Iterate through each medium, address combo, hash it,
# and store in the db
batch_size = 500
count = 0
while count < row_count:
sql = (
"SELECT medium, address FROM %s ORDER BY id LIMIT %s OFFSET %s" %
(table, batch_size, count)
)
res = cur.execute(sql)
rows = res.fetchall()

for medium, address in rows:
# Skip broken db entry
if not medium or not address:
continue

# Combine the medium, address and pepper together in the
# following form: "address medium pepper"
# According to MSC2134: https://github.com/matrix-org/matrix-doc/pull/2134
combo = "%s %s %s" % (address, medium, pepper)

# Hash the resulting string
result = hashing_function(combo)

# Save the result to the DB
sql = (
"UPDATE %s SET lookup_hash = ? "
"WHERE medium = ? AND address = ?"
% table
)
# Lines up the query to be executed on commit
cur.execute(sql, (result, medium, address))

count += len(rows)
36 changes: 36 additions & 0 deletions sydent/db/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _upgradeSchema(self):
self.db.commit()
logger.info("v0 -> v1 schema migration complete")
self._setSchemaVersion(1)

if curVer < 2:
logger.info("Migrating schema from v1 to v2")
cur = self.db.cursor()
Expand All @@ -140,6 +141,41 @@ def _upgradeSchema(self):
logger.info("v1 -> v2 schema migration complete")
self._setSchemaVersion(2)

if curVer < 3:
cur = self.db.cursor()

# Add lookup_hash columns to threepid association tables
cur.execute(
"ALTER TABLE local_threepid_associations "
"ADD COLUMN lookup_hash VARCHAR(256)"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS lookup_hash_medium "
"on local_threepid_associations "
"(lookup_hash, medium)"
)
cur.execute(
"ALTER TABLE global_threepid_associations "
"ADD COLUMN lookup_hash VARCHAR(256)"
)
cur.execute(
"CREATE INDEX IF NOT EXISTS lookup_hash_medium "
"on global_threepid_associations "
"(lookup_hash, medium)"
)

# Create hashing_metadata table to store the current lookup_pepper
cur.execute(
"CREATE TABLE IF NOT EXISTS hashing_metadata ("
"id integer primary key, "
"lookup_pepper varchar(256)"
")"
)

self.db.commit()
logger.info("v2 -> v3 schema migration complete")
self._setSchemaVersion(3)

def _getSchemaVersion(self):
cur = self.db.cursor()
res = cur.execute("PRAGMA user_version");
Expand Down
54 changes: 42 additions & 12 deletions sydent/db/threepid_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def addOrUpdateAssociation(self, assoc):

# sqlite's support for upserts is atrocious
cur.execute("insert or replace into local_threepid_associations "
"('medium', 'address', 'mxid', 'ts', 'notBefore', 'notAfter')"
" values (?, ?, ?, ?, ?, ?)",
(assoc.medium, assoc.address, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after))
"('medium', 'address', 'lookup_hash', 'mxid', 'ts', 'notBefore', 'notAfter')"
" values (?, ?, ?, ?, ?, ?, ?)",
(assoc.medium, assoc.address, assoc.lookup_hash, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after))
self.sydent.db.commit()

def getAssociationsAfterId(self, afterId, limit):
Expand All @@ -45,7 +45,8 @@ def getAssociationsAfterId(self, afterId, limit):
if afterId is None:
afterId = -1

q = "select id, medium, address, mxid, ts, notBefore, notAfter from local_threepid_associations " \
q = "select id, medium, address, lookup_hash, mxid, ts, notBefore, notAfter from " \
"local_threepid_associations " \
"where id > ? order by id asc"
if limit is not None:
q += " limit ?"
Expand All @@ -58,7 +59,7 @@ def getAssociationsAfterId(self, afterId, limit):

assocs = {}
for row in res.fetchall():
assoc = ThreepidAssociation(row[1], row[2], row[3], row[4], row[5], row[6])
assoc = ThreepidAssociation(row[1], row[2], row[3], row[4], row[5], row[6], row[7])
assocs[row[0]] = assoc
maxId = row[0]

Expand Down Expand Up @@ -139,10 +140,20 @@ def getMxid(self, medium, address):
return row[0]

def getMxids(self, threepid_tuples):
"""Given a list of threepid_tuples, return the same list but with
mxids appended to each tuple for which a match was found in the
database for. Output is ordered by medium, address, timestamp DESC
:param threepid_tuples: List containing (medium, address) tuples
:type threepid_tuples: [(str, str)]
:returns a list of (medium, address, mxid) tuples
:rtype [(str, str, str)]
"""
cur = self.sydent.db.cursor()

cur.execute("CREATE TEMPORARY TABLE tmp_getmxids (medium VARCHAR(16), address VARCHAR(256))");
cur.execute("CREATE INDEX tmp_getmxids_medium_lower_address ON tmp_getmxids (medium, lower(address))");
cur.execute("CREATE TEMPORARY TABLE tmp_getmxids (medium VARCHAR(16), address VARCHAR(256))")
cur.execute("CREATE INDEX tmp_getmxids_medium_lower_address ON tmp_getmxids (medium, lower(address))")

try:
inserted_cap = 0
Expand Down Expand Up @@ -181,14 +192,13 @@ def getMxids(self, threepid_tuples):
def addAssociation(self, assoc, rawSgAssoc, originServer, originId, commit=True):
"""
:param assoc: (sydent.threepid.GlobalThreepidAssociation) The association to add as a high level object
:param sgAssoc The original raw bytes of the signed association
:return:
:param sgAssoc: The original raw bytes of the signed association
"""
cur = self.sydent.db.cursor()
res = cur.execute("insert or ignore into global_threepid_associations "
"(medium, address, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) values "
"(?, ?, ?, ?, ?, ?, ?, ?, ?)",
(assoc.medium, assoc.address, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after,
"(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) values "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(assoc.medium, assoc.address, assoc.lookup_hash, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after,
originServer, originId, rawSgAssoc))
if commit:
self.sydent.db.commit()
Expand Down Expand Up @@ -216,3 +226,23 @@ def removeAssociation(self, medium, address):
cur.rowcount, medium, address,
)
self.sydent.db.commit()

def retrieveMxidFromHash(self, lookup_hash):
"""Returns an mxid from a given lookup_hash value
:param input_hash: The lookup_hash value to lookup in the database
:type input_hash: str
:returns the mxid relating to the lookup_hash value if found,
otherwise None
:rtype: str|None
"""
cur = self.sydent.db.cursor()

res = cur.execute(
"SELECT mxid FROM global_threepid_associations WHERE lookup_hash = ?", (lookup_hash,)
)
row = res.fetchone()
if not row:
return None
return row[0]
27 changes: 25 additions & 2 deletions sydent/db/threepid_associations.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,33 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

CREATE TABLE IF NOT EXISTS local_threepid_associations (id integer primary key, medium varchar(16) not null, address varchar(256) not null, mxid varchar(256) not null, ts integer not null, notBefore bigint not null, notAfter bigint not null);
CREATE TABLE IF NOT EXISTS local_threepid_associations (
id integer primary key,
medium varchar(16) not null,
address varchar(256) not null,
lookup_hash varchar,
mxid varchar(256) not null,
ts integer not null,
notBefore bigint not null,
notAfter bigint not null
);
CREATE INDEX IF NOT EXISTS lookup_hash_medium on local_threepid_associations (lookup_hash, medium);
CREATE UNIQUE INDEX IF NOT EXISTS medium_address on local_threepid_associations(medium, address);

CREATE TABLE IF NOT EXISTS global_threepid_associations (id integer primary key, medium varchar(16) not null, address varchar(256) not null, mxid varchar(256) not null, ts integer not null, notBefore bigint not null, notAfter integer not null, originServer varchar(255) not null, originId integer not null, sgAssoc text not null);
CREATE TABLE IF NOT EXISTS global_threepid_associations (
id integer primary key,
medium varchar(16) not null,
address varchar(256) not null,
lookup_hash varchar,
mxid varchar(256) not null,
ts integer not null,
notBefore bigint not null,
notAfter integer not null,
originServer varchar(255) not null,
originId integer not null,
sgAssoc text not null
);
CREATE INDEX IF NOT EXISTS lookup_hash_medium on global_threepid_associations (lookup_hash, medium);
CREATE INDEX IF NOT EXISTS medium_address on global_threepid_associations (medium, address);
CREATE INDEX IF NOT EXISTS medium_lower_address on global_threepid_associations (medium, lower(address));
CREATE UNIQUE INDEX IF NOT EXISTS originServer_originId on global_threepid_associations (originServer, originId);
8 changes: 8 additions & 0 deletions sydent/http/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, sydent):
identity = Resource()
api = Resource()
v1 = self.sydent.servlets.v1
v2 = self.sydent.servlets.v2

validate = Resource()
email = Resource()
Expand All @@ -51,6 +52,9 @@ def __init__(self, sydent):
lookup = self.sydent.servlets.lookup
bulk_lookup = self.sydent.servlets.bulk_lookup

hash_details = self.sydent.servlets.hash_details
lookup_v2 = self.sydent.servlets.lookup_v2

threepid = Resource()
bind = self.sydent.servlets.threepidBind
unbind = self.sydent.servlets.threepidUnbind
Expand All @@ -63,6 +67,7 @@ def __init__(self, sydent):
root.putChild('_matrix', matrix)
matrix.putChild('identity', identity)
identity.putChild('api', api)
identity.putChild('v2', v2)
api.putChild('v1', v1)

v1.putChild('validate', validate)
Expand Down Expand Up @@ -93,6 +98,9 @@ def __init__(self, sydent):

v1.putChild('sign-ed25519', self.sydent.servlets.blindlySignStuffServlet)

v2.putChild('lookup', lookup_v2)
v2.putChild('hash_details', hash_details)

self.factory = Site(root)
self.factory.displayTracebacks = False

Expand Down
4 changes: 2 additions & 2 deletions sydent/http/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

def get_args(request, required_args):
"""
Helper function to get arguments for an HTTP request
Helper function to get arguments for an HTTP request.
Currently takes args from the top level keys of a json object or
www-form-urlencoded for backwards compatability.
Returns a tuple (error, args) where if error is non-null,
the requesat is malformed. Otherwise, args contains the
the request is malformed. Otherwise, args contains the
parameters passed.
"""
args = None
Expand Down
Loading

0 comments on commit 81f0de7

Please sign in to comment.