Skip to content

Commit

Permalink
Merge pull request #501 from njtierney/fix-check-v-test-500
Browse files Browse the repository at this point in the history
Address `check()` vs `test()` failure modes

Resolves #500
  • Loading branch information
njtierney authored Mar 14, 2022
2 parents 15262c1 + 79a0c77 commit f8c67d5
Show file tree
Hide file tree
Showing 37 changed files with 694 additions and 687 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Type: Package
Package: greta
Title: Simple and Scalable Statistical Modelling in R
Version: 0.4.0
Date: 2021-11-26
Date: 2022-02-21
Authors@R: c(
person("Nick", "Golding", , "nick.golding.research@gmail.com", role = "aut",
comment = c(ORCID = "0000-0001-8916-5570")),
Expand Down Expand Up @@ -121,5 +121,6 @@ Collate:
'reinstallers.R'
'checkers.R'
'test_if_forked_cluster.R'
'testthat-helpers.R'
'zzz.R'
'internals.R'
123 changes: 123 additions & 0 deletions R/testthat-helpers.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# an array of random standard normals with the specificed dims
# e.g. randn(3, 2, 1)
randn <- function(...) {
dim <- c(...)
array(stats::rnorm(prod(dim)), dim = dim)
}

# ditto for standard uniforms
randu <- function(...) {
dim <- c(...)
array(stats::runif(prod(dim)), dim = dim)
}

# create a variable with the same dimensions as as_data(x)
as_variable <- function(x) {
x <- as_2d_array(x)
variable(dim = dim(x))
}


# check a greta operation and the equivalent R operation give the same output
# e.g. check_op(sum, randn(100, 3))
check_op <- function(op, a, b, greta_op = NULL,
other_args = list(),
tolerance = 1e-3,
only = c("data", "variable", "batched"),
relative_error = FALSE) {
if (is.null(greta_op)) {
greta_op <- op
}

r_out <- run_r_op(op, a, b, other_args)

for (type in only) {
# compare with ops on data greta arrays
greta_out <- run_greta_op(greta_op, a, b, other_args, type)
compare_op(r_out, greta_out, tolerance, relative_error = relative_error)
}
}

compare_op <- function(r_out, greta_out, tolerance = 1e-4, relative_error = FALSE) {
if (relative_error){
difference <- as.vector(abs(r_out - greta_out) / abs(r_out))
} else if (!relative_error){
difference <- as.vector(abs(r_out - greta_out))
}
difference_lt_tolerance <- difference < tolerance
are_all_true <- all(difference_lt_tolerance)
are_all_true
testthat::expect_true(are_all_true)
}

run_r_op <- function(op, a, b, other_args) {
arg_list <- list(a)
if (!missing(b)) {
arg_list <- c(arg_list, list(b))
}
arg_list <- c(arg_list, other_args)
do.call(op, arg_list)
}

run_greta_op <- function(greta_op, a, b, other_args,
type = c("data", "variable", "batched")) {
type <- match.arg(type)

converter <- switch(type,
data = as_data,
variable = as_variable,
batched = as_variable
)

g_a <- converter(a)

arg_list <- list(g_a)
values <- list(g_a = a)

if (!missing(b)) {
g_b <- converter(b)
arg_list <- c(arg_list, list(g_b))
values <- c(values, list(g_b = b))
}

arg_list <- c(arg_list, other_args)
out <- do.call(greta_op, arg_list)

if (type == "data") {
# data greta arrays should provide their own values
result <- calculate(out, values = list())[[1]]
} else if (type == "variable") {
result <- grab_via_free_state(out, values)
} else if (type == "batched") {
result <- grab_via_free_state(out, values, batches = 3)
} else {
result <- calculate(out, values = values)[[1]]
}

result
}

# get the value of the target greta array, by passing values for the named
# variable greta arrays via the free state parameter, optionally with batches
grab_via_free_state <- function(target, values, batches = 1) {
dag <- dag_class$new(list(target))
dag$define_tf()
inits <- do.call(initials, values)
inits_flat <- prep_initials(inits, 1, dag)[[1]]
if (batches > 1) {
inits_list <- replicate(batches, inits_flat, simplify = FALSE)
inits_flat <- do.call(rbind, inits_list)
vals <- dag$trace_values(inits_flat)[1, ]
} else {
vals <- dag$trace_values(inits_flat)
}
array(vals, dim = dim(target))
}

expect_ok <- function(expr) {
testthat::expect_error(expr, NA)
}

is.greta_array <- function(x) { # nolint
inherits(x, "greta_array")
}
2 changes: 1 addition & 1 deletion tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
options(testthat.progress.max_fails = 100)
library(testthat)
library(greta)

test_check("greta")
File renamed without changes.
File renamed without changes.
22 changes: 17 additions & 5 deletions tests/testthat/_snaps/tensorflow-rpkg-stability.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# Tensor behaves as we expect

[1] 1

---

integer(0)

---

Tensor("Reshape:0", shape=(1, 1, 1, 1, 1), dtype=int32)

# shape returns right thing

TensorShape([])
Expand Down Expand Up @@ -26,6 +38,10 @@

TensorShape([Dimension(None), Dimension(4)])

---

TensorShape([Dimension(1), Dimension(1), Dimension(1)])

# placeholder and friends behave the same way

[1] 2 NA
Expand Down Expand Up @@ -89,7 +105,7 @@

---

Tensor("Const_6:0", shape=(2,), dtype=int32)
Tensor("Const_8:0", shape=(2,), dtype=int32)

---

Expand Down Expand Up @@ -167,10 +183,6 @@

[1] TRUE

# tf$reshape behaves as expected

Tensor("Reshape:0", shape=(2, 4), dtype=float32)

# [, [[, and assignment returns right object

TensorShape([Dimension(1)])
Expand Down
28 changes: 0 additions & 28 deletions tests/testthat/_snaps/tf_rpkg.md

This file was deleted.

Loading

0 comments on commit f8c67d5

Please sign in to comment.