Skip to content

Commit

Permalink
TRT Facility API: various improvements (skyportal#5261)
Browse files Browse the repository at this point in the history
Various minor improvements to the TRT facility API:
* accounts for some of the submission failures with 200 status_codes
* updates to SQL alchemy 2.0 syntax
* adds more error handling + frontend notifications
* use the run_async decorator to call the download_observations method.
  • Loading branch information
Theodlz authored Sep 11, 2024
1 parent 755d7fb commit 0865251
Showing 1 changed file with 183 additions and 108 deletions.
291 changes: 183 additions & 108 deletions skyportal/facility_apis/trt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import base64
import functools
import json
import os
from datetime import datetime, timedelta
import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker, scoped_session
from tornado.ioloop import IOLoop

import requests
from astropy.time import Time, TimeDelta
Expand Down Expand Up @@ -116,44 +113,48 @@ def download_observations(request_id, urls):
Group,
)

Session = scoped_session(sessionmaker())
if Session.registry.has():
session = Session()
else:
session = Session(bind=DBSession.session_factory.kw["bind"])

try:
req = session.scalars(
sa.select(FollowupRequest).where(FollowupRequest.id == request_id)
).first()

group_ids = [g.id for g in req.requester.accessible_groups]
groups = session.scalars(
Group.select(req.requester).where(Group.id.in_(group_ids))
).all()
for url in urls:
url_parse = urlparse(url)
attachment_name = os.path.basename(url_parse.path)
with urllib.request.urlopen(url) as f:
attachment_bytes = base64.b64encode(f.read())
comment = Comment(
text=f'TRT: {attachment_name}',
obj_id=req.obj.id,
attachment_bytes=attachment_bytes,
attachment_name=attachment_name,
author=req.requester,
groups=groups,
bot=True,
with DBSession() as session:
try:
req = session.scalar(
sa.select(FollowupRequest).where(FollowupRequest.id == request_id)
)
session.add(comment)
req.status = f'{len(urls)} images posted as comment'
session.commit()
except Exception as e:
session.rollback()
log(f"Unable to post data for {request_id}: {e}")
finally:
session.close()
Session.remove()

group_ids = [g.id for g in req.requester.accessible_groups]
groups = session.scalars(
Group.select(req.requester).where(Group.id.in_(group_ids))
).all()
for url in urls:
url_parse = urlparse(url)
attachment_name = os.path.basename(url_parse.path)
try:
with urllib.request.urlopen(url) as f:
attachment_bytes = base64.b64encode(f.read())
comment = Comment(
text=f'TRT: {attachment_name}',
obj_id=req.obj.id,
attachment_bytes=attachment_bytes,
attachment_name=attachment_name,
author=req.requester,
groups=groups,
bot=True,
)
except Exception as e:
log(
f"TRT API Retrieve: unable to download data for {request_id}: {e}"
)
comment = Comment(
text=f'TRT: {attachment_name}, **failed to download** data [at this url]({url})',
obj_id=req.obj.id,
author=req.requester,
groups=groups,
bot=True,
)
session.add(comment)
req.status = f'{len(urls)} images posted as comment'
session.commit()
except Exception as e:
session.rollback()
log(f"Unable to post data for {request_id}: {e}")


class TRTAPI(FollowUpAPI):
Expand Down Expand Up @@ -195,7 +196,11 @@ def submit(request, session, **kwargs):
headers=headers,
)

if r.status_code == 200:
if r.status_code == 200 and 'token expired' in r.content:
request.status = (
'rejected: API token specified in the allocation is expired.'
)
elif r.status_code == 200:
request.status = 'submitted'
else:
request.status = f'rejected: {r.content}'
Expand All @@ -218,19 +223,31 @@ def submit(request, session, **kwargs):

session.add(transaction)

if kwargs.get('refresh_source', False):
try:
flow = Flow()
flow.push(
'*',
'skyportal/REFRESH_SOURCE',
payload={'obj_key': request.obj.internal_key},
)
if kwargs.get('refresh_requests', False):
flow = Flow()
flow.push(
request.last_modified_by_id,
'skyportal/REFRESH_FOLLOWUP_REQUESTS',
)
if kwargs.get('refresh_source', False):
flow.push(
'*',
'skyportal/REFRESH_SOURCE',
payload={'obj_key': request.obj.internal_key},
)
if kwargs.get('refresh_requests', False):
flow.push(
request.last_modified_by_id,
'skyportal/REFRESH_FOLLOWUP_REQUESTS',
)
if request.status != 'submitted':
flow.push(
request.last_modified_by_id,
'baselayer/SHOW_NOTIFICATION',
payload={
'message': f'Failed to submit TRT request: "{request.status}"',
'type': 'error',
},
)
except Exception as e:
log(f'Failed to send notification: {e}')
pass

@staticmethod
def get(request, session, **kwargs):
Expand All @@ -244,28 +261,38 @@ def get(request, session, **kwargs):
Database session for this transaction
"""

from ..models import DBSession, FacilityTransaction, FollowupRequest
from ..models import FacilityTransaction, FollowupRequest
from ..utils.asynchronous import run_async

if cfg['app.trt_endpoint'] is not None:
altdata = request.allocation.altdata

req = (
DBSession()
.query(FollowupRequest)
.filter(FollowupRequest.id == request.id)
.one()
)

altdata = request.allocation.altdata

if not altdata:
raise ValueError('Missing allocation information.')

req = session.scalar(
sa.select(FollowupRequest).where(FollowupRequest.id == request.id)
)

url = f"{cfg['app.trt_endpoint']}/getfilepath"

content = req.transactions[-1].response["content"]
content = json.loads(content)

if 'token expired' in content:
raise ValueError(
'Token expired, the request might have not been submitted correctly, or cannot be retrieved.'
)

try:
content = json.loads(content)
except json.JSONDecodeError:
raise ValueError(
f'Unable to parse submission response from TRT: {content}'
)

uid = content[0]
if not uid:
raise ValueError('Unable to find observation ID in response from TRT.')

payload = json.dumps({"obs_id": uid})

Expand All @@ -278,25 +305,32 @@ def get(request, session, **kwargs):
r.raise_for_status()

if r.status_code == 200:
try:
data = r.json()
except json.JSONDecodeError:
raise ValueError(
f'Unable to parse retrieval response from TRT: {r.content}'
)

urls = []
for file_path in r.json()['file_path']:

if not isinstance(data.get('file_path', []), list):
raise ValueError(
f'Unexpected response from TRT, expected list of file paths, got {data.get("file_path", [])}'
)
for file_path in data.get('file_path', []):
for key in file_path.keys():
calibrated = file_path[key].get('calibrated', '')
if calibrated:
urls.append(calibrated)

if len(urls) > 0:
request.status = "complete"
download_obs = functools.partial(
download_observations,
request.id,
urls,
)
IOLoop.current().run_in_executor(None, download_obs)
run_async(download_observations, request.id, urls)
else:
request.status = "pending"
else:
request.status = r.content.decode()
request.status = f'failed to retrieve: {r.content.decode()}'

transaction = FacilityTransaction(
request=http.serialize_requests_request(r.request),
Expand All @@ -308,19 +342,49 @@ def get(request, session, **kwargs):
session.add(transaction)
session.commit()

if kwargs.get('refresh_source', False):
flow = Flow()
flow.push(
'*',
'skyportal/REFRESH_SOURCE',
payload={'obj_key': request.obj.internal_key},
)
if kwargs.get('refresh_requests', False):
try:
flow = Flow()
flow.push(
request.last_modified_by_id,
'skyportal/REFRESH_FOLLOWUP_REQUESTS',
)
if kwargs.get('refresh_source', False):
flow.push(
'*',
'skyportal/REFRESH_SOURCE',
payload={'obj_key': request.obj.internal_key},
)
if kwargs.get('refresh_requests', False):
flow.push(
request.last_modified_by_id,
'skyportal/REFRESH_FOLLOWUP_REQUESTS',
)
if request.status == 'pending':
flow.push(
request.last_modified_by_id,
'baselayer/SHOW_NOTIFICATION',
payload={
'message': 'TRT request is still pending.',
'type': 'warning',
},
)
elif request.status.startswith('complete'):
flow.push(
request.last_modified_by_id,
'baselayer/SHOW_NOTIFICATION',
payload={
'message': 'TRT request is complete, observations will be downloaded shortly.',
'type': 'info',
},
)
else:
flow.push(
request.last_modified_by_id,
'baselayer/SHOW_NOTIFICATION',
payload={
'message': f'Failed to retrieve TRT request: "{request.status}"',
'type': 'error',
},
)
except Exception as e:
log(f'Failed to send notification: {e}')
pass

@staticmethod
def delete(request, session, **kwargs):
Expand All @@ -334,47 +398,58 @@ def delete(request, session, **kwargs):
Database session for this transaction
"""

from ..models import DBSession, FacilityTransaction, FollowupRequest
from ..models import FacilityTransaction, FollowupRequest

last_modified_by_id = request.last_modified_by_id
obj_internal_key = request.obj.internal_key

if cfg['app.trt_endpoint'] is not None:
req = (
DBSession()
.query(FollowupRequest)
.filter(FollowupRequest.id == request.id)
.one()
)

altdata = request.allocation.altdata

if not altdata:
raise ValueError('Missing allocation information.')

req = session.scalar(
sa.select(FollowupRequest).where(FollowupRequest.id == request.id)
)

url = f"{cfg['app.trt_endpoint']}/cancelobservation"

content = req.transactions[-1].response["content"]
content = json.loads(content)
uid = content[0]
if 'token expired' in content:
request.status = 'failed to delete: API token specified in the allocation is expired.'
session.commit()
else:
try:
content = json.loads(content)
except json.JSONDecodeError:
raise ValueError(
f'Unable to parse submission response from TRT: {content}'
)

payload = json.dumps({"obs_id": [uid]})
uid = content[0]
if not uid:
raise ValueError(
'Unable to find observation ID in response from TRT.'
)

headers = {
'Content-Type': 'application/json',
'TRT': altdata['token'],
}
r = requests.request("POST", url, headers=headers, data=payload)
payload = json.dumps({"obs_id": [uid]})

r.raise_for_status()
request.status = "deleted"
headers = {
'Content-Type': 'application/json',
'TRT': altdata['token'],
}
r = requests.request("POST", url, headers=headers, data=payload)

transaction = FacilityTransaction(
request=http.serialize_requests_request(r.request),
response=http.serialize_requests_response(r),
followup_request=request,
initiator_id=request.last_modified_by_id,
)
r.raise_for_status()
request.status = "deleted"

transaction = FacilityTransaction(
request=http.serialize_requests_request(r.request),
response=http.serialize_requests_response(r),
followup_request=request,
initiator_id=request.last_modified_by_id,
)
else:
request.status = 'deleted'

Expand Down

0 comments on commit 0865251

Please sign in to comment.