Skip to content

Commit

Permalink
Support foreign implementations of trait interfaces (mozilla#1578)
Browse files Browse the repository at this point in the history
Scaffolding:
* Generate a struct that implements the trait using a callback interface callback
* Make `try_lift` input a callback interface handle and create one of those structs.
* Don't use `try_lift` in the trait interface method scaffolding.
  `try_lift` expects to lift a callback handle, but scaffolding
  methods are called with a leaked object pointer.
* Removed the unused RustCallStatus param from the callback initialization function

Kotlin/Python/Swift:
* Factored out the callback interface impl and interface/protocol
  templates so it can also be used for trait interfaces.
* Changed the callback interface handle map code so that it doesn't
  try to re-use the handles. If an object is lowered twice, we now
  generate two different handles. This is required for trait
  interfaces, and I think it's also would be the right thing for
  callback interfaces if they could be passed back into the foreign
  language from Rust.
* Make `lower()` return a callback interface handle.
* Added some code to clarify how we generate the protocol and the
  implementation of that protocol for an object

Other:
* Trait interfaces are still not supported on Ruby.
* Updated the coverall bindings tests to test this.
* Updated the traits example, although there's definitely more room for improvement.

TODO:

I think a better handle solution (mozilla#1730) could help with a few things:

* We're currently wrapping the object every time it's passed across the
  FFI.  If the foreign code receives a trait object, then passes it back
  to Rust.  Rust now has a handle to the foreign impl and that foreign
  impl just calls back into Rust.  This can lead to some extremely
  inefficent FFI calls if an object is passed around enough.
* The way we're coercing between pointers, usize, and uint64 is
  probably wrong and at the very least extremely brittle.

There should be better tests for reference counts, but I'm waiting until
we address the handle issue to implement them.
  • Loading branch information
bendk committed Oct 12, 2023
1 parent 484f73a commit 7291c25
Show file tree
Hide file tree
Showing 42 changed files with 1,001 additions and 533 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
- Error types must now implement `Error + Send + Sync + 'static`.
- Proc-macros: The `handle_unknown_callback_error` attribute is no longer needed for callback
interface errors
- Foreign types can now implement trait interfaces

### What's Fixed

Expand Down
18 changes: 17 additions & 1 deletion docs/manual/src/udl/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,28 @@ fn get_buttons() -> Vec<Arc<dyn Button>> { ... }
fn press(button: Arc<dyn Button>) -> Arc<dyn Button> { ... }
```

See the ["traits" example](https://github.com/mozilla/uniffi-rs/tree/main/examples/traits) for more.
### Foreign implementations

Traits can also be implemented on the foreign side passed into Rust, for example:

```python
class PyButton(uniffi_module.Button):
def name(self):
return "PyButton"

uniffi_module.press(PyButton())
```

Note: This is currently supported on Python, Kotlin, and Swift.

### Traits construction

Because any number of `struct`s may implement a trait, they don't have constructors.

### Traits example

See the ["traits" example](https://github.com/mozilla/uniffi-rs/tree/main/examples/traits) for more.

## Alternate Named Constructors

In addition to the default constructor connected to the `::new()` method, you can specify
Expand Down
16 changes: 16 additions & 0 deletions examples/traits/tests/bindings/test_traits.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import uniffi.traits.*

for (button in getButtons()) {
val name = button.name()
// Check that the name is one of the expected values
assert(name in listOf("go", "stop"))
// Check that we can round-trip the button through Rust
assert(press(button).name() == name)
}

// Test a button implemented in Kotlin
class KtButton : Button {
override fun name() = "KtButton"
}

assert(press(KtButton()).name() == "KtButton")
16 changes: 12 additions & 4 deletions examples/traits/tests/bindings/test_traits.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from traits import *

for button in get_buttons():
if button.name() in ["go", "stop"]:
press(button)
else:
print("unknown button", button)
name = button.name()
# Check that the name is one of the expected values
assert(name in ["go", "stop"])
# Check that we can round-trip the button through Rust
assert(press(button).name() == name)

# Test a button implemented in Python
class PyButton(Button):
def name(self):
return "PyButton"

assert(press(PyButton()).name() == "PyButton")
18 changes: 18 additions & 0 deletions examples/traits/tests/bindings/test_traits.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import traits

for button in getButtons() {
let name = button.name()
// Check that the name is one of the expected values
assert(["go", "stop"].contains(name))
// Check that we can round-trip the button through Rust
assert(press(button: button).name() == name)
}

// Test a Button implemented in Swift
class SwiftButton: Button {
func name() -> String {
return "SwiftButton"
}
}

assert(press(button: SwiftButton()).name() == "SwiftButton")
6 changes: 5 additions & 1 deletion examples/traits/tests/test_generated_bindings.rs
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
uniffi::build_foreign_language_testcases!("tests/bindings/test_traits.py",);
uniffi::build_foreign_language_testcases!(
"tests/bindings/test_traits.py",
"tests/bindings/test_traits.kts",
"tests/bindings/test_traits.swift",
);
97 changes: 92 additions & 5 deletions fixtures/coverall/tests/bindings/test_coverall.kts
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,67 @@ Coveralls("test_interfaces_in_dicts").use { coveralls ->
assert(coveralls.getRepairs().size == 2)
}

Coveralls("test_regressions").use { coveralls ->
assert(coveralls.getStatus("success") == "status: success")
}

class KotlinGetters : Getters {
override fun getBool(v: Boolean, arg2: Boolean) : Boolean {
return v != arg2
}

override fun getString(v: String, arg2: Boolean) : String {
if (v == "too-many-holes") {
throw CoverallException.TooManyHoles("too many holes")
} else if (v == "unexpected-error") {
throw RuntimeException("unexpected error")
} else if (arg2) {
return v.uppercase()
} else {
return v
}
}

override fun getOption(v: String, arg2: Boolean) : String? {
if (v == "os-error") {
throw ComplexException.OsException(100, 200)
} else if (v == "unknown-error") {
throw ComplexException.UnknownException()
} else if (arg2) {
if (!v.isEmpty()) {
return v.uppercase()
} else {
return null
}
} else {
return v
}
}

override fun getList(v: List<Int>, arg2: Boolean) : List<Int> {
if (arg2) {
return v
} else {
return listOf()
}
}

@Suppress("UNUSED_PARAMETER")
override fun getNothing(v: String) = Unit
}

// Test traits implemented in Rust
makeRustGetters().let { rustGetters ->
testGetters(rustGetters)
testGettersFromKotlin(rustGetters)
}

// Test traits implemented in Kotlin
KotlinGetters().let { kotlinGetters ->
testGetters(kotlinGetters)
testGettersFromKotlin(kotlinGetters)
}

fun testGettersFromKotlin(getters: Getters) {
assert(getters.getBool(true, true) == false);
assert(getters.getBool(true, false) == true);
Expand Down Expand Up @@ -258,11 +313,27 @@ fun testGettersFromKotlin(getters: Getters) {

try {
getters.getString("unexpected-error", true)
} catch(e: InternalException) {
} catch(e: Exception) {
// Expected
}
}

class KotlinNode() : NodeTrait {
var currentParent: NodeTrait? = null

override fun name() = "node-kt"

override fun setParent(parent: NodeTrait?) {
currentParent = parent
}

override fun getParent() = currentParent

override fun strongCount() : ULong {
return 0.toULong() // TODO
}
}

// Test NodeTrait
getTraits().let { traits ->
assert(traits[0].name() == "node-1")
Expand All @@ -273,16 +344,32 @@ getTraits().let { traits ->
assert(traits[1].name() == "node-2")
assert(traits[1].strongCount() == 2UL)

// Note: this doesn't increase the Rust strong count, since we wrap the Rust impl with a
// Swift impl before passing it to `setParent()`
traits[0].setParent(traits[1])
assert(ancestorNames(traits[0]) == listOf("node-2"))
assert(ancestorNames(traits[1]).isEmpty())
assert(traits[1].strongCount() == 3UL)
assert(traits[1].strongCount() == 2UL)
assert(traits[0].getParent()!!.name() == "node-2")

val ktNode = KotlinNode()
traits[1].setParent(ktNode)
assert(ancestorNames(traits[0]) == listOf("node-2", "node-kt"))
assert(ancestorNames(traits[1]) == listOf("node-kt"))
assert(ancestorNames(ktNode) == listOf<String>())

traits[1].setParent(null)
ktNode.setParent(traits[0])
assert(ancestorNames(ktNode) == listOf("node-1", "node-2"))
assert(ancestorNames(traits[0]) == listOf("node-2"))
assert(ancestorNames(traits[1]) == listOf<String>())

// Unset everything and check that we don't get a memory error
ktNode.setParent(null)
traits[0].setParent(null)

Coveralls("test_regressions").use { coveralls ->
assert(coveralls.getStatus("success") == "status: success")
}
// FIXME: We should be calling `NodeTraitImpl.close()` to release the Rust pointer, however that's
// not possible through the `NodeTrait` interface (see #1787).
}

// This tests that the UniFFI-generated scaffolding doesn't introduce any unexpected locking.
Expand Down
90 changes: 85 additions & 5 deletions fixtures/coverall/tests/bindings/test_coverall.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,68 @@ def test_bytes(self):
coveralls = Coveralls("test_bytes")
self.assertEqual(coveralls.reverse(b"123"), b"321")

class PyGetters:
def get_bool(self, v, arg2):
return v ^ arg2

def get_string(self, v, arg2):
if v == "too-many-holes":
raise CoverallError.TooManyHoles
elif v == "unexpected-error":
raise RuntimeError("unexpected error")
elif arg2:
return v.upper()
else:
return v

def get_option(self, v, arg2):
if v == "os-error":
raise ComplexError.OsError(100, 200)
elif v == "unknown-error":
raise ComplexError.UnknownError
elif arg2:
if v:
return v.upper()
else:
return None
else:
return v

def get_list(self, v, arg2):
if arg2:
return v
else:
return []

def get_nothing(self, _v):
return None

class PyNode:
def __init__(self):
self.parent = None

def name(self):
return "node-py"

def set_parent(self, parent):
self.parent = parent

def get_parent(self):
return self.parent

def strong_count(self):
return 0 # TODO

class TraitsTest(unittest.TestCase):
# Test traits implemented in Rust
def test_rust_getters(self):
test_getters(make_rust_getters())
self.check_getters_from_python(make_rust_getters())
# def test_rust_getters(self):
# test_getters(None)
# self.check_getters_from_python(make_rust_getters())

# Test traits implemented in Rust
def test_python_getters(self):
test_getters(PyGetters())
#self.check_getters_from_python(PyGetters())

def check_getters_from_python(self, getters):
self.assertEqual(getters.get_bool(True, True), False);
Expand Down Expand Up @@ -316,7 +373,8 @@ def check_getters_from_python(self, getters):
with self.assertRaises(InternalError):
getters.get_string("unexpected-error", True)

def test_node(self):
def test_path(self):
# Get traits creates 2 objects that implement the trait
traits = get_traits()
self.assertEqual(traits[0].name(), "node-1")
# Note: strong counts are 1 more than you might expect, because the strong_count() method
Expand All @@ -326,11 +384,33 @@ def test_node(self):
self.assertEqual(traits[1].name(), "node-2")
self.assertEqual(traits[1].strong_count(), 2)

# Let's try connecting them together
traits[0].set_parent(traits[1])
# Note: this doesn't increase the Rust strong count, since we wrap the Rust impl with a
# python impl before passing it to `set_parent()`
self.assertEqual(traits[1].strong_count(), 2)
self.assertEqual(ancestor_names(traits[0]), ["node-2"])
self.assertEqual(ancestor_names(traits[1]), [])
self.assertEqual(traits[1].strong_count(), 3)
self.assertEqual(traits[0].get_parent().name(), "node-2")

# Throw in a Python implementation of the trait
# The ancestry chain now goes traits[0] -> traits[1] -> py_node
py_node = PyNode()
traits[1].set_parent(py_node)
self.assertEqual(ancestor_names(traits[0]), ["node-2", "node-py"])
self.assertEqual(ancestor_names(traits[1]), ["node-py"])
self.assertEqual(ancestor_names(py_node), [])

# Rotating things.
# The ancestry chain now goes py_node -> traits[0] -> traits[1]
traits[1].set_parent(None)
py_node.set_parent(traits[0])
self.assertEqual(ancestor_names(py_node), ["node-1", "node-2"])
self.assertEqual(ancestor_names(traits[0]), ["node-2"])
self.assertEqual(ancestor_names(traits[1]), [])

# Make sure we don't crash when undoing it all
py_node.set_parent(None)
traits[0].set_parent(None)

if __name__=='__main__':
Expand Down
Loading

0 comments on commit 7291c25

Please sign in to comment.