Skip to content

Commit

Permalink
Merge pull request NVIDIA#199 from crohkohl/master
Browse files Browse the repository at this point in the history
Windows Compatibility
  • Loading branch information
lukeyeager committed Aug 13, 2015
2 parents 7f9d50a + d75a972 commit 50eee51
Show file tree
Hide file tree
Showing 21 changed files with 130 additions and 63 deletions.
9 changes: 8 additions & 1 deletion digits/config/caffe_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def validate(cls, value):
if value == '<PATHS>':
# Find the executable
executable = cls.find_executable('caffe')
if not executable:
executable = cls.find_executable('caffe.exe')
if not executable:
raise config_option.BadValue('caffe binary not found in PATH')
cls.validate_version(executable)
Expand Down Expand Up @@ -187,6 +189,9 @@ def get_version(executable):
elif platform.system() == 'Darwin':
# XXX: guess and let the user figure out errors later
return (0,11,0)
elif platform.system() == 'Windows':
# XXX: guess and let the user figure out errors later
return (0,11,0)
else:
print 'WARNING: platform "%s" not supported' % platform.system()
return None
Expand All @@ -197,6 +202,8 @@ def _set_config_dict_value(self, value):
else:
if value == '<PATHS>':
executable = self.find_executable('caffe')
if not executable:
executable = self.find_executable('caffe.exe')
else:
executable = os.path.join(value, 'build', 'tools', 'caffe')

Expand Down Expand Up @@ -238,7 +245,7 @@ def apply(self):
print 'Did you forget to "make pycaffe"?'
raise

if platform.system() == 'Darwin':
if platform.system() == 'Darwin' or platform.system() == 'Windows':
# Strange issue with protocol buffers and pickle - see issue #32
sys.path.insert(0, os.path.join(
os.path.dirname(caffe.__file__), 'proto'))
Expand Down
4 changes: 2 additions & 2 deletions digits/dataset/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __setstate__(self, state):
import numpy as np

old_blob = caffe_pb2.BlobProto()
with open(task.path(task.mean_file)) as infile:
with open(task.path(task.mean_file),'rb') as infile:
old_blob.ParseFromString(infile.read())
data = np.array(old_blob.data).reshape(
old_blob.channels,
Expand All @@ -48,7 +48,7 @@ def __setstate__(self, state):
new_blob.num = 1
new_blob.channels, new_blob.height, new_blob.width = data.shape
new_blob.data.extend(data.astype(float).flat)
with open(task.path(task.mean_file), 'w') as outfile:
with open(task.path(task.mean_file), 'wb') as outfile:
outfile.write(new_blob.SerializeToString())
else:
print '\tSetting "%s" status to ERROR because it was created with RGB channels' % self.name()
Expand Down
2 changes: 1 addition & 1 deletion digits/dataset/images/generic/test_lmdb_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _save_mean(mean, filename):
blob.channels = 1
blob.height, blob.width = mean.shape
blob.data.extend(mean.astype(float).flat)
with open(filename, 'w') as outfile:
with open(filename, 'wb') as outfile:
outfile.write(blob.SerializeToString())

elif filename.endswith(('.jpg', '.jpeg', '.png')):
Expand Down
18 changes: 16 additions & 2 deletions digits/device_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,24 @@ def get_library(name):
return ctypes.cdll.LoadLibrary('%s.so' % name)
elif platform.system() == 'Darwin':
return ctypes.cdll.LoadLibrary('%s.dylib' % name)
elif platform.system() == 'Windows':
return ctypes.windll.LoadLibrary('%s.dll' % name)
except OSError:
pass
return None

devices = None

def get_cudart():
if not platform.system() == 'Windows':
return get_library('libcudart')

arch = platform.architecture()[0]
for ver in range(90,50,-5):
cudart = get_library('cudart%s_%d' % (arch[:2], ver))
if not cudart is None:
return cudart

def get_devices(force_reload=False):
"""
Returns a list of c_cudaDeviceProp's
Expand All @@ -131,7 +143,7 @@ def get_devices(force_reload=False):
return devices
devices = []

cudart = get_library('libcudart')
cudart = get_cudart()
if cudart is None:
return []

Expand Down Expand Up @@ -180,7 +192,9 @@ def get_nvml_info(device_id):

nvml = get_library('libnvidia-ml')
if nvml is None:
return None
nvml = get_library('nvml')
if nvml is None:
return None

rc = nvml.nvmlInit()
if rc != 0:
Expand Down
2 changes: 1 addition & 1 deletion digits/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def path(self, filename, relative=False):
path = os.path.join(self._dir, filename)
if relative:
path = os.path.relpath(path, config_value('jobs_dir'))
return str(path)
return str(path).replace("\\","/")

def path_is_local(self, path):
"""assert that a path is local to _dir"""
Expand Down
4 changes: 2 additions & 2 deletions digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_classify_one(self):
category = self.imageset_paths.keys()[0]
image_path = self.imageset_paths[category][0]
image_path = os.path.join(self.imageset_folder, image_path)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand All @@ -360,7 +360,7 @@ def test_classify_one_json(self):
category = self.imageset_paths.keys()[0]
image_path = self.imageset_paths[category][0]
image_path = os.path.join(self.imageset_folder, image_path)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand Down
9 changes: 6 additions & 3 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from forms import ImageClassificationModelForm
from job import ImageClassificationModelJob
from digits.status import Status
import platform

NAMESPACE = '/models/images/classification'

Expand Down Expand Up @@ -246,9 +247,11 @@ def image_classification_model_classify_one():
if 'image_url' in flask.request.form and flask.request.form['image_url']:
image = utils.image.load_image(flask.request.form['image_url'])
elif 'image_file' in flask.request.files and flask.request.files['image_file']:
with tempfile.NamedTemporaryFile() as outfile:
flask.request.files['image_file'].save(outfile.name)
image = utils.image.load_image(outfile.name)
outfile = tempfile.mkstemp(suffix='.bin')
flask.request.files['image_file'].save(outfile[1])
image = utils.image.load_image(outfile[1])
os.close(outfile[0])
os.remove(outfile[1])
else:
raise werkzeug.exceptions.BadRequest('must provide image_url or image_file')

Expand Down
4 changes: 2 additions & 2 deletions digits/model/images/generic/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_model_json(self):

def test_infer_one(self):
image_path = os.path.join(self.imageset_folder, self.test_image)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand All @@ -355,7 +355,7 @@ def test_infer_one(self):

def test_infer_one_json(self):
image_path = os.path.join(self.imageset_folder, self.test_image)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand Down
9 changes: 6 additions & 3 deletions digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from forms import GenericImageModelForm
from job import GenericImageModelJob
from digits.status import Status
import platform

NAMESPACE = '/models/images/generic'

Expand Down Expand Up @@ -222,9 +223,11 @@ def generic_image_model_infer_one():
if 'image_url' in flask.request.form and flask.request.form['image_url']:
image = utils.image.load_image(flask.request.form['image_url'])
elif 'image_file' in flask.request.files and flask.request.files['image_file']:
with tempfile.NamedTemporaryFile() as outfile:
flask.request.files['image_file'].save(outfile.name)
image = utils.image.load_image(outfile.name)
outfile = tempfile.mkstemp(suffix='.bin')
flask.request.files['image_file'].save(outfile[1])
image = utils.image.load_image(outfile[1])
os.close(outfile[0])
os.remove(outfile[1])
else:
raise werkzeug.exceptions.BadRequest('must provide image_url or image_file')

Expand Down
6 changes: 3 additions & 3 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def save_files_classification(self):
val_data_layer.data_param.backend = caffe_pb2.DataParameter.LMDB
if self.use_mean:
mean_pixel = None
with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as f:
with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as f:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(f.read())
mean = np.reshape(blob.data,
Expand Down Expand Up @@ -1312,7 +1312,7 @@ def get_transformer(self):
channel_swap = (2,1,0)

if self.use_mean:
with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as infile:
with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as infile:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(infile.read())
mean_pixel = np.reshape(blob.data,
Expand All @@ -1331,7 +1331,7 @@ def get_transformer(self):
channel_swap = (2,1,0)

if self.dataset.mean_file:
with open(self.dataset.path(self.dataset.mean_file)) as infile:
with open(self.dataset.path(self.dataset.mean_file),'rb') as infile:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(infile.read())
mean_pixel = np.reshape(blob.data,
Expand Down
6 changes: 4 additions & 2 deletions digits/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from config import config_value
from status import Status, StatusCls

import platform

# NOTE: Increment this everytime the pickled version changes
PICKLE_VERSION = 1

Expand Down Expand Up @@ -128,7 +130,7 @@ def path(self, filename, relative=False):
path = os.path.join(self.job_dir, filename)
if relative:
path = os.path.relpath(path, config_value('jobs_dir'))
return str(path)
return str(path).replace("\\","/")

def ready_to_queue(self):
"""
Expand Down Expand Up @@ -193,7 +195,7 @@ def run(self, resources):
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=self.job_dir,
close_fds=True,
close_fds=False if platform.system() == 'Windows' else True,
)

try:
Expand Down
21 changes: 15 additions & 6 deletions digits/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@

import os
import math
import fcntl
import locale
from random import uniform
from urlparse import urlparse
from io import BlockingIOError
import inspect
import platform


if not platform.system() == 'Windows':
import fcntl
else:
import gevent.os

HTTP_TIMEOUT = 6.05

def is_url(url):
return url is not None and urlparse(url).scheme != ""
return url is not None and urlparse(url).scheme != "" and not os.path.exists(url)

def wait_time():
"""Wait a random number of seconds"""
Expand All @@ -27,22 +33,25 @@ def nonblocking_readlines(f):
Newlines are normalized to the Unix standard.
"""
fd = f.fileno()
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
if not platform.system() == 'Windows':
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
enc = locale.getpreferredencoding(False)

buf = bytearray()
while True:
try:
block = os.read(fd, 8192)
if not platform.system() == 'Windows':
block = os.read(fd, 8192)
else:
block = gevent.os.tp_read(fd, 8192)
except (BlockingIOError, OSError):
yield ""
continue

if not block:
if buf:
yield buf.decode(enc)
buf.clear()
break

buf.extend(block)
Expand Down
27 changes: 22 additions & 5 deletions digits/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import mock
import PIL.Image
import numpy as np
import os
import platform

from . import image as _, errors

Expand Down Expand Up @@ -52,9 +54,19 @@ def check_good_file(self, args):
orig_mode, suffix, pixel, new_mode = args

orig = PIL.Image.new(orig_mode, (10,10), pixel)
with tempfile.NamedTemporaryFile(suffix='.' + suffix) as tmp:
orig.save(tmp.name)
new = _.load_image(tmp.name)

# temp files cause permission errors so just generate the name
tmp = tempfile.mkstemp(suffix='.' + suffix)
orig.save(tmp[1])
new = _.load_image(tmp[1])
try:
# sometimes on windows the file is not closed yet
# which can cause an exception
os.close(tmp[0])
os.remove(tmp[1])
except:
pass

assert new is not None, 'load_image should never return None'
assert new.mode == new_mode, 'Image mode should be "%s", not "%s\nargs - %s' % (new_mode, new.mode, args)

Expand Down Expand Up @@ -94,16 +106,21 @@ def test_corrupted_file(self):
corrupted = encoded[:size/2] + encoded[size/2:][::-1]

# Save the corrupted image to a temporary file.
f = tempfile.NamedTemporaryFile(delete=False)
fname = tempfile.mkstemp(suffix='.bin')
f = os.fdopen(fname[0],'wb')
fname = fname[1]

f.write(corrupted)
f.close()

assert_raises(
errors.LoadImageError,
_.load_image,
f.name,
fname,
)

os.remove(fname)

class TestResizeImage():

@classmethod
Expand Down
Loading

0 comments on commit 50eee51

Please sign in to comment.