/*
 * Copyright (c) 2019-2024 Two Sigma Open Source, LLC
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#define _GNU_SOURCE
#include <sys/mount.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <sched.h>
#include <stdarg.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

/*
 * verytmp - create a temporary directory as an ephemeral tmpfs
 *
 * verytmp() mounts a new tmpfs disconnected from the filesystem and
 * returns a file descriptor to the root of the tmpfs. Once all file
 * descriptors that refer to this tmpfs are closed, the kernel will
 * automatically clean up the tmpfs and free any memory it uses.
 *
 * The intended use is with the *at family of functions, e.g.,
 *
 *  int tmpdir = verytmp();
 *  int fd = openat(tmpdir, "somefile", O_RDWR | O_CREAT, 0644);
 *  FILE *f = fdopen(fd, "w+");
 *
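 * However, if you do need a path to pass to functions expecting an
 * absolute path, you can go through /proc/<pid>/fd, e.g.,
 *
 *  char *path;
 *  asprintf(&path, "/proc/%d/fd/%d/", getpid(), tmpdir);
 *
 * (Using the pid rather than /proc/self means other processes can also
 * resolve the path.) The path is only usable while this process is
 * alive and still holds tmpdir open.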
 *
 * verytmp() relies on unprivileged user namespaces: it works by
 * creating a child process in a user + mount namespace, mounting a
 * tmpfs in that child process, and passing a file descriptor back to
 * the parent.
 */
int verytmp(void);

/* Sample program demonstrating use of verytmp. Run `free -m` during and
 * after the program to see the memory usage. Note that the cleanup even
 * happens if the program is killed abruptly with e.g. `kill -9`,
 * because there is no userspace code that needs to run to perform the
 * cleanup.
 *
 * main() is declared weak so that a program linking this file can
 * supply its own (strong) main() instead. */
__attribute__((weak))
int main(void) {
	int tmpdir = verytmp();
	if (tmpdir < 0)
		errx(1, "verytmp failed"); /* verytmp() already printed the cause */
	int fd = openat(tmpdir, "myfile", O_WRONLY | O_CREAT, 0644);
	if (fd < 0)
		err(1, "open tmpdir/myfile failed");
	/* Compute the length in off_t (fallocate's parameter type): 1L << 31
	 * overflows where long is 32 bits. Still assumes a 64-bit off_t
	 * (64-bit platform or _FILE_OFFSET_BITS=64). */
	if (fallocate(fd, 0, 0, (off_t)1 << 31 /* 2 GiB */) != 0)
		err(1, "fallocate failed");
	sleep(60);
	// no cleanup!
	return 0;
}

static int write_file(const char *filename, const char *format, ...)
	__attribute__((format(printf, 2, 3)));

/*
 * Child half of verytmp(), run after fork(). Enters a new user + mount
 * namespace, identity-maps the given uid/gid into it, mounts a fresh
 * tmpfs, and sends a descriptor for the tmpfs root to the parent over
 * `sock` as SCM_RIGHTS ancillary data.
 *
 * Never returns: exits with status 0 once the descriptor has been sent,
 * or 1 (after printing a warning) on any failure.
 *
 * NOTE(review): write_file() uses stdio (and hence malloc) after fork(),
 * which is only safe if the caller is single-threaded at fork time;
 * confirm for multithreaded users.
 */
static void __attribute__((noreturn))
verytmp_child(int sock, uid_t uid, gid_t gid)
{
	int fd;

	if (unshare(CLONE_NEWUSER | CLONE_NEWNS) != 0) {
		warn("verytmp: unshare");
		_exit(1);
	}
	/* An unprivileged process must deny setgroups(2) before it may
	 * write gid_map (see user_namespaces(7)). */
	if (write_file("/proc/self/setgroups", "deny") != 0)
		_exit(1);
	/* IDs are unsigned; %u avoids printing large IDs as negatives. */
	if (write_file("/proc/self/gid_map", "%1$u %1$u 1", (unsigned)gid) != 0)
		_exit(1);
	if (write_file("/proc/self/uid_map", "%1$u %1$u 1", (unsigned)uid) != 0)
		_exit(1);

#ifdef HAVE_FSOPEN
	// Cleaner implementation using fsopen(), which requires
	// kernel 5.1 (2019) and glibc 2.36 (2022).
	int config_fd = fsopen("tmpfs", 0);
	if (config_fd < 0) {
		warn("verytmp: fsopen");
		_exit(1);
	}
	if (fsconfig(config_fd, FSCONFIG_SET_STRING, "source", "verytmp", 0) != 0) {
		warn("verytmp: fsconfig source");
		_exit(1);
	}
	if (fsconfig(config_fd, FSCONFIG_CMD_CREATE, NULL, NULL, 0) != 0) {
		warn("verytmp: fsconfig create");
		_exit(1);
	}
	fd = fsmount(config_fd, 0, 0);
	if (fd < 0) {
		warn("verytmp: fsmount");
		_exit(1);
	}
#else
	// Implementation that abuses /proc/self as a way to get
	// a directory on which we can mount things. Since the
	// mount only exists in this child process, and this
	// process exits very soon, there's no real impact.
	if (mount("verytmp", "/proc/self/task", "tmpfs", 0, NULL) != 0) {
		warn("verytmp: mount");
		_exit(1);
	}
	fd = open("/proc/self/task", O_RDONLY | O_DIRECTORY);
	if (fd < 0) {
		warn("verytmp: open O_DIRECTORY");
		_exit(1);
	}
#endif

	/* The union guarantees the control buffer is aligned for cmsghdr. */
	union {
		char buf[CMSG_SPACE(sizeof(int))];
		struct cmsghdr align;
	} control;
	memset(&control, 0, sizeof(control));
	char byte = 'x'; /* at least one data byte must accompany the fd */
	struct iovec iov = {
		.iov_base = &byte,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = control.buf,
		.msg_controllen = sizeof(control.buf),
	};
	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
	/* CMSG_DATA may be unaligned for int; copy rather than cast. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	if (sendmsg(sock, &msg, 0) < 0) {
		warn("verytmp: sendmsg");
		_exit(1);
	}
	_exit(0);
}

/*
 * Parent half of verytmp(): receive the child's single message on `sock`
 * and extract the descriptor passed with it (marked close-on-exec via
 * MSG_CMSG_CLOEXEC). Returns the descriptor, or -1 after printing a
 * warning if none arrived.
 */
static int
verytmp_receive_fd(int sock)
{
	union {
		char buf[CMSG_SPACE(sizeof(int))];
		struct cmsghdr align;
	} control;
	char byte;
	struct iovec iov = {
		.iov_base = &byte,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = control.buf,
		.msg_controllen = sizeof(control.buf),
	};
	ssize_t n;

	do {
		n = recvmsg(sock, &msg, MSG_CMSG_CLOEXEC);
	} while (n < 0 && errno == EINTR);
	if (n < 0) {
		warn("verytmp: recvmsg");
		return -1;
	}
	if (n == 0) {
		/* EOF: the child exited without sending anything; on its
		 * error paths it has already printed the reason. */
		warnx("verytmp: child did not send a descriptor");
		return -1;
	}

	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
	if (cmsg == NULL ||
	    cmsg->cmsg_level != SOL_SOCKET ||
	    cmsg->cmsg_type != SCM_RIGHTS ||
	    cmsg->cmsg_len != CMSG_LEN(sizeof(int))) {
		warnx("verytmp: unexpected cmsg");
		return -1;
	}

	int fd;
	memcpy(&fd, CMSG_DATA(cmsg), sizeof(fd));
	return fd;
}

/*
 * Create an anonymous tmpfs and return a close-on-exec descriptor for
 * its root, or -1 (after printing a warning) on failure.
 *
 * The work happens in a short-lived child (verytmp_child); the parent
 * only receives the descriptor, closes both socket ends, and reaps the
 * child, so no descriptors or zombies are left behind.
 */
int
verytmp(void)
{
	int sock[2];
	/* Captured in the parent and mapped 1:1 in the child's namespace. */
	uid_t uid = geteuid();
	gid_t gid = getegid();

	/* SOCK_CLOEXEC keeps the pair out of programs exec'd concurrently
	 * by other threads of this process. */
	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, sock) != 0) {
		warn("verytmp: socketpair");
		return -1;
	}

	pid_t pid = fork();
	if (pid < 0) {
		warn("verytmp: fork");
		close(sock[0]);
		close(sock[1]);
		return -1;
	}
	if (pid == 0) {
		close(sock[0]);
		verytmp_child(sock[1], uid, gid); /* does not return */
	}

	close(sock[1]);
	int fd = verytmp_receive_fd(sock[0]);
	close(sock[0]);
	while (waitpid(pid, NULL, 0) < 0 && errno == EINTR)
		;
	return fd;
}

/*
 * Replace the contents of `filename` with the printf-style expansion of
 * `format` / `ap`. The file is opened with "we" (the glibc "e" flag
 * requests O_CLOEXEC). Output is buffered by stdio, so errors from the
 * underlying write may only surface from fclose(), whose result is
 * therefore checked too.
 *
 * Returns 0 on success; on failure prints a warning naming the failing
 * step and returns -1.
 */
static int
vwrite_file(const char *filename, const char *format, va_list ap)
{
	int status = -1;
	FILE *out = fopen(filename, "we");

	if (out == NULL) {
		warn("verytmp: fopen %s", filename);
	} else if (vfprintf(out, format, ap) < 0) {
		warn("verytmp: writing to %s", filename);
		fclose(out);
	} else if (fclose(out) != 0) {
		warn("verytmp: close %s", filename);
	} else {
		status = 0;
	}
	return status;
}

/* Variadic front end to vwrite_file(); same arguments and result. */
static int
write_file(const char *filename, const char *format, ...)
{
	va_list args;
	int rc;

	va_start(args, format);
	rc = vwrite_file(filename, format, args);
	va_end(args);
	return rc;
}
