diff --git a/contracts/tfi-pair/src/contract.rs b/contracts/tfi-pair/src/contract.rs index 53769e0..deea568 100644 --- a/contracts/tfi-pair/src/contract.rs +++ b/contracts/tfi-pair/src/contract.rs @@ -184,26 +184,41 @@ pub fn receive_cw20( // bytes data = 2; // } // Let's do this by hand to avoid whole protobuf libs -fn parse_init_addr(init_result: &[u8]) -> StdResult<&str> { +fn parse_init_addr(init_result: &[u8]) -> Result<&str, ContractError> { + if init_result.len() < 2 { + return Err(ContractError::InvalidAddressLength(init_result.len())); + } + // ensure the first byte (field 1, type 2 = 1 << 3 + 2 = 10) if init_result[0] != 10 { - return Err(StdError::generic_err("Unexpected field, must be 10")); + return Err(StdError::generic_err("Unexpected field, must be 10").into()); } // parse the length (this will always be less than 127 in our case) let length = init_result[1] as usize; + + if init_result.len() < 2 + length { + return Err(ContractError::InvalidAddressLength(init_result.len())); + } + let addr_bytes = &init_result[2..][..length]; - Ok(std::str::from_utf8(addr_bytes)?) + + Ok(std::str::from_utf8(addr_bytes).map_err(StdError::from)?) } /// This just stores the result for future query #[cfg_attr(not(feature = "library"), entry_point)] -pub fn reply(deps: DepsMut, _env: Env, msg: Reply) -> StdResult { +pub fn reply(deps: DepsMut, _env: Env, msg: Reply) -> Result { // this is the only expected one from init if msg.id != 1 { - return Err(StdError::generic_err("Unsupported reply id")); + return Err(StdError::generic_err("Unsupported reply id").into()); } - let data = msg.result.unwrap().data.unwrap(); + let data = msg + .result + .into_result() + .map_err(ContractError::MessageFailure)? + .data + .ok_or(ContractError::MissingData {})?; let contract_addr = parse_init_addr(&data)?; let liquidity_token = deps.api.addr_validate(contract_addr)?; @@ -237,13 +252,13 @@ pub fn provide_liquidity( assets .iter() .find(|a| a.info.equal(&pools[0].info)) - .map(|a| a.amount) - .expect("Wrong asset info is given"), + .ok_or_else(|| ContractError::AssetMismatch(pools[0].info.to_string()))? + .amount, assets .iter() .find(|a| a.info.equal(&pools[1].info)) - .map(|a| a.amount) - .expect("Wrong asset info is given"), + .ok_or_else(|| ContractError::AssetMismatch(pools[1].info.to_string()))? + .amount, ]; let mut messages: Vec = vec![]; @@ -389,7 +404,7 @@ pub fn swap( }; ask_pool = pools[0].clone(); } else { - return Err(ContractError::AssetMismatch {}); + return Err(ContractError::AssetMismatch(offer_asset.info.to_string())); } let offer_amount = offer_asset.amount; @@ -480,7 +495,7 @@ pub fn query_simulation( offer_pool = pools[1].clone(); ask_pool = pools[0].clone(); } else { - return Err(ContractError::AssetMismatch {}); + return Err(ContractError::AssetMismatch(offer_asset.info.to_string())); } let (return_amount, spread_amount, commission_amount) = @@ -511,7 +526,7 @@ pub fn query_reverse_simulation( ask_pool = pools[1].clone(); offer_pool = pools[0].clone(); } else { - return Err(ContractError::AssetMismatch {}); + return Err(ContractError::AssetMismatch(ask_asset.info.to_string())); } let (offer_amount, spread_amount, commission_amount) = @@ -549,7 +564,7 @@ fn compute_swap( let commission_amount: Uint128 = return_amount * Decimal::from_str(&COMMISSION_RATE).unwrap(); // commission will be absorbed to pool - let return_amount: Uint128 = return_amount.checked_sub(commission_amount).unwrap(); + let return_amount: Uint128 = return_amount.checked_sub(commission_amount)?; Ok((return_amount, spread_amount, commission_amount)) } diff --git a/contracts/tfi-pair/src/error.rs b/contracts/tfi-pair/src/error.rs index bee0dea..6d62550 100644 --- a/contracts/tfi-pair/src/error.rs +++ b/contracts/tfi-pair/src/error.rs @@ -21,6 +21,15 @@ pub enum ContractError { #[error("Max slippage assertion")] MaxSlippageAssertion {}, - #[error("Asset mismatch")] - AssetMismatch {}, + #[error("Asset mismatch: {0}")] + AssetMismatch(String), + + #[error("Explicit failure in message: {0}")] + MessageFailure(String), + + #[error("Missing required data")] + MissingData {}, + + #[error("Invalid address length: {0}")] + InvalidAddressLength(usize), }