Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting defaults through a config file. #279

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions splash/browser_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from twisted.internet import defer
from twisted.python import log

from splash import defaults
from splash.config import settings
from splash.har.qt import cookies2har
from splash.qtrender_image import QtImageRenderer
from splash.qtutils import OPERATION_QT_CONSTANTS, WrappedSignal, qt2py, qurl2ascii
Expand Down Expand Up @@ -84,10 +84,10 @@ def _init_webpage(self, verbosity, network_manager, splash_proxy_factory, render
self.web_view.move(0, 0)
self.web_view.show()

self.set_viewport(defaults.VIEWPORT_SIZE)
self.set_viewport(settings.VIEWPORT_SIZE)
# XXX: hack to ensure that default window size is not 640x480.
self.web_view.resize(
QSize(*map(int, defaults.VIEWPORT_SIZE.split('x'))))
QSize(*map(int, settings.VIEWPORT_SIZE.split('x'))))

def set_js_enabled(self, val):
settings = self.web_page.settings()
Expand Down Expand Up @@ -184,7 +184,7 @@ def set_viewport(self, size, raise_if_empty=False):
if raise_if_empty:
raise RuntimeError("Cannot detect viewport size")
else:
size = defaults.VIEWPORT_SIZE
size = settings.VIEWPORT_SIZE
self.logger.log("Viewport is empty, falling back to: %s" %
size)

Expand Down
4 changes: 2 additions & 2 deletions splash/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from __future__ import absolute_import
from PyQt4.QtNetwork import QNetworkDiskCache
from twisted.python import log
from splash import defaults
from splash.config import settings


def construct(path=defaults.CACHE_PATH, size=defaults.CACHE_SIZE):
def construct(path=settings.CACHE_PATH, size=settings.CACHE_SIZE):
log.msg("Initializing cache on %s (maxsize: %d Mb)" % (path, size))
cache = QNetworkDiskCache()
cache.setCacheDirectory(path)
Expand Down
67 changes: 67 additions & 0 deletions splash/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-

import __builtin__
import ast
import ConfigParser
import os

from . import defaults


class ConfigError(Exception):
pass

global CONFIG_PATH
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks brittle.. It will require changing more code, but maybe we can make this path a Settings constructor parameter and don't create a global settings variable at module level? We can create it in server startup code (when the path is known) and pass it down to all components via function/method/constructor arguments instead.



class Settings(object):
"""Handles config files and default values of config settings."""

NO_CONFIG_FILE_MSG = "Config file doesn't exist at %s"

def __init__(self):
try:
self.config_path = CONFIG_PATH
except NameError:
# CONFIG_PATH is not defined. User hasn't passed in a config file.
self.config_path = None
self.defaults = {}
for name in dir(defaults):
if name.isupper():
self.defaults[name] = getattr(defaults, name)
parser = ConfigParser.SafeConfigParser()
# don't convert keys to lowercase.
parser.optionxform = str
if parser.read(self._get_configfile_paths()):
# Safely evaluate configuration values.
self.cfg = {key: ast.literal_eval(val) for (key, val) in parser.items('settings')}
else:
self.cfg = {}

def _get_configfile_paths(self):
"""Returns a list of config file paths."""
if self.config_path:
config_dir_path = os.path.abspath(os.path.expanduser(self.config_path))
configfile_path = os.path.abspath(os.path.join(config_dir_path, 'splash.cfg'))
if not os.path.isfile(configfile_path):
# file doesn't exist
raise ConfigError(self.NO_CONFIG_FILE_MSG % configfile_path)
else:
return configfile_path
else:
xdg_config_home = os.environ.get('XDG_CONFIG_HOME') or \
os.path.expanduser('~/.config')
return ['/etc/splash.cfg',
'C:\\splash\splash.cfg',
os.path.join(xdg_config_home, 'splash.cfg'),
os.path.expanduser('~/.splash.cfg')]

def __getattr__(self, item):
val = self.cfg.get(item, None)
if val is None:
val = self.defaults.get(item, None)
if val is None:
raise AttributeError("There is no settings named %s" % item)
return val

settings = Settings()
6 changes: 3 additions & 3 deletions splash/network_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
AdblockRulesRegistry,
ResourceTimeoutMiddleware)
from splash.response_middleware import ContentTypeMiddleware
from splash import defaults
from splash.config import settings


def create_default(filters_path=None, verbosity=None, allowed_schemes=None):
verbosity = defaults.VERBOSITY if verbosity is None else verbosity
verbosity = settings.VERBOSITY if verbosity is None else verbosity
if allowed_schemes is None:
allowed_schemes = defaults.ALLOWED_SCHEMES
allowed_schemes = settings.ALLOWED_SCHEMES
else:
allowed_schemes = allowed_schemes.split(',')
manager = SplashQNetworkAccessManager(
Expand Down
6 changes: 3 additions & 3 deletions splash/qtrender.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import functools
import pprint
from splash import defaults
from splash.config import settings
from splash.browser_tab import BrowserTab
from splash.exceptions import RenderError

Expand Down Expand Up @@ -80,11 +80,11 @@ def start(self, url, baseurl=None, wait=None, viewport=None,
render_all=False, resource_timeout=None):

self.url = url
self.wait_time = defaults.WAIT_TIME if wait is None else wait
self.wait_time = settings.WAIT_TIME if wait is None else wait
self.js_source = js_source
self.js_profile = js_profile
self.console = console
self.viewport = defaults.VIEWPORT_SIZE if viewport is None else viewport
self.viewport = settings.VIEWPORT_SIZE if viewport is None else viewport
self.render_all = render_all or viewport == 'full'

if resource_timeout:
Expand Down
14 changes: 7 additions & 7 deletions splash/qtrender_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PyQt4.QtCore import QBuffer, QPoint, QRect, QSize, Qt
from PyQt4.QtGui import QImage, QPainter, QRegion

from splash import defaults
from splash.config import settings


class QtImageRenderer(object):
Expand All @@ -36,7 +36,7 @@ def __init__(self, web_page, logger=None, image_format=None,
self.width = width
self.height = height
if scale_method is None:
scale_method = defaults.IMAGE_SCALE_METHOD
scale_method = settings.IMAGE_SCALE_METHOD
self.scale_method = scale_method
self.image_format = image_format.upper()
if not (self.is_png() or self.is_jpeg()):
Expand Down Expand Up @@ -319,7 +319,7 @@ def _calculate_image_parameters(self, web_viewport, img_width, img_height):
return image_viewport, image_size

def _calculate_tiling(self, to_paint):
tile_maxsize = defaults.TILE_MAXSIZE
tile_maxsize = settings.TILE_MAXSIZE
tile_hsize = min(tile_maxsize, to_paint.width())
tile_vsize = min(tile_maxsize, to_paint.height())
htiles = 1 + (to_paint.width() - 1) // tile_hsize
Expand Down Expand Up @@ -420,15 +420,15 @@ def crop(self, rect):
assert isinstance(rect, QRect)
self.img = self.img.copy(rect)

def to_png(self, complevel=defaults.PNG_COMPRESSION_LEVEL):
def to_png(self, complevel=settings.PNG_COMPRESSION_LEVEL):
quality = 90 - (complevel * 10)
buf = QBuffer()
self.img.save(buf, 'png', quality)
return bytes(buf.data())

def to_jpeg(self, quality=None):
if quality is None:
quality = defaults.JPEG_QUALITY
quality = settings.JPEG_QUALITY
buf = QBuffer()
self.img.save(buf, 'jpeg', quality)
return bytes(buf.data())
Expand All @@ -454,14 +454,14 @@ def crop(self, rect):
top, bottom = rect.top(), rect.top() + rect.height()
self.img = self.img.crop((left, top, right, bottom))

def to_png(self, complevel=defaults.PNG_COMPRESSION_LEVEL):
def to_png(self, complevel=settings.PNG_COMPRESSION_LEVEL):
buf = StringIO()
self.img.save(buf, 'png', compress_level=complevel)
return buf.getvalue()

def to_jpeg(self, quality=None):
if quality is None:
quality = defaults.JPEG_QUALITY
quality = settings.JPEG_QUALITY
buf = StringIO()
self.img.save(buf, 'jpeg', quality=quality)
return buf.getvalue()
44 changes: 22 additions & 22 deletions splash/render_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import absolute_import
import os
import json
from splash import defaults
from splash.config import settings
from splash.utils import path_join_secure
from splash.exceptions import BadOption

Expand Down Expand Up @@ -115,19 +115,19 @@ def get_baseurl(self):
return self._get_url("baseurl", default=None)

def get_wait(self):
return self.get("wait", defaults.WAIT_TIME,
type=float, range=(0, defaults.MAX_WAIT_TIME))
return self.get("wait", settings.WAIT_TIME,
type=float, range=(0, settings.MAX_WAIT_TIME))

def get_timeout(self):
default = min(self.max_timeout, defaults.TIMEOUT)
default = min(self.max_timeout, settings.TIMEOUT)
return self.get("timeout", default, type=float, range=(0, self.max_timeout))

def get_resource_timeout(self):
return self.get("resource_timeout", defaults.RESOURCE_TIMEOUT,
return self.get("resource_timeout", settings.RESOURCE_TIMEOUT,
type=float, range=(0, 1e6))

def get_images(self):
return self._get_bool("images", defaults.AUTOLOAD_IMAGES)
return self._get_bool("images", settings.AUTOLOAD_IMAGES)

def get_proxy(self):
return self.get("proxy", default=None)
Expand All @@ -136,13 +136,13 @@ def get_js_source(self):
return self.get("js_source", default=None)

def get_width(self):
return self.get("width", None, type=int, range=(1, defaults.MAX_WIDTH))
return self.get("width", None, type=int, range=(1, settings.MAX_WIDTH))

def get_height(self):
return self.get("height", None, type=int, range=(1, defaults.MAX_HEIGTH))
return self.get("height", None, type=int, range=(1, settings.MAX_HEIGTH))

def get_scale_method(self):
scale_method = self.get("scale_method", defaults.IMAGE_SCALE_METHOD)
scale_method = self.get("scale_method", settings.IMAGE_SCALE_METHOD)
allowed_scale_methods = ['raster', 'vector']
if scale_method not in allowed_scale_methods:
self.raise_error(
Expand All @@ -155,7 +155,7 @@ def get_scale_method(self):
return scale_method

def get_quality(self):
return self.get("quality", defaults.JPEG_QUALITY, type=int, range=(0, 100))
return self.get("quality", settings.JPEG_QUALITY, type=int, range=(0, 100))

def get_http_method(self):
method = self.get("http_method", "GET")
Expand Down Expand Up @@ -226,7 +226,7 @@ def get_headers(self):
return headers

def get_viewport(self, wait=None):
viewport = self.get("viewport", defaults.VIEWPORT_SIZE)
viewport = self.get("viewport", settings.VIEWPORT_SIZE)

if viewport == 'full':
if wait == 0:
Expand Down Expand Up @@ -323,14 +323,14 @@ def get_jpeg_params(self):

def get_include_params(self):
return dict(
html=self._get_bool("html", defaults.DO_HTML),
iframes=self._get_bool("iframes", defaults.DO_IFRAMES),
png=self._get_bool("png", defaults.DO_PNG),
jpeg=self._get_bool("jpeg", defaults.DO_JPEG),
script=self._get_bool("script", defaults.SHOW_SCRIPT),
console=self._get_bool("console", defaults.SHOW_CONSOLE),
history=self._get_bool("history", defaults.SHOW_HISTORY),
har=self._get_bool("har", defaults.SHOW_HAR),
html=self._get_bool("html", settings.DO_HTML),
iframes=self._get_bool("iframes", settings.DO_IFRAMES),
png=self._get_bool("png", settings.DO_PNG),
jpeg=self._get_bool("jpeg", settings.DO_JPEG),
script=self._get_bool("script", settings.SHOW_SCRIPT),
console=self._get_bool("console", settings.SHOW_CONSOLE),
history=self._get_bool("history", settings.SHOW_HISTORY),
har=self._get_bool("har", settings.SHOW_HAR),
)


Expand All @@ -345,9 +345,9 @@ def validate_size_str(size_str):
:param size_str: string to validate

"""
max_width = defaults.VIEWPORT_MAX_WIDTH
max_heigth = defaults.VIEWPORT_MAX_HEIGTH
max_area = defaults.VIEWPORT_MAX_AREA
max_width = settings.VIEWPORT_MAX_WIDTH
max_heigth = settings.VIEWPORT_MAX_HEIGTH
max_area = settings.VIEWPORT_MAX_AREA
try:
w, h = map(int, size_str.split('x'))
except ValueError:
Expand Down
Loading