diff --git a/src/gretel_synthetics/actgan/actgan.py b/src/gretel_synthetics/actgan/actgan.py index baac69a0..41403a03 100644 --- a/src/gretel_synthetics/actgan/actgan.py +++ b/src/gretel_synthetics/actgan/actgan.py @@ -6,6 +6,19 @@ import pandas as pd import torch +from packaging import version +from torch import optim +from torch.nn import ( + BatchNorm1d, + Dropout, + functional, + LeakyReLU, + Linear, + Module, + ReLU, + Sequential, +) + from gretel_synthetics.actgan.base import BaseSynthesizer, random_state from gretel_synthetics.actgan.column_encodings import ( BinaryColumnEncoding, @@ -21,18 +34,6 @@ ) from gretel_synthetics.actgan.train_data import TrainData from gretel_synthetics.typing import DFLike -from packaging import version -from torch import optim -from torch.nn import ( - BatchNorm1d, - Dropout, - functional, - LeakyReLU, - Linear, - Module, - ReLU, - Sequential, -) logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/src/gretel_synthetics/actgan/actgan_wrapper.py b/src/gretel_synthetics/actgan/actgan_wrapper.py index 30612074..a2961e80 100644 --- a/src/gretel_synthetics/actgan/actgan_wrapper.py +++ b/src/gretel_synthetics/actgan/actgan_wrapper.py @@ -10,21 +10,23 @@ import numpy as np import pandas as pd +from rdt.transformers import BaseTransformer +from sdv.tabular.base import BaseTabularModel + from gretel_synthetics.actgan.actgan import ACTGANSynthesizer from gretel_synthetics.actgan.columnar_df import ColumnarDF from gretel_synthetics.actgan.structures import ConditionalVectorType from gretel_synthetics.detectors.sdv import SDVTableMetadata from gretel_synthetics.utils import rdt_patches, torch_utils -from rdt.transformers import BaseTransformer -from sdv.tabular.base import BaseTabularModel if TYPE_CHECKING: - from gretel_synthetics.actgan.structures import EpochInfo from numpy.random import RandomState from sdv.constraints import Constraint from sdv.metadata import Metadata from torch import Generator + from gretel_synthetics.actgan.structures import EpochInfo + EPOCH_CALLBACK = "epoch_callback" logging.basicConfig() diff --git a/src/gretel_synthetics/actgan/data_transformer.py b/src/gretel_synthetics/actgan/data_transformer.py index eae2ad63..1e2e4bfd 100644 --- a/src/gretel_synthetics/actgan/data_transformer.py +++ b/src/gretel_synthetics/actgan/data_transformer.py @@ -6,6 +6,8 @@ import numpy as np import pandas as pd +from rdt.transformers import BinaryEncoder, OneHotEncoder + from gretel_synthetics.actgan.column_encodings import ( BinaryColumnEncoding, FloatColumnEncoding, @@ -22,7 +24,6 @@ ClusterBasedNormalizer, ) from gretel_synthetics.typing import DFLike -from rdt.transformers import BinaryEncoder, OneHotEncoder logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/src/gretel_synthetics/actgan/structures.py b/src/gretel_synthetics/actgan/structures.py index 34a233db..a4760d3d 100644 --- a/src/gretel_synthetics/actgan/structures.py +++ b/src/gretel_synthetics/actgan/structures.py @@ -11,9 +11,10 @@ if TYPE_CHECKING: import numpy as np - from gretel_synthetics.actgan.column_encodings import ColumnEncoding from rdt.transformers.base import BaseTransformer + from gretel_synthetics.actgan.column_encodings import ColumnEncoding + class ColumnType(str, Enum): CONTINUOUS = "continuous" diff --git a/src/gretel_synthetics/actgan/transformers.py b/src/gretel_synthetics/actgan/transformers.py index fc87e1eb..f074437f 100644 --- a/src/gretel_synthetics/actgan/transformers.py +++ b/src/gretel_synthetics/actgan/transformers.py @@ -10,11 +10,12 @@ import pandas as pd from category_encoders import BaseNEncoder, BinaryEncoder -from gretel_synthetics.typing import ListOrSeriesOrDF, SeriesOrDFLike from rdt.transformers import BaseTransformer from rdt.transformers import ClusterBasedNormalizer as RDTClusterBasedNormalizer from rdt.transformers import FloatFormatter +from gretel_synthetics.typing import ListOrSeriesOrDF, SeriesOrDFLike + MODE = "mode" VALID_ROUNDING_MODES = (MODE,) diff --git a/src/gretel_synthetics/batch.py b/src/gretel_synthetics/batch.py index 6284aa70..ca58924b 100644 --- a/src/gretel_synthetics/batch.py +++ b/src/gretel_synthetics/batch.py @@ -30,10 +30,14 @@ from typing import List, Optional, Tuple, Type, Union import cloudpickle -import gretel_synthetics.const as const import numpy as np import pandas as pd +from pandas.errors import EmptyDataError +from tqdm.auto import tqdm + +import gretel_synthetics.const as const + from gretel_synthetics.config import ( BaseConfig, config_from_model_dir, @@ -44,8 +48,6 @@ from gretel_synthetics.generate import generate_text, GenText, SeedingGenerator from gretel_synthetics.tokenizers import BaseTokenizerTrainer from gretel_synthetics.train import train -from pandas.errors import EmptyDataError -from tqdm.auto import tqdm logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/src/gretel_synthetics/config.py b/src/gretel_synthetics/config.py index 82d01a84..bd3dc4b5 100644 --- a/src/gretel_synthetics/config.py +++ b/src/gretel_synthetics/config.py @@ -13,9 +13,10 @@ from pathlib import Path from typing import Callable, Optional, TYPE_CHECKING -import gretel_synthetics.const as const import tensorflow as tf +import gretel_synthetics.const as const + from gretel_synthetics.tensorflow.generator import TensorFlowGenerator from gretel_synthetics.tensorflow.train import train_rnn diff --git a/src/gretel_synthetics/detectors/sdv.py b/src/gretel_synthetics/detectors/sdv.py index aa256b5a..c2c0cbd0 100644 --- a/src/gretel_synthetics/detectors/sdv.py +++ b/src/gretel_synthetics/detectors/sdv.py @@ -9,10 +9,11 @@ import numpy as np import pandas as pd -from gretel_synthetics.detectors.dates import detect_datetimes from rdt.transformers import BaseTransformer from rdt.transformers.datetime import UnixTimestampEncoder +from gretel_synthetics.detectors.dates import detect_datetimes + if TYPE_CHECKING: from gretel_synthetics.detectors.dates import DateTimeColumn diff --git a/src/gretel_synthetics/generate_utils.py b/src/gretel_synthetics/generate_utils.py index 76c0ed17..200b8e99 100644 --- a/src/gretel_synthetics/generate_utils.py +++ b/src/gretel_synthetics/generate_utils.py @@ -11,11 +11,12 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Callable, Optional, Union +from smart_open import open as smart_open + from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID from gretel_synthetics.config import config_from_model_dir from gretel_synthetics.generate import generate_text from gretel_synthetics.utils.tar_util import safe_extractall -from smart_open import open as smart_open logging.basicConfig() logger = logging.getLogger(__name__) diff --git a/src/gretel_synthetics/tensorflow/train.py b/src/gretel_synthetics/tensorflow/train.py index 102662af..4d5029ad 100644 --- a/src/gretel_synthetics/tensorflow/train.py +++ b/src/gretel_synthetics/tensorflow/train.py @@ -17,6 +17,8 @@ import pandas as pd import tensorflow as tf +from tqdm import tqdm + from gretel_synthetics.const import ( METRIC_ACCURACY, METRIC_DELTA, @@ -30,7 +32,6 @@ from gretel_synthetics.tensorflow.model import build_model, load_model from gretel_synthetics.tokenizers import BaseTokenizer from gretel_synthetics.train import EpochState -from tqdm import tqdm if TYPE_CHECKING: from gretel_synthetics.config import TensorFlowConfig diff --git a/src/gretel_synthetics/timeseries_dgan/dgan.py b/src/gretel_synthetics/timeseries_dgan/dgan.py index 04843fd1..95374e3c 100644 --- a/src/gretel_synthetics/timeseries_dgan/dgan.py +++ b/src/gretel_synthetics/timeseries_dgan/dgan.py @@ -56,6 +56,8 @@ import pandas as pd import torch +from torch.utils.data import DataLoader, Dataset, TensorDataset + from gretel_synthetics.errors import DataError, InternalError, ParameterError from gretel_synthetics.timeseries_dgan.config import DfStyle, DGANConfig, OutputType from gretel_synthetics.timeseries_dgan.structures import ProgressInfo @@ -67,7 +69,6 @@ Output, transform, ) -from torch.utils.data import DataLoader, Dataset, TensorDataset logger = logging.getLogger(__name__) diff --git a/src/gretel_synthetics/timeseries_dgan/transformations.py b/src/gretel_synthetics/timeseries_dgan/transformations.py index 0cd4968f..41d3c42b 100644 --- a/src/gretel_synthetics/timeseries_dgan/transformations.py +++ b/src/gretel_synthetics/timeseries_dgan/transformations.py @@ -8,9 +8,10 @@ import numpy as np from category_encoders import BinaryEncoder, OneHotEncoder -from gretel_synthetics.timeseries_dgan.config import Normalization, OutputType from scipy.stats import mode +from gretel_synthetics.timeseries_dgan.config import Normalization, OutputType + def _new_uuid() -> str: """Return a random uuid prefixed with 'gretel-'.""" diff --git a/src/gretel_synthetics/tokenizers.py b/src/gretel_synthetics/tokenizers.py index d30fe913..38838989 100644 --- a/src/gretel_synthetics/tokenizers.py +++ b/src/gretel_synthetics/tokenizers.py @@ -47,13 +47,15 @@ ) import cloudpickle -import gretel_synthetics.const as const import numpy as np import sentencepiece as spm -from gretel_synthetics.errors import ParameterError from smart_open import open as smart_open +import gretel_synthetics.const as const + +from gretel_synthetics.errors import ParameterError + if TYPE_CHECKING: from gretel_synthetics.config import BaseConfig else: diff --git a/tests-integration/test_generate.py b/tests-integration/test_generate.py index 7cabf163..30914bdf 100644 --- a/tests-integration/test_generate.py +++ b/tests-integration/test_generate.py @@ -31,10 +31,11 @@ import pandas as pd import pytest +from smart_open import open as smart_open + from gretel_synthetics.batch import DataFrameBatch, GenerationProgress from gretel_synthetics.generate_utils import DataFileGenerator from gretel_synthetics.utils.tar_util import safe_extractall -from smart_open import open as smart_open BATCH_MODELS = [ "https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz", diff --git a/tests-integration/test_train.py b/tests-integration/test_train.py index cc1ab70a..f4d9cec4 100644 --- a/tests-integration/test_train.py +++ b/tests-integration/test_train.py @@ -7,10 +7,11 @@ from pathlib import Path -import gretel_synthetics.const as const import pandas as pd import pytest +import gretel_synthetics.const as const + from gretel_synthetics.batch import DataFrameBatch, PATH_HOLDER from gretel_synthetics.config import TensorFlowConfig from gretel_synthetics.errors import DataError diff --git a/tests/actgan/test_actgan.py b/tests/actgan/test_actgan.py index 3ad5b1ba..07d1c6b5 100644 --- a/tests/actgan/test_actgan.py +++ b/tests/actgan/test_actgan.py @@ -7,10 +7,11 @@ import pandas as pd import pytest +from pandas.api.types import is_number + from gretel_synthetics.actgan import ACTGAN from gretel_synthetics.actgan.data_transformer import BinaryEncodingTransformer from gretel_synthetics.actgan.structures import ConditionalVectorType -from pandas.api.types import is_number @pytest.fixture diff --git a/tests/detectors/test_sdv.py b/tests/detectors/test_sdv.py index 0d21449f..f7df623f 100644 --- a/tests/detectors/test_sdv.py +++ b/tests/detectors/test_sdv.py @@ -7,12 +7,13 @@ import pandas as pd import pytest -from gretel_synthetics.detectors.dates import DateTimeColumn, DateTimeColumns -from gretel_synthetics.detectors.sdv import EmptyFieldTransformer, SDVTableMetadata from rdt import HyperTransformer from rdt.transformers.datetime import UnixTimestampEncoder from sdv import Table +from gretel_synthetics.detectors.dates import DateTimeColumn, DateTimeColumns +from gretel_synthetics.detectors.sdv import EmptyFieldTransformer, SDVTableMetadata + def _create_info() -> DateTimeColumns: return DateTimeColumns(columns={"footime": DateTimeColumn("footime", "%Y-%m-%d")}) diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index fea8c2df..26e287ab 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -1,9 +1,10 @@ from copy import deepcopy from pathlib import Path -import gretel_synthetics.tokenizers as tok import pytest +import gretel_synthetics.tokenizers as tok + from gretel_synthetics.config import BaseConfig from gretel_synthetics.tokenizers import VocabSizeTooSmall diff --git a/tests/timeseries_dgan/test_dgan.py b/tests/timeseries_dgan/test_dgan.py index 64fd856d..d9a41e8b 100644 --- a/tests/timeseries_dgan/test_dgan.py +++ b/tests/timeseries_dgan/test_dgan.py @@ -5,6 +5,9 @@ import pandas as pd import pytest +from pandas.api.types import is_numeric_dtype, is_object_dtype +from pandas.testing import assert_frame_equal + from gretel_synthetics.errors import DataError, ParameterError from gretel_synthetics.timeseries_dgan.config import ( DfStyle, @@ -27,8 +30,6 @@ ContinuousOutput, OneHotEncodedOutput, ) -from pandas.api.types import is_numeric_dtype, is_object_dtype -from pandas.testing import assert_frame_equal @pytest.fixture diff --git a/tests/utils/test_rdt_float_formatter_orig.py b/tests/utils/test_rdt_float_formatter_orig.py index 66d54dbd..c6793fc5 100644 --- a/tests/utils/test_rdt_float_formatter_orig.py +++ b/tests/utils/test_rdt_float_formatter_orig.py @@ -9,10 +9,11 @@ import pandas as pd import pytest -from gretel_synthetics.utils.rdt_patches import patch_float_formatter_rounding_bug from rdt.transformers.null import NullTransformer from rdt.transformers.numerical import FloatFormatter +from gretel_synthetics.utils.rdt_patches import patch_float_formatter_rounding_bug + with patch_float_formatter_rounding_bug(): # This is the original suite of tests for the FloatFormatter from rdt. # Source code is copied from diff --git a/tests/utils/test_rdt_patches.py b/tests/utils/test_rdt_patches.py index e0b6bf2c..a081ec99 100644 --- a/tests/utils/test_rdt_patches.py +++ b/tests/utils/test_rdt_patches.py @@ -3,11 +3,12 @@ import numpy as np import pandas as pd +from rdt.transformers.numerical import FloatFormatter + from gretel_synthetics.utils.rdt_patches import ( _patched_float_formatter_reverse_transform, patch_float_formatter_rounding_bug, ) -from rdt.transformers.numerical import FloatFormatter def test_original_rounding_bug_upstream():