Skip to content

Commit

Permalink
(SD) Cleanup after each test. (nod-ai#744)
Browse files Browse the repository at this point in the history
Segfaults otherwise on certain runners.
  • Loading branch information
monorimet authored Jun 26, 2024
1 parent e328124 commit 7cc1be5
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions models/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def testExportT5Model(self):
new_blob_name = new_blob_name[0] + "-pass.mlir"
turbine_tank.changeBlobName(blob_name, new_blob_name)
del current_args
del turbine

def testExportClipVitLarge14(self):
current_args = copy.deepcopy(default_arguments)
Expand Down Expand Up @@ -152,6 +153,8 @@ def testExportClipVitLarge14(self):
if platform.system() != "Windows":
os.remove(current_args["external_weight_path"])
os.remove(current_args["vmfb_path"])
del current_args
del turbine

def testExportClipModel(self):
current_args = copy.deepcopy(default_arguments)
Expand Down Expand Up @@ -190,7 +193,10 @@ def testExportClipModel(self):
if platform.system() != "Windows":
os.remove(current_args["external_weight_path"])
os.remove(current_args["vmfb_path"])
del current_args
del turbine

@unittest.expectedFailure
def testExportUnetModel(self):
current_args = copy.deepcopy(default_arguments)
blob_name = unet.export_unet_model(
Expand Down Expand Up @@ -301,6 +307,7 @@ def testExportVaeModelDecode(self):
new_blob_name = blob_name.split(".")
new_blob_name = new_blob_name[0] + "-pass.mlir"
turbine_tank.changeBlobName(blob_name, new_blob_name)
del current_args
del torch_output
del turbine
os.remove("stable_diffusion_v1_4_vae.safetensors")
Expand Down Expand Up @@ -352,6 +359,8 @@ def testExportVaeModelEncode(self):
turbine_tank.changeBlobName(blob_name, new_blob_name)
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")
del current_args
del turbine

@unittest.expectedFailure
def testExportPNDMScheduler(self):
Expand Down Expand Up @@ -405,6 +414,7 @@ def testExportPNDMScheduler(self):
turbine_tank.changeBlobName(blob_name, new_blob_name)
os.remove("stable_diffusion_v1_4_scheduler.safetensors")
os.remove("stable_diffusion_v1_4_scheduler.vmfb")
del current_args
del torch_output
del turbine

Expand Down

0 comments on commit 7cc1be5

Please sign in to comment.