Skip to content

Commit

Permalink
🐛 fix seed context error
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 4, 2024
1 parent 2782182 commit faceb2b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
43 changes: 38 additions & 5 deletions modules/utils/SeedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,31 @@ def deterministic(seed=0):
torch.backends.cudnn.benchmark = False


def is_numeric(obj):
if isinstance(obj, str):
try:
float(obj)
return True
except ValueError:
return False
elif isinstance(obj, (np.integer, np.signedinteger, np.unsignedinteger)):
return True
elif isinstance(obj, np.floating):
return True
elif isinstance(obj, (int, float)):
return True
else:
return False


class SeedContext:
def __init__(self, seed):
assert (
isinstance(seed, int)
or isinstance(seed, float)
or (isinstance(seed, str) and seed.isdigit())
), "Seed must be an integer or a float."
assert is_numeric(seed), "Seed must be an number."

try:
self.seed = int(np.clip(int(seed), -1, 2**32 - 1))
except Exception as e:
raise ValueError("Seed must be an integer.")

self.seed = seed
self.state = None
Expand All @@ -42,3 +60,18 @@ def __exit__(self, exc_type, exc_value, traceback):
torch.set_rng_state(self.state[0])
random.setstate(self.state[1])
np.random.set_state(self.state[2])


if __name__ == "__main__":
print(is_numeric("1234")) # True
print(is_numeric("12.34")) # True
print(is_numeric("-1234")) # True
print(is_numeric("abc123")) # False
print(is_numeric(np.int32(10))) # True
print(is_numeric(np.float64(10.5))) # True
print(is_numeric(10)) # True
print(is_numeric(10.5)) # True
print(is_numeric(np.int8(10))) # True
print(is_numeric(np.uint64(10))) # True
print(is_numeric(np.float16(10.5))) # True
print(is_numeric([1, 2, 3])) # False
1 change: 1 addition & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def tts_generate(
prompt2 = prompt2 or params.get("prompt2", "")

infer_seed = clip(infer_seed, -1, 2**32 - 1)
infer_seed = int(infer_seed)

if not disable_normalize:
text = text_normalize(text)
Expand Down

0 comments on commit faceb2b

Please sign in to comment.