Skip to content

Commit

Permalink
remove all forget() (#2)
Browse files Browse the repository at this point in the history
remove all `forget()`. 

Signed-off-by: 闹钟大魔王 <1348651580@qq.com>
  • Loading branch information
anti-entropy123 authored Sep 26, 2023
1 parent 48d5a5d commit 261e103
Show file tree
Hide file tree
Showing 14 changed files with 193 additions and 67 deletions.
86 changes: 67 additions & 19 deletions crates/libcontainer/src/channel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use nix::{
sys::socket::{self, UnixAddr},
unistd::{self},
};
use nix::sys::socket::{self, UnixAddr};
use serde::{Deserialize, Serialize};
use std::{
io::{IoSlice, IoSliceMut},
Expand All @@ -10,6 +7,7 @@ use std::{
fd::{AsRawFd, OwnedFd},
unix::prelude::RawFd,
},
sync::Arc,
};

#[derive(Debug, thiserror::Error)]
Expand All @@ -20,16 +18,18 @@ pub enum ChannelError {
Serde(#[from] serde_json::Error),
#[error("channel connection broken")]
BrokenChannel,
#[error("Unable to be closed")]
Unclosed,
}
#[derive(Clone)]
pub struct Receiver<T> {
receiver: RawFd,
receiver: Option<Arc<OwnedFd>>,
phantom: PhantomData<T>,
}

#[derive(Clone)]
pub struct Sender<T> {
sender: RawFd,
sender: Option<Arc<OwnedFd>>,
phantom: PhantomData<T>,
}

Expand All @@ -47,8 +47,14 @@ where
} else {
vec![]
};
socket::sendmsg::<UnixAddr>(self.sender, iov, &cmsgs, socket::MsgFlags::empty(), None)
.map_err(|e| e.into())
socket::sendmsg::<UnixAddr>(
self.sender.as_ref().unwrap().as_raw_fd(),
iov,
&cmsgs,
socket::MsgFlags::empty(),
None,
)
.map_err(|e| e.into())
}

fn send_slice_with_len(
Expand Down Expand Up @@ -84,8 +90,30 @@ where
Ok(())
}

pub fn close(&self) -> Result<(), ChannelError> {
Ok(unistd::close(self.sender)?)
pub fn close(&mut self) -> Result<(), ChannelError> {
// must ensure that the fd is closed immediately.
let count = Arc::strong_count(self.sender.as_ref().unwrap());
if count != 1 {
tracing::trace!(?count, "incorrect reference count value");
return Err(ChannelError::Unclosed)?;
};
self.sender = None;

Ok(())
}

/// Enforce a decrement of the inner reference counter by 1.
///
/// # Safety
/// The reason for `unsafe` is the caller must ensure that it's only called
/// when absolutely necessary. For instance, in the current implementation,
/// `clone()` can cause a leak of references residing on the stack in the
/// childprocess. This function allows for manual adjustment of the counter
/// to correct such situations.
pub unsafe fn decrement_count(&self) {
let rc = Arc::into_raw(Arc::clone(self.sender.as_ref().unwrap()));
Arc::decrement_strong_count(rc);
Arc::from_raw(rc);
}
}

Expand All @@ -101,8 +129,12 @@ where
std::mem::size_of::<u64>(),
)
})];
let _ =
socket::recvmsg::<UnixAddr>(self.receiver, &mut iov, None, socket::MsgFlags::MSG_PEEK)?;
let _ = socket::recvmsg::<UnixAddr>(
self.receiver.as_ref().unwrap().as_raw_fd(),
&mut iov,
None,
socket::MsgFlags::MSG_PEEK,
)?;
match len {
0 => Err(ChannelError::BrokenChannel),
_ => Ok(len),
Expand All @@ -118,7 +150,7 @@ where
{
let mut cmsgspace = nix::cmsg_space!(F);
let msg = socket::recvmsg::<UnixAddr>(
self.receiver,
self.receiver.as_ref().unwrap().as_raw_fd(),
iov,
Some(&mut cmsgspace),
socket::MsgFlags::MSG_CMSG_CLOEXEC,
Expand Down Expand Up @@ -190,8 +222,26 @@ where
Ok((serde_json::from_slice(&buf[..])?, fds))
}

pub fn close(&self) -> Result<(), ChannelError> {
Ok(unistd::close(self.receiver)?)
pub fn close(&mut self) -> Result<(), ChannelError> {
// must ensure that the fd is closed immediately.
let count = Arc::strong_count(self.receiver.as_ref().unwrap());
if count != 1 {
tracing::trace!(?count, "incorrect reference count value");
return Err(ChannelError::Unclosed)?;
};
self.receiver = None;

Ok(())
}

/// Enforce a decrement of the inner reference counter by 1.
///
/// # Safety
/// The reason for `unsafe` is same as `Sender::decrement_count()`.
pub unsafe fn decrement_count(&self) {
let rc = Arc::into_raw(Arc::clone(self.receiver.as_ref().unwrap()));
Arc::decrement_strong_count(rc);
Arc::from_raw(rc);
}
}

Expand All @@ -201,15 +251,13 @@ where
{
let (os_sender, os_receiver) = unix_channel()?;
let receiver = Receiver {
receiver: os_receiver.as_raw_fd(),
receiver: Some(Arc::from(os_receiver)),
phantom: PhantomData,
};
let sender = Sender {
sender: os_sender.as_raw_fd(),
sender: Some(Arc::from(os_sender)),
phantom: PhantomData,
};
std::mem::forget(os_sender);
std::mem::forget(os_receiver);
Ok((sender, receiver))
}

Expand Down
6 changes: 3 additions & 3 deletions crates/libcontainer/src/container/builder_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
use libcgroups::common::CgroupManager;
use nix::unistd::Pid;
use oci_spec::runtime::Spec;
use std::{fs, io::Write, os::unix::prelude::RawFd, path::PathBuf, rc::Rc};
use std::{fs, io::Write, os::fd::OwnedFd, path::PathBuf, rc::Rc};

pub(super) struct ContainerBuilderImpl {
/// Flag indicating if an init or a tenant container should be created
Expand All @@ -35,7 +35,7 @@ pub(super) struct ContainerBuilderImpl {
/// container process to the higher level runtime
pub pid_file: Option<PathBuf>,
/// Socket to communicate the file descriptor of the ptty
pub console_socket: Option<RawFd>,
pub console_socket: Option<OwnedFd>,
/// Options for new user namespace
pub user_ns_config: Option<UserNamespaceConfig>,
/// Path to the Unix Domain Socket to communicate container start
Expand Down Expand Up @@ -140,7 +140,7 @@ impl ContainerBuilderImpl {
syscall: self.syscall,
spec: Rc::clone(&self.spec),
rootfs: self.rootfs.to_owned(),
console_socket: self.console_socket,
console_socket: self.console_socket.take().map(Rc::from),
notify_listener,
preserve_fds: self.preserve_fds,
container: self.container.to_owned(),
Expand Down
16 changes: 1 addition & 15 deletions crates/libcontainer/src/container/init_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use nix::unistd;
use oci_spec::runtime::Spec;
use std::{
fs,
mem::forget,
os::fd::AsRawFd,
path::{Path, PathBuf},
rc::Rc,
};
Expand Down Expand Up @@ -79,22 +77,10 @@ impl InitContainerBuilder {
// if socket file path is given in commandline options,
// get file descriptors of console socket
let csocketfd = if let Some(console_socket) = &self.base.console_socket {
Some(tty::setup_console_socket(
&container_dir,
console_socket,
"console-socket",
)?)
tty::setup_console_socket(&container_dir, console_socket, "console-socket")?
} else {
None
};
let csocketfd = csocketfd.map(|sockfd| match sockfd {
Some(sockfd) => {
let fd = sockfd.as_raw_fd();
forget(sockfd);
fd
}
None => -1,
});

let user_ns_config = UserNamespaceConfig::new(&spec)?;

Expand Down
8 changes: 1 addition & 7 deletions crates/libcontainer/src/container/tenant_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use oci_spec::runtime::{
};
use procfs::process::Namespace;

use std::mem::forget;
use std::os::fd::{AsRawFd, OwnedFd};
use std::os::fd::OwnedFd;
use std::rc::Rc;
use std::{
collections::HashMap,
Expand Down Expand Up @@ -118,11 +117,6 @@ impl TenantContainerBuilder {
// if socket file path is given in commandline options,
// get file descriptors of console socket
let csocketfd = self.setup_tty_socket(&container_dir)?;
let csocketfd = csocketfd.map(|sockfd| {
let fd = sockfd.as_raw_fd();
forget(sockfd);
fd
});

let use_systemd = self.should_use_systemd(&container);
let user_ns_config = UserNamespaceConfig::new(&spec)?;
Expand Down
1 change: 1 addition & 0 deletions crates/libcontainer/src/namespaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl Namespaces {
tracing::debug!("unshare or setns: {:?}", namespace);
match namespace.path() {
Some(path) => {
// Note that the fd passed to `set_ns()` will be closed, and should not close it again.
let fd = fcntl::open(path, fcntl::OFlag::empty(), stat::Mode::empty()).map_err(
|err| {
tracing::error!(?err, ?namespace, "failed to open namespace file");
Expand Down
23 changes: 22 additions & 1 deletion crates/libcontainer/src/process/args.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use libcgroups::common::CgroupConfig;
use oci_spec::runtime::Spec;
use std::os::fd::OwnedFd;
use std::os::unix::prelude::RawFd;
use std::path::PathBuf;
use std::rc::Rc;
Expand All @@ -26,7 +27,7 @@ pub struct ContainerArgs {
/// Root filesystem of the container
pub rootfs: PathBuf,
/// Socket to communicate the file descriptor of the ptty
pub console_socket: Option<RawFd>,
pub console_socket: Option<Rc<OwnedFd>>,
/// The Unix Domain Socket to communicate container start
pub notify_listener: NotifyListener,
/// File descriptors preserved/passed to the container init process.
Expand All @@ -42,3 +43,23 @@ pub struct ContainerArgs {
/// Manage the functions that actually run on the container
pub executor: Box<dyn Executor>,
}

impl ContainerArgs {
/// Enforce a decrement of the inner reference counter by 1.
///
/// # Safety
/// The reason for `unsafe` is the caller must ensure that it's only called
/// when absolutely necessary. Please refer to `Sender::decrement_count()`
/// for more details.
pub unsafe fn decrement_count(&self) {
let rc = Rc::into_raw(Rc::clone(&self.spec));
Rc::decrement_strong_count(rc);
Rc::from_raw(rc);

if let Some(socket) = &self.console_socket {
let socket = Rc::into_raw(Rc::clone(socket));
Rc::decrement_strong_count(socket);
Rc::from_raw(socket);
}
}
}
Loading

0 comments on commit 261e103

Please sign in to comment.