Skip to content

Commit ee40fb2

Browse files
Guillaume Naultdavem330
Guillaume Nault
authored andcommitted
l2tp: protect sock pointer of struct pppol2tp_session with RCU
pppol2tp_session_create() registers sessions that can't have their corresponding socket initialised. This socket has to be created by userspace, then connected to the session by pppol2tp_connect(). Therefore, we need to protect the pppol2tp socket pointer of L2TP sessions, so that it can safely be updated when userspace is connecting or closing the socket. This will eventually allow pppol2tp_connect() to avoid generating transient states while initialising its parts of the session. To this end, this patch protects the pppol2tp socket pointer using RCU. The pppol2tp socket pointer is still set in pppol2tp_connect(), but only once we know the function isn't going to fail. It's eventually reset by pppol2tp_release(), which now has to wait for a grace period to elapse before it can drop the last reference on the socket. This ensures that pppol2tp_session_get_sock() can safely grab a reference on the socket, even after ps->sk is reset to NULL but before this operation actually gets visible from pppol2tp_session_get_sock(). The rest is standard RCU conversion: pppol2tp_recv(), which already runs in atomic context, is simply enclosed by rcu_read_lock() and rcu_read_unlock(), while other functions are converted to use pppol2tp_session_get_sock() followed by sock_put(). pppol2tp_session_setsockopt() is a special case. It used to retrieve the pppol2tp socket from the L2TP session, which itself was retrieved from the pppol2tp socket. Therefore we can just avoid dereferencing ps->sk and directly use the original socket pointer instead. With all users of ps->sk now handling NULL and concurrent updates, the L2TP ->ref() and ->deref() callbacks aren't needed anymore. Therefore, rather than converting pppol2tp_session_sock_hold() and pppol2tp_session_sock_put(), we can just drop them. Signed-off-by: Guillaume Nault <[email protected]> Signed-off-by: David S. Miller <[email protected]>
1 parent ee28de6 commit ee40fb2

File tree

1 file changed

+101
-53
lines changed

1 file changed

+101
-53
lines changed

net/l2tp/l2tp_ppp.c

Lines changed: 101 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,11 @@
122122
struct pppol2tp_session {
123123
int owner; /* pid that opened the socket */
124124

125-
struct sock *sock; /* Pointer to the session
125+
struct mutex sk_lock; /* Protects .sk */
126+
struct sock __rcu *sk; /* Pointer to the session
126127
* PPPoX socket */
128+
struct sock *__sk; /* Copy of .sk, for cleanup */
129+
struct rcu_head rcu; /* For asynchronous release */
127130
struct sock *tunnel_sock; /* Pointer to the tunnel UDP
128131
* socket */
129132
int flags; /* accessed by PPPIOCGFLAGS.
@@ -138,6 +141,24 @@ static const struct ppp_channel_ops pppol2tp_chan_ops = {
138141

139142
static const struct proto_ops pppol2tp_ops;
140143

144+
/* Retrieves the pppol2tp socket associated to a session.
145+
* A reference is held on the returned socket, so this function must be paired
146+
* with sock_put().
147+
*/
148+
static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session)
149+
{
150+
struct pppol2tp_session *ps = l2tp_session_priv(session);
151+
struct sock *sk;
152+
153+
rcu_read_lock();
154+
sk = rcu_dereference(ps->sk);
155+
if (sk)
156+
sock_hold(sk);
157+
rcu_read_unlock();
158+
159+
return sk;
160+
}
161+
141162
/* Helpers to obtain tunnel/session contexts from sockets.
142163
*/
143164
static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
@@ -224,7 +245,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
224245
/* If the socket is bound, send it in to PPP's input queue. Otherwise
225246
* queue it on the session socket.
226247
*/
227-
sk = ps->sock;
248+
rcu_read_lock();
249+
sk = rcu_dereference(ps->sk);
228250
if (sk == NULL)
229251
goto no_sock;
230252

@@ -247,30 +269,16 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
247269
kfree_skb(skb);
248270
}
249271
}
272+
rcu_read_unlock();
250273

251274
return;
252275

253276
no_sock:
277+
rcu_read_unlock();
254278
l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name);
255279
kfree_skb(skb);
256280
}
257281

258-
static void pppol2tp_session_sock_hold(struct l2tp_session *session)
259-
{
260-
struct pppol2tp_session *ps = l2tp_session_priv(session);
261-
262-
if (ps->sock)
263-
sock_hold(ps->sock);
264-
}
265-
266-
static void pppol2tp_session_sock_put(struct l2tp_session *session)
267-
{
268-
struct pppol2tp_session *ps = l2tp_session_priv(session);
269-
270-
if (ps->sock)
271-
sock_put(ps->sock);
272-
}
273-
274282
/************************************************************************
275283
* Transmit handling
276284
***********************************************************************/
@@ -431,14 +439,16 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
431439
*/
432440
static void pppol2tp_session_close(struct l2tp_session *session)
433441
{
434-
struct pppol2tp_session *ps = l2tp_session_priv(session);
435-
struct sock *sk = ps->sock;
436-
struct socket *sock = sk->sk_socket;
442+
struct sock *sk;
437443

438444
BUG_ON(session->magic != L2TP_SESSION_MAGIC);
439445

440-
if (sock)
441-
inet_shutdown(sock, SEND_SHUTDOWN);
446+
sk = pppol2tp_session_get_sock(session);
447+
if (sk) {
448+
if (sk->sk_socket)
449+
inet_shutdown(sk->sk_socket, SEND_SHUTDOWN);
450+
sock_put(sk);
451+
}
442452

443453
/* Don't let the session go away before our socket does */
444454
l2tp_session_inc_refcount(session);
@@ -461,6 +471,14 @@ static void pppol2tp_session_destruct(struct sock *sk)
461471
}
462472
}
463473

474+
static void pppol2tp_put_sk(struct rcu_head *head)
475+
{
476+
struct pppol2tp_session *ps;
477+
478+
ps = container_of(head, typeof(*ps), rcu);
479+
sock_put(ps->__sk);
480+
}
481+
464482
/* Called when the PPPoX socket (session) is closed.
465483
*/
466484
static int pppol2tp_release(struct socket *sock)
@@ -486,11 +504,24 @@ static int pppol2tp_release(struct socket *sock)
486504

487505
session = pppol2tp_sock_to_session(sk);
488506

489-
/* Purge any queued data */
490507
if (session != NULL) {
508+
struct pppol2tp_session *ps;
509+
491510
__l2tp_session_unhash(session);
492511
l2tp_session_queue_purge(session);
493-
sock_put(sk);
512+
513+
ps = l2tp_session_priv(session);
514+
mutex_lock(&ps->sk_lock);
515+
ps->__sk = rcu_dereference_protected(ps->sk,
516+
lockdep_is_held(&ps->sk_lock));
517+
RCU_INIT_POINTER(ps->sk, NULL);
518+
mutex_unlock(&ps->sk_lock);
519+
call_rcu(&ps->rcu, pppol2tp_put_sk);
520+
521+
/* Rely on the sock_put() call at the end of the function for
522+
* dropping the reference held by pppol2tp_sock_to_session().
523+
* The last reference will be dropped by pppol2tp_put_sk().
524+
*/
494525
}
495526
release_sock(sk);
496527

@@ -557,12 +588,14 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
557588
static void pppol2tp_show(struct seq_file *m, void *arg)
558589
{
559590
struct l2tp_session *session = arg;
560-
struct pppol2tp_session *ps = l2tp_session_priv(session);
591+
struct sock *sk;
592+
593+
sk = pppol2tp_session_get_sock(session);
594+
if (sk) {
595+
struct pppox_sock *po = pppox_sk(sk);
561596

562-
if (ps) {
563-
struct pppox_sock *po = pppox_sk(ps->sock);
564-
if (po)
565-
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
597+
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
598+
sock_put(sk);
566599
}
567600
}
568601
#endif
@@ -693,13 +726,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
693726
/* Using a pre-existing session is fine as long as it hasn't
694727
* been connected yet.
695728
*/
696-
if (ps->sock) {
729+
mutex_lock(&ps->sk_lock);
730+
if (rcu_dereference_protected(ps->sk,
731+
lockdep_is_held(&ps->sk_lock))) {
732+
mutex_unlock(&ps->sk_lock);
697733
error = -EEXIST;
698734
goto end;
699735
}
700736

701737
/* consistency checks */
702738
if (ps->tunnel_sock != tunnel->sock) {
739+
mutex_unlock(&ps->sk_lock);
703740
error = -EEXIST;
704741
goto end;
705742
}
@@ -716,19 +753,21 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
716753
goto end;
717754
}
718755

756+
ps = l2tp_session_priv(session);
757+
mutex_init(&ps->sk_lock);
719758
l2tp_session_inc_refcount(session);
759+
760+
mutex_lock(&ps->sk_lock);
720761
error = l2tp_session_register(session, tunnel);
721762
if (error < 0) {
763+
mutex_unlock(&ps->sk_lock);
722764
kfree(session);
723765
goto end;
724766
}
725767
drop_refcnt = true;
726768
}
727769

728-
/* Associate session with its PPPoL2TP socket */
729-
ps = l2tp_session_priv(session);
730770
ps->owner = current->pid;
731-
ps->sock = sk;
732771
ps->tunnel_sock = tunnel->sock;
733772

734773
session->recv_skb = pppol2tp_recv;
@@ -737,12 +776,6 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
737776
session->show = pppol2tp_show;
738777
#endif
739778

740-
/* We need to know each time a skb is dropped from the reorder
741-
* queue.
742-
*/
743-
session->ref = pppol2tp_session_sock_hold;
744-
session->deref = pppol2tp_session_sock_put;
745-
746779
/* If PMTU discovery was enabled, use the MTU that was discovered */
747780
dst = sk_dst_get(tunnel->sock);
748781
if (dst != NULL) {
@@ -776,12 +809,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
776809
po->chan.mtu = session->mtu;
777810

778811
error = ppp_register_net_channel(sock_net(sk), &po->chan);
779-
if (error)
812+
if (error) {
813+
mutex_unlock(&ps->sk_lock);
780814
goto end;
815+
}
781816

782817
out_no_ppp:
783818
/* This is how we get the session context from the socket. */
784819
sk->sk_user_data = session;
820+
rcu_assign_pointer(ps->sk, sk);
821+
mutex_unlock(&ps->sk_lock);
822+
785823
sk->sk_state = PPPOX_CONNECTED;
786824
l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
787825
session->name);
@@ -827,6 +865,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
827865
}
828866

829867
ps = l2tp_session_priv(session);
868+
mutex_init(&ps->sk_lock);
830869
ps->tunnel_sock = tunnel->sock;
831870

832871
error = l2tp_session_register(session, tunnel);
@@ -998,12 +1037,10 @@ static int pppol2tp_session_ioctl(struct l2tp_session *session,
9981037
"%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n",
9991038
session->name, cmd, arg);
10001039

1001-
sk = ps->sock;
1040+
sk = pppol2tp_session_get_sock(session);
10021041
if (!sk)
10031042
return -EBADR;
10041043

1005-
sock_hold(sk);
1006-
10071044
switch (cmd) {
10081045
case SIOCGIFMTU:
10091046
err = -ENXIO;
@@ -1279,7 +1316,6 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
12791316
int optname, int val)
12801317
{
12811318
int err = 0;
1282-
struct pppol2tp_session *ps = l2tp_session_priv(session);
12831319

12841320
switch (optname) {
12851321
case PPPOL2TP_SO_RECVSEQ:
@@ -1300,8 +1336,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
13001336
}
13011337
session->send_seq = !!val;
13021338
{
1303-
struct sock *ssk = ps->sock;
1304-
struct pppox_sock *po = pppox_sk(ssk);
1339+
struct pppox_sock *po = pppox_sk(sk);
1340+
13051341
po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
13061342
PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
13071343
}
@@ -1640,8 +1676,9 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16401676
{
16411677
struct l2tp_session *session = v;
16421678
struct l2tp_tunnel *tunnel = session->tunnel;
1643-
struct pppol2tp_session *ps = l2tp_session_priv(session);
1644-
struct pppox_sock *po = pppox_sk(ps->sock);
1679+
unsigned char state;
1680+
char user_data_ok;
1681+
struct sock *sk;
16451682
u32 ip = 0;
16461683
u16 port = 0;
16471684

@@ -1651,16 +1688,23 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16511688
port = ntohs(inet->inet_sport);
16521689
}
16531690

1691+
sk = pppol2tp_session_get_sock(session);
1692+
if (sk) {
1693+
state = sk->sk_state;
1694+
user_data_ok = (session == sk->sk_user_data) ? 'Y' : 'N';
1695+
} else {
1696+
state = 0;
1697+
user_data_ok = 'N';
1698+
}
1699+
16541700
seq_printf(m, " SESSION '%s' %08X/%d %04X/%04X -> "
16551701
"%04X/%04X %d %c\n",
16561702
session->name, ip, port,
16571703
tunnel->tunnel_id,
16581704
session->session_id,
16591705
tunnel->peer_tunnel_id,
16601706
session->peer_session_id,
1661-
ps->sock->sk_state,
1662-
(session == ps->sock->sk_user_data) ?
1663-
'Y' : 'N');
1707+
state, user_data_ok);
16641708
seq_printf(m, " %d/%d/%c/%c/%s %08x %u\n",
16651709
session->mtu, session->mru,
16661710
session->recv_seq ? 'R' : '-',
@@ -1677,8 +1721,12 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
16771721
atomic_long_read(&session->stats.rx_bytes),
16781722
atomic_long_read(&session->stats.rx_errors));
16791723

1680-
if (po)
1724+
if (sk) {
1725+
struct pppox_sock *po = pppox_sk(sk);
1726+
16811727
seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan));
1728+
sock_put(sk);
1729+
}
16821730
}
16831731

16841732
static int pppol2tp_seq_show(struct seq_file *m, void *v)

0 commit comments

Comments
 (0)