[Feature request] Support mask in Timeseries Forecast #2708

Open
eromoe opened this issue Dec 7, 2020 · 8 comments
Labels: documentation, enhancement, help wanted

Comments

@eromoe

eromoe commented Dec 7, 2020

Missing observations are very common in real work, but I haven't found a way to write a correct model.
And the tutorial http://pyro.ai/examples/forecasting_iii.html does not cover this problem: NaNs in real data.
For example:
in 2010, there were only 48 stations;
in 2011, 3 stations were closed and 5 new ones opened, giving 50 stations;
in 2012, 2 of the stations closed in 2011 reopened.

For another example:
My data is the sales count of various products in many stores.
I have reshaped it to torch.Size([44, 103, 671, 1]), meaning:
44 stores, 103 products, and 671 days of sales counts. Some stores may be closed on different days, and products may be taken off the shelf for many reasons.

Sales-count history of 10 random products:

[image]

The resulting matrix must contain NaN values, and:

  • We can't fill them with 0, because missing values are different from true zeros.
  • We should not take NaN values into account.
  • We can't drop the NaNs when training, because that breaks the time-series order.

These are real cases.

Following https://forum.pyro.ai/t/how-to-ignore-nan-values-when-do-hierachical-forecast/2412,

I altered the forecasting_iii code as below; the structure is exactly the same, I only changed some names.
I added a mask to the noise distribution, but without luck: it doesn't work in prediction.

import math
import numpy as np
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile

from pyro.contrib.forecast.util import reshape_batch

@reshape_batch.register(dist.MaskedDistribution)
def _(d, batch_shape):
    base_dist = reshape_batch(d.base_dist, batch_shape)
    return dist.MaskedDistribution(base_dist, d._mask.reshape(base_dist.shape()) )


class Model2(ForecastingModel):
    def __init__(self, mask=None):
        super().__init__()
        self.mask = mask

    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Model the time-dependent part with only O(num_products * duration) parameters,
        # rather than the full O(num_stores * num_products * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))

        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # tile the weekly pattern over the full duration
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # a store-local and a product-local component.
        with stores_plate:
            stores_scale = pyro.sample("stores_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_scale = pyro.sample("products_scale", dist.LogNormal(-5, 5))
        scale = stores_scale + products_scale

        # At this point our prediction and scale have shape (num_stores, num_products, duration)
        # and (num_stores, num_products, 1) respectively, but we want them to have shape
        # (num_stores, num_products, duration, 1) to satisfy the Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the stores and products plates.
        noise_dist = dist.Normal(0, scale)

        if self.mask is not None:
            noise_dist = noise_dist.mask(self.mask)

        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)


# msc is my NumPy array of sales counts, shape (44, 103, 671, 1), with NaN where missing.
test_data = torch.Tensor(msc)

T2 = test_data.size(-2)    # end
T1 = T2 - 7 * 2  # train/test split
T0 = 0    # beginning: train on all data before T1
covariates = torch.zeros(test_data.size(-2), 0)  # empty covariates

pyro.set_rng_seed(1)
pyro.clear_param_store()
mask = torch.isnan(test_data)

# test_data = torch.Tensor(msc)
test_data = torch.Tensor(np.nan_to_num(msc))

covariates = torch.zeros(test_data.size(-2), 0)  
forecaster = Forecaster(Model2(mask=mask[..., T0:T1, :]), test_data[..., T0:T1, :], covariates[T0:T1],
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)

samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)

###################################
# Here samples is tensor([], size=(44, 103, 0, 1)),
# so the code below raises an error.
###################################

samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
crps = eval_crps(samples, test_data[..., T1:T2, :])
print(samples.shape, p10.shape)
@fritzo
Member

fritzo commented Dec 7, 2020

Hi @eromoe, can you try inverting your mask:

- mask = torch.isnan(test_data)
+ mask = ~torch.isnan(test_data)

In Pyro's mask semantics, True means observed and False means missing.
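To make the True-means-observed convention concrete, here is a NumPy sketch of what masking does to a Normal likelihood, roughly what Pyro's MaskedDistribution does to log_prob: observed entries keep their log density and masked-out entries contribute zero. The helpers normal_log_prob and masked_log_prob are hypothetical names for illustration, not Pyro code:

```python
import math
import numpy as np

def normal_log_prob(value, loc, scale):
    # Elementwise log density of Normal(loc, scale).
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def masked_log_prob(value, loc, scale, mask):
    # Observed (True) entries keep their log density; missing (False) entries contribute 0.
    return np.where(mask, normal_log_prob(value, loc, scale), 0.0)

data = np.array([1.0, np.nan, -0.5, np.nan, 2.0])
mask = ~np.isnan(data)           # True = observed, as in Pyro
filled = np.nan_to_num(data)     # placeholder values under the mask are ignored

total = masked_log_prob(filled, 0.0, 1.0, mask).sum()
observed_only = normal_log_prob(data[mask], 0.0, 1.0).sum()

# An inverted mask (np.isnan(data)) would instead score the zero-filled gaps
# and ignore every real observation.
wrong = masked_log_prob(filled, 0.0, 1.0, np.isnan(data)).sum()
```

With the correct mask, total equals the log likelihood of the observed values alone; with the inverted mask, only the two fake zeros are scored.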

@eromoe
Author

eromoe commented Dec 8, 2020

@fritzo

Oh, I see, prefix_condition is used by the Forecaster. Here is another problem that comes up:

Traceback (most recent call last):
  File "/home/ufo/.pycharm_helpers/pydev/pydevd.py", line 1741, in <module>
    main()
  File "/home/ufo/.pycharm_helpers/pydev/pydevd.py", line 1735, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/home/ufo/.pycharm_helpers/pydev/pydevd.py", line 1135, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/ufo/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/tmp/pycharm_project_837/src/pyro_test.py", line 139, in <module>
    samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 333, in __call__
    return super().__call__(data, covariates, num_samples, batch_size)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 361, in forward
    return self.model(data, covariates)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/nn/module.py", line 413, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 174, in forward
    self.model(zero_data, covariates)
  File "/tmp/pycharm_project_837/src/pyro_test.py", line 117, in model
    self.predict(noise_dist, prediction)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 104, in predict
    noise_dist.batch_shape[:-2] + prediction.shape[-2:])
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/pyro/distributions/torch_distribution.py", line 270, in expand
    new.base_dist = self.base_dist.expand(batch_shape)
  File "/home/ufo/anaconda3/envs/dl/lib/python3.7/site-packages/torch/distributions/normal.py", line 54, in expand
    new.loc = self.loc.expand(batch_shape)
RuntimeError: The expanded size of the tensor (671) must match the existing size (657) at non-singleton dimension 3.  Target sizes: [100, 44, 103, 671, 1].  Tensor sizes: [100, 44, 103, 657, 1]

Here:

self.loc.shape
Out[12]: torch.Size([100, 44, 103, 657, 1])
batch_shape
Out[13]: torch.Size([100, 44, 103, 671, 1])

I can see that the time length of prediction is the time length of zero_data plus the extra length of the covariates (657 + 14 = 671).
So this expand fails:

noise_dist = noise_dist.expand(
                    noise_dist.batch_shape[:-2] + prediction.shape[-2:])

As for MaskedDistribution.expand, the mask size doesn't match the prediction size either (the mask is fixed at model init):

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(MaskedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(batch_shape)   # the exception is raised here
        new._mask = self._mask    # but the mask here still has its old size
        if isinstance(new._mask, torch.Tensor):
            new._mask = new._mask.expand(batch_shape)
        super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
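The expand failure above follows the usual rule: expanding can only grow dimensions of size 1 (or prepend new leading dimensions); it never stretches a time axis of length 657 to 671. NumPy's broadcast_to follows the same rule as torch.Tensor.expand, so this scaled-down sketch reproduces the mismatch without Pyro. The sizes 6 and 8 stand in for 657 and 671, and try_expand is a hypothetical helper:

```python
import numpy as np

train_len, full_len = 6, 8       # stand-ins for 657 (T1) and 671 (T2)
mask = np.ones((2, 3, train_len, 1), dtype=bool)   # mask built once, at training length

def try_expand(arr, shape):
    # Return True if arr can be broadcast (expanded) to `shape`, else False.
    try:
        np.broadcast_to(arr, shape)
        return True
    except ValueError:
        return False

# Prepending a num_samples dim is fine: new leading dims are allowed.
ok = try_expand(mask, (100, 2, 3, train_len, 1))
# Growing the time dim from train_len to full_len is not: it is not a size-1 dim.
bad = try_expand(mask, (100, 2, 3, full_len, 1))
```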

I also have a question about how Pyro makes predictions: I don't understand why it needs the whole time length of 671 here.
For example, there are usually two kinds of time-series forecasting models:

  1. forward (recursive): predict one step, feed that step back in, predict the next, until the target prediction length is reached.
  2. train on multiple steps and predict multiple steps at once.

But I can't figure out which one Pyro does here. There is a forward method, but I don't understand how it makes predictions. Could you give some explanation or point me to some resources that make this clear?

@fritzo
Member

fritzo commented Dec 8, 2020

@eromoe could you paste a little of your model code that triggers the above exception? Ideally could you provide a minimal model and fake data that triggers the error? It's hard for me to suggest a fix without being able to reproduce the error.

> I also have a question about how pyro make prediction

pyro.contrib.forecast does batch multi-step prediction (2). To predict you create a Forecaster instance and call it on covariates longer than your data:

model = MyForecastingModel(...)
forecaster = Forecaster(model, data, covariates_as_long_as_data)  # fits a model 
prediction = forecaster(data, covariates_that_are_longer_than_data, num_samples=100)  # predicts

If you want to do windowed prediction then you'll need to create multiple Forecaster instances, one per window that is a truncated form of data, covariates.
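For the windowed case, a small pure-Python helper can enumerate the (T0, T1, T2) boundaries of each window. forecast_windows is a hypothetical name, and the expanding-window scheme is an assumption; Pyro's pyro.contrib.forecast.backtest utility runs a similar loop, so check its documentation for the exact options before rolling your own:

```python
def forecast_windows(duration, test_window, stride, min_train_window=1):
    """Return (t0, t1, t2) triples: train on [t0, t1), forecast [t1, t2).

    Hypothetical helper: the training window always starts at 0 and expands,
    the test window has a fixed length, and windows are `stride` steps apart.
    """
    windows = []
    for t1 in range(min_train_window, duration - test_window + 1, stride):
        windows.append((0, t1, t1 + test_window))
    return windows

# Each window would get its own Forecaster, e.g. (not executed here):
# for t0, t1, t2 in forecast_windows(duration, test_window, stride):
#     forecaster = Forecaster(model, data[..., t0:t1, :], covariates[t0:t1], ...)
#     samples = forecaster(data[..., t0:t1, :], covariates[t0:t2], num_samples=100)
```

With the sizes from this thread, forecast_windows(671, 14, 14, 657) yields the single window (0, 657, 671), i.e. the T0, T1, T2 of the scripts; the forecast samples then cover only the t2 - t1 = 14 future steps.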

@eromoe
Author

eromoe commented Dec 9, 2020

Sure, it is easy to generate fake data:

import math
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile

import pandas as pd
import numpy as np


assert pyro.__version__.startswith('1.5.1')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)


from pyro.contrib.forecast.util import reshape_batch
from pyro.contrib.forecast.util import prefix_condition

@reshape_batch.register(dist.MaskedDistribution)
def _(d, batch_shape):
    mask = d._mask.reshape(batch_shape)
    base_dist = reshape_batch(d.base_dist, batch_shape)
    return base_dist.mask(mask)

@prefix_condition.register(dist.MaskedDistribution)
def _(d, data):
    base_dist = prefix_condition(d.base_dist, data)
    mask = d._mask[tuple(slice(-size, None) for size in base_dist.batch_shape)]
    return base_dist.mask(mask)



class Model2(ForecastingModel):
    def __init__(self, mask=None):
        super().__init__()
        self.mask = mask

    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Model the time-dependent part with only O(num_products * duration) parameters,
        # rather than the full O(num_stores * num_products * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))

        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # tile the weekly pattern over the full duration
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # a store-local and a product-local component.
        with stores_plate:
            stores_scale = pyro.sample("stores_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_scale = pyro.sample("products_scale", dist.LogNormal(-5, 5))
        scale = stores_scale + products_scale

        # At this point our prediction and scale have shape (num_stores, num_products, duration)
        # and (num_stores, num_products, 1) respectively, but we want them to have shape
        # (num_stores, num_products, duration, 1) to satisfy the Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the stores and products plates.
        noise_dist = dist.Normal(0, scale)

        if self.mask is not None:
            noise_dist = noise_dist.mask(self.mask)

        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)


import scipy.stats as ss
np.random.seed(20)
x = ss.norm.rvs(0, 100, (44, 103, 271, 1))
for i in range(int(x.size*.25)):
    a = np.random.choice(44)
    b = np.random.choice(103)
    c = np.random.choice(271)
    x[a,b,c,0] = np.nan
msc = x

test_data = torch.Tensor(msc)

T2 = test_data.size(-2)    # end
T1 = T2 - 7 * 2  # train/test split
T0 = 0    # beginning: train on all data before T1
covariates = torch.zeros(test_data.size(-2), 0)  # empty covariates

pyro.set_rng_seed(1)
pyro.clear_param_store()
mask = ~torch.isnan(test_data)

# test_data = torch.Tensor(msc)
test_data = torch.Tensor(np.nan_to_num(msc))

covariates = torch.zeros(test_data.size(-2), 0)  
forecaster = Forecaster(Model2(mask=mask[..., T0:T1, :]), test_data[..., T0:T1, :], covariates[T0:T1],
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)

samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)
samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
crps = eval_crps(samples, test_data[..., T1:T2, :])
print(samples.shape, p10.shape)


trace = poutine.trace(Model2(mask=mask[..., T0:T1, :])).get_trace(test_data[..., T0:T1, :], covariates[T0:T1])
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
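The indexing expression in the prefix_condition registration above is compact. This NumPy sketch with made-up sizes shows what it does: for each batch dimension it keeps the trailing `size` entries, which is the whole axis where sizes already match and the last (future) steps along the time axis. It assumes the mask has as many dimensions as the conditioned batch shape:

```python
import numpy as np

# Mask over the full horizon: 2 stores x 3 products x 8 days x 1 (illustrative sizes).
mask = np.arange(2 * 3 * 8).reshape(2, 3, 8, 1) % 5 != 0

# Batch shape of the conditioned (future-only) distribution: the last 2 days.
future_batch_shape = (2, 3, 2, 1)

# Same indexing expression as in the registration:
# keep the trailing `size` entries along every batch dim.
index = tuple(slice(-size, None) for size in future_batch_shape)
future_mask = mask[index]
```

One caveat of this idiom: a size of 0 gives slice(-0, None), which equals slice(0, None) and selects the whole axis rather than nothing.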

@eromoe
Author

eromoe commented Dec 11, 2020

@fritzo Is there anything I can help with?
I am also trying to understand the source code; if you could give me some reading suggestions, that would be very nice :)

@fritzo
Member

fritzo commented Dec 12, 2020

Hi @eromoe, I'll try to debug it this weekend. Thanks for providing a reproducible example!

@fritzo
Member

fritzo commented Dec 13, 2020

Hi @eromoe, I managed to get your model working with a small fix (see below). The problem is that the mask should have shape broadcastable to zero_data.shape. But your self.mask is stored once in .__init__() whereas zero_data changes depending on whether the model is being trained (which works in your case) or being used for prediction (which fails in your case because the expanded size of zero_data no longer matches the static size of self.mask).

My proposed solution is to apply the mask only during training. This works because masking only affects the log_prob and log_prob is used only during training, not during prediction. We can distinguish training from prediction because the prediction tensor requires gradients only during training (the prediction method forecaster.__call__() is called in a torch.no_grad() context). To apply this solution you can simply change

- if self.mask is not None:
+ if self.mask is not None and prediction.requires_grad:
      noise_dist = noise_dist.mask(self.mask)

I admit this is kind of a hack; if you find a more elegant solution please share on this thread. If you're feeling ambitious, we'd welcome an additional section in one of the tutorials demonstrating how to do masking 😄
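One possible alternative to the training-only check, sketched under assumptions and not tested against Pyro here, is to make the mask follow zero_data: pad the stored training mask with False along the time axis up to zero_data.size(-2) inside model(), so its shape matches in both training and prediction. Padding with False marks future steps as unobserved, which is harmless because log_prob is not used during prediction. pad_mask is a hypothetical helper written in NumPy for brevity; a torch version would use torch.cat with torch.zeros(..., dtype=torch.bool) the same way, e.g. noise_dist.mask(pad_mask(self.mask, zero_data.size(-2))):

```python
import numpy as np

def pad_mask(mask, duration, time_dim=-2):
    """Pad a boolean mask with False along `time_dim` up to `duration` steps.

    Hypothetical helper: False marks the padded (future) steps as not observed.
    """
    extra = duration - mask.shape[time_dim]
    if extra < 0:
        raise ValueError("mask is longer than the requested duration")
    pad_shape = list(mask.shape)
    pad_shape[time_dim] = extra
    return np.concatenate([mask, np.zeros(pad_shape, dtype=bool)], axis=time_dim)

train_mask = np.ones((2, 3, 6, 1), dtype=bool)   # stand-in for the T1-length mask
full_mask = pad_mask(train_mask, 8)              # stand-in for zero_data.size(-2) == T2
```

During training extra is 0, so the mask is unchanged; during prediction the padded mask can be expanded to the full batch shape.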

Here's the full script:

import math
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile

import pandas as pd
import numpy as np

assert pyro.__version__.startswith('1.5.1')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)

from pyro.contrib.forecast.util import reshape_batch
from pyro.contrib.forecast.util import prefix_condition

@reshape_batch.register(dist.MaskedDistribution)
def _(d, batch_shape):
    mask = d._mask.reshape(batch_shape)
    base_dist = reshape_batch(d.base_dist, batch_shape)
    return base_dist.mask(mask)

@prefix_condition.register(dist.MaskedDistribution)
def _(d, data):
    base_dist = prefix_condition(d.base_dist, data)
    mask = d._mask[tuple(slice(-size, None) for size in base_dist.batch_shape)]
    return base_dist.mask(mask)

class Model2(ForecastingModel):
    def __init__(self, mask=None):
        super().__init__()
        self.mask = mask

    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Model the time-dependent part with only O(num_products * duration) parameters,
        # rather than the full O(num_stores * num_products * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))

        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # tile the weekly pattern over the full duration
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # a store-local and a product-local component.
        with stores_plate:
            stores_scale = pyro.sample("stores_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_scale = pyro.sample("products_scale", dist.LogNormal(-5, 5))
        scale = stores_scale + products_scale

        # At this point our prediction and scale have shape (num_stores, num_products, duration)
        # and (num_stores, num_products, 1) respectively, but we want them to have shape
        # (num_stores, num_products, duration, 1) to satisfy the Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the stores and products plates.
        noise_dist = dist.Normal(zero_data, scale)  # We should be able to use either 0 or zero_data for loc.

        # Mask only if in training mode, i.e. if prediction.requires_grad.
        # This is kind of a hack, but works since prediction is called in a torch.no_grad() context.
        if self.mask is not None and prediction.requires_grad:
            noise_dist = noise_dist.mask(self.mask)

        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)


import scipy.stats as ss
np.random.seed(20)
x = ss.norm.rvs(0, 100, (44, 103, 271, 1))
for i in range(int(x.size*.25)):
    a = np.random.choice(44)
    b = np.random.choice(103)
    c = np.random.choice(271)
    x[a,b,c,0] = np.nan
msc = x

test_data = torch.Tensor(msc)

T2 = test_data.size(-2)    # end
T1 = T2 - 7 * 2  # train/test split
T0 = 0    # beginning: train on all data before T1
covariates = torch.zeros(test_data.size(-2), 0)  # empty covariates

pyro.set_rng_seed(1)
pyro.clear_param_store()
mask = ~torch.isnan(test_data)

# test_data = torch.Tensor(msc)
test_data = torch.Tensor(np.nan_to_num(msc))

covariates = torch.zeros(test_data.size(-2), 0)
forecaster = Forecaster(Model2(mask=mask[..., T0:T1, :]), test_data[..., T0:T1, :], covariates[T0:T1],
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)

samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)
samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
crps = eval_crps(samples, test_data[..., T1:T2, :])
print(samples.shape, p10.shape)

with torch.no_grad():  # Run in a no_grad context to avoid shape error.
    trace = poutine.trace(Model2(mask=mask)).get_trace(test_data[..., T0:T1, :], covariates[T0:T1])
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
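Both scripts inject missing values with a loop that draws int(x.size * .25) positions with replacement, so duplicates make the realized NaN fraction somewhat lower than 25% (about 1 - e^(-0.25), roughly 22%, in expectation). If an exact fraction matters, a vectorized NumPy alternative (an illustrative swap, not what the scripts above do) samples distinct flat indices without replacement:

```python
import numpy as np

rng = np.random.default_rng(20)
x = rng.normal(0, 100, size=(44, 103, 271, 1))

# Choose exactly 25% of positions, without replacement, and blank them out.
num_missing = int(x.size * 0.25)
flat_idx = rng.choice(x.size, size=num_missing, replace=False)
x.flat[flat_idx] = np.nan
```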

@eromoe
Author

eromoe commented Dec 14, 2020

Thank you for the detailed explanation!
I would like to help with the new section, I just need some time 😄

fritzo added the help wanted label on Jun 3, 2021