Skip to content

Commit

Permalink
Merge pull request BVLC#11 from longjon/master
Browse files Browse the repository at this point in the history
Python interface to blobs and blob data through boost::python
  • Loading branch information
shelhamer committed Jan 20, 2014
2 parents 2f80e1a + f03201f commit 58e7f39
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions python/caffe/pycaffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include <boost/python.hpp>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include <numpy/arrayobject.h>
#include "caffe/caffe.hpp"

Expand All @@ -20,6 +21,78 @@ using boost::python::extract;
using boost::python::len;
using boost::python::list;
using boost::python::object;
using boost::python::handle;
using boost::python::vector_indexing_suite;


// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
// to Python
class CaffeBlob {
 public:
  // Wrap an existing blob. The shared_ptr shares ownership, so the wrapped
  // Blob stays alive at least as long as this wrapper does.
  // Intentionally NOT explicit: CaffeNet::blobs()/params() rely on the
  // implicit conversion when range-constructing vector<CaffeBlob> from a
  // vector<shared_ptr<Blob<float> > >.
  CaffeBlob(const shared_ptr<Blob<float> > &blob)
      : blob_(blob) {}

  // Default-constructed wrapper holds a null blob_ (required so the type is
  // default-constructible, e.g. for vector_indexing_suite). The accessors
  // below must not be called on such an instance.
  CaffeBlob() {}

  // Shape/size accessors; each forwards to the wrapped blob (blob_ must be
  // non-null).
  int num() const { return blob_->num(); }
  int channels() const { return blob_->channels(); }
  int height() const { return blob_->height(); }
  int width() const { return blob_->width(); }
  int count() const { return blob_->count(); }

  // Two wrappers compare equal iff they refer to the same underlying Blob
  // object (pointer identity, not contents). const-qualified so const
  // instances can be compared — the original was a non-const member.
  bool operator==(const CaffeBlob &other) const {
    return blob_ == other.blob_;
  }

 protected:
  // Shared ownership of the wrapped blob; null only for default-constructed
  // instances.
  shared_ptr<Blob<float> > blob_;
};


// we need another wrapper (used as boost::python's HeldType) that receives a
// self PyObject * which we can use as ndarray.base, so that data/diff memory
// is not freed while still being used in Python
class CaffeBlobWrap : public CaffeBlob {
 public:
  CaffeBlobWrap(PyObject *p, shared_ptr<Blob<float> > &blob)
      : CaffeBlob(blob), self_(p) {}

  CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
      : CaffeBlob(blob), self_(p) {}

  // Return the blob's data as a (num, channels, height, width) float32
  // ndarray that aliases — does not copy — the blob's CPU memory.
  object get_data() {
    return as_ndarray(blob_->mutable_cpu_data());
  }

  // Same as get_data(), but viewing the gradient (diff) memory.
  object get_diff() {
    return as_ndarray(blob_->mutable_cpu_diff());
  }

 private:
  // Build a 4-d float32 ndarray view over `data`, installing this wrapper's
  // own Python object as the array's base so the underlying memory is kept
  // alive for the lifetime of the array.
  object as_ndarray(float *data) {
    npy_intp dims[] = {num(), channels(), height(), width()};

    PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, data);
    // PyArray_SetBaseObject STEALS a reference to its base argument, so we
    // must own the reference we hand over: incref self_ BEFORE the call.
    // (The original increfed after, momentarily giving away a reference it
    // did not own.)
    Py_INCREF(self_);
    PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
    handle<> h(obj);

    return object(h);
  }

  PyObject *self_;
};



// A simple wrapper over CaffeNet that runs the forward process.
Expand Down Expand Up @@ -143,14 +216,24 @@ struct CaffeNet
// Set the process-wide Caffe phase to TEST (forwards to Caffe::set_phase;
// global state, not per-net).
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
// Select the compute device by id (forwards to Caffe::SetDevice; global,
// process-wide setting).
void set_device(int device_id) { Caffe::SetDevice(device_id); }

// Expose every intermediate blob of the net, wrapped as CaffeBlob (relies
// on CaffeBlob's implicit conversion from shared_ptr<Blob<float> >).
vector<CaffeBlob> blobs() {
  const vector<shared_ptr<Blob<float> > > &net_blobs = net_->blobs();
  return vector<CaffeBlob>(net_blobs.begin(), net_blobs.end());
}

// Expose the net's learnable parameter blobs, wrapped as CaffeBlob.
vector<CaffeBlob> params() {
  const vector<shared_ptr<Blob<float> > > &net_params = net_->params();
  vector<CaffeBlob> wrapped;
  wrapped.reserve(net_params.size());
  for (vector<shared_ptr<Blob<float> > >::const_iterator it =
           net_params.begin();
       it != net_params.end(); ++it) {
    wrapped.push_back(CaffeBlob(*it));
  }
  return wrapped;
}

// The pointer to the internal caffe::Net instance.
shared_ptr<Net<float> > net_;
};



// The boost python module definition.
BOOST_PYTHON_MODULE(pycaffe)
{

boost::python::class_<CaffeNet>(
"CaffeNet", boost::python::init<string, string>())
.def("Forward", &CaffeNet::Forward)
Expand All @@ -160,5 +243,24 @@ BOOST_PYTHON_MODULE(pycaffe)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
.def("blobs", &CaffeNet::blobs)
.def("params", &CaffeNet::params)
;

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"CaffeBlob", boost::python::no_init)
.add_property("num", &CaffeBlob::num)
.add_property("channels", &CaffeBlob::channels)
.add_property("height", &CaffeBlob::height)
.add_property("width", &CaffeBlob::width)
.add_property("count", &CaffeBlob::count)
.add_property("data", &CaffeBlobWrap::get_data)
.add_property("diff", &CaffeBlobWrap::get_diff)
;

boost::python::class_<vector<CaffeBlob> >("BlobVec")
.def(vector_indexing_suite<vector<CaffeBlob>, true>());

import_array();

}

0 comments on commit 58e7f39

Please sign in to comment.