-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7b04c1b
commit 421e429
Showing
2 changed files
with
75 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
apiVersion: crd.chenshaowen.com/v1 | ||
kind: Task | ||
metadata: | ||
name: drain-error-gpu | ||
namespace: ops-system | ||
spec: | ||
desc: Drain GPUs with error states by checking PCI bus IDs and UUIDs. | ||
host: "alert-gpu=enabled" | ||
steps: | ||
- name: drain-error-gpu | ||
content: | | ||
#!/usr/bin/python | ||
import os | ||
import subprocess | ||
import re | ||
import requests | ||
import json | ||
from datetime import datetime | ||
def get_pci_info(): | ||
try: | ||
output = subprocess.run(['nvidia-smi', '-L'], | ||
universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).stdout | ||
print(f"nvidia-smi output: {output}") | ||
pci_bus_ids = re.findall(r'gpu (\S+):', output) | ||
if pci_bus_ids: | ||
return pci_bus_ids | ||
else: | ||
return [] | ||
except Exception as e: | ||
raise Exception(f"Failed to extract PCI bus IDs: {str(e)}") | ||
def disable_gpu(pci_bus_id): | ||
try: | ||
subprocess.run(['nvidia-smi', 'drain', '-p', pci_bus_id, '-m', '1'], check=True) | ||
except Exception as e: | ||
raise Exception(f"Failed to disable GPU {pci_bus_id}: {str(e)}") | ||
def send_alert(status, message): | ||
payload = { | ||
'kind': '${TASKRUN}', | ||
'type': 'TaskRunReport', | ||
'status': status, | ||
'message': message, | ||
'host': '${HOSTNAME}', | ||
'operator': '>' | ||
} | ||
headers = { | ||
'Content-Type': 'application/json' | ||
} | ||
response = requests.post('${OPSSERVER_ENDPOINT}/api/v1/namespaces/${NAMESPACE}/events/taskruns.${TASKRUN}.reports.' + host, | ||
headers=headers, data=json.dumps(payload)) | ||
print(response.text) | ||
def main(): | ||
try: | ||
if not os.path.exists('/usr/bin/nvidia-smi'): | ||
raise Exception("nvidia-smi not found") | ||
gpu_info = get_pci_info() | ||
print(f"Found GPUs: {gpu_info}") | ||
for pci_bus_id in gpu_info: | ||
disable_gpu(pci_bus_id) | ||
message = f"dissabled GPU {pci_bus_id}" | ||
send_alert('alerting', message) | ||
except Exception as e: | ||
send_alert('alerting', str(e)) | ||
if __name__ == "__main__": | ||
main() |