/*
 * NPF tests for socket user/group ID (uid/gid) rule matching.
 *
 * Public Domain.
 */

#ifdef _KERNEL
#include <sys/types.h>
#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: npf_rid_test.c,v 1.3 2025/07/01 20:19:30 joe Exp $");
#endif

#include "npf_impl.h"
#include "npf_test.h"

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/kauth.h>
#include <sys/socketvar.h>
#include <sys/lwp.h>
#include <sys/cpu.h>

/* Expected verdicts: 0 when a packet is passed, ENETUNREACH when blocked. */
#define	RESULT_PASS	0
#define	RESULT_BLOCK	ENETUNREACH

/*
 * Endpoints used by the tests.  The test socket is bound to
 * LOCAL_IP:LOCAL_PORT and, for outbound cases, connected to
 * REMOTE_IP:REMOTE_PORT; both addresses are the IPv4 loopback address.
 * These port numbers were chosen as suitable for testing.
 */
#define REMOTE_PORT	65500
#define LOCAL_PORT	65000
#define LOCAL_IP	"127.0.0.1"
#define REMOTE_IP	LOCAL_IP

/*
 * Test vectors.  Each entry describes one UDP packet together with the
 * effective credentials of the local socket it should be attributed to:
 *
 *  af, src/sport, dst/dport	address family and packet endpoints;
 *  uid, gid			effective uid/gid set on the test socket;
 *  ifname, di			interface and direction (PFIL_IN/PFIL_OUT);
 *  ret				expected result of inspecting the ruleset
 *				directly (run_raw_testcase);
 *  stateful_ret		expected result of the full packet handler
 *				(run_handler_testcase).
 *
 * The npf.conf rules quoted in the comments below are presumably the ones
 * loaded by the test configuration, which is not part of this file --
 * keep the two in sync.
 */
static const struct test_case {
	int		af;
	const char *	src;
	uint16_t	sport;
	const char *	dst;
	uint16_t	dport;
	uint32_t	uid;
	uint32_t	gid;
	const char *	ifname;
	int		di;
	int 	ret;
	int 	stateful_ret;
} test_cases[] = {
	{
		/* pass in final from $local_ip4 user $Kojo = 1001 group $wheel = 20 */
		/* uid and gid both match the rule, so the packet passes */
		.af = AF_INET,
		.src = "10.1.1.4",			.sport = 9000,
		.dst = LOCAL_IP,			.dport = LOCAL_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_IN,
		.uid = 1001,				.gid = 20, /* matches so pass it */
		.ret = RESULT_PASS,			.stateful_ret = RESULT_PASS
	},
	{
		/* same UID but a different GID: the rule does not match, so block */
		.af = AF_INET,
		.src = "10.1.1.4",			.sport = 9000,
		.dst = LOCAL_IP,			.dport = LOCAL_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_IN,
		.uid = 1001,				.gid = 10, /* mismatch gid so block it */
		.ret = RESULT_BLOCK,			.stateful_ret = RESULT_BLOCK
	},
	{
		/* different UID with a matching GID: the rule does not match, so block */
		.af = AF_INET,
		.src = "10.1.1.4",			.sport = 9000,
		.dst = LOCAL_IP,			.dport = LOCAL_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_IN,
		.uid = 100,				.gid = 20, /* mismatch uid so block it */
		.ret = RESULT_BLOCK,			.stateful_ret = RESULT_BLOCK
	},


	/* block out final to 127.0.0.1 user > $Kojo (> 1001) group 1 >< $wheel (1 >< 20) */
	{
		/* uid 1005 > 1001 and gid 14 within 1 >< 20: the rule matches, so block */
		.af = AF_INET,
		.src = LOCAL_IP,			.sport = LOCAL_PORT,
		.dst = REMOTE_IP,			.dport = REMOTE_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_OUT,
		.uid = 1005,				.gid = 14, /* matches so blocks it */
		.ret = RESULT_BLOCK,			.stateful_ret = RESULT_BLOCK
	},
	{
		/* uid matches but gid 30 is outside 1 >< 20: no match, so pass */
		.af = AF_INET,
		.src = LOCAL_IP,			.sport = LOCAL_PORT,
		.dst = REMOTE_IP,			.dport = REMOTE_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_OUT,
		.uid = 1005,				.gid = 30, /* mismatch gid so pass it */
		.ret = RESULT_PASS,			.stateful_ret = RESULT_PASS
	},
	{
		/* gid matches but uid 100 is not > 1001: no match, so pass */
		.af = AF_INET,
		.src = LOCAL_IP,			.sport = LOCAL_PORT,
		.dst = REMOTE_IP,			.dport = REMOTE_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_OUT,
		.uid = 100,				.gid = 15, /* mismatch uid so pass it */
		.ret = RESULT_PASS,			.stateful_ret = RESULT_PASS
	},
	{
		/* uid 1010 > 1001 and gid 11 within 1 >< 20: the rule matches, so block */
		.af = AF_INET,
		.src = LOCAL_IP,			.sport = LOCAL_PORT,
		.dst = REMOTE_IP,			.dport = REMOTE_PORT,
		.ifname = IFNAME_EXT,			.di = PFIL_OUT,
		.uid = 1010,				.gid = 11, /* matches so blocks it */
		.ret = RESULT_BLOCK,			.stateful_ret = RESULT_BLOCK
	},
};

/*
 * run_raw_testcase: evaluate test case 'i' directly against the active
 * ruleset, without going through the packet handler.
 *
 * A UDP packet is built from the test vector and inspected at layer 3.
 * If a rule matches, its verdict is taken from npf_rule_conclude().  When
 * npf_rule_match_rid() reports 0, the verdict is reversed with
 * npf_rule_reverse().  A result of -1 leaves the verdict untouched;
 * presumably -1 means the rule has no uid/gid criteria (confirm against
 * npf_rule_match_rid).
 *
 * Returns the verdict to compare with t->ret, or ENOENT if no rule
 * matched at all.
 */
static int
run_raw_testcase(unsigned i, bool verbose)
{
	const struct test_case *tc = &test_cases[i];
	npf_t *npf = npf_getkernctx();
	struct mbuf *pkt;
	npf_cache_t *cache;
	npf_rule_t *rule;
	int rlock;
	int result = ENOENT;	/* verdict when no rule matches */

	pkt = mbuf_get_pkt(tc->af, IPPROTO_UDP, tc->src, tc->dst,
	    tc->sport, tc->dport);
	cache = get_cached_pkt(pkt, tc->ifname, NPF_RULE_LAYER_3);

	rlock = npf_config_read_enter(npf);
	rule = npf_ruleset_inspect(cache, npf_config_ruleset(npf), tc->di,
	    NPF_RULE_LAYER_3);
	if (rule != NULL) {
		npf_match_info_t minfo;
		const int rid_match = npf_rule_match_rid(rule, cache, tc->di);

		result = npf_rule_conclude(rule, &minfo);
		if (verbose)
			printf("id match is ...%d\n", rid_match);

		/* The uid/gid criteria did not match: invert the rule's verdict. */
		if (rid_match == 0)
			result = npf_rule_reverse(cache, &minfo, result);
	}
	npf_config_read_exit(npf, rlock);

	put_cached_pkt(cache);
	return result;
}

/*
 * run_handler_testcase: push test case 'i' through the full NPF packet
 * handler (npfk_packet_handler) on the test case's interface and in its
 * direction.
 *
 * The handler receives a pointer to the mbuf pointer and may consume the
 * packet, leaving the pointer NULL.  Any packet still owned here is freed.
 *
 * Returns the handler's result, which the caller compares with
 * t->stateful_ret.
 */
static int
run_handler_testcase(unsigned i)
{
	const struct test_case *tc = &test_cases[i];
	ifnet_t *ifp = npf_test_getif(tc->ifname);
	npf_t *npf = npf_getkernctx();
	struct mbuf *pkt = mbuf_get_pkt(tc->af, IPPROTO_UDP,
	    tc->src, tc->dst, tc->sport, tc->dport);
	const int rv = npfk_packet_handler(npf, &pkt, ifp, tc->di);

	if (pkt != NULL)
		m_freem(pkt);

	return rv;
}

/*
 * Create the UDP socket for a test case and give it the requested
 * effective uid/gid.  The socket lives in the protocol table, so NPF can
 * find it through a PCB lookup and match the packet against the
 * socket's credentials.
 *
 * The socket is bound to LOCAL_IP:LOCAL_PORT.  For outbound cases
 * (dir == PFIL_OUT), it is also connected to REMOTE_IP:REMOTE_PORT so
 * that the full local/remote address-port 4-tuple is set.
 *
 * Returns the socket on success; the caller must soclose() it.  Returns
 * NULL on failure, and any partially set-up socket has then already
 * been closed.
 */
static struct socket *
test_socket(int dir, uid_t uid, gid_t gid)
{
	struct lwp *cur = curlwp;
	struct sockaddr_in server;
	struct socket *so;
	kauth_cred_t cred, old;
	int error;

	memset(&server, 0, sizeof(server));
	server.sin_len = sizeof(server);
	server.sin_family = AF_INET;
	npf_inet_pton(AF_INET, LOCAL_IP, &server.sin_addr.s_addr); /* we bind to 127.0.0.1 */
	server.sin_port = htons(LOCAL_PORT);

	error = socreate(AF_INET, &so, SOCK_DGRAM, 0, cur, NULL);
	if (error) {
		printf("socket creation failed: error is %d\n", error);
		return NULL;
	}

	/*
	 * Build the credentials before taking the socket lock, because
	 * kauth_cred_alloc() allocates memory and may sleep.  The new
	 * reference is handed straight to the socket (released when the
	 * socket is freed).  The old code made an extra dup here and
	 * leaked the original.
	 */
	cred = kauth_cred_alloc();
	kauth_cred_seteuid(cred, uid);
	kauth_cred_setegid(cred, gid);

	solock(so);
	old = so->so_cred;
	so->so_cred = cred;
	sounlock(so);
	kauth_cred_free(old);

	/* sobind() is called without the socket lock held. */
	if ((error = sobind(so, (struct sockaddr *)&server, cur)) != 0) {
		printf("bind failed %d\n", error);
		soclose(so);
		return NULL;
	}

	if (dir == PFIL_OUT) {
		/* connect to an additional remote address to set the 4 tuple addr-port state */
		struct sockaddr_in remote;

		memset(&remote, 0, sizeof(remote));
		remote.sin_len = sizeof(remote);
		remote.sin_family = AF_INET;
		npf_inet_pton(AF_INET, REMOTE_IP, &remote.sin_addr.s_addr); /* we connect to 127.0.0.1 */
		remote.sin_port = htons(REMOTE_PORT);

		/*
		 * soconnect() is called with the socket locked.  Drop the lock
		 * on every path, including errors, before soclose(), which
		 * takes the lock itself.
		 */
		solock(so);
		error = soconnect(so, (struct sockaddr *)&remote, cur);
		sounlock(so);
		if (error != 0) {
			printf("connect failed :%d\n", error);
			soclose(so);
			return NULL;
		}
	}

	return so;
}

/*
 * Run every entry of test_cases[].  For each entry:
 *  1. Create a socket owned by the case's uid/gid (test_socket).
 *  2. Evaluate the packet directly against the ruleset (run_raw_testcase).
 *  3. Evaluate it through the packet handler (run_handler_testcase).
 *  4. Compare both results with the expected values.
 *
 * Returns false on a setup failure.  On a mismatching result,
 * CHECK_TRUE() is expected to report it and bail out early.  The socket
 * is closed before any check, so no exit path leaks it.
 */
static bool
test_static(bool verbose)
{
	for (size_t i = 0; i < __arraycount(test_cases); i++) {
		const struct test_case *t = &test_cases[i];
		struct socket *so;
		int error, serror;

		/* Validate the interface first so this failure cannot leak a socket. */
		if (npf_test_getif(t->ifname) == NULL) {
			printf("Interface %s is not configured.\n", t->ifname);
			return false;
		}

		so = test_socket(t->di, t->uid, t->gid);
		if (so == NULL) {
			printf("rule test %zu: socket setup failed\n", i + 1);
			return false;
		}

		error = run_raw_testcase(i, verbose);
		serror = run_handler_testcase(i);

		/*
		 * The socket is no longer needed once both results are known.
		 * Closing it here, ahead of the CHECK_TRUE() early returns, stops
		 * a failed case from leaking it and leaving LOCAL_PORT bound.
		 */
		soclose(so);

		if (verbose) {
			printf("rule test %zu:\texpected %d (stateful) and %d\n"
			    "\t\t-> returned %d and %d\n",
			    i + 1, t->stateful_ret, t->ret, serror, error);
		}
		CHECK_TRUE(error == t->ret);
		CHECK_TRUE(serror == t->stateful_ret);
	}
	return true;
}

/*
 * npf_guid_test: entry point for the NPF socket uid/gid matching tests.
 *
 * Calls soinit1() first, which presumably initialises socket-layer state
 * needed by socreate(); confirm if the harness changes.  It then runs the
 * static test vectors.  Returns true on success.  On failure, CHECK_TRUE()
 * is expected to report the failing line and return early.
 */
bool
npf_guid_test(bool verbose)
{
	bool passed;

	soinit1();

	passed = test_static(verbose);
	CHECK_TRUE(passed);

	return true;
}
