diff --git a/drivers/net/tun.c b/drivers/net/tun.c index 9254bca2813d..9b24861464bc 100644 --- a/drivers/net/tun.c +++ b/drivers/net/tun.c @@ -1661,6 +1661,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun, int len, int *skb_xdp) { struct page_frag *alloc_frag = ¤t->task_frag; + struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx; struct bpf_prog *xdp_prog; int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); char *buf; @@ -1700,6 +1701,7 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun, local_bh_disable(); rcu_read_lock(); + bpf_net_ctx = bpf_net_ctx_set(&__bpf_net_ctx); xdp_prog = rcu_dereference(tun->xdp_prog); if (xdp_prog) { struct xdp_buff xdp; @@ -1728,12 +1730,14 @@ static struct sk_buff *tun_build_skb(struct tun_struct *tun, pad = xdp.data - xdp.data_hard_start; len = xdp.data_end - xdp.data; } + bpf_net_ctx_clear(bpf_net_ctx); rcu_read_unlock(); local_bh_enable(); return __tun_build_skb(tfile, alloc_frag, buf, buflen, len, pad); out: + bpf_net_ctx_clear(bpf_net_ctx); rcu_read_unlock(); local_bh_enable(); return NULL; @@ -2566,6 +2570,7 @@ static int tun_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) if (m->msg_controllen == sizeof(struct tun_msg_ctl) && ctl && ctl->type == TUN_MSG_PTR) { + struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx; struct tun_page tpage; int n = ctl->num; int flush = 0, queued = 0; @@ -2574,6 +2579,7 @@ static int tun_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) local_bh_disable(); rcu_read_lock(); + bpf_net_ctx = bpf_net_ctx_set(&__bpf_net_ctx); for (i = 0; i < n; i++) { xdp = &((struct xdp_buff *)ctl->ptr)[i]; @@ -2588,6 +2594,7 @@ static int tun_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) if (tfile->napi_enabled && queued > 0) napi_schedule(&tfile->napi); + bpf_net_ctx_clear(bpf_net_ctx); rcu_read_unlock(); local_bh_enable(); diff --git a/net/core/dev.c b/net/core/dev.c index 385c4091aa77..73e5af6943c3 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -5126,11 +5126,14 @@ static DEFINE_STATIC_KEY_FALSE(generic_xdp_needed_key); int do_xdp_generic(struct bpf_prog *xdp_prog, struct sk_buff **pskb) { + struct bpf_net_context __bpf_net_ctx, *bpf_net_ctx; + if (xdp_prog) { struct xdp_buff xdp; u32 act; int err; + bpf_net_ctx = bpf_net_ctx_set(&__bpf_net_ctx); act = netif_receive_generic_xdp(pskb, &xdp, xdp_prog); if (act != XDP_PASS) { switch (act) { @@ -5144,11 +5147,13 @@ int do_xdp_generic(struct bpf_prog *xdp_prog, struct sk_buff **pskb) generic_xdp_tx(*pskb, xdp_prog); break; } + bpf_net_ctx_clear(bpf_net_ctx); return XDP_DROP; } } return XDP_PASS; out_redir: + bpf_net_ctx_clear(bpf_net_ctx); kfree_skb_reason(*pskb, SKB_DROP_REASON_XDP); return XDP_DROP; }