diff --git a/src/repository.c b/src/repository.c index 0b12053a4..bd048c73d 100644 --- a/src/repository.c +++ b/src/repository.c @@ -179,21 +179,31 @@ Repository_head__get__(Repository *self) } int -Repository_head__set__(Repository *self, PyObject *py_refname) +Repository_head__set__(Repository *self, PyObject *py_val) { int err; - const char *refname; - PyObject *trefname; + if (PyObject_TypeCheck(py_val, &OidType)) { + git_oid oid; + py_oid_to_git_oid(py_val, &oid); + err = git_repository_set_head_detached(self->repo, &oid, NULL, NULL); + if (err < 0) { + Error_set(err); + return -1; + } + } else { + const char *refname; + PyObject *trefname; - refname = py_str_borrow_c_str(&trefname, py_refname, NULL); - if (refname == NULL) - return -1; + refname = py_str_borrow_c_str(&trefname, py_val, NULL); + if (refname == NULL) + return -1; - err = git_repository_set_head(self->repo, refname, NULL, NULL); - Py_DECREF(trefname); - if (err < 0) { - Error_set_str(err, refname); - return -1; + err = git_repository_set_head(self->repo, refname, NULL, NULL); + Py_DECREF(trefname); + if (err < 0) { + Error_set_str(err, refname); + return -1; + } } return 0; @@ -512,6 +522,27 @@ Repository_workdir__get__(Repository *self, void *closure) return to_path(c_path); } +int +Repository_workdir__set__(Repository *self, PyObject *py_workdir) +{ + int err; + const char *workdir; + PyObject *tworkdir; + + workdir = py_str_borrow_c_str(&tworkdir, py_workdir, NULL); + if (workdir == NULL) + return -1; + + err = git_repository_set_workdir(self->repo, workdir, 0 /* update_gitlink */); + Py_DECREF(tworkdir); + if (err < 0) { + Error_set_str(err, workdir); + return -1; + } + + return 0; +} + PyDoc_STRVAR(Repository_merge_base__doc__, "merge_base(oid, oid) -> Oid\n" "\n" @@ -1597,7 +1628,7 @@ PyGetSetDef Repository_getseters[] = { GETTER(Repository, head_is_unborn), GETTER(Repository, is_empty), GETTER(Repository, is_bare), - GETTER(Repository, workdir), + GETSET(Repository, workdir), GETTER(Repository, default_signature), GETTER(Repository, _pointer), {NULL} diff --git a/test/test_repository.py b/test/test_repository.py index 4581f4774..4d6472ec1 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -70,6 +70,15 @@ def test_head(self): self.assertFalse(self.repo.head_is_unborn) self.assertFalse(self.repo.head_is_detached) + def test_set_head(self): + # Test setting a detatched HEAD. + self.repo.head = Oid(hex=PARENT_SHA) + self.assertEqual(self.repo.head.target.hex, PARENT_SHA) + # And test setting a normal HEAD. + self.repo.head = "refs/heads/master" + self.assertEqual(self.repo.head.name, "refs/heads/master") + self.assertEqual(self.repo.head.target.hex, HEAD_SHA) + def test_read(self): self.assertRaises(TypeError, self.repo.read, 123) self.assertRaisesWithArg(KeyError, '1' * 40, self.repo.read, '1' * 40) @@ -190,6 +199,11 @@ def test_get_workdir(self): expected = realpath(self.repo_path) self.assertEqual(directory, expected) + def test_set_workdir(self): + directory = tempfile.mkdtemp() + self.repo.workdir = directory + self.assertEqual(realpath(self.repo.workdir), realpath(directory)) + def test_checkout_ref(self): ref_i18n = self.repo.lookup_reference('refs/heads/i18n')