Skip to content

Commit

Permalink
fix,docs(tx.py)!:
Browse files Browse the repository at this point in the history
1. Make sequence errors near-impossible.
2. Use more type hints.
3. Improve wallet.Address initialization robustness.
  • Loading branch information
Unique-Divine committed Jul 8, 2023
1 parent 772b7b4 commit 1d73353
Showing 1 changed file with 163 additions and 77 deletions.
240 changes: 163 additions & 77 deletions nibiru/tx.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,50 @@
"""
Classes:
TxClient: A client for building, simulating, and broadcasting transactions.
Transaction: Transactions trigger state changes based on messages. Each message
must be cryptographically signed before being broadcasted to the network.
Transaction: Transactions trigger state changes based on messages. Each
message must be cryptographically signed before being broadcasted to
the network.
"""
import json
import logging
import pprint
from numbers import Number
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from typing import Any, Iterable, List, Optional, Tuple, Union

from google.protobuf import any_pb2, message
from google.protobuf.json_format import MessageToDict
from nibiru_proto.cosmos.base.abci.v1beta1 import abci_pb2 as abci_type
from nibiru_proto.cosmos.base.v1beta1 import coin_pb2 as cosmos_base_coin_pb
from nibiru_proto.cosmos.base.v1beta1.coin_pb2 import Coin
from nibiru_proto.cosmos.tx.signing.v1beta1 import signing_pb2 as tx_sign
from nibiru_proto.cosmos.tx.v1beta1 import service_pb2 as tx_service
from nibiru_proto.cosmos.tx.v1beta1 import tx_pb2 as cosmos_tx_type

from nibiru import pytypes as pt
from nibiru import wallet
from nibiru import exceptions
from nibiru.exceptions import SimulationError, TxError
from nibiru.grpc_client import GrpcClient


class TxClient:
"""
A client for building, simulating, and broadcasting transactions.
Attributes:
address (Optional[wallet.Address])
client (GrpcClient)
network (pt.Network)
priv_key (wallet.PrivateKey)
tx_config (pt.TxConfig)
"""

address: Optional[wallet.Address]
client: GrpcClient
network: pt.Network
priv_key: wallet.PrivateKey
tx_config: pt.TxConfig

def __init__(
self,
Expand All @@ -42,72 +57,62 @@ def __init__(
self.network = network
self.client = client
self.address = None
self.config = config
self.tx_config = config

def execute_msgs(
self,
msgs: Union[pt.PythonMsg, List[pt.PythonMsg]],
get_sequence_from_node: bool = False,
sequence: Optional[int] = None,
tx_config: Optional[pt.TxConfig] = None,
try_decrease_seq: bool = False,
) -> pt.RawSyncTxResp:
"""
Broadcasts messages to a node in a single transaction. This function first
simulates the corresponding transaction to estimate the amount of gas needed.
Broadcasts messages to a node in a single transaction. This function
first simulates the corresponding transaction to estimate the amount of
gas needed.
If the transaction fails because of account sequence mismatch, we try to
query the sequence from the LCD endpoint and broadcast with the updated
sequence value.
If the transaction fails because of account sequence mismatch, we try
to query the sequence from the LCD endpoint and broadcast with the
updated sequence value.
Args:
get_sequence_from_node (bool, optional): Specifies whether the sequence
comes from the local value or the lcd endpoint. Defaults to False.
msgs (Union[pt.PythonMsg, List[pt.PythonMsg]]):
sequence (Optional[int]): Account sequence for the tx. Sequence
is used to enforce tx ordering and prevent double-spending.
Each time a tx is procesed and committed to the blockchain,
the account sequence number is incremented.
tx_config (Optional[pt.TxConfig] = None)
get_sequence_from_node (bool, optional): Specifies whether the
sequence comes from the local value or the lcd endpoint.
Defaults to False.
Raises:
SimulationError: If broadcasting fails during the simulation.
TxError: If the response code is nonzero, the 'TxError' includes the
raw error logs from the blockchain.
TxError: If the response code is nonzero, the 'TxError' includes
the raw error logs from the blockchain.
Returns:
Union[RawSyncTxResp, Dict[str, Any]]: The transaction response as a dict
in proto3 JSON format.
Union[RawSyncTxResp, Dict[str, Any]]: The transaction response as
a dict in proto3 JSON format.
"""

tx: Transaction
address: wallet.Address
address: wallet.Address = self.ensure_address_info()
tx, address = self.build_tx(
msgs=msgs, get_sequence_from_node=get_sequence_from_node
msgs=msgs,
)
print("Msgs", msgs)
if sequence is not None:
...
elif address:
address_seq = address.sequence
sequence = address_seq
else:
breakpoint()

# Validate account sequence
if sequence is None:
sequence = address.sequence
sequence_err: str = "sequence was not given or available on the wallet object."
assert address, sequence_err
assert sequence, sequence_err

tx = tx.with_sequence(sequence=sequence)

try:
sim_res = self.simulate(tx)
gas_estimate: float = sim_res.gas_info.gas_used

# breakpoint()
tx_resp: abci_type.TxResponse = self.execute_tx(
tx, gas_estimate, tx_config=tx_config
)
# Convert raw log into a dictionary
tx_resp: dict[str, Any] = MessageToDict(tx_resp)
tx_output = self.client.tx_by_hash(tx_hash=tx_resp["txhash"])

# breakpoint()
if tx_output.get("tx_response").get("code") != 0:
address.decrease_sequence()
raise TxError(tx_output.raw_log)

tx_output["rawLog"] = json.loads(tx_output.get("rawLog", "{}"))
return pt.RawSyncTxResp(tx_output)
except SimulationError as err:
if "account sequence mismatch, expected" in str(err):

Expand All @@ -121,17 +126,46 @@ def execute_msgs(
return self.execute_msgs(
msgs=msgs,
sequence=sequence,
get_sequence_from_node=get_sequence_from_node,
tx_config=tx_config,
)
if address:
address.decrease_sequence()
raise SimulationError(f"Failed to simulate transaction: {err}") from err

try:
tx_resp: abci_type.TxResponse = self.execute_tx(
tx, gas_estimate, tx_config=tx_config
)
tx_resp: dict[str, Any] = MessageToDict(tx_resp)
tx_hash: Union[str, None] = tx_resp.get("txhash")
assert tx_hash, f"null txhash on tx_resp: {tx_resp}"
tx_output: tx_service.GetTxResponse = self.client.tx_by_hash(
tx_hash=tx_hash
)

if tx_output.get("tx_response").get("code") != 0:
address.decrease_sequence()
raise TxError(tx_output.raw_log)
breakpoint()

tx_output["rawLog"] = json.loads(tx_output.get("rawLog", "{}"))
return pt.RawSyncTxResp(tx_output)
except exceptions.ErrorQueryTx as err:
logging.info("ErrorQueryTx")
logging.error(err)
raise err
except BaseException as err:
logging.info("BaseException")
logging.error(err)
raise err

def execute_tx(
self, tx: "Transaction", gas_estimate: float, tx_config: pt.TxConfig = None
self,
tx: "Transaction",
gas_estimate: float,
tx_config: pt.TxConfig = None,
) -> abci_type.TxResponse:
conf: pt.TxConfig = self.get_config(tx_config=tx_config)
conf: pt.TxConfig = self.ensure_tx_config(new_tx_config=tx_config)

def compute_gas_wanted() -> float:
# Related to https://github.com/cosmos/cosmos-sdk/issues/14405
Expand Down Expand Up @@ -163,34 +197,67 @@ def compute_gas_wanted() -> float:
)
tx_raw_bytes = tx.get_signed_tx_data()

return self._send_tx(tx_raw_bytes, conf.tx_type)
return self._broadcast_tx(tx_raw_bytes, conf.broadcast_mode)

def _broadcast_tx(
self,
tx_raw_bytes: bytes,
tx_type: pt.TxBroadcastMode = pt.TxBroadcastMode.SYNC,
) -> abci_type.TxResponse:
"""Broadcast the signed transaction to one or more nodes in the
network. The nodes in the network will receive the transaction
and validate its integrity by verifying the signature, checking
if the sender has sufficient funds or permissions, and running
the `ValidateBasic` check on each tx message.
Args:
tx_raw_bytes (bytes): Signed transaction.
tx_type (pt.TxBroadcastMode): Broadcast mode for the transaction
def _send_tx(self, tx_raw_bytes: bytes, tx_type: pt.TxType) -> abci_type.TxResponse:
broadcast_fn: Callable[[bytes], abci_type.TxResponse]
Returns:
(abci_type.TxResponse)
"""

if tx_type == pt.TxType.ASYNC:
broadcast_fn = self.client.send_tx_async_mode
broadcast_mode: tx_service.Broadcast
if tx_type == pt.TxBroadcastMode.ASYNC:
broadcast_mode = tx_service.BroadcastMode.BROADCAST_MODE_ASYNC
else:
broadcast_fn = self.client.send_tx_sync_mode
broadcast_mode = tx_service.BroadcastMode.BROADCAST_MODE_SYNC
return self.client.broadcast_tx(
tx_byte=tx_raw_bytes,
mode=broadcast_mode,
)

return broadcast_fn(tx_raw_bytes)
def build_tx_with_node_sequence(
self,
msgs: Union[pt.PythonMsg, List[pt.PythonMsg]],
):
address: wallet.Address = self.ensure_address_info()
sequence: int = address.get_sequence(
from_node=True,
lcd_endpoint=self.network.lcd_endpoint,
)
return self.build_tx(msgs=msgs, sequence=sequence)

def build_tx(
self,
msgs: Union[pt.PythonMsg, List[pt.PythonMsg]],
get_sequence_from_node: bool = False,
sequence: int = None,
) -> Tuple["Transaction", wallet.Address]:
if not isinstance(msgs, list):
msgs = [msgs]

pb_msgs = [msg.to_pb() for msg in msgs]

self.client.sync_timeout_height()
address: wallet.Address = self.get_address_info()
sequence: int = address.get_sequence(
from_node=get_sequence_from_node,
lcd_endpoint=self.network.lcd_endpoint,
)

address: wallet.Address = self.address
if self.address is None:
address = self.ensure_address_info()
self.address = address

if sequence is None:
sequence = self.address.sequence

tx = (
Transaction()
.with_messages(pb_msgs)
Expand All @@ -207,8 +274,8 @@ def simulate(self, tx: "Transaction") -> abci_type.SimulationResponse:
tx (Transaction): The transaction being simulated.
Returns:
SimulationResponse: SimulationResponse defines the response generated
when a transaction is simulated successfully.
SimulationResponse: SimulationResponse defines the response
generated when a transaction is simulated successfully.
Raises:
SimulationError
Expand All @@ -223,27 +290,46 @@ def simulate(self, tx: "Transaction") -> abci_type.SimulationResponse:

return sim_res

def get_address_info(self) -> wallet.Address:
def ensure_address_info(self) -> wallet.Address:
"""Guarantees that the TxClient.address has been set and returns it.
If the wallet address has not been set prior to this function call,
(1) the address is derived from the 'priv_key' and
(2) the sequence is derived from the 'network.lcd_endpoint'.
"""
if self.address is None:
pub_key: wallet.PublicKey = self.priv_key.to_public_key()
self.address = pub_key.to_address()
self.address = self.address.init_num_seq(self.network.lcd_endpoint)

return self.address

def get_config(self, tx_config: pt.TxConfig = None) -> pt.TxConfig:
"""
Properties in kwargs overwrite the self.config
def ensure_tx_config(
self,
new_tx_config: pt.TxConfig = None,
) -> pt.TxConfig:
"""Guarantees that the TxClient.tx_config has been set and returns it.
Args:
new_tx_config (Optional[pytypes.TxConfig]): Becomes the new value
for the tx config if given. Defaults to None.
Returns:
(pt.TxConfig): The new value for the TxClient.tx_config.
"""
config: pt.TxConfig
if tx_config is not None:
config = tx_config
tx_config: pt.TxConfig
if new_tx_config is not None:
tx_config = new_tx_config
elif self.tx_config is None:
# Set as the default if the TxConfig has not been initialized.
tx_config = pt.TxConfig()
else:
config = self.config
return config
pass
tx_config = self.tx_config
return tx_config


class Transaction:
# TODO: Refactor this into a dataclass for brevity.
class Transaction(pt.Jsonable):
"""
Transactions trigger state changes based on messages ('msgs'). Each message
must be signed before being broadcasted to the network, included in a block,
Expand Down Expand Up @@ -273,9 +359,9 @@ class Transaction:
convention, the signer from the first message is referred to as the
primary signer and pays the fee for the whole transaction. We refer
to this primary signer with 'priv_key'.
memo (str): Memo is a note or comment to be added to the transction.
timeout_height (int): Timeout height is the block height after which the
transaction will not be processed by the chain.
memo (str): Memo is a note or comment to be added to the transaction.
timeout_height (int): Timeout height is the block height after which
the transaction will not be processed by the chain.
"""

def __init__(
Expand Down Expand Up @@ -403,7 +489,7 @@ def get_sign_doc(
self, public_key: wallet.PublicKey = None
) -> cosmos_tx_type.SignDoc:
if len(self.msgs) == 0:
raise ValueError("message is empty")
raise ValueError("no messages in the tx body")

if self.account_num is None:
raise RuntimeError("account_num should be defined")
Expand Down Expand Up @@ -442,4 +528,4 @@ def get_signed_tx_data(self) -> bytes:
pub_key = self.priv_key.to_public_key()
sign_doc = self.get_sign_doc(pub_key)
sig = self.priv_key.sign(sign_doc.SerializeToString())
return self.get_tx_data(sig, pub_key)
return self.get_tx_data(signature=sig, public_key=pub_key)

0 comments on commit 1d73353

Please sign in to comment.