]> git.kernelconcepts.de Git - karo-tx-linux.git/blobdiff - net/ipv4/udp.c
Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[karo-tx-linux.git] / net / ipv4 / udp.c
index 8aab7d78d25bc6eaa42dcc960cdbd5086f614cad..ea6e4cff9fafe99af23fd8ea666cd979d5af9104 100644 (file)
@@ -134,14 +134,21 @@ EXPORT_SYMBOL(udp_memory_allocated);
 #define MAX_UDP_PORTS 65536
 #define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN)
 
+/* IPCB reference means this can not be used from early demux */
+static bool udp_lib_exact_dif_match(struct net *net, struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+       if (!net->ipv4.sysctl_udp_l3mdev_accept &&
+           skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
+               return true;
+#endif
+       return false;
+}
+
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
                               const struct udp_hslot *hslot,
                               unsigned long *bitmap,
-                              struct sock *sk,
-                              int (*saddr_comp)(const struct sock *sk1,
-                                                const struct sock *sk2,
-                                                bool match_wildcard),
-                              unsigned int log)
+                              struct sock *sk, unsigned int log)
 {
        struct sock *sk2;
        kuid_t uid = sock_i_uid(sk);
@@ -153,13 +160,18 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
                    (!sk2->sk_reuse || !sk->sk_reuse) &&
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                   (!sk2->sk_reuseport || !sk->sk_reuseport ||
-                    rcu_access_pointer(sk->sk_reuseport_cb) ||
-                    !uid_eq(uid, sock_i_uid(sk2))) &&
-                   saddr_comp(sk, sk2, true)) {
-                       if (!bitmap)
-                               return 1;
-                       __set_bit(udp_sk(sk2)->udp_port_hash >> log, bitmap);
+                   inet_rcv_saddr_equal(sk, sk2, true)) {
+                       if (sk2->sk_reuseport && sk->sk_reuseport &&
+                           !rcu_access_pointer(sk->sk_reuseport_cb) &&
+                           uid_eq(uid, sock_i_uid(sk2))) {
+                               if (!bitmap)
+                                       return 0;
+                       } else {
+                               if (!bitmap)
+                                       return 1;
+                               __set_bit(udp_sk(sk2)->udp_port_hash >> log,
+                                         bitmap);
+                       }
                }
        }
        return 0;
@@ -171,10 +183,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
  */
 static int udp_lib_lport_inuse2(struct net *net, __u16 num,
                                struct udp_hslot *hslot2,
-                               struct sock *sk,
-                               int (*saddr_comp)(const struct sock *sk1,
-                                                 const struct sock *sk2,
-                                                 bool match_wildcard))
+                               struct sock *sk)
 {
        struct sock *sk2;
        kuid_t uid = sock_i_uid(sk);
@@ -188,11 +197,14 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
                    (!sk2->sk_reuse || !sk->sk_reuse) &&
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                   (!sk2->sk_reuseport || !sk->sk_reuseport ||
-                    rcu_access_pointer(sk->sk_reuseport_cb) ||
-                    !uid_eq(uid, sock_i_uid(sk2))) &&
-                   saddr_comp(sk, sk2, true)) {
-                       res = 1;
+                   inet_rcv_saddr_equal(sk, sk2, true)) {
+                       if (sk2->sk_reuseport && sk->sk_reuseport &&
+                           !rcu_access_pointer(sk->sk_reuseport_cb) &&
+                           uid_eq(uid, sock_i_uid(sk2))) {
+                               res = 0;
+                       } else {
+                               res = 1;
+                       }
                        break;
                }
        }
@@ -200,10 +212,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
        return res;
 }
 
-static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
-                                 int (*saddr_same)(const struct sock *sk1,
-                                                   const struct sock *sk2,
-                                                   bool match_wildcard))
+static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot)
 {
        struct net *net = sock_net(sk);
        kuid_t uid = sock_i_uid(sk);
@@ -217,7 +226,7 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
                    (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) &&
                    (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
                    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
-                   (*saddr_same)(sk, sk2, false)) {
+                   inet_rcv_saddr_equal(sk, sk2, false)) {
                        return reuseport_add_sock(sk, sk2);
                }
        }
@@ -233,14 +242,10 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
  *
  *  @sk:          socket struct in question
  *  @snum:        port number to look up
- *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
  *  @hash2_nulladdr: AF-dependent hash value in secondary hash chains,
  *                   with NULL address
  */
 int udp_lib_get_port(struct sock *sk, unsigned short snum,
-                    int (*saddr_comp)(const struct sock *sk1,
-                                      const struct sock *sk2,
-                                      bool match_wildcard),
                     unsigned int hash2_nulladdr)
 {
        struct udp_hslot *hslot, *hslot2;
@@ -269,7 +274,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                        bitmap_zero(bitmap, PORTS_PER_CHAIN);
                        spin_lock_bh(&hslot->lock);
                        udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
-                                           saddr_comp, udptable->log);
+                                           udptable->log);
 
                        snum = first;
                        /*
@@ -285,6 +290,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                                snum += rand;
                        } while (snum != first);
                        spin_unlock_bh(&hslot->lock);
+                       cond_resched();
                } while (++first != last);
                goto fail;
        } else {
@@ -301,12 +307,11 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                        if (hslot->count < hslot2->count)
                                goto scan_primary_hash;
 
-                       exist = udp_lib_lport_inuse2(net, snum, hslot2,
-                                                    sk, saddr_comp);
+                       exist = udp_lib_lport_inuse2(net, snum, hslot2, sk);
                        if (!exist && (hash2_nulladdr != slot2)) {
                                hslot2 = udp_hashslot2(udptable, hash2_nulladdr);
                                exist = udp_lib_lport_inuse2(net, snum, hslot2,
-                                                            sk, saddr_comp);
+                                                            sk);
                        }
                        if (exist)
                                goto fail_unlock;
@@ -314,8 +319,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                                goto found;
                }
 scan_primary_hash:
-               if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk,
-                                       saddr_comp, 0))
+               if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, 0))
                        goto fail_unlock;
        }
 found:
@@ -324,7 +328,7 @@ found:
        udp_sk(sk)->udp_portaddr_hash ^= snum;
        if (sk_unhashed(sk)) {
                if (sk->sk_reuseport &&
-                   udp_reuseport_add_sock(sk, hslot, saddr_comp)) {
+                   udp_reuseport_add_sock(sk, hslot)) {
                        inet_sk(sk)->inet_num = 0;
                        udp_sk(sk)->udp_port_hash = 0;
                        udp_sk(sk)->udp_portaddr_hash ^= snum;
@@ -356,24 +360,6 @@ fail:
 }
 EXPORT_SYMBOL(udp_lib_get_port);
 
-/* match_wildcard == true:  0.0.0.0 equals to any IPv4 addresses
- * match_wildcard == false: addresses must be exactly the same, i.e.
- *                          0.0.0.0 only equals to 0.0.0.0
- */
-int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
-                        bool match_wildcard)
-{
-       struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
-
-       if (!ipv6_only_sock(sk2)) {
-               if (inet1->inet_rcv_saddr == inet2->inet_rcv_saddr)
-                       return 1;
-               if (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr)
-                       return match_wildcard;
-       }
-       return 0;
-}
-
 static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr,
                              unsigned int port)
 {
@@ -389,12 +375,13 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum)
 
        /* precompute partial secondary hash */
        udp_sk(sk)->udp_portaddr_hash = hash2_partial;
-       return udp_lib_get_port(sk, snum, ipv4_rcv_saddr_equal, hash2_nulladdr);
+       return udp_lib_get_port(sk, snum, hash2_nulladdr);
 }
 
 static int compute_score(struct sock *sk, struct net *net,
                         __be32 saddr, __be16 sport,
-                        __be32 daddr, unsigned short hnum, int dif)
+                        __be32 daddr, unsigned short hnum, int dif,
+                        bool exact_dif)
 {
        int score;
        struct inet_sock *inet;
@@ -425,7 +412,7 @@ static int compute_score(struct sock *sk, struct net *net,
                score += 4;
        }
 
-       if (sk->sk_bound_dev_if) {
+       if (sk->sk_bound_dev_if || exact_dif) {
                if (sk->sk_bound_dev_if != dif)
                        return -1;
                score += 4;
@@ -450,7 +437,7 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
 /* called with rcu_read_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
                __be32 saddr, __be16 sport,
-               __be32 daddr, unsigned int hnum, int dif,
+               __be32 daddr, unsigned int hnum, int dif, bool exact_dif,
                struct udp_hslot *hslot2,
                struct sk_buff *skb)
 {
@@ -462,7 +449,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
        badness = 0;
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif);
+                                     daddr, hnum, dif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -497,6 +484,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
        unsigned short hnum = ntohs(dport);
        unsigned int hash2, slot2, slot = udp_hashfn(net, hnum, udptable->mask);
        struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
+       bool exact_dif = udp_lib_exact_dif_match(net, skb);
        int score, badness, matches = 0, reuseport = 0;
        u32 hash = 0;
 
@@ -509,7 +497,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 
                result = udp4_lib_lookup2(net, saddr, sport,
                                          daddr, hnum, dif,
-                                         hslot2, skb);
+                                         exact_dif, hslot2, skb);
                if (!result) {
                        unsigned int old_slot2 = slot2;
                        hash2 = udp4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
@@ -524,7 +512,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 
                        result = udp4_lib_lookup2(net, saddr, sport,
                                                  daddr, hnum, dif,
-                                                 hslot2, skb);
+                                                 exact_dif, hslot2, skb);
                }
                return result;
        }
@@ -533,7 +521,7 @@ begin:
        badness = 0;
        sk_for_each_rcu(sk, &hslot->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif);
+                                     daddr, hnum, dif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -1113,7 +1101,8 @@ out:
        return err;
 
 do_confirm:
-       dst_confirm(&rt->dst);
+       if (msg->msg_flags & MSG_PROBE)
+               dst_confirm_neigh(&rt->dst, &fl4->daddr);
        if (!(msg->msg_flags&MSG_PROBE) || len)
                goto back_from_confirm;
        err = 0;