[Feature request] Support mask in Timeseries Forecast #2708
Hi @eromoe, can you try inverting your mask:

- mask = torch.isnan(test_data)
+ mask = ~torch.isnan(test_data)

In Pyro's mask semantics, True means observed and False means missing.
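The following minimal NumPy sketch (not Pyro code) makes those semantics concrete for a Gaussian log-likelihood. Entries where the mask is True contribute their log density, and entries where it is False contribute zero. In Pyro the masking itself would be done with `dist.Normal(loc, scale).mask(mask)`. The helper `normal_log_prob` is purely illustrative.

```python
import math

import numpy as np


def normal_log_prob(x, loc, scale):
    """Elementwise log density of Normal(loc, scale) (illustrative helper)."""
    z = (x - loc) / scale
    return -0.5 * z ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


data = np.array([1.0, np.nan, -0.5, np.nan])
mask = ~np.isnan(data)          # True = observed, False = missing (Pyro's convention)
filled = np.nan_to_num(data)    # placeholder 0.0 wherever the value is missing

# Masked-out entries contribute exactly 0 to the total log-likelihood.
per_entry = np.where(mask, normal_log_prob(filled, 0.0, 1.0), 0.0)
total = per_entry.sum()

# With the inverted mask (np.isnan(data)), only the placeholders would be scored.
wrong = np.where(np.isnan(data), normal_log_prob(filled, 0.0, 1.0), 0.0).sum()
```

The sketch fills missing values with a placeholder before scoring them. This mirrors the `np.nan_to_num` call in the scripts in this thread, so no NaN reaches the likelihood computation.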
Oh, I see. Here:

self.loc.shape
Out[12]: torch.Size([100, 44, 103, 657, 1])
batch_shape
Out[13]: torch.Size([100, 44, 103, 671, 1])

I found where the time length comes from:

noise_dist = noise_dist.expand(
    noise_dist.batch_shape[:-2] + prediction.shape[-2:])

As for MaskedDistribution.expand:

def expand(self, batch_shape, _instance=None):
    new = self._get_checked_instance(MaskedDistribution, _instance)
    batch_shape = torch.Size(batch_shape)
    new.base_dist = self.base_dist.expand(batch_shape)  # exception here
    new._mask = self._mask  # but here the mask still has the old size too
    if isinstance(new._mask, torch.Tensor):
        new._mask = new._mask.expand(batch_shape)
    super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
    new._validate_args = self._validate_args
    return new

I also have a question about how Pyro makes predictions. I don't understand why it needs the whole time length (671) here. There is a forward method, but I can't figure out how it produces the forecast. Could you explain, or point me to some resources that make this clear?
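The traceback comes down to a broadcasting rule. `expand`, like NumPy's `broadcast_to`, can only stretch dimensions of size 1. A mask whose time dimension has the training length therefore cannot be expanded to the longer full length, and PyTorch's `Tensor.expand` follows the same rule. The small NumPy sketch below uses toy sizes standing in for 657 and 671.

```python
import numpy as np

# Toy sizes standing in for the traceback: time length 657 (training) vs 671 (full).
T_train, T_full = 5, 7
batch = (2, 3)

# A mask whose time dimension (dim -2) has the training length.
mask = np.ones(batch + (T_train, 1), dtype=bool)

try:
    np.broadcast_to(mask, batch + (T_full, 1))
    expand_ok = True
except ValueError:
    # Only size-1 dimensions can be stretched, so a length-5 time dim cannot become 7.
    expand_ok = False

# A size-1 time dimension, by contrast, broadcasts to any length.
flat_mask = np.ones(batch + (1, 1), dtype=bool)
stretched = np.broadcast_to(flat_mask, batch + (T_full, 1))
```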
@eromoe could you paste a little of your model code that triggers the above exception? Ideally could you provide a minimal model and fake data that triggers the error? It's hard for me to suggest a fix without being able to reproduce the error.
model = MyForecastingModel(...)
forecaster = Forecaster(model, data, covariates_as_long_as_data)  # fits a model
prediction = forecaster(data, covariates_that_are_longer_than_data, num_samples=100)  # predicts

If you want to do windowed prediction then you'll need to create multiple
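The fit/predict split above comes down to simple index arithmetic. The data covers `[T0, T1)` and the covariates cover `[T0, T2)`, and Pyro documents `Forecaster.__call__` as returning samples of shape `(num_samples,) + batch_shape + (t2 - t1, obs_dim)`, i.e. only the forecast window. The pure-Python sketch below checks that arithmetic against the numbers in this thread. The helper names `forecast_window` and `expected_sample_shape` are hypothetical.

```python
def forecast_window(T2, holdout_days=7 * 2, T0=0):
    """Return (T1, train_len, horizon) for a train/test split at T2 - holdout_days."""
    T1 = T2 - holdout_days
    return T1, T1 - T0, T2 - T1


def expected_sample_shape(num_samples, batch_shape, horizon, obs_dim=1):
    # Forecaster samples only the window [T1, T2), not the training prefix.
    return (num_samples,) + tuple(batch_shape) + (horizon, obs_dim)


# Fake data in the repro script: 271 days, 44 stores x 103 products.
T1, train_len, horizon = forecast_window(271)
shape = expected_sample_shape(100, (44, 103), horizon)

# Real data in the original post: 671 days, which gives the 657 seen in the traceback.
T1_real, _, _ = forecast_window(671)
```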
Sure, it is very easy to generate:

import math
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile
import pandas as pd
import numpy as np

assert pyro.__version__.startswith('1.5.1')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)

from pyro.contrib.forecast.util import reshape_batch
from pyro.contrib.forecast.util import prefix_condition


@reshape_batch.register(dist.MaskedDistribution)
def _(d, batch_shape):
    mask = d._mask.reshape(batch_shape)
    base_dist = reshape_batch(d.base_dist, batch_shape)
    return base_dist.mask(mask)


@prefix_condition.register(dist.MaskedDistribution)
def _(d, data):
    base_dist = prefix_condition(d.base_dist, data)
    mask = d._mask[tuple(slice(-size, None) for size in base_dist.batch_shape)]
    return base_dist.mask(mask)


class Model2(ForecastingModel):
    def __init__(self, mask=None):
        super().__init__()
        self.mask = mask

    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Let's model the time-dependent part with only O(num_products * duration) many
        # parameters, rather than the full possible O(num_stores * num_products * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))
        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # tile the weekly pattern over time
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # a store-local and a product-local component.
        with stores_plate:
            stores_scale = pyro.sample("stores_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_scale = pyro.sample("products_scale", dist.LogNormal(-5, 5))
        scale = stores_scale + products_scale

        # At this point our prediction and scale have shape (num_stores, num_products, duration)
        # and (num_stores, num_products, 1) respectively, but we want them to have shape
        # (num_stores, num_products, duration, 1) to satisfy the Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the stores and products plates.
        noise_dist = dist.Normal(0, scale)
        if self.mask is not None:
            noise_dist = noise_dist.mask(self.mask)
        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)


import scipy.stats as ss

np.random.seed(20)
x = ss.norm.rvs(0, 100, (44, 103, 271, 1))
# Knock out random cells. Indices are drawn with replacement, so slightly
# fewer than 25% of the cells end up NaN.
for i in range(int(x.size * .25)):
    a = np.random.choice(44)
    b = np.random.choice(103)
    c = np.random.choice(271)
    x[a, b, c, 0] = np.nan
msc = x
test_data = torch.Tensor(msc)

T2 = test_data.size(-2)  # end
T1 = T2 - 7 * 2  # train/test split
T0 = 0  # beginning: train on the first T1 days of data
covariates = torch.zeros(test_data.size(-2), 0)  # empty covariates

pyro.set_rng_seed(1)
pyro.clear_param_store()
mask = ~torch.isnan(test_data)
# test_data = torch.Tensor(msc)
test_data = torch.Tensor(np.nan_to_num(msc))
covariates = torch.zeros(test_data.size(-2), 0)
forecaster = Forecaster(Model2(mask=mask[..., T0:T1, :]), test_data[..., T0:T1, :], covariates[T0:T1],
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)
samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)
samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
crps = eval_crps(samples, test_data[..., T1:T2, :])
print(samples.shape, p10.shape)

trace = poutine.trace(Model2(mask=mask), test_data[..., T0:T1, :], covariates[T0:T1], ).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
@fritzo Is there anything I can help with?
Hi @eromoe, I'll try to debug it this weekend. Thanks for providing a reproducible example!
Hi @eromoe, I managed to get your model working with a small fix (see below). The problem is that the

My proposed solution is to apply the mask only during training. This works because masking only affects the

- if self.mask is not None:
+ if self.mask is not None and prediction.requires_grad:
      noise_dist = noise_dist.mask(self.mask)

I admit this is kind of a hack; if you find a more elegant solution please share on this thread. If you're feeling ambitious, we'd welcome an additional section in one of the tutorials demonstrating how to do masking 😄

Here's the full script:

import math
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile
import pandas as pd
import numpy as np

assert pyro.__version__.startswith('1.5.1')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)

from pyro.contrib.forecast.util import reshape_batch
from pyro.contrib.forecast.util import prefix_condition


@reshape_batch.register(dist.MaskedDistribution)
def _(d, batch_shape):
    mask = d._mask.reshape(batch_shape)
    base_dist = reshape_batch(d.base_dist, batch_shape)
    return base_dist.mask(mask)


@prefix_condition.register(dist.MaskedDistribution)
def _(d, data):
    base_dist = prefix_condition(d.base_dist, data)
    mask = d._mask[tuple(slice(-size, None) for size in base_dist.batch_shape)]
    return base_dist.mask(mask)


class Model2(ForecastingModel):
    def __init__(self, mask=None):
        super().__init__()
        self.mask = mask

    def model(self, zero_data, covariates):
        num_stores, num_products, duration, one = zero_data.shape

        # We construct plates once so we can reuse them later. We ensure they don't collide by
        # specifying different dim args for each: -3, -2, -1. Note the time_plate is dim=-1.
        stores_plate = pyro.plate("stores", num_stores, dim=-3)
        products_plate = pyro.plate("products", num_products, dim=-2)
        day_of_week_plate = pyro.plate("day_of_week", 7, dim=-1)

        # Let's model the time-dependent part with only O(num_products * duration) many
        # parameters, rather than the full possible O(num_stores * num_products * duration) data size.
        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5))
        with stores_plate:
            with day_of_week_plate:
                stores_seasonal = pyro.sample("stores_seasonal", dist.Normal(0, 5))
        with products_plate:
            with day_of_week_plate:
                product_seasonal = pyro.sample("product_seasonal", dist.Normal(0, 5))
            with self.time_plate:
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample("drift",
                                            dist.Stable(drift_stability, 0, drift_scale))
        with stores_plate, products_plate:
            pairwise = pyro.sample("pairwise", dist.Normal(0, 1))

        # Outside of the time plate we can now form the prediction.
        seasonal = stores_seasonal + product_seasonal  # Note this broadcasts.
        seasonal = periodic_repeat(seasonal, duration, dim=-1)  # tile the weekly pattern over time
        motion = drift.cumsum(dim=-1)  # A Levy stable motion to model shocks.
        prediction = motion + seasonal + pairwise

        # We will decompose the noise scale parameter into
        # a store-local and a product-local component.
        with stores_plate:
            stores_scale = pyro.sample("stores_scale", dist.LogNormal(-5, 5))
        with products_plate:
            products_scale = pyro.sample("products_scale", dist.LogNormal(-5, 5))
        scale = stores_scale + products_scale

        # At this point our prediction and scale have shape (num_stores, num_products, duration)
        # and (num_stores, num_products, 1) respectively, but we want them to have shape
        # (num_stores, num_products, duration, 1) to satisfy the Forecaster requirements.
        scale = scale.unsqueeze(-1)
        prediction = prediction.unsqueeze(-1)

        # Finally we construct a noise distribution and call the .predict() method.
        # Note that predict must be called inside the stores and products plates.
        noise_dist = dist.Normal(zero_data, scale)  # We should be able to use either 0 or zero_data for loc.
        # Mask only if in training mode, i.e. if prediction.requires_grad.
        # This is kind of a hack, but works since forecasting runs in a torch.no_grad() context.
        if self.mask is not None and prediction.requires_grad:
            noise_dist = noise_dist.mask(self.mask)
        with stores_plate, products_plate:
            self.predict(noise_dist, prediction)


import scipy.stats as ss

np.random.seed(20)
x = ss.norm.rvs(0, 100, (44, 103, 271, 1))
# Knock out random cells. Indices are drawn with replacement, so slightly
# fewer than 25% of the cells end up NaN.
for i in range(int(x.size * .25)):
    a = np.random.choice(44)
    b = np.random.choice(103)
    c = np.random.choice(271)
    x[a, b, c, 0] = np.nan
msc = x
test_data = torch.Tensor(msc)

T2 = test_data.size(-2)  # end
T1 = T2 - 7 * 2  # train/test split
T0 = 0  # beginning: train on the first T1 days of data
covariates = torch.zeros(test_data.size(-2), 0)  # empty covariates

pyro.set_rng_seed(1)
pyro.clear_param_store()
mask = ~torch.isnan(test_data)
# test_data = torch.Tensor(msc)
test_data = torch.Tensor(np.nan_to_num(msc))
covariates = torch.zeros(test_data.size(-2), 0)
forecaster = Forecaster(Model2(mask=mask[..., T0:T1, :]), test_data[..., T0:T1, :], covariates[T0:T1],
                        learning_rate=0.1, learning_rate_decay=1, num_steps=501, log_every=50)
samples = forecaster(test_data[..., T0:T1, :], covariates[T0:T2], num_samples=100)
samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
p10, p50, p90 = quantile(samples[..., 0], (0.1, 0.5, 0.9))
crps = eval_crps(samples, test_data[..., T1:T2, :])
print(samples.shape, p10.shape)

with torch.no_grad():  # Run in a no_grad context to avoid shape error.
    trace = poutine.trace(Model2(mask=mask)).get_trace(test_data[..., T0:T1, :], covariates[T0:T1])
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
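The fix hinges on how autograd tracks `requires_grad`. During SVI the prediction is built from guide samples that depend on trainable parameters, so it requires grad. According to the comment in the script, forecasting runs inside `torch.no_grad()`, where nothing requires grad. The minimal PyTorch sketch below shows that distinction. A plain tensor stands in for the guide's parameters, and `maybe_mask` is a hypothetical helper mirroring the condition in the fix.

```python
import torch

# Stand-in for trainable guide parameters, which require grad during SVI.
loc = torch.zeros(3, requires_grad=True)


def make_prediction():
    # Any tensor computed from a requires_grad leaf inherits requires_grad...
    return loc + 1.0


def maybe_mask(prediction, mask):
    # Mirrors the fix: only apply the mask when gradients are being tracked.
    return mask if prediction.requires_grad else None


train_pred = make_prediction()

with torch.no_grad():
    # ...except inside torch.no_grad(), where autograd recording is disabled.
    forecast_pred = make_prediction()

m = torch.tensor([True, False, True])
```

One side effect follows from the same condition: any run of the model under `torch.no_grad()` skips the mask, including the `poutine.trace` call at the end of the script.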
Thank you for the detailed explanation!
Missing observations are very common in real-world work, but I haven't found a way to write a correct model.
And the tutorial http://pyro.ai/examples/forecasting_iii.html does not cover this problem: NaNs in real data.
For example:
In 2010, there were only 48 stations.
In 2011, 3 stations were closed and 5 new ones opened, giving 50 stations.
In 2012, 2 of the stations closed in 2011 reopened.
…
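Examples like these fit the masking approach in this thread if the station set is taken to be every station that ever existed. A closed station then simply contributes masked-out entries for the periods it was closed. The toy pure-Python sketch below builds such a mask from open/closed periods. The station names and years are hypothetical, not real data.

```python
# Hypothetical station history: each station maps to the set of years it was open.
years = [2010, 2011, 2012]
open_years = {
    "A": {2010, 2011, 2012},  # always open
    "B": {2010, 2012},        # closed in 2011, reopened in 2012
    "C": {2011, 2012},        # newly opened in 2011
}


def build_mask(open_years, years):
    # mask[station][t] is True when the station was observed (open) in year t,
    # following Pyro's convention that True means observed.
    return {s: [y in ys for y in years] for s, ys in open_years.items()}


mask = build_mask(open_years, years)
# Number of observed (open) stations per year.
num_open = [sum(mask[s][t] for s in mask) for t in range(len(years))]
```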
Another example:
My data is the sales count of various products in many stores.
I have reshaped it to torch.Size([44, 103, 671, 1]), meaning 44 stores, 103 products, and 671 days of sales counts. Some stores may be closed on different days, and products may be taken off the shelf for many reasons.
[Figure: sales-count history of 10 random products]
Building this matrix necessarily produces NaN values, and these are real cases.
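For data shaped like this, the preparation used in the scripts above can be packaged as a small helper: `~torch.isnan` for the mask and `np.nan_to_num` for the data. The NumPy sketch below works on a toy `(stores, products, days, 1)` array. `split_observed` is a hypothetical name. Converting both outputs with `torch.as_tensor` before passing them to the model would mirror the script.

```python
import numpy as np


def split_observed(sales):
    """Split an array containing NaNs into (filled_data, mask).

    mask is True where a value was observed, following Pyro's convention.
    filled_data replaces NaNs with 0.0 so that no NaN reaches the likelihood.
    """
    mask = ~np.isnan(sales)
    filled = np.nan_to_num(sales, nan=0.0)
    return filled, mask


# Toy array shaped (stores, products, days, 1): 2 stores, 2 products, 3 days.
sales = np.array([
    [[[3.0], [np.nan], [5.0]], [[1.0], [2.0], [np.nan]]],
    [[[np.nan], [np.nan], [np.nan]], [[4.0], [0.0], [6.0]]],
])
filled, mask = split_observed(sales)

# Observed days per (store, product); store 1 / product 0 is never observed.
observed_days = mask.sum(axis=(2, 3))
```

Note that a genuine zero sale (the `0.0` in the last row) stays observed, while a missing day becomes `0.0` in `filled` but `False` in `mask`. The mask, not the filled value, is what keeps the two cases apart.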
Following https://forum.pyro.ai/t/how-to-ignore-nan-values-when-do-hierachical-forecast/2412, I altered the forecast_iii code as below. The structure is exactly the same; I just changed some names.
I added a mask to the noise distribution, but without luck: it doesn't work in prediction.