Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto update clients to prevent expiry #412

Merged
merged 10 commits into from
Feb 12, 2021
9 changes: 9 additions & 0 deletions cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
flagTimeoutHeightOffset = "timeout-height-offset"
flagTimeoutTimeOffset = "timeout-time-offset"
flagMaxRetries = "max-retries"
flagThresholdTime = "time-threshold"
)

func ibcDenomFlags(cmd *cobra.Command) *cobra.Command {
Expand Down Expand Up @@ -259,3 +260,11 @@ func retryFlag(cmd *cobra.Command) *cobra.Command {
}
return cmd
}

func updateTimeFlags(cmd *cobra.Command) *cobra.Command {
cmd.Flags().Duration(flagThresholdTime, 6*time.Hour, "time before to expiry time to update client")
if err := viper.BindPFlag(flagThresholdTime, cmd.Flags().Lookup(flagThresholdTime)); err != nil {
panic(err)
}
return cmd
}
64 changes: 63 additions & 1 deletion cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ package cmd

import (
"fmt"
"math"
"os"
"os/signal"
"strings"
"syscall"
"time"

retry "github.com/avast/retry-go"
"github.com/cosmos/relayer/relayer"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/sync/errgroup"
)

// startCmd represents the start command
Expand Down Expand Up @@ -70,11 +75,33 @@ $ %s start demo-path2 --max-tx-size 10`, appName, appName)),
return err
}

thresholdTime := viper.GetDuration(flagThresholdTime)

eg := new(errgroup.Group)
eg.Go(func() error {
colin-axner marked this conversation as resolved.
Show resolved Hide resolved
for {
var timeToExpiry time.Duration
if err := retry.Do(func() error {
timeToExpiry, err = UpdateClientsFromChains(c[src], c[dst], thresholdTime)
if err != nil {
return err
}
return nil
}, retry.Attempts(5), retry.Delay(time.Millisecond*500), retry.LastErrorOnly(true)); err != nil {
return err
}
time.Sleep(timeToExpiry - thresholdTime)
}
})
if err = eg.Wait(); err != nil {
return err
}

trapSignal(done)
return nil
},
}
return strategyFlag(cmd)
return strategyFlag(updateTimeFlags(cmd))
}

// trap signal waits for a SIGINT or SIGTERM and then sends down the done channel
Expand All @@ -91,3 +118,38 @@ func trapSignal(done func()) {
// call the cleanup func
done()
}

// UpdateClientsFromChains takes src, dst chains, threshold time and update clients based on expiry time
func UpdateClientsFromChains(src, dst *relayer.Chain, thresholdTime time.Duration) (time.Duration, error) {
var (
srcTimeExpiry, dstTimeExpiry time.Duration
err error
)

eg := new(errgroup.Group)
eg.Go(func() error {
srcTimeExpiry, err = relayer.AutoUpdateClient(src, dst, thresholdTime)
return err
})
eg.Go(func() error {
dstTimeExpiry, err = relayer.AutoUpdateClient(dst, src, thresholdTime)
return err
})
if err := eg.Wait(); err != nil {
return 0, err
}

if srcTimeExpiry <= 0 {
return 0, fmt.Errorf("client (%s) of chain: %s is expired",
src.PathEnd.ClientID, src.ChainID)
}

if dstTimeExpiry <= 0 {
return 0, fmt.Errorf("client (%s) of chain: %s is expired",
dst.PathEnd.ClientID, dst.ChainID)
}

minTimeExpiry := math.Min(float64(srcTimeExpiry), float64(dstTimeExpiry))

return time.Duration(int64(minTimeExpiry)), nil
}
89 changes: 89 additions & 0 deletions relayer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,92 @@ func IsMatchingClient(clientStateA, clientStateB ibctmtypes.ClientState) bool {
func IsMatchingConsensusState(consensusStateA, consensusStateB *ibctmtypes.ConsensusState) bool {
return reflect.DeepEqual(*consensusStateA, *consensusStateB)
}

// AutoUpdateClient update client automatically to prevent expiry
func AutoUpdateClient(src, dst *Chain, thresholdTime time.Duration) (time.Duration, error) {
height, err := src.QueryLatestHeight()
if err != nil {
return 0, err
}

clientStateRes, err := src.QueryClientState(height)
if err != nil {
return 0, err
}

// unpack any into ibc tendermint client state
clientStateExported, err := clienttypes.UnpackClientState(clientStateRes.ClientState)
if err != nil {
return 0, err
}

// cast from interface to concrete type
clientState, ok := clientStateExported.(*ibctmtypes.ClientState)
if !ok {
return 0, fmt.Errorf("error when casting exported clientstate with clientID %s on chain: %s",
src.PathEnd.ClientID, src.PathEnd.ChainID)
}

if clientState.TrustingPeriod <= thresholdTime {
return 0, fmt.Errorf("client (%s) trusting period time is less than or equal to threshold time",
src.PathEnd.ClientID)
}

// query the latest consensus state of the potential matching client
consensusStateResp, err := clientutils.QueryConsensusStateABCI(src.CLIContext(0),
src.PathEnd.ClientID, clientState.GetLatestHeight())
if err != nil {
return 0, err
}

exportedConsState, err := clienttypes.UnpackConsensusState(consensusStateResp.ConsensusState)
if err != nil {
return 0, err
}

consensusState, ok := exportedConsState.(*ibctmtypes.ConsensusState)
if !ok {
return 0, fmt.Errorf("consensus state with clientID %s from chain %s is not IBC tendermint type",
src.PathEnd.ClientID, src.PathEnd.ChainID)
}

expirationTime := consensusState.Timestamp.Add(clientState.TrustingPeriod)

timeToExpiry := time.Until(expirationTime)

if timeToExpiry > thresholdTime {
return timeToExpiry, nil
}

if clientState.IsExpired(consensusState.Timestamp, time.Now()) {
return 0, fmt.Errorf("client (%s) is already expired on chain: %s", src.PathEnd.ClientID, src.ChainID)
}

srcUpdateHeader, err := src.GetIBCUpdateHeader(dst)
if err != nil {
return 0, err
}

updateMsg, err := src.UpdateClient(dst)
if err != nil {
return 0, err
}

msgs := []sdk.Msg{updateMsg}

res, success, err := src.SendMsgs(msgs)
if err != nil {
return 0, err
}
if !success {
return 0, fmt.Errorf("tx failed: %s", res.RawLog)
}
src.Log(fmt.Sprintf("★ Client updated: [%s]client(%s) {%d}->{%d}",
src.ChainID,
src.PathEnd.ClientID,
MustGetHeight(srcUpdateHeader.TrustedHeight),
srcUpdateHeader.Header.Height,
))

return clientState.TrustingPeriod, nil
}