forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THGenerator.hpp
39 lines (34 loc) · 1.08 KB
/
THGenerator.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#pragma once
#include <ATen/core/MT19937RNGEngine.h>
/**
* THGeneratorState is a POD class needed for memcpys
* in torch.get_rng_state() and torch.set_rng_state().
* It is a legacy class and even though it is replaced with
* at::CPUGeneratorImpl, we need this class and some of its fields
* to support backward compatibility on loading checkpoints.
*/
struct THGeneratorState {
/* The initial seed. */
uint64_t the_initial_seed;
int left; /* = 1; */
int seeded; /* = 0; */
uint64_t next;
uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */
/********************************/
/* For normal distribution */
double normal_x;
double normal_y;
double normal_rho;
int normal_is_valid; /* = 0; */
};
/**
* THGeneratorStateNew is a POD class containing
* new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
* as a helper for torch.get_rng_state() and torch.set_rng_state()
* functions.
*/
struct THGeneratorStateNew {
THGeneratorState legacy_pod;
float next_float_normal_sample;
bool is_next_float_normal_sample_valid;
};