Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rerun newest black on entire repo #3178

Merged
merged 1 commit into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/air/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
likelihood_sd=0.3,
use_cuda=False,
):

super().__init__()

self.num_steps = num_steps
Expand Down Expand Up @@ -127,7 +126,6 @@ def __init__(
self.cuda()

def prior(self, n, **kwargs):

state = ModelState(
x=torch.zeros(n, self.x_size, self.x_size, **self.options),
z_pres=torch.ones(n, self.z_pres_size, **self.options),
Expand All @@ -145,7 +143,6 @@ def prior(self, n, **kwargs):
return (z_where, z_pres), state.x

def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):

# Sample presence indicators.
z_pres = pyro.sample(
"z_pres_{}".format(t),
Expand Down Expand Up @@ -263,7 +260,6 @@ def guide(self, data, batch_size, **kwargs):
return z_where, z_pres

def guide_step(self, t, n, prev, inputs):

rnn_input = torch.cat(
(inputs["embed"], prev.z_where, prev.z_what, prev.z_pres), 1
)
Expand Down
2 changes: 0 additions & 2 deletions examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def load_data():


def main(**kwargs):

args = argparse.Namespace(**kwargs)

if "save" in args:
Expand Down Expand Up @@ -229,7 +228,6 @@ def per_param_optim_args(param_name):
examples_to_viz = X[5:10]

for i in range(1, args.num_steps + 1):

loss = svi.step(
X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i)
)
Expand Down
2 changes: 1 addition & 1 deletion examples/air/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
layers = []
in_sizes = [in_size] + out_sizes[0:-1]
sizes = list(zip(in_sizes, out_sizes))
for (i, o) in sizes[0:-1]:
for i, o in sizes[0:-1]:
layers.append(nn.Linear(i, o))
layers.append(non_linear_layer())
layers.append(nn.Linear(sizes[-1][0], sizes[-1][1]))
Expand Down
1 change: 1 addition & 0 deletions examples/contrib/funsor/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
# and randomly subsample data to size batch_size. To add jit support we
# silence some warnings and try to avoid dynamic program structure.


# Note that this is the "HMM" model in reference [1] (with the difference that
# in [1] the probabilities probs_x and probs_y are not MAP-regularized with
# Dirichlet and Beta distributions for any of the models)
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/mue/FactorMuE.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def generate_data(small_test, include_stop, device):


def main(args):

# Load dataset.
if args.cpu_data or not args.cuda:
device = torch.device("cpu")
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/mue/ProfileHMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def generate_data(small_test, include_stop, device):


def main(args):

pyro.set_rng_seed(args.rng_seed)

# Load dataset.
Expand Down
1 change: 0 additions & 1 deletion examples/contrib/oed/ab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def true_ape(ns):


def main(num_vi_steps, num_bo_steps, seed):

pyro.set_rng_seed(seed)
pyro.clear_param_store()

Expand Down
1 change: 0 additions & 1 deletion examples/cvae/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def train(
early_stop_patience,
model_path,
):

# Train baseline
baseline_net = BaselineNet(500, 500)
baseline_net.to(device)
Expand Down
2 changes: 0 additions & 2 deletions examples/cvae/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def model(self, xs, ys=None):
pyro.module("generation_net", self)
batch_size = xs.shape[0]
with pyro.plate("data"):

# Prior network uses the baseline predictions as initial guess.
# This is the generative process with recurrent connection
with torch.no_grad():
Expand Down Expand Up @@ -130,7 +129,6 @@ def train(
model_path,
pre_trained_baseline_net,
):

# clear param store
pyro.clear_param_store()

Expand Down
2 changes: 0 additions & 2 deletions examples/cvae/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def visualize(
num_samples,
image_path=None,
):

# Load sample random data
datasets, _, dataset_sizes = get_data(
num_quadrant_inputs=num_quadrant_inputs, batch_size=num_images
Expand Down Expand Up @@ -121,7 +120,6 @@ def generate_table(
num_particles,
col_name,
):

# Load sample random data
datasets, dataloaders, dataset_sizes = get_data(
num_quadrant_inputs=num_quadrant_inputs, batch_size=32
Expand Down
2 changes: 0 additions & 2 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def model(
mini_batch_seq_lengths,
annealing_factor=1.0,
):

# this is the number of time steps we need to process in the mini-batch
T_max = mini_batch.size(1)

Expand Down Expand Up @@ -269,7 +268,6 @@ def guide(
mini_batch_seq_lengths,
annealing_factor=1.0,
):

# this is the number of time steps we need to process in the mini-batch
T_max = mini_batch.size(1)
# register all PyTorch (sub)modules with pyro
Expand Down
4 changes: 0 additions & 4 deletions examples/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def jit_prob(equation, *operands, **kwargs):
"""
key = "prob", equation, kwargs["plates"]
if key not in _CACHE:

# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(equation, *operands, **kwargs)
Expand All @@ -61,7 +60,6 @@ def jit_logprob(equation, *operands, **kwargs):
"""
key = "logprob", equation, kwargs["plates"]
if key not in _CACHE:

# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(
Expand All @@ -81,7 +79,6 @@ def jit_gradient(equation, *operands, **kwargs):
"""
key = "gradient", equation, kwargs["plates"]
if key not in _CACHE:

# This wraps einsum for jit compilation, but we will call backward on the result.
def _einsum(*operands):
return einsum(
Expand Down Expand Up @@ -114,7 +111,6 @@ def _jit_adjoint(equation, *operands, **kwargs):
backend = kwargs.pop("backend", "pyro.ops.einsum.torch_sample")
key = backend, equation, tuple(x.shape for x in operands), kwargs["plates"]
if key not in _CACHE:

# This wraps a complete adjoint algorithm call.
@ignore_jit_warnings()
def _forward_backward(*operands):
Expand Down
1 change: 1 addition & 0 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
# and randomly subsample data to size batch_size. To add jit support we
# silence some warnings and try to avoid dynamic program structure.


# Note that this is the "HMM" model in reference [1] (with the difference that
# in [1] the probabilities probs_x and probs_y are not MAP-regularized with
# Dirichlet and Beta distributions for any of the models)
Expand Down
3 changes: 0 additions & 3 deletions examples/mixed_hmm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _size(tensor):


def run_expt(args):

data_dir = args["folder"]
dataset = "seal" # args["dataset"]
seed = args["seed"]
Expand Down Expand Up @@ -79,7 +78,6 @@ def run_expt(args):
schedule_step_loss = True

for t in range(timesteps):

optimizer.zero_grad()
loss = loss_fn(model, guide)
loss.backward()
Expand Down Expand Up @@ -166,7 +164,6 @@ def closure():


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("-g", "--group", default="none", type=str)
parser.add_argument("-i", "--individual", default="none", type=str)
Expand Down
4 changes: 0 additions & 4 deletions examples/mixed_hmm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def guide_generic(config):

N_c = config["sizes"]["group"]
with pyro.plate("group", N_c, dim=-1):

if config["group"]["random"] == "continuous":
pyro.sample(
"eps_g",
Expand All @@ -59,7 +58,6 @@ def guide_generic(config):
with pyro.plate("individual", N_s, dim=-2), poutine.mask(
mask=config["individual"]["mask"]
):

# individual-level random effects
if config["individual"]["random"] == "continuous":
pyro.sample(
Expand Down Expand Up @@ -158,7 +156,6 @@ def model_generic(config):

N_c = config["sizes"]["group"]
with pyro.plate("group", N_c, dim=-1):

# group-level random effects
if config["group"]["random"] == "discrete":
# group-level discrete effect
Expand All @@ -179,7 +176,6 @@ def model_generic(config):
with pyro.plate("individual", N_s, dim=-2), poutine.mask(
mask=config["individual"]["mask"]
):

# individual-level random effects
if config["individual"]["random"] == "discrete":
# individual-level discrete effect
Expand Down
1 change: 0 additions & 1 deletion examples/mixed_hmm/seal_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def download_seal_data(filename):


def prepare_seal(filename, random_effects):

if not os.path.exists(filename):
download_seal_data(filename)

Expand Down
1 change: 0 additions & 1 deletion examples/rsa/hyperbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def test_truth():


def main(args):

# test_truth()

pragmatic_marginal = pragmatic_listener(args.price)
Expand Down
1 change: 0 additions & 1 deletion examples/rsa/schelling_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def bob(preference, depth):


def main(args):

# Here Alice and Bob slightly prefer one location over the other a priori
shared_preference = torch.tensor([args.preference])

Expand Down
1 change: 0 additions & 1 deletion examples/rsa/search_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def sample_escape(tr, site):
)

def _fn(*args, **kwargs):

for i in range(int(1e6)):
assert (
not queue.empty()
Expand Down
1 change: 0 additions & 1 deletion examples/rsa/semantic_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ def literal_listener_raw(utterance, qud):


def main(args):

mll = Marginal(literal_listener_raw, num_samples=args.num_samples)

def is_any_qud(world):
Expand Down
8 changes: 1 addition & 7 deletions examples/vae/ss_vae_M2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
use_cuda=False,
aux_loss_multiplier=None,
):

super().__init__()

# initialize the class with all arguments provided to the constructor
Expand All @@ -68,7 +67,6 @@ def __init__(
self.setup_networks()

def setup_networks(self):

z_dim = self.z_dim
hidden_sizes = self.hidden_layers

Expand Down Expand Up @@ -127,7 +125,6 @@ def model(self, xs, ys=None):
batch_size = xs.size(0)
options = dict(dtype=xs.dtype, device=xs.device)
with pyro.plate("data"):

# sample the handwriting style from the constant prior distribution
prior_loc = torch.zeros(batch_size, self.z_dim, **options)
prior_scale = torch.ones(batch_size, self.z_dim, **options)
Expand Down Expand Up @@ -167,7 +164,6 @@ def guide(self, xs, ys=None):
"""
# inform Pyro that the variables in the batch of xs, ys are conditionally independent
with pyro.plate("data"):

# if the class label (the digit) is not supervised, sample
# (and score) the digit with the variational distribution
# q(y|x) = categorical(alpha(x))
Expand Down Expand Up @@ -245,7 +241,6 @@ def run_inference_for_epoch(data_loaders, losses, periodic_interval_batches):
# count the number of supervised batches seen in this epoch
ctr_sup = 0
for i in range(batches_per_epoch):

# whether this batch is supervised or not
is_supervised = (i % periodic_interval_batches == 1) and ctr_sup < sup_batches

Expand Down Expand Up @@ -277,7 +272,7 @@ def get_accuracy(data_loader, classifier_fn, batch_size):
predictions, actuals = [], []

# use the appropriate data loader
for (xs, ys) in data_loader:
for xs, ys in data_loader:
# use classification function to compute all predictions for each batch
predictions.append(classifier_fn(xs))
actuals.append(ys)
Expand Down Expand Up @@ -370,7 +365,6 @@ def main(args):

# run inference for a certain number of epochs
for i in range(0, args.num_epochs):

# get the losses for an epoch
epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
data_loaders, losses, periodic_interval_batches
Expand Down
2 changes: 0 additions & 2 deletions examples/vae/utils/custom_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,11 @@ def __init__(
else output_activation
)
else:

# we're going to have a bunch of separate layers we can spit out (a tuple of outputs)
out_layers = []

# multiple outputs? handle separately
for out_ix, out_size in enumerate(output_size):

# for a single output object, we create a linear layer and some weights
split_layer = []

Expand Down
1 change: 0 additions & 1 deletion examples/vae/utils/mnist_cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def target_transform(y):
], "invalid train/test option values"

if mode in ["sup", "unsup", "valid"]:

# transform the training data if transformations are provided
if transform is not None:
self.data = transform(self.data.float())
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/funsor/handlers/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ def queue(
:param num_samples: optional number of extended traces for extend_fn to return
:returns: stochastic function decorated with poutine logic
"""

# TODO rewrite this to use purpose-built trace/replay handlers
def wrapper(wrapped):
def _fn(*args, **kwargs):

for i in range(max_tries):
assert (
not queue.empty()
Expand Down
2 changes: 0 additions & 2 deletions pyro/contrib/funsor/handlers/named_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __exit__(self, *args, **kwargs):

@staticmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_data(msg):

(funsor_value,) = msg["args"]
name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)
Expand All @@ -82,7 +81,6 @@ def _pyro_to_data(msg):

@staticmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_funsor(msg):

if len(msg["args"]) == 2:
raw_value, output = msg["args"]
else:
Expand Down
1 change: 0 additions & 1 deletion pyro/contrib/funsor/handlers/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def _genvalue(self, key, value_request):
)

def allocate(self, key_to_value_request):

# step 1: split into fresh and non-fresh
key_to_value = OrderedDict()
for key, value_request in tuple(key_to_value_request.items()):
Expand Down
Loading