]> git.kernelconcepts.de Git - karo-tx-linux.git/blobdiff - net/ipv4/udp.c
udp: ipv4: Add udp early demux
[karo-tx-linux.git] / net / ipv4 / udp.c
index 5950e12bd3abb33f6fefd06f3c196bdb1e02810d..262ea3929da65a7963c2b3bbbb53ef6dc5eb4e6d 100644 (file)
 #include <linux/seq_file.h>
 #include <net/net_namespace.h>
 #include <net/icmp.h>
+#include <net/inet_hashtables.h>
 #include <net/route.h>
 #include <net/checksum.h>
 #include <net/xfrm.h>
@@ -565,6 +566,26 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 
+static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
+                                      __be16 loc_port, __be32 loc_addr,
+                                      __be16 rmt_port, __be32 rmt_addr,
+                                      int dif, unsigned short hnum)
+{
+       struct inet_sock *inet = inet_sk(sk);
+
+       if (!net_eq(sock_net(sk), net) ||
+           udp_sk(sk)->udp_port_hash != hnum ||
+           (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
+           (inet->inet_dport != rmt_port && inet->inet_dport) ||
+           (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
+           ipv6_only_sock(sk) ||
+           (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+               return false;
+       if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
+               return false;
+       return true;
+}
+
 static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
                                             __be16 loc_port, __be32 loc_addr,
                                             __be16 rmt_port, __be32 rmt_addr,
@@ -575,20 +596,11 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
        unsigned short hnum = ntohs(loc_port);
 
        sk_nulls_for_each_from(s, node) {
-               struct inet_sock *inet = inet_sk(s);
-
-               if (!net_eq(sock_net(s), net) ||
-                   udp_sk(s)->udp_port_hash != hnum ||
-                   (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
-                   (inet->inet_dport != rmt_port && inet->inet_dport) ||
-                   (inet->inet_rcv_saddr &&
-                    inet->inet_rcv_saddr != loc_addr) ||
-                   ipv6_only_sock(s) ||
-                   (s->sk_bound_dev_if && s->sk_bound_dev_if != dif))
-                       continue;
-               if (!ip_mc_sf_allow(s, loc_addr, rmt_addr, dif))
-                       continue;
-               goto found;
+               if (__udp_is_mcast_sock(net, s,
+                                       loc_port, loc_addr,
+                                       rmt_port, rmt_addr,
+                                       dif, hnum))
+                       goto found;
        }
        s = NULL;
 found:
@@ -1581,6 +1593,14 @@ static void flush_stack(struct sock **stack, unsigned int count,
                kfree_skb(skb1);
 }
 
+static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
+{
+       struct dst_entry *dst = skb_dst(skb);
+
+       dst_hold(dst);
+       sk->sk_rx_dst = dst;
+}
+
 /*
  *     Multicasts and broadcasts go to each listener.
  *
@@ -1709,11 +1729,28 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
-               return __udp4_lib_mcast_deliver(net, skb, uh,
-                               saddr, daddr, udptable);
+       if (skb->sk) {
+               int ret;
+               sk = skb->sk;
 
-       sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+               if (unlikely(sk->sk_rx_dst == NULL))
+                       udp_sk_rx_dst_set(sk, skb);
+
+               ret = udp_queue_rcv_skb(sk, skb);
+
+               /* a return value > 0 means to resubmit the input, but
+                * it wants the return to be -protocol, or 0
+                */
+               if (ret > 0)
+                       return -ret;
+               return 0;
+       } else {
+               if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
+                       return __udp4_lib_mcast_deliver(net, skb, uh,
+                                       saddr, daddr, udptable);
+
+               sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+       }
 
        if (sk != NULL) {
                int ret;
@@ -1771,6 +1808,135 @@ drop:
        return 0;
 }
 
+/* We can only early demux multicast if there is a single matching socket.
+ * If more than one socket found returns NULL
+ */
+static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
+                                                 __be16 loc_port, __be32 loc_addr,
+                                                 __be16 rmt_port, __be32 rmt_addr,
+                                                 int dif)
+{
+       struct sock *sk, *result;
+       struct hlist_nulls_node *node;
+       unsigned short hnum = ntohs(loc_port);
+       unsigned int count, slot = udp_hashfn(net, hnum, udp_table.mask);
+       struct udp_hslot *hslot = &udp_table.hash[slot];
+
+       rcu_read_lock();
+begin:
+       count = 0;
+       result = NULL;
+       sk_nulls_for_each_rcu(sk, node, &hslot->head) {
+               if (__udp_is_mcast_sock(net, sk,
+                                       loc_port, loc_addr,
+                                       rmt_port, rmt_addr,
+                                       dif, hnum)) {
+                       result = sk;
+                       ++count;
+               }
+       }
+       /*
+        * if the nulls value we got at the end of this lookup is
+        * not the expected one, we must restart lookup.
+        * We probably met an item that was moved to another chain.
+        */
+       if (get_nulls_value(node) != slot)
+               goto begin;
+
+       if (result) {
+               if (count != 1 ||
+                   unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+                       result = NULL;
+               else if (unlikely(!__udp_is_mcast_sock(net, sk,
+                                                      loc_port, loc_addr,
+                                                      rmt_port, rmt_addr,
+                                                      dif, hnum))) {
+                       sock_put(result);
+                       result = NULL;
+               }
+       }
+       rcu_read_unlock();
+       return result;
+}
+
+/* For unicast we should only early demux connected sockets or we can
+ * break forwarding setups.  The chains here can be long so only check
+ * if the first socket is an exact match and if not move on.
+ */
+static struct sock *__udp4_lib_demux_lookup(struct net *net,
+                                           __be16 loc_port, __be32 loc_addr,
+                                           __be16 rmt_port, __be32 rmt_addr,
+                                           int dif)
+{
+       struct sock *sk, *result;
+       struct hlist_nulls_node *node;
+       unsigned short hnum = ntohs(loc_port);
+       unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum);
+       unsigned int slot2 = hash2 & udp_table.mask;
+       struct udp_hslot *hslot2 = &udp_table.hash2[slot2];
+       INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr)
+       const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum);
+
+       rcu_read_lock();
+       result = NULL;
+       udp_portaddr_for_each_entry_rcu(sk, node, &hslot2->head) {
+               if (INET_MATCH(sk, net, acookie,
+                              rmt_addr, loc_addr, ports, dif))
+                       result = sk;
+               /* Only check first socket in chain */
+               break;
+       }
+
+       if (result) {
+               if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+                       result = NULL;
+               else if (unlikely(!INET_MATCH(sk, net, acookie,
+                                             rmt_addr, loc_addr,
+                                             ports, dif))) {
+                       sock_put(result);
+                       result = NULL;
+               }
+       }
+       rcu_read_unlock();
+       return result;
+}
+
+void udp_v4_early_demux(struct sk_buff *skb)
+{
+       const struct iphdr *iph = ip_hdr(skb);
+       const struct udphdr *uh = udp_hdr(skb);
+       struct sock *sk;
+       struct dst_entry *dst;
+       struct net *net = dev_net(skb->dev);
+       int dif = skb->dev->ifindex;
+
+       /* validate the packet */
+       if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
+               return;
+
+       if (skb->pkt_type == PACKET_BROADCAST ||
+           skb->pkt_type == PACKET_MULTICAST)
+               sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
+                                                  uh->source, iph->saddr, dif);
+       else if (skb->pkt_type == PACKET_HOST)
+               sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
+                                            uh->source, iph->saddr, dif);
+       else
+               return;
+
+       if (!sk)
+               return;
+
+       skb->sk = sk;
+       skb->destructor = sock_edemux;
+       dst = sk->sk_rx_dst;
+
+       if (dst)
+               dst = dst_check(dst, 0);
+       if (dst)
+               skb_dst_set_noref(skb, dst);
+}
+
 int udp_rcv(struct sk_buff *skb)
 {
        return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP);