Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements #451 #468

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 36 additions & 30 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,28 @@ def before_request_checks():
g.motomo_info = settings.motomo_info
g.settings = settings

def check_auth():
if settings.debug_oidc_token:
oidc_blueprint.session.token = {'access_token': settings.debug_oidc_token}
else:
try:
if not oidc_blueprint.session.authorized or 'username' not in session:
return logout(next_url=request.full_path)

if oidc_blueprint.session.token['expires_in'] < 20:
app.logger.debug("Force refresh token")
oidc_blueprint.session.get(settings.oidcUserInfoPath)
except (InvalidTokenError, TokenExpiredError, InvalidGrantError):
flash("Token expired.", 'warning')
return logout(next_url=request.full_path)
return None

def authorized_with_valid_token(f):
@wraps(f)
def decorated_function(*args, **kwargs):

if settings.debug_oidc_token:
oidc_blueprint.session.token = {'access_token': settings.debug_oidc_token}
else:
try:
if not oidc_blueprint.session.authorized or 'username' not in session:
return logout(next_url=request.full_path)

if oidc_blueprint.session.token['expires_in'] < 20:
app.logger.debug("Force refresh token")
oidc_blueprint.session.get(settings.oidcUserInfoPath)
except (InvalidTokenError, TokenExpiredError, InvalidGrantError):
flash("Token expired.", 'warning')
return logout(next_url=request.full_path)

check_auth_res = check_auth()
if check_auth_res:
return check_auth_res # redirect to login
return f(*args, **kwargs)

return decorated_function
Expand Down Expand Up @@ -189,11 +193,9 @@ def home():

if settings.debug_oidc_token:
oidc_blueprint.session.token = {'access_token': settings.debug_oidc_token}
else:
if not oidc_blueprint.session.authorized:
return redirect(url_for('login'))

if 'userid' not in session or not session['userid']:
if ((oidc_blueprint.session.authorized or settings.debug_oidc_token) and
('userid' not in session or not session['userid'])):
# Only contact userinfo endpoint first time in session
try:
account_info = oidc_blueprint.session.get(settings.oidcUserInfoPath)
Expand Down Expand Up @@ -665,7 +667,6 @@ def infoutputs(infid=None):
return render_template('outputs.html', infid=infid, outputs=outputs)

@app.route('/configure')
@authorized_with_valid_token
def configure():
selected_tosca = None
inf_id = request.args.get('inf_id', None)
Expand All @@ -676,6 +677,9 @@ def configure():
inputs = {}
infra_name = ""
if inf_id:
check_auth_res = check_auth()
if check_auth_res:
return check_auth_res # redirect to login
access_token = oidc_blueprint.session.token['access_token']
auth_data = utils.getIMUserAuthData(access_token, cred, get_cred_id())
try:
Expand Down Expand Up @@ -706,7 +710,7 @@ def configure():
flash("Invalid TOSCA template name: %s" % selected_tosca, "error")
return redirect(url_for('home'))

if not utils.valid_template_vos(session['vos'], toscaInfo[selected_tosca]["metadata"]):
if not utils.valid_template_vos(session, toscaInfo[selected_tosca]["metadata"]):
flash("Invalid TOSCA template name: %s" % selected_tosca, "error")
return redirect(url_for('home'))

Expand All @@ -715,26 +719,27 @@ def configure():
if "childs" in toscaInfo[selected_tosca]["metadata"]:
if childs is not None:
for child in childs:
if child in toscaInfo and utils.valid_template_vos(session['vos'], toscaInfo[child]["metadata"]):
if child in toscaInfo and utils.valid_template_vos(session, toscaInfo[child]["metadata"]):
child_templates[child] = toscaInfo[child]
if "inputs" in toscaInfo[child]:
selected_template["inputs"].update(toscaInfo[child]["inputs"])
if "tabs" in toscaInfo[child]:
selected_template["tabs"].extend(toscaInfo[child]["tabs"])
else:
for child in toscaInfo[selected_tosca]["metadata"]["childs"]:
if child in toscaInfo and utils.valid_template_vos(session['vos'], toscaInfo[child]["metadata"]):
if child in toscaInfo and utils.valid_template_vos(session, toscaInfo[child]["metadata"]):
child_templates[child] = toscaInfo[child]
return render_template('portfolio.html', templates=child_templates, parent=selected_tosca)
else:
app.logger.debug("Template: " + json.dumps(toscaInfo[selected_tosca]))

try:
creds = cred.get_creds(get_cred_id(), 1)
except Exception as ex:
flash("Error getting user credentials: %s" % ex, "error")
creds = []
utils.get_project_ids(creds)
creds = []
if check_auth() is None:
try:
creds = cred.get_creds(get_cred_id(), 1)
except Exception as ex:
flash("Error getting user credentials: %s" % ex, "error")
utils.get_project_ids(creds)

# Enable to get input values from URL parameters
for input_name, input_value in selected_template["inputs"].items():
Expand Down Expand Up @@ -1467,9 +1472,10 @@ def internal_server_error(error):
def reload_sites():
scheduler.modify_job('reload_sites', trigger='interval', seconds=settings.appdb_cache_timeout - 30)
with app.app_context():
app.logger.debug('Reload Site List.')
app.logger.debug('Reload Site/VO List.')
g.settings = settings
utils.getCachedSiteList(True)
utils.getCachedVOList(True)

# Reload internally the TOSCA tamplates
@scheduler.task('interval', id='reload_templates', seconds=settings.checkToscaChangesTime)
Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def login(self, avatar):
def test_index_with_no_login(self):
self.oauth.session.authorized = False
res = self.client.get('/')
self.assertEqual(302, res.status_code)
self.assertIn('/login', res.headers['location'])
self.assertEqual(200, res.status_code)
self.assertIn(b'Portfolio', res.data)

@patch("app.utils.avatar")
def test_index(self, avatar):
Expand Down
10 changes: 6 additions & 4 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,12 +756,12 @@ def get_project_ids(creds):
return creds


def getCachedVOList():
def getCachedVOList(force=False):
global VO_LIST
global VO_LAST_UPDATE

now = int(time.time())
if not VO_LIST or now - VO_LAST_UPDATE > g.settings.appdb_cache_timeout:
if force or not VO_LIST or now - VO_LAST_UPDATE > g.settings.appdb_cache_timeout:
try:
VO_LIST = appdb.get_vo_list()
# in case of error do not update time
Expand Down Expand Up @@ -845,9 +845,11 @@ def discover_oidc_urls(base_url):
return res


def valid_template_vos(user_vos, template_metadata):
def valid_template_vos(session, template_metadata):
if 'vos' in template_metadata:
return [vo for vo in user_vos if vo in template_metadata['vos']]
if 'vos' not in session or not session['vos']:
return []
return [vo for vo in session['vos'] if vo in template_metadata['vos']]
else:
return ['all']

Expand Down