-
Notifications
You must be signed in to change notification settings - Fork 0
/
dependencies.py
352 lines (293 loc) · 12.5 KB
/
dependencies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
import time
from fastapi import FastAPI, Depends, HTTPException, status, Request, Response
from fastapi.security import OAuth2PasswordBearer
from fastapi.routing import APIRoute
from typing import Callable
from google.auth.transport import requests
from google.oauth2 import id_token
from google.auth import exceptions
from pydantic import BaseModel
from pydantic.typing import List, Set, Dict, Any, Mapping, Optional
from config import *
from stores import firestore
import jwt
# stores reference to global APP
# (routes/middleware below attach to this instance; other modules import it)
app = FastAPI()
# handle CORS preflight requests
@app.options('/{rest_of_path:path}', include_in_schema=False)
async def preflight_handler(request: Request, rest_of_path: str) -> Response:
response = Response()
response.headers['Access-Control-Allow-Origin'] = ALLOWED_ORIGINS
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Authorization, Content-Type, Range'
return response
# set CORS headers
@app.middleware("http")
async def add_CORS_header(request: Request, call_next):
response = await call_next(request)
response.headers['Access-Control-Allow-Origin'] = ALLOWED_ORIGINS
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Authorization, Content-Type, Range'
return response
def version_str_to_int(version_str: str) -> int:
    """Return an integer encoding of a semantic-version string.

    Accepts "major", "major.minor", or "major.minor.patch", with an
    optional leading "v" on the major number (e.g. "v1.2.3").  The
    encoding is major*1_000_000 + minor*1_000 + patch, so versions
    compare correctly as plain integers (assuming minor/patch < 1000).

    Raises:
        HTTPException(400): if there are more than 3 dotted parts or
            any part cannot be parsed as an integer.
    """
    parts = version_str.split('.')
    if len(parts) > 3:
        raise HTTPException(status_code=400, detail=f'version tag "{version_str}" should only have 3 parts (major, minor, patch numbers)')
    # NOTE: str.split always yields at least one element, so the old
    # "no major number" branch was unreachable; an empty major simply
    # fails the int() parse below and raises the 400 there.
    if parts[0].startswith('v'):
        parts[0] = parts[0][1:]
    nums = [0, 0, 0]  # major, minor, patch defaults
    try:
        for i, part in enumerate(parts):
            nums[i] = int(part)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'unable to parse version tag "{version_str}": {e}')
    major, minor, patch = nums
    return major * 1000 * 1000 + minor * 1000 + patch
# reloads User and Dataset info from DB after this many seconds
USER_REFRESH_SECS = 600.0  # per-user cache TTL
MEMBERSHIPS_REFRESH_SECS = 600.0  # group-membership cache TTL
DATASET_REFRESH_SECS = 600.0  # dataset-metadata cache TTL
class NeuprintServer(BaseModel):
    """Location of a neuprint server associated with a dataset."""
    dataset: str  # What the dataset is called in the neuprint server
    server: str  # name.domain.org
class Dataset(BaseModel):
    """Metadata for a single dataset; mirrors the Firestore dataset document."""
    title: Optional[str]
    description: str
    tag: Optional[str]
    uuid: Optional[str]
    dvid: Optional[str]  # The base URL including http or https for dvid server.
    mainLayer: Optional[str]
    neuroglancer: Optional[dict]  # presumably a neuroglancer state blob -- TODO confirm schema
    versions: Optional[list]
    typing: Optional[dict]
    neuprintHTTP: Optional[NeuprintServer]
    bodyAnnotationSchema: Optional[dict]
    orderedLayers: Optional[list]
    # legacy -- will be removed after UI accommodates new schema
    public: Optional[bool] = False  # if True, readable without per-dataset roles (see User.can_read)
    layers: Optional[List[dict]] = []  # segmentation refs
    dimensions: Optional[dict]
    position: Optional[List[float]]
    crossSectionScale: Optional[float]
    projectionScale: Optional[float]
    location: Optional[str]  # legacy grayscale image ref that will be moved to layers with type=image.
class DatasetCache(BaseModel):
    """In-memory cache of Dataset metadata backed by a Firestore collection.

    The entire collection is re-read whenever the cache is older than
    DATASET_REFRESH_SECS.
    """
    collection: Any  # Firestore collection of dataset documents
    cache: Dict[str, Dataset] = {}  # dataset id -> Dataset
    public: Set[str] = set()  # ids of datasets flagged public
    updated: float = time.time()  # update time for all dataset

    def refresh_cache(self):
        """Reload every dataset document, replacing the previous cache.

        Builds fresh containers and swaps them in atomically-per-attribute,
        so datasets deleted or un-published upstream do not linger from a
        prior load (the old in-place update never removed stale entries).
        """
        fresh_cache: Dict[str, Dataset] = {}
        fresh_public: Set[str] = set()
        for dataset_ref in self.collection.get():
            dataset_obj = Dataset(**dataset_ref.to_dict())
            fresh_cache[dataset_ref.id] = dataset_obj
            if dataset_obj.public:
                fresh_public.add(dataset_ref.id)
        self.cache = fresh_cache
        self.public = fresh_public
        self.updated = time.time()
        print(f"Cached {len(self.cache)} dataset metadata.")

    def _refresh_if_stale(self):
        """Re-read the collection if the cache is older than DATASET_REFRESH_SECS."""
        age = time.time() - self.updated
        if age > DATASET_REFRESH_SECS:
            print(f"dataset cache last checked {age} secs ago... refreshing")
            self.refresh_cache()

    def get_dataset(self, dataset_id: str) -> Dataset:
        """Returns dataset information.

        Raises:
            HTTPException(404): if the dataset id is unknown.
        """
        self._refresh_if_stale()
        if dataset_id not in self.cache:
            raise HTTPException(status_code=404, detail=f"dataset {dataset_id} not found")
        return self.cache[dataset_id]

    def is_public(self, dataset_id: str) -> bool:
        """Returns True if dataset is public."""
        self._refresh_if_stale()
        return dataset_id in self.public
def public_dataset(dataset_id: str) -> bool:
    """Module-level convenience wrapper: True when the dataset with the
    given id is flagged public in the global dataset cache."""
    is_pub = datasets.is_public(dataset_id)
    return is_pub
def get_dataset(dataset_id: str) -> Dataset:
    """Returns dataset given the dataset id"""
    # NOTE(review): this definition is shadowed by an identical
    # module-level `get_dataset` defined a few lines below; one of the
    # two is redundant and could be removed.
    return datasets.get_dataset(dataset_id)
# cache everything initially on startup of service
# (module-level side effect: reads the whole CLIO_DATASETS collection once at import)
datasets = DatasetCache(collection = firestore.get_collection([CLIO_DATASETS]))
datasets.refresh_cache()
def get_dataset(dataset_id: str) -> Dataset:
    # NOTE(review): duplicate of the `get_dataset` defined just above the
    # `datasets` bootstrap; this redefinition is the one callers actually
    # get. Consider deleting one of the two.
    return datasets.get_dataset(dataset_id)
class User(BaseModel):
    """An authenticated user with global and per-dataset role-based permissions."""
    email: str  # Used for Google authentication
    # Possible additions to user fields:
    #
    # userid: str  # Janelia userid if available (None for external users)
    # email_verified: bool = False
    name: Optional[str]  # full name
    org: Optional[str]  # affiliated organization
    disabled: Optional[bool] = False
    global_roles: Optional[Set[str]] = set()
    datasets: Optional[Dict[str, Set[str]]] = {}  # dataset id -> roles held on that dataset
    groups: Optional[Set[str]] = set()
    google_idinfo: Optional[Mapping[str, Any]] = None

    def has_role(self, role: str, dataset: str = "") -> bool:
        """Return True if the user holds `role` globally or on `dataset`."""
        if role in self.global_roles:
            return True
        if dataset == "":
            return False
        if role in self.datasets.get(dataset, set()):
            return True
        # public datasets grant clio_general to everyone
        return role == "clio_general" and dataset in datasets.public

    def can_read(self, dataset: str = "") -> bool:
        """Return True if the user may read the given dataset."""
        if "clio_general" in self.global_roles:
            return True
        if dataset in datasets.public:
            return True
        dataset_roles = self.datasets.get(dataset, set())
        read_roles = {"clio_read", "clio_general", "clio_write"}
        # FIX: original returned the raw set intersection despite the
        # declared `-> bool`; coerce so callers comparing `== True` work.
        return bool(read_roles & dataset_roles)

    def can_write_own(self, dataset: str = "") -> bool:
        """Return True if the user may write their own annotations on the dataset."""
        if "clio_general" in self.global_roles:
            return True
        if dataset in datasets.public:
            return True
        dataset_roles = self.datasets.get(dataset, set())
        write_roles = {"clio_general", "clio_write"}
        return bool(write_roles & dataset_roles)  # FIX: was a set, declared -> bool

    def can_write_others(self, dataset: str = "") -> bool:
        """Return True if the user may modify other users' data on the dataset."""
        if "clio_write" in self.global_roles:
            return True
        return "clio_write" in self.datasets.get(dataset, set())

    def is_dataset_admin(self, dataset: str = "") -> bool:
        """Return True if the user administers the given dataset (or is global admin)."""
        if "admin" in self.global_roles:
            return True
        return "dataset_admin" in self.datasets.get(dataset, set())  # FIX: was a set, declared -> bool

    def is_admin(self) -> bool:
        """Return True if the user holds the global admin role."""
        return "admin" in self.global_roles
class UserCache(BaseModel):
    """Cache of User records and group memberships, backed by Firestore."""
    collection: Any  # users collection
    cache: Dict[str, User] = {}  # email -> User
    user_updated: Dict[str, float] = {}  # update time per user
    memberships: Dict[str, Set[str]] = {}  # set of user emails per group names
    memberships_updated: float = 0.0  # last full update of memberships

    def cache_user(self, user: User):
        """Insert or refresh a user in the cache and index their group memberships."""
        self.user_updated[user.email] = time.time()
        for group in user.groups:
            self.memberships.setdefault(group, set()).add(user.email)
        # the configured OWNER is always a global admin
        if user.email == OWNER:
            user.global_roles.add("admin")
        self.cache[user.email] = user

    def uncache_user(self, email: str):
        """Drop a user from the cache and from all group-membership sets."""
        if email in self.cache:
            user = self.cache[email]
            for group in user.groups:
                if group in self.memberships:
                    self.memberships[group].discard(email)
            del self.cache[email]

    def refresh_user(self, user_ref) -> User:
        """Build a User from a Firestore document reference and cache it."""
        user_dict = user_ref.to_dict()
        user_dict["email"] = user_ref.id  # document id is the email
        user_obj = User(**user_dict)
        self.cache_user(user_obj)
        return user_obj

    def refresh_cache(self) -> Dict[str, User]:
        """Reload every user document; return the users just read."""
        users = {}
        t0 = time.time()
        for user_ref in self.collection.get():
            users[user_ref.id] = self.refresh_user(user_ref)
        # BUG FIX: original used `==` (a no-op comparison) instead of `=`,
        # so the memberships timestamp never advanced and group_members()
        # re-read the entire collection on every call after the first TTL.
        self.memberships_updated = time.time()
        print(f"Cached {len(self.cache)} user metadata and {len(self.memberships)} groups in {time.time() - t0} secs.")
        return users

    def get_user(self, email: str, google_idinfo: Mapping[str, Any] = None) -> User:
        """Return the cached user, re-reading from Firestore when stale or missing.

        Unknown emails yield a role-less User rather than None, so callers
        always receive a User object.
        """
        user = self.cache.get(email)
        if user is not None:
            age = time.time() - self.user_updated.get(email, 0)
            if age > USER_REFRESH_SECS:
                user = None  # stale -> force re-read below
        if user is None:
            t0 = time.time()
            user_ref = self.collection.document(email).get()
            print(f"get_user {email} took {time.time() - t0} secs")
            if user_ref.exists:
                user = self.refresh_user(user_ref)
            else:
                user = User(email=email)
        if google_idinfo and user is not None:
            user.google_idinfo = google_idinfo
        return user

    def group_members(self, user: User, groups: Set[str]) -> Set[str]:
        """Returns set of emails for groups given user belongs unless user is admin"""
        if not user.is_admin():
            # FIX: intersect into a new set instead of mutating the
            # caller's `groups` argument in place.
            groups = groups & user.groups
        if len(groups) == 0:
            return set()
        age = time.time() - self.memberships_updated
        if age > MEMBERSHIPS_REFRESH_SECS:
            self.refresh_cache()
        members = set()
        for group in groups:
            if group in self.memberships:
                members.update(self.memberships[group])
        return members
# cache all users on startup of service
# (module-level side effect: reads the whole CLIO_USERS collection once at import)
users = UserCache(collection = firestore.get_collection([CLIO_USERS]))
users.refresh_cache()
def group_members(user: User, groups: Set[str]) -> Set[str]:
    """Module-level convenience wrapper around the global user cache.

    Returns the email addresses of members belonging to the given groups,
    restricted to groups the given user is itself a member of (admins are
    not restricted).
    """
    return users.group_members(user, groups)
# handle OAuth2
# auto_error=False lets requests without a bearer token reach
# get_user_from_token (token arrives as None) rather than failing with
# an immediate 401, so the TEST_USER fallback can apply.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_user_from_token(token: str = Depends(oauth2_scheme)) -> User:
    """Check token (either FlyEM or Google identity) and return user roles and data.

    Resolution order:
      1. Try to decode the token as a FlyEM JWT signed with FLYEM_SECRET.
      2. Otherwise verify it as a Google OAuth2 identity token.
      3. If neither works and TEST_USER is configured, fall back to it.

    Raises:
        HTTPException(401): if the token cannot be validated by any route.
    """
    # FIX: define the 401 up front; originally it was created inside the
    # `if not email:` branch but referenced on the final user check, which
    # would NameError on the FlyEM-token path.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    email = None
    idinfo = None  # Google ID token claims if supplied
    # Check if token is a "FlyEM token"
    if FLYEM_SECRET:
        try:
            # FIX: algorithms must be a list -- passing the string "HS256"
            # makes PyJWT's membership test a substring check.
            decoded = jwt.decode(token, FLYEM_SECRET, algorithms=["HS256"])
            exp = decoded.get('exp', 0)
            # tokens without an exp claim (exp defaults to 0) are rejected
            if time.time() <= exp:
                email = decoded.get('email', None)
        except Exception:
            # not a FlyEM token (bad signature, malformed, or token is None);
            # fall through and try it as a Google identity token
            pass
    # Consider case when token passed is not a "FlyEM" token (with shared secret)
    if not email:
        try:
            idinfo = id_token.verify_oauth2_token(token, requests.Request())
            email = idinfo["email"].lower()
        except exceptions.GoogleAuthError:
            print(f"Non-FlyEM token is also not a Google identity token: {token}")
            raise credentials_exception
        except Exception:
            # no usable token at all (e.g. token is None) -- allow a
            # configured test user, otherwise reject
            print(f"no user token so using TEST_USER {TEST_USER}")
            if TEST_USER is not None:
                email = TEST_USER
            else:
                raise credentials_exception
    user = users.get_user(email, idinfo)
    if user is None:
        print(f"Valid token for user {email} but not associated with a valid user from Clio Firestore")
        raise credentials_exception
    return user
def get_user(current_user: User = Depends(get_user_from_token)):
    """FastAPI dependency yielding the authenticated, active user.

    Raises:
        HTTPException(400): if the account is marked disabled.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")