Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] Synchronous, typed interface for Python client #1272

Merged
merged 14 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cryptol-remote-api/python/cryptol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
from .bitvector import BV
from .commands import *
from .connection import *
# import everything from `.synchronous` except `connect` and `connect_stdio`
# (since functions with those names were already imported from `.connection`)
from .synchronous import Qed, Safe, Counterexample, Satisfiable, Unsatisfiable, CryptolSyncConnection
# and add an alias `sync` for the `synchronous` module
from . import synchronous
sync = synchronous


__all__ = ['bitvector', 'commands', 'connections', 'cryptoltypes', 'solver']
__all__ = ['bitvector', 'commands', 'connection', 'cryptoltypes', 'opaque', 'solver', 'synchronous']


def fail_with(x : Exception) -> NoReturn:
Expand Down
111 changes: 76 additions & 35 deletions cryptol-remote-api/python/cryptol/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from __future__ import annotations

import base64
from abc import ABC
from enum import Enum
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Tuple, List, Dict, Optional, Union
from typing_extensions import Literal

import argo_client.interaction as argo
Expand All @@ -20,7 +21,9 @@ def extend_hex(string : str) -> str:
else:
return string

def from_cryptol_arg(val : Any) -> Any:
CryptolValue = Union[bool, int, BV, Tuple, List, Dict, OpaqueValue]

def from_cryptol_arg(val : Any) -> CryptolValue:
"""Return the canonical Python value for a Cryptol JSON value."""
if isinstance(val, bool):
return val
Expand Down Expand Up @@ -80,27 +83,36 @@ def process_result(self, res : Any) -> Any:
return res


class CryptolEvalExpr(argo.Command):
class CryptolEvalExprRaw(argo.Command):
def __init__(self, connection : HasProtocolState, expr : Any) -> None:
super(CryptolEvalExpr, self).__init__(
super(CryptolEvalExprRaw, self).__init__(
'evaluate expression',
{'expression': expr},
connection
)

def process_result(self, res : Any) -> Any:
return from_cryptol_arg(res['value'])
return res['value']

class CryptolEvalExpr(CryptolEvalExprRaw):
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(super(CryptolEvalExpr, self).process_result(res))

class CryptolCall(argo.Command):

class CryptolCallRaw(argo.Command):
def __init__(self, connection : HasProtocolState, fun : str, args : List[Any]) -> None:
super(CryptolCall, self).__init__(
super(CryptolCallRaw, self).__init__(
'call',
{'function': fun, 'arguments': args},
connection
)

def process_result(self, res : Any) -> Any:
return from_cryptol_arg(res['value'])
return res['value']

class CryptolCall(CryptolCallRaw):
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(super(CryptolCall, self).process_result(res))


@dataclass
Expand All @@ -111,43 +123,53 @@ class CheckReport:
error_msg: Optional[str]
tests_run: int
tests_possible: Optional[int]

def __bool__(self) -> bool:
return self.success

def to_check_report(res : Any) -> CheckReport:
if res['result'] == 'pass':
return CheckReport(
success=True,
args=[],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'fail':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'error':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = res['error message'],
tests_run=res['tests run'],
tests_possible=res['tests possible'])
else:
raise ValueError("Unknown check result " + str(res))

class CryptolCheck(argo.Command):
class CryptolCheckRaw(argo.Command):
def __init__(self, connection : HasProtocolState, expr : Any, num_tests : Union[Literal['all'],int, None]) -> None:
if num_tests:
args = {'expression': expr, 'number of tests':num_tests}
else:
args = {'expression': expr}
super(CryptolCheck, self).__init__(
super(CryptolCheckRaw, self).__init__(
'check',
args,
connection
)

def process_result(self, res : Any) -> Any:
return res

class CryptolCheck(CryptolCheckRaw):
def process_result(self, res : Any) -> 'CheckReport':
if res['result'] == 'pass':
return CheckReport(
success=True,
args=[],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'fail':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'error':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = res['error message'],
tests_run=res['tests run'],
tests_possible=res['tests possible'])
else:
raise ValueError("Unknown check result " + str(res))
return to_check_report(super(CryptolCheck, self).process_result(res))


class CryptolCheckType(argo.Command):
Expand All @@ -161,14 +183,15 @@ def __init__(self, connection : HasProtocolState, expr : Any) -> None:
def process_result(self, res : Any) -> Any:
return res['type schema']


class SmtQueryType(str, Enum):
PROVE = 'prove'
SAFE = 'safe'
SAT = 'sat'

class CryptolProveSat(argo.Command):
class CryptolProveSatRaw(argo.Command):
def __init__(self, connection : HasProtocolState, qtype : SmtQueryType, expr : Any, solver : Solver, count : Optional[int]) -> None:
super(CryptolProveSat, self).__init__(
super(CryptolProveSatRaw, self).__init__(
'prove or satisfy',
{'query type': qtype,
'expression': expr,
Expand All @@ -180,6 +203,11 @@ def __init__(self, connection : HasProtocolState, qtype : SmtQueryType, expr : A
self.qtype = qtype

def process_result(self, res : Any) -> Any:
return res

class CryptolProveSat(CryptolProveSatRaw):
def process_result(self, res : Any) -> Any:
res = super(CryptolProveSat, self).process_result(res)
if res['result'] == 'unsatisfiable':
if self.qtype == SmtQueryType.SAT:
return False
Expand All @@ -199,25 +227,36 @@ def process_result(self, res : Any) -> Any:
else:
raise ValueError("Unknown SMT result: " + str(res))

class CryptolProveRaw(CryptolProveSatRaw):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
super(CryptolProveRaw, self).__init__(connection, SmtQueryType.PROVE, expr, solver, 1)
class CryptolProve(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
super(CryptolProve, self).__init__(connection, SmtQueryType.PROVE, expr, solver, 1)

class CryptolSatRaw(CryptolProveSatRaw):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver, count : int) -> None:
super(CryptolSatRaw, self).__init__(connection, SmtQueryType.SAT, expr, solver, count)
class CryptolSat(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver, count : int) -> None:
super(CryptolSat, self).__init__(connection, SmtQueryType.SAT, expr, solver, count)

class CryptolSafeRaw(CryptolProveSatRaw):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
super(CryptolSafeRaw, self).__init__(connection, SmtQueryType.SAFE, expr, solver, 1)
class CryptolSafe(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
super(CryptolSafe, self).__init__(connection, SmtQueryType.SAFE, expr, solver, 1)


class CryptolNames(argo.Command):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolNames, self).__init__('visible names', {}, connection)

def process_result(self, res : Any) -> Any:
return res


class CryptolFocusedModule(argo.Command):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolFocusedModule, self).__init__(
Expand All @@ -229,6 +268,7 @@ def __init__(self, connection : HasProtocolState) -> None:
def process_result(self, res : Any) -> Any:
return res


class CryptolReset(argo.Notification):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolReset, self).__init__(
Expand All @@ -237,6 +277,7 @@ def __init__(self, connection : HasProtocolState) -> None:
connection
)


class CryptolResetServer(argo.Notification):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolResetServer, self).__init__(
Expand Down
49 changes: 47 additions & 2 deletions cryptol-remote-api/python/cryptol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def load_module(self, module_name : str) -> argo.Command:
self.most_recent_result = CryptolLoadModule(self, module_name)
return self.most_recent_result

def eval_raw(self, expression : Any) -> argo.Command:
"""Like the member method ``eval``, but does not call
``from_cryptol_arg`` on the ``.result()``.
"""
self.most_recent_result = CryptolEvalExprRaw(self, expression)
return self.most_recent_result

def eval(self, expression : Any) -> argo.Command:
"""Evaluate a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing
Expand All @@ -189,15 +196,33 @@ def evaluate_expression(self, expression : Any) -> argo.Command:
return self.eval(expression)

def extend_search_path(self, *dir : str) -> argo.Command:
"""Load a Cryptol module, like ``:module`` at the Cryptol REPL."""
"""Extend the search path for loading Cryptol modules."""
self.most_recent_result = CryptolExtendSearchPath(self, list(dir))
return self.most_recent_result

def call_raw(self, fun : str, *args : List[Any]) -> argo.Command:
"""Like the member method ``call``, but does not call
``from_cryptol_arg`` on the ``.result()``.
"""
encoded_args = [cryptoltypes.CryptolType().from_python(a) for a in args]
self.most_recent_result = CryptolCallRaw(self, fun, encoded_args)
return self.most_recent_result

def call(self, fun : str, *args : List[Any]) -> argo.Command:
encoded_args = [cryptoltypes.CryptolType().from_python(a) for a in args]
self.most_recent_result = CryptolCall(self, fun, encoded_args)
return self.most_recent_result

def check_raw(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> argo.Command:
"""Like the member method ``check``, but does not call
`to_check_report` on the ``.result()``.
"""
if num_tests == "all" or isinstance(num_tests, int) or num_tests is None:
self.most_recent_result = CryptolCheckRaw(self, expr, num_tests)
return self.most_recent_result
else:
raise ValueError('``num_tests`` must be an integer, ``None``, or the string literall ``"all"``')

def check(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> argo.Command:
"""Tests the validity of a Cryptol expression with random inputs. The expression must be a function with
return type ``Bit``.
Expand All @@ -212,7 +237,6 @@ def check(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = No
else:
raise ValueError('``num_tests`` must be an integer, ``None``, or the string literall ``"all"``')


def check_type(self, code : Any) -> argo.Command:
"""Check the type of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
Expand All @@ -221,6 +245,13 @@ def check_type(self, code : Any) -> argo.Command:
self.most_recent_result = CryptolCheckType(self, code)
return self.most_recent_result

def sat_raw(self, expr : Any, solver : solver.Solver = solver.Z3, count : int = 1) -> argo.Command:
"""Like the member method ``sat``, but does not call
`to_smt_query_result` on the ``.result()``.
"""
self.most_recent_result = CryptolSatRaw(self, expr, solver, count)
return self.most_recent_result

def sat(self, expr : Any, solver : solver.Solver = solver.Z3, count : int = 1) -> argo.Command:
"""Check the satisfiability of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
Expand All @@ -230,6 +261,13 @@ def sat(self, expr : Any, solver : solver.Solver = solver.Z3, count : int = 1) -
self.most_recent_result = CryptolSat(self, expr, solver, count)
return self.most_recent_result

def prove_raw(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
"""Like the member method ``prove``, but does not call
`to_smt_query_result` on the ``.result()``.
"""
self.most_recent_result = CryptolProveRaw(self, expr, solver)
return self.most_recent_result

def prove(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
"""Check the validity of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
Expand All @@ -238,6 +276,13 @@ def prove(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
self.most_recent_result = CryptolProve(self, expr, solver)
return self.most_recent_result

def safe_raw(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
"""Like the member method ``safe``, but does not call
`to_smt_query_result` on the ``.result()``.
"""
self.most_recent_result = CryptolSafeRaw(self, expr, solver)
return self.most_recent_result

def safe(self, expr : Any, solver : solver.Solver = solver.Z3) -> argo.Command:
"""Check via an external SMT solver that the given term is safe for all inputs,
which means it cannot encounter a run-time error.
Expand Down
Loading