Skip to content

Commit

Permalink
[PAY-747] Handle solana-nft-gated premium content (#4480)
Browse files Browse the repository at this point in the history
* Update track schema

* Add sol nft collection ownership check

* Remove unused imports

* Check for None

* Clean up

* Add comment

* Use asyncio

* Increase the max number of pool executors

* Return track ids as signature map keys

* Remove Any

Co-authored-by: Saliou Diallo <saliou@audius.co>
  • Loading branch information
sddioulde and Saliou Diallo authored Dec 22, 2022
1 parent 44fa496 commit 8c3013e
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 11 deletions.
206 changes: 200 additions & 6 deletions discovery-provider/src/queries/get_premium_track_signatures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import pathlib
import struct
from collections import defaultdict
from typing import Dict, List, Set

import base58
from eth_typing import ChecksumAddress
from solana.publickey import PublicKey
from sqlalchemy.orm.session import Session
from src.models.tracks.track import Track
from src.models.users.user import User
Expand All @@ -14,14 +19,19 @@
)
from src.premium_content.signature import get_premium_content_signature_for_user
from src.queries.get_associated_user_wallet import get_associated_user_wallet
from src.solana.solana_client_manager import SolanaClientManager
from src.solana.solana_helpers import METADATA_PROGRAM_ID_PK
from src.utils import db_session, web3_provider
from src.utils.config import shared_config
from web3 import Web3

logger = logging.getLogger(__name__)

erc721_abi = None
erc1155_abi = None

solana_client_manager = None

eth_web3 = web3_provider.get_eth_web3()


Expand Down Expand Up @@ -91,6 +101,7 @@ def _get_nft_gated_tracks(track_ids: List[int], session: Session):
return list(
filter(
lambda track: track.is_premium # type: ignore
and track.premium_conditions != None # type: ignore
and "nft_collection" in track.premium_conditions, # type: ignore
_get_tracks(track_ids, session),
)
Expand All @@ -104,6 +115,7 @@ def _get_eth_nft_gated_track_signatures(
track_token_id_map: Dict[int, List[str]],
):
track_signature_map = {}
track_cid_to_id_map = {}

user_eth_wallets = list(
map(Web3.toChecksumAddress, eth_associated_wallets + [user_wallet])
Expand All @@ -126,6 +138,7 @@ def _get_eth_nft_gated_track_signatures(
track.premium_conditions["nft_collection"]["address"] # type: ignore
)
erc721_collection_track_map[contract_address].append(track.track_cid)
track_cid_to_id_map[track.track_cid] = track.track_id

erc1155_gated_tracks = list(
filter(
Expand All @@ -152,6 +165,7 @@ def _get_eth_nft_gated_track_signatures(
contract_address_token_id_map[contract_address] = contract_address_token_id_map[
contract_address
].union(track_token_id_set)
track_cid_to_id_map[track.track_cid] = track.track_id

with concurrent.futures.ThreadPoolExecutor() as executor:
# Check ownership of nfts from erc721 collections from given contract addresses,
Expand All @@ -171,8 +185,9 @@ def _get_eth_nft_gated_track_signatures(
# nft collection is owned by the user.
if future.result():
for track_cid in erc721_collection_track_map[contract_address]:
track_id = track_cid_to_id_map[track_cid]
track_signature_map[
track_cid
track_id
] = get_premium_content_signature_for_user(
{
"id": track_cid,
Expand All @@ -183,7 +198,7 @@ def _get_eth_nft_gated_track_signatures(
)
except Exception as e:
logger.error(
f"Could not future result for erc721 contract_address {contract_address}. Error: {e}"
f"Could not get future result for erc721 contract_address {contract_address}. Error: {e}"
)

# Check ownership of nfts from erc1155 collections from given contract addresses,
Expand All @@ -206,8 +221,9 @@ def _get_eth_nft_gated_track_signatures(
# nft collection is owned by the user.
if future.result():
for track_cid in erc1155_collection_track_map[contract_address]:
track_id = track_cid_to_id_map[track_cid]
track_signature_map[
track_cid
track_id
] = get_premium_content_signature_for_user(
{
"id": track_cid,
Expand All @@ -218,19 +234,191 @@ def _get_eth_nft_gated_track_signatures(
)
except Exception as e:
logger.error(
f"Could not future result for erc1155 contract_address {contract_address}. Error: {e}"
f"Could not get future result for erc1155 contract_address {contract_address}. Error: {e}"
)

return track_signature_map


# todo: this will be implemented later
# Extended and simplified based on the reference links below
# https://docs.metaplex.com/programs/token-metadata/accounts#metadata
# https://github.com/metaplex-foundation/python-api/blob/441c2ba9be76962d234d7700405358c72ee1b35b/metaplex/metadata.py#L123
def _unpack_metadata_account_for_metaplex_nft(data):
assert data[0] == 4
i = 1 # key
i += 32 # update authority
i += 32 # mint
i += 36 # name
i += 14 # symbol
i += 204 # uri
i += 2 # seller fee basis points
has_creator = data[i]
i += 1 # whether has creators
if has_creator:
creator_len = struct.unpack("<I", data[i : i + 4])[0]
i += 4 # num creators
for _ in range(creator_len):
i += 32 # creator address
i += 1 # creator verified
i += 1 # creator share
i += 1 # primary sale happened
i += 1 # is mutable
i += 2 # edition nonce
i += 2 # token standard
has_collection = data[i]
if not has_collection:
return {"collection": None}

i += 1 # whether has collection
collection_verified = bool(data[i])
i += 1 # collection verified
collection_key = base58.b58encode(
bytes(struct.unpack("<" + "B" * 32, data[i : i + 32]))
)
return {"collection": {"verified": collection_verified, "key": collection_key}}


def _get_metadata_account(mint_address: str):
return PublicKey.find_program_address(
[
b"metadata",
bytes(METADATA_PROGRAM_ID_PK),
bytes(PublicKey(mint_address)), # type: ignore
],
METADATA_PROGRAM_ID_PK,
)[0]


def _get_token_account_info(token_account):
return token_account["account"]["data"]["parsed"]["info"]


def _decode_metadata_account(metadata_account):
return base64.b64decode(
solana_client_manager.get_account_info(metadata_account)["value"]["data"][0]
)


async def _wrap_decode_metadata_account(metadata_account):
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as pool:
result = await loop.run_in_executor(
pool, _decode_metadata_account, metadata_account
)
return result


async def _decode_metadata_accounts_async(metadata_accounts):
datas = await asyncio.gather(*map(_wrap_decode_metadata_account, metadata_accounts))
return datas


# - Fet and parse token accounts from given wallets to get the mint addresses
# - Filter out token accounts with positive amounts and whose decimal places are not 0
# - Find the metadata PDAs for the mint addresses
# - Get the account infos for the PDAs if they exist
# - Unpack the chain metadatas from the account infos
# - Verify that the nft is from a verified collection whose mint address is the same as that passed into the function
# - If so, then user owns nft from that collection
def _does_user_own_sol_nft_collection(
collection_mint_address: str, user_sol_wallets: List[str]
):
if not solana_client_manager:
return False

for wallet in user_sol_wallets:
try:
result = solana_client_manager.get_token_accounts_by_owner(wallet)
token_accounts = result["value"]
nft_token_accounts = list(
filter(
lambda item: _get_token_account_info(item)["tokenAmount"]["amount"]
!= "0"
and _get_token_account_info(item)["tokenAmount"]["decimals"] == 0,
token_accounts,
)
)
nft_mints = list(
map(
lambda item: _get_token_account_info(item)["mint"],
nft_token_accounts,
)
)
metadata_accounts = list(map(_get_metadata_account, nft_mints))
datas = asyncio.run(_decode_metadata_accounts_async(metadata_accounts))
metadatas = list(map(_unpack_metadata_account_for_metaplex_nft, datas))
collections = list(map(lambda metadata: metadata["collection"], metadatas))
has_collection_mint_address = list(
filter(
lambda collection: collection
and collection["verified"]
and collection["key"].decode() == collection_mint_address,
collections,
)
)
if has_collection_mint_address:
return True
except Exception as e:
logger.error(
f"Could not get nft balance for nft collection {collection_mint_address} and user wallet {wallet}. Error: {e}"
)
return False


def _get_sol_nft_gated_track_signatures(
user_wallet: str,
sol_associated_wallets: List[str],
tracks: List[Track],
):
return {}
track_signature_map = {}
track_cid_to_id_map = {}

# Build a map of collection mint address -> track ids
# so that only one chain call will be made for premium tracks
# that share the same nft collection gate.
collection_track_map = defaultdict(list)
for track in tracks:
collection_mint_address = track.premium_conditions["nft_collection"]["address"] # type: ignore
collection_track_map[collection_mint_address].append(track.track_cid)
track_cid_to_id_map[track.track_cid] = track.track_id

with concurrent.futures.ThreadPoolExecutor() as executor:
# Check ownership of nfts from collections from given collection mint addresses,
# using all user sol wallets, and generate signatures for corresponding tracks.
future_to_collection_mint_address_map = {
executor.submit(
_does_user_own_sol_nft_collection,
collection_mint_address,
sol_associated_wallets,
): collection_mint_address
for collection_mint_address in list(collection_track_map.keys())
}
for future in concurrent.futures.as_completed(
future_to_collection_mint_address_map
):
collection_mint_address = future_to_collection_mint_address_map[future]
try:
# Generate premium content signatures for tracks whose
# nft collection is owned by the user.
if future.result():
for track_cid in collection_track_map[collection_mint_address]:
track_id = track_cid_to_id_map[track_cid]
track_signature_map[
track_id
] = get_premium_content_signature_for_user(
{
"id": track_cid,
"type": "track",
"user_wallet": user_wallet,
"is_premium": True,
}
)
except Exception as e:
logger.error(
f"Could not get future result for collection_mint_address {collection_mint_address}. Error: {e}"
)

return track_signature_map


# Generates a premium content signature for each of the nft-gated tracks.
Expand Down Expand Up @@ -353,4 +541,10 @@ def _load_abis():
erc1155_abi = json.dumps(json.load(f1155))


def _init_solana_client_manager():
global solana_client_manager
solana_client_manager = SolanaClientManager(shared_config["solana"]["endpoint"])


_load_abis()
_init_solana_client_manager()
4 changes: 2 additions & 2 deletions discovery-provider/src/schemas/track_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,13 @@
"type": "string",
"const": "sol"
},
"name": {
"address": {
"type": "string"
}
},
"required": [
"chain",
"name"
"address"
],
"title": "PremiumConditionsSolNFTCollection"
},
Expand Down
Loading

0 comments on commit 8c3013e

Please sign in to comment.