]> git.kernelconcepts.de Git - karo-tx-linux.git/blob - net/bridge/br_multicast.c
10e6fce1bb62abb0987a98c167bf653356e088ac
[karo-tx-linux.git] / net / bridge / br_multicast.c
1 /*
2  * Bridge multicast support.
3  *
4  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation; either version 2 of the License, or (at your option)
9  * any later version.
10  *
11  */
12
13 #include <linux/err.h>
14 #include <linux/if_ether.h>
15 #include <linux/igmp.h>
16 #include <linux/jhash.h>
17 #include <linux/kernel.h>
18 #include <linux/log2.h>
19 #include <linux/netdevice.h>
20 #include <linux/netfilter_bridge.h>
21 #include <linux/random.h>
22 #include <linux/rculist.h>
23 #include <linux/skbuff.h>
24 #include <linux/slab.h>
25 #include <linux/timer.h>
26 #include <net/ip.h>
27 #if IS_ENABLED(CONFIG_IPV6)
28 #include <net/ipv6.h>
29 #include <net/mld.h>
30 #include <net/ip6_checksum.h>
31 #endif
32
33 #include "br_private.h"
34
/* Forward declaration; the querier restart logic is defined later in
 * this file but used by the querier-expiry timer callback above it.
 */
static void br_multicast_start_querier(struct net_bridge *br);
/* Incremented on every successful hash-table rehash.  NOTE(review):
 * consumers live outside this chunk — presumably the mdb netlink dump
 * code uses it to detect concurrent rehashes; confirm against br_mdb.c.
 */
unsigned int br_mdb_rehash_seq;
37
38 static inline int br_ip_equal(const struct br_ip *a, const struct br_ip *b)
39 {
40         if (a->proto != b->proto)
41                 return 0;
42         if (a->vid != b->vid)
43                 return 0;
44         switch (a->proto) {
45         case htons(ETH_P_IP):
46                 return a->u.ip4 == b->u.ip4;
47 #if IS_ENABLED(CONFIG_IPV6)
48         case htons(ETH_P_IPV6):
49                 return ipv6_addr_equal(&a->u.ip6, &b->u.ip6);
50 #endif
51         }
52         return 0;
53 }
54
55 static inline int __br_ip4_hash(struct net_bridge_mdb_htable *mdb, __be32 ip,
56                                 __u16 vid)
57 {
58         return jhash_2words((__force u32)ip, vid, mdb->secret) & (mdb->max - 1);
59 }
60
61 #if IS_ENABLED(CONFIG_IPV6)
62 static inline int __br_ip6_hash(struct net_bridge_mdb_htable *mdb,
63                                 const struct in6_addr *ip,
64                                 __u16 vid)
65 {
66         return jhash_2words(ipv6_addr_hash(ip), vid,
67                             mdb->secret) & (mdb->max - 1);
68 }
69 #endif
70
71 static inline int br_ip_hash(struct net_bridge_mdb_htable *mdb,
72                              struct br_ip *ip)
73 {
74         switch (ip->proto) {
75         case htons(ETH_P_IP):
76                 return __br_ip4_hash(mdb, ip->u.ip4, ip->vid);
77 #if IS_ENABLED(CONFIG_IPV6)
78         case htons(ETH_P_IPV6):
79                 return __br_ip6_hash(mdb, &ip->u.ip6, ip->vid);
80 #endif
81         }
82         return 0;
83 }
84
/* Walk the chain of bucket @hash (current table version) looking for an
 * entry equal to @dst.  Safe under either RCU read-side protection or
 * br->multicast_lock.  Returns the entry or NULL.
 */
static struct net_bridge_mdb_entry *__br_mdb_ip_get(
	struct net_bridge_mdb_htable *mdb, struct br_ip *dst, int hash)
{
	struct net_bridge_mdb_entry *mp;

	hlist_for_each_entry_rcu(mp, &mdb->mhash[hash], hlist[mdb->ver]) {
		if (br_ip_equal(&mp->addr, dst))
			return mp;
	}

	return NULL;
}
97
98 struct net_bridge_mdb_entry *br_mdb_ip_get(struct net_bridge_mdb_htable *mdb,
99                                            struct br_ip *dst)
100 {
101         if (!mdb)
102                 return NULL;
103
104         return __br_mdb_ip_get(mdb, dst, br_ip_hash(mdb, dst));
105 }
106
107 static struct net_bridge_mdb_entry *br_mdb_ip4_get(
108         struct net_bridge_mdb_htable *mdb, __be32 dst, __u16 vid)
109 {
110         struct br_ip br_dst;
111
112         br_dst.u.ip4 = dst;
113         br_dst.proto = htons(ETH_P_IP);
114         br_dst.vid = vid;
115
116         return br_mdb_ip_get(mdb, &br_dst);
117 }
118
119 #if IS_ENABLED(CONFIG_IPV6)
120 static struct net_bridge_mdb_entry *br_mdb_ip6_get(
121         struct net_bridge_mdb_htable *mdb, const struct in6_addr *dst,
122         __u16 vid)
123 {
124         struct br_ip br_dst;
125
126         br_dst.u.ip6 = *dst;
127         br_dst.proto = htons(ETH_P_IPV6);
128         br_dst.vid = vid;
129
130         return br_mdb_ip_get(mdb, &br_dst);
131 }
132 #endif
133
134 struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
135                                         struct sk_buff *skb)
136 {
137         struct net_bridge_mdb_htable *mdb = rcu_dereference(br->mdb);
138         struct br_ip ip;
139
140         if (br->multicast_disabled)
141                 return NULL;
142
143         if (BR_INPUT_SKB_CB(skb)->igmp)
144                 return NULL;
145
146         ip.proto = skb->protocol;
147
148         switch (skb->protocol) {
149         case htons(ETH_P_IP):
150                 ip.u.ip4 = ip_hdr(skb)->daddr;
151                 break;
152 #if IS_ENABLED(CONFIG_IPV6)
153         case htons(ETH_P_IPV6):
154                 ip.u.ip6 = ipv6_hdr(skb)->daddr;
155                 break;
156 #endif
157         default:
158                 return NULL;
159         }
160
161         return br_mdb_ip_get(mdb, &ip);
162 }
163
164 static void br_mdb_free(struct rcu_head *head)
165 {
166         struct net_bridge_mdb_htable *mdb =
167                 container_of(head, struct net_bridge_mdb_htable, rcu);
168         struct net_bridge_mdb_htable *old = mdb->old;
169
170         mdb->old = NULL;
171         kfree(old->mhash);
172         kfree(old);
173 }
174
175 static int br_mdb_copy(struct net_bridge_mdb_htable *new,
176                        struct net_bridge_mdb_htable *old,
177                        int elasticity)
178 {
179         struct net_bridge_mdb_entry *mp;
180         int maxlen;
181         int len;
182         int i;
183
184         for (i = 0; i < old->max; i++)
185                 hlist_for_each_entry(mp, &old->mhash[i], hlist[old->ver])
186                         hlist_add_head(&mp->hlist[new->ver],
187                                        &new->mhash[br_ip_hash(new, &mp->addr)]);
188
189         if (!elasticity)
190                 return 0;
191
192         maxlen = 0;
193         for (i = 0; i < new->max; i++) {
194                 len = 0;
195                 hlist_for_each_entry(mp, &new->mhash[i], hlist[new->ver])
196                         len++;
197                 if (len > maxlen)
198                         maxlen = len;
199         }
200
201         return maxlen > elasticity ? -EINVAL : 0;
202 }
203
204 void br_multicast_free_pg(struct rcu_head *head)
205 {
206         struct net_bridge_port_group *p =
207                 container_of(head, struct net_bridge_port_group, rcu);
208
209         kfree(p);
210 }
211
212 static void br_multicast_free_group(struct rcu_head *head)
213 {
214         struct net_bridge_mdb_entry *mp =
215                 container_of(head, struct net_bridge_mdb_entry, rcu);
216
217         kfree(mp);
218 }
219
/* Timer callback: the membership timer of an mdb entry fired.  Drops the
 * local-host membership flag and, when no port groups keep the entry
 * alive, unlinks it from the hash table and frees it after an RCU grace
 * period.  Runs in timer (softirq) context, hence plain spin_lock.
 */
static void br_multicast_group_expired(unsigned long data)
{
	struct net_bridge_mdb_entry *mp = (void *)data;
	struct net_bridge *br = mp->br;
	struct net_bridge_mdb_htable *mdb;

	spin_lock(&br->multicast_lock);
	/* Bail out if the bridge is down or the timer was re-armed while
	 * this callback was waiting for the lock.
	 */
	if (!netif_running(br->dev) || timer_pending(&mp->timer))
		goto out;

	mp->mglist = false;

	/* Joined ports still reference the entry — keep it. */
	if (mp->ports)
		goto out;

	mdb = mlock_dereference(br->mdb, br);

	hlist_del_rcu(&mp->hlist[mdb->ver]);
	mdb->size--;

	call_rcu_bh(&mp->rcu, br_multicast_free_group);

out:
	spin_unlock(&br->multicast_lock);
}
245
/* Unlink port group @pg from its mdb entry and schedule it for RCU
 * freeing.  If the entry then has neither ports nor a local-host
 * membership, its expiry timer is armed to fire immediately so
 * br_multicast_group_expired() reaps it.  Caller holds
 * br->multicast_lock; the WARN_ONs flag state corruption (entry or
 * group missing from where it must be).
 */
static void br_multicast_del_pg(struct net_bridge *br,
				struct net_bridge_port_group *pg)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;

	mdb = mlock_dereference(br->mdb, br);

	mp = br_mdb_ip_get(mdb, &pg->addr);
	if (WARN_ON(!mp))
		return;

	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (p != pg)
			continue;

		rcu_assign_pointer(*pp, p->next);
		hlist_del_init(&p->mglist);
		del_timer(&p->timer);
		call_rcu_bh(&p->rcu, br_multicast_free_pg);

		/* Last reference gone: let the entry expire right away. */
		if (!mp->ports && !mp->mglist &&
		    netif_running(br->dev))
			mod_timer(&mp->timer, jiffies);

		return;
	}

	WARN_ON(1);
}
280
/* Timer callback: a port's membership in a group timed out.  Removes the
 * port group unless the timer was re-armed, the group was already
 * unlinked, or the entry is a permanent (statically configured) one.
 */
static void br_multicast_port_group_expired(unsigned long data)
{
	struct net_bridge_port_group *pg = (void *)data;
	struct net_bridge *br = pg->port->br;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) || timer_pending(&pg->timer) ||
	    hlist_unhashed(&pg->mglist) || pg->state & MDB_PERMANENT)
		goto out;

	br_multicast_del_pg(br, pg);

out:
	spin_unlock(&br->multicast_lock);
}
296
297 static int br_mdb_rehash(struct net_bridge_mdb_htable __rcu **mdbp, int max,
298                          int elasticity)
299 {
300         struct net_bridge_mdb_htable *old = rcu_dereference_protected(*mdbp, 1);
301         struct net_bridge_mdb_htable *mdb;
302         int err;
303
304         mdb = kmalloc(sizeof(*mdb), GFP_ATOMIC);
305         if (!mdb)
306                 return -ENOMEM;
307
308         mdb->max = max;
309         mdb->old = old;
310
311         mdb->mhash = kzalloc(max * sizeof(*mdb->mhash), GFP_ATOMIC);
312         if (!mdb->mhash) {
313                 kfree(mdb);
314                 return -ENOMEM;
315         }
316
317         mdb->size = old ? old->size : 0;
318         mdb->ver = old ? old->ver ^ 1 : 0;
319
320         if (!old || elasticity)
321                 get_random_bytes(&mdb->secret, sizeof(mdb->secret));
322         else
323                 mdb->secret = old->secret;
324
325         if (!old)
326                 goto out;
327
328         err = br_mdb_copy(mdb, old, elasticity);
329         if (err) {
330                 kfree(mdb->mhash);
331                 kfree(mdb);
332                 return err;
333         }
334
335         br_mdb_rehash_seq++;
336         call_rcu_bh(&mdb->rcu, br_mdb_free);
337
338 out:
339         rcu_assign_pointer(*mdbp, mdb);
340
341         return 0;
342 }
343
344 static struct sk_buff *br_ip4_multicast_alloc_query(struct net_bridge *br,
345                                                     __be32 group)
346 {
347         struct sk_buff *skb;
348         struct igmphdr *ih;
349         struct ethhdr *eth;
350         struct iphdr *iph;
351
352         skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*iph) +
353                                                  sizeof(*ih) + 4);
354         if (!skb)
355                 goto out;
356
357         skb->protocol = htons(ETH_P_IP);
358
359         skb_reset_mac_header(skb);
360         eth = eth_hdr(skb);
361
362         memcpy(eth->h_source, br->dev->dev_addr, 6);
363         eth->h_dest[0] = 1;
364         eth->h_dest[1] = 0;
365         eth->h_dest[2] = 0x5e;
366         eth->h_dest[3] = 0;
367         eth->h_dest[4] = 0;
368         eth->h_dest[5] = 1;
369         eth->h_proto = htons(ETH_P_IP);
370         skb_put(skb, sizeof(*eth));
371
372         skb_set_network_header(skb, skb->len);
373         iph = ip_hdr(skb);
374
375         iph->version = 4;
376         iph->ihl = 6;
377         iph->tos = 0xc0;
378         iph->tot_len = htons(sizeof(*iph) + sizeof(*ih) + 4);
379         iph->id = 0;
380         iph->frag_off = htons(IP_DF);
381         iph->ttl = 1;
382         iph->protocol = IPPROTO_IGMP;
383         iph->saddr = 0;
384         iph->daddr = htonl(INADDR_ALLHOSTS_GROUP);
385         ((u8 *)&iph[1])[0] = IPOPT_RA;
386         ((u8 *)&iph[1])[1] = 4;
387         ((u8 *)&iph[1])[2] = 0;
388         ((u8 *)&iph[1])[3] = 0;
389         ip_send_check(iph);
390         skb_put(skb, 24);
391
392         skb_set_transport_header(skb, skb->len);
393         ih = igmp_hdr(skb);
394         ih->type = IGMP_HOST_MEMBERSHIP_QUERY;
395         ih->code = (group ? br->multicast_last_member_interval :
396                             br->multicast_query_response_interval) /
397                    (HZ / IGMP_TIMER_SCALE);
398         ih->group = group;
399         ih->csum = 0;
400         ih->csum = ip_compute_csum((void *)ih, sizeof(struct igmphdr));
401         skb_put(skb, sizeof(*ih));
402
403         __skb_pull(skb, sizeof(*eth));
404
405 out:
406         return skb;
407 }
408
409 #if IS_ENABLED(CONFIG_IPV6)
410 static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
411                                                     const struct in6_addr *group)
412 {
413         struct sk_buff *skb;
414         struct ipv6hdr *ip6h;
415         struct mld_msg *mldq;
416         struct ethhdr *eth;
417         u8 *hopopt;
418         unsigned long interval;
419
420         skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*ip6h) +
421                                                  8 + sizeof(*mldq));
422         if (!skb)
423                 goto out;
424
425         skb->protocol = htons(ETH_P_IPV6);
426
427         /* Ethernet header */
428         skb_reset_mac_header(skb);
429         eth = eth_hdr(skb);
430
431         memcpy(eth->h_source, br->dev->dev_addr, 6);
432         eth->h_proto = htons(ETH_P_IPV6);
433         skb_put(skb, sizeof(*eth));
434
435         /* IPv6 header + HbH option */
436         skb_set_network_header(skb, skb->len);
437         ip6h = ipv6_hdr(skb);
438
439         *(__force __be32 *)ip6h = htonl(0x60000000);
440         ip6h->payload_len = htons(8 + sizeof(*mldq));
441         ip6h->nexthdr = IPPROTO_HOPOPTS;
442         ip6h->hop_limit = 1;
443         ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1));
444         if (ipv6_dev_get_saddr(dev_net(br->dev), br->dev, &ip6h->daddr, 0,
445                                &ip6h->saddr)) {
446                 kfree_skb(skb);
447                 return NULL;
448         }
449         ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
450
451         hopopt = (u8 *)(ip6h + 1);
452         hopopt[0] = IPPROTO_ICMPV6;             /* next hdr */
453         hopopt[1] = 0;                          /* length of HbH */
454         hopopt[2] = IPV6_TLV_ROUTERALERT;       /* Router Alert */
455         hopopt[3] = 2;                          /* Length of RA Option */
456         hopopt[4] = 0;                          /* Type = 0x0000 (MLD) */
457         hopopt[5] = 0;
458         hopopt[6] = IPV6_TLV_PAD1;              /* Pad1 */
459         hopopt[7] = IPV6_TLV_PAD1;              /* Pad1 */
460
461         skb_put(skb, sizeof(*ip6h) + 8);
462
463         /* ICMPv6 */
464         skb_set_transport_header(skb, skb->len);
465         mldq = (struct mld_msg *) icmp6_hdr(skb);
466
467         interval = ipv6_addr_any(group) ? br->multicast_last_member_interval :
468                                           br->multicast_query_response_interval;
469
470         mldq->mld_type = ICMPV6_MGM_QUERY;
471         mldq->mld_code = 0;
472         mldq->mld_cksum = 0;
473         mldq->mld_maxdelay = htons((u16)jiffies_to_msecs(interval));
474         mldq->mld_reserved = 0;
475         mldq->mld_mca = *group;
476
477         /* checksum */
478         mldq->mld_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
479                                           sizeof(*mldq), IPPROTO_ICMPV6,
480                                           csum_partial(mldq,
481                                                        sizeof(*mldq), 0));
482         skb_put(skb, sizeof(*mldq));
483
484         __skb_pull(skb, sizeof(*eth));
485
486 out:
487         return skb;
488 }
489 #endif
490
491 static struct sk_buff *br_multicast_alloc_query(struct net_bridge *br,
492                                                 struct br_ip *addr)
493 {
494         switch (addr->proto) {
495         case htons(ETH_P_IP):
496                 return br_ip4_multicast_alloc_query(br, addr->u.ip4);
497 #if IS_ENABLED(CONFIG_IPV6)
498         case htons(ETH_P_IPV6):
499                 return br_ip6_multicast_alloc_query(br, &addr->u.ip6);
500 #endif
501         }
502         return NULL;
503 }
504
/* Find @group in bucket @hash, or make room for it.
 *
 * Returns the existing entry when found, NULL when the caller may insert
 * a new entry at @hash, or an ERR_PTR:
 *   -EAGAIN  the table was rehashed (grown and/or re-keyed); caller must
 *            recompute the hash and retry,
 *   -EEXIST  a previous rehash has not finished its RCU grace period,
 *   -E2BIG   the table hit hash_max; snooping has been disabled.
 * Called with br->multicast_lock held.
 */
static struct net_bridge_mdb_entry *br_multicast_get_group(
	struct net_bridge *br, struct net_bridge_port *port,
	struct br_ip *group, int hash)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	unsigned int count = 0;
	unsigned int max;
	int elasticity;
	int err;

	mdb = rcu_dereference_protected(br->mdb, 1);
	hlist_for_each_entry(mp, &mdb->mhash[hash], hlist[mdb->ver]) {
		count++;
		if (unlikely(br_ip_equal(group, &mp->addr)))
			return mp;
	}

	elasticity = 0;
	max = mdb->max;

	/* An over-long chain forces a re-keyed rehash (possibly a hash
	 * collision attack, possibly just a bad secret).
	 */
	if (unlikely(count > br->hash_elasticity && count)) {
		if (net_ratelimit())
			br_info(br, "Multicast hash table "
				"chain limit reached: %s\n",
				port ? port->dev->name : br->dev->name);

		elasticity = br->hash_elasticity;
	}

	/* Table full: double it, giving up beyond hash_max. */
	if (mdb->size >= max) {
		max *= 2;
		if (unlikely(max > br->hash_max)) {
			br_warn(br, "Multicast hash table maximum of %d "
				"reached, disabling snooping: %s\n",
				br->hash_max,
				port ? port->dev->name : br->dev->name);
			err = -E2BIG;
disable:
			br->multicast_disabled = 1;
			goto err;
		}
	}

	if (max > mdb->max || elasticity) {
		/* Refuse to rehash while the previous old table is still
		 * awaiting its RCU grace period.
		 */
		if (mdb->old) {
			if (net_ratelimit())
				br_info(br, "Multicast hash table "
					"on fire: %s\n",
					port ? port->dev->name : br->dev->name);
			err = -EEXIST;
			goto err;
		}

		err = br_mdb_rehash(&br->mdb, max, elasticity);
		if (err) {
			br_warn(br, "Cannot rehash multicast "
				"hash table, disabling snooping: %s, %d, %d\n",
				port ? port->dev->name : br->dev->name,
				mdb->size, err);
			goto disable;
		}

		err = -EAGAIN;
		goto err;
	}

	return NULL;

err:
	mp = ERR_PTR(err);
	return mp;
}
578
/* Look up @group, creating a new (portless) mdb entry when absent.
 *
 * Note the PTR_ERR() switch below: a NULL result from
 * br_multicast_get_group() (group not found) gives case 0 and falls
 * through to the allocation; -EAGAIN means the table was rehashed so the
 * hash must be recomputed before inserting; anything else — a valid
 * pointer or a real ERR_PTR — lands in "default" and is returned as-is.
 * Called with br->multicast_lock held.
 */
struct net_bridge_mdb_entry *br_multicast_new_group(struct net_bridge *br,
	struct net_bridge_port *port, struct br_ip *group)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	int hash;
	int err;

	mdb = rcu_dereference_protected(br->mdb, 1);
	if (!mdb) {
		/* First group ever: create the initial hash table. */
		err = br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0);
		if (err)
			return ERR_PTR(err);
		goto rehash;
	}

	hash = br_ip_hash(mdb, group);
	mp = br_multicast_get_group(br, port, group, hash);
	switch (PTR_ERR(mp)) {
	case 0:
		break;

	case -EAGAIN:
rehash:
		mdb = rcu_dereference_protected(br->mdb, 1);
		hash = br_ip_hash(mdb, group);
		break;

	default:
		goto out;
	}

	mp = kzalloc(sizeof(*mp), GFP_ATOMIC);
	if (unlikely(!mp))
		return ERR_PTR(-ENOMEM);

	mp->br = br;
	mp->addr = *group;
	setup_timer(&mp->timer, br_multicast_group_expired,
		    (unsigned long)mp);

	hlist_add_head_rcu(&mp->hlist[mdb->ver], &mdb->mhash[hash]);
	mdb->size++;

out:
	return mp;
}
626
627 struct net_bridge_port_group *br_multicast_new_port_group(
628                         struct net_bridge_port *port,
629                         struct br_ip *group,
630                         struct net_bridge_port_group __rcu *next,
631                         unsigned char state)
632 {
633         struct net_bridge_port_group *p;
634
635         p = kzalloc(sizeof(*p), GFP_ATOMIC);
636         if (unlikely(!p))
637                 return NULL;
638
639         p->addr = *group;
640         p->port = port;
641         p->state = state;
642         rcu_assign_pointer(p->next, next);
643         hlist_add_head(&p->mglist, &port->mglist);
644         setup_timer(&p->timer, br_multicast_port_group_expired,
645                     (unsigned long)p);
646         return p;
647 }
648
649 static int br_multicast_add_group(struct net_bridge *br,
650                                   struct net_bridge_port *port,
651                                   struct br_ip *group)
652 {
653         struct net_bridge_mdb_entry *mp;
654         struct net_bridge_port_group *p;
655         struct net_bridge_port_group __rcu **pp;
656         unsigned long now = jiffies;
657         int err;
658
659         spin_lock(&br->multicast_lock);
660         if (!netif_running(br->dev) ||
661             (port && port->state == BR_STATE_DISABLED))
662                 goto out;
663
664         mp = br_multicast_new_group(br, port, group);
665         err = PTR_ERR(mp);
666         if (IS_ERR(mp))
667                 goto err;
668
669         if (!port) {
670                 mp->mglist = true;
671                 mod_timer(&mp->timer, now + br->multicast_membership_interval);
672                 goto out;
673         }
674
675         for (pp = &mp->ports;
676              (p = mlock_dereference(*pp, br)) != NULL;
677              pp = &p->next) {
678                 if (p->port == port)
679                         goto found;
680                 if ((unsigned long)p->port < (unsigned long)port)
681                         break;
682         }
683
684         p = br_multicast_new_port_group(port, group, *pp, MDB_TEMPORARY);
685         if (unlikely(!p))
686                 goto err;
687         rcu_assign_pointer(*pp, p);
688         br_mdb_notify(br->dev, port, group, RTM_NEWMDB);
689
690 found:
691         mod_timer(&p->timer, now + br->multicast_membership_interval);
692 out:
693         err = 0;
694
695 err:
696         spin_unlock(&br->multicast_lock);
697         return err;
698 }
699
700 static int br_ip4_multicast_add_group(struct net_bridge *br,
701                                       struct net_bridge_port *port,
702                                       __be32 group,
703                                       __u16 vid)
704 {
705         struct br_ip br_group;
706
707         if (ipv4_is_local_multicast(group))
708                 return 0;
709
710         br_group.u.ip4 = group;
711         br_group.proto = htons(ETH_P_IP);
712         br_group.vid = vid;
713
714         return br_multicast_add_group(br, port, &br_group);
715 }
716
717 #if IS_ENABLED(CONFIG_IPV6)
718 static int br_ip6_multicast_add_group(struct net_bridge *br,
719                                       struct net_bridge_port *port,
720                                       const struct in6_addr *group,
721                                       __u16 vid)
722 {
723         struct br_ip br_group;
724
725         if (!ipv6_is_transient_multicast(group))
726                 return 0;
727
728         br_group.u.ip6 = *group;
729         br_group.proto = htons(ETH_P_IPV6);
730         br_group.vid = vid;
731
732         return br_multicast_add_group(br, port, &br_group);
733 }
734 #endif
735
/* Timer callback: a dynamically learned multicast router behind @port
 * timed out.  Removes the port from the bridge's router list unless the
 * port's router mode is not the dynamic default (multicast_router != 1,
 * the value set in br_multicast_add_port()), the timer was re-armed, or
 * the port was already unlinked.
 */
static void br_multicast_router_expired(unsigned long data)
{
	struct net_bridge_port *port = (void *)data;
	struct net_bridge *br = port->br;

	spin_lock(&br->multicast_lock);
	if (port->multicast_router != 1 ||
	    timer_pending(&port->multicast_router_timer) ||
	    hlist_unhashed(&port->rlist))
		goto out;

	hlist_del_init_rcu(&port->rlist);

out:
	spin_unlock(&br->multicast_lock);
}
752
/* Timer callback for the bridge's own router timer.  Intentionally
 * empty: callers only test timer_pending(), so expiry itself is the
 * whole state change.
 */
static void br_multicast_local_router_expired(unsigned long data)
{
}
756
/* Timer callback: the foreign-querier timeout elapsed without seeing
 * another querier, so this bridge resumes sending queries itself.
 */
static void br_multicast_querier_expired(unsigned long data)
{
	struct net_bridge *br = (void *)data;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) || br->multicast_disabled)
		goto out;

	br_multicast_start_querier(br);

out:
	spin_unlock(&br->multicast_lock);
}
770
/* Emit one query for @ip.  With a @port the frame goes out that port
 * through the bridge netfilter LOCAL_OUT hook; without one it is fed
 * back into the local receive path (netif_rx) so the bridge itself
 * processes it.  Silently does nothing on allocation failure.
 */
static void __br_multicast_send_query(struct net_bridge *br,
				      struct net_bridge_port *port,
				      struct br_ip *ip)
{
	struct sk_buff *skb;

	skb = br_multicast_alloc_query(br, ip);
	if (!skb)
		return;

	if (port) {
		/* alloc_query left data past the Ethernet header;
		 * restore it before transmitting.
		 */
		__skb_push(skb, sizeof(struct ethhdr));
		skb->dev = port->dev;
		NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_OUT, skb, NULL, skb->dev,
			dev_queue_xmit);
	} else
		netif_rx(skb);
}
789
/* Send general queries (IPv4 and, when built in, IPv6 — the zeroed
 * address union makes both general queries) and re-arm the matching
 * query timer.  The first multicast_startup_query_count queries (@sent)
 * use the shorter startup interval.  Skipped when the bridge is down,
 * snooping or querying is disabled, or a foreign querier is active
 * (querier timer pending).
 */
static void br_multicast_send_query(struct net_bridge *br,
				    struct net_bridge_port *port, u32 sent)
{
	unsigned long time;
	struct br_ip br_group;

	if (!netif_running(br->dev) || br->multicast_disabled ||
	    !br->multicast_querier ||
	    timer_pending(&br->multicast_querier_timer))
		return;

	memset(&br_group.u, 0, sizeof(br_group.u));

	br_group.proto = htons(ETH_P_IP);
	__br_multicast_send_query(br, port, &br_group);

#if IS_ENABLED(CONFIG_IPV6)
	br_group.proto = htons(ETH_P_IPV6);
	__br_multicast_send_query(br, port, &br_group);
#endif

	time = jiffies;
	time += sent < br->multicast_startup_query_count ?
		br->multicast_startup_query_interval :
		br->multicast_query_interval;
	mod_timer(port ? &port->multicast_query_timer :
			 &br->multicast_query_timer, time);
}
818
/* Timer callback: time to send the next query out of @port, unless the
 * port is not in a forwarding-capable STP state.  Tracks how many
 * startup queries were sent so br_multicast_send_query() can pick the
 * right interval.
 */
static void br_multicast_port_query_expired(unsigned long data)
{
	struct net_bridge_port *port = (void *)data;
	struct net_bridge *br = port->br;

	spin_lock(&br->multicast_lock);
	if (port->state == BR_STATE_DISABLED ||
	    port->state == BR_STATE_BLOCKING)
		goto out;

	if (port->multicast_startup_queries_sent <
	    br->multicast_startup_query_count)
		port->multicast_startup_queries_sent++;

	br_multicast_send_query(port->br, port,
				port->multicast_startup_queries_sent);

out:
	spin_unlock(&br->multicast_lock);
}
839
840 void br_multicast_add_port(struct net_bridge_port *port)
841 {
842         port->multicast_router = 1;
843
844         setup_timer(&port->multicast_router_timer, br_multicast_router_expired,
845                     (unsigned long)port);
846         setup_timer(&port->multicast_query_timer,
847                     br_multicast_port_query_expired, (unsigned long)port);
848 }
849
/* Tear-down counterpart of br_multicast_add_port(): make sure the router
 * timer is not running (and waits for a running callback to finish)
 * before the port is freed.
 */
void br_multicast_del_port(struct net_bridge_port *port)
{
	del_timer_sync(&port->multicast_router_timer);
}
854
/* Restart the per-port startup-query cycle.  The try/del timer pair
 * stops the timer regardless of whether it was pending or its callback
 * is currently running, then re-arms it to fire immediately.  Caller
 * holds br->multicast_lock.
 */
static void __br_multicast_enable_port(struct net_bridge_port *port)
{
	port->multicast_startup_queries_sent = 0;

	if (try_to_del_timer_sync(&port->multicast_query_timer) >= 0 ||
	    del_timer(&port->multicast_query_timer))
		mod_timer(&port->multicast_query_timer, jiffies);
}
863
864 void br_multicast_enable_port(struct net_bridge_port *port)
865 {
866         struct net_bridge *br = port->br;
867
868         spin_lock(&br->multicast_lock);
869         if (br->multicast_disabled || !netif_running(br->dev))
870                 goto out;
871
872         __br_multicast_enable_port(port);
873
874 out:
875         spin_unlock(&br->multicast_lock);
876 }
877
/* Quiesce all multicast state for @port: drop every group it joined,
 * take it out of the router list, and stop its timers.  Uses the _safe
 * iterator because br_multicast_del_pg() unlinks entries as we walk.
 */
void br_multicast_disable_port(struct net_bridge_port *port)
{
	struct net_bridge *br = port->br;
	struct net_bridge_port_group *pg;
	struct hlist_node *n;

	spin_lock(&br->multicast_lock);
	hlist_for_each_entry_safe(pg, n, &port->mglist, mglist)
		br_multicast_del_pg(br, pg);

	if (!hlist_unhashed(&port->rlist))
		hlist_del_init_rcu(&port->rlist);
	del_timer(&port->multicast_router_timer);
	del_timer(&port->multicast_query_timer);
	spin_unlock(&br->multicast_lock);
}
894
/* Parse an IGMPv3 membership report and register a join for every group
 * record whose type can imply membership.  Each record is treated as an
 * IGMPv2-style join: source lists are pulled past but otherwise ignored.
 * Returns 0 on success or -EINVAL on truncated input.
 */
static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
					 struct net_bridge_port *port,
					 struct sk_buff *skb)
{
	struct igmpv3_report *ih;
	struct igmpv3_grec *grec;
	int i;
	int len;
	int num;
	int type;
	int err = 0;
	__be32 group;
	u16 vid = 0;

	if (!pskb_may_pull(skb, sizeof(*ih)))
		return -EINVAL;

	br_vlan_get_tag(skb, &vid);
	ih = igmpv3_report_hdr(skb);
	num = ntohs(ih->ngrec);
	len = sizeof(*ih);

	for (i = 0; i < num; i++) {
		len += sizeof(*grec);
		if (!pskb_may_pull(skb, len))
			return -EINVAL;

		/* pskb_may_pull() may relocate skb->data, so the record
		 * pointer is computed only after the pull.
		 */
		grec = (void *)(skb->data + len - sizeof(*grec));
		group = grec->grec_mca;
		type = grec->grec_type;

		/* Skip this record's source list (4 bytes per source).
		 * NOTE(review): grec_auxwords is not added here.  RFC 3376
		 * says reports must not carry auxiliary data, but a
		 * malformed packet that does would desynchronise this
		 * parser — confirm this is acceptable.
		 */
		len += ntohs(grec->grec_nsrcs) * 4;
		if (!pskb_may_pull(skb, len))
			return -EINVAL;

		/* We treat this as an IGMPv2 report for now. */
		switch (type) {
		case IGMPV3_MODE_IS_INCLUDE:
		case IGMPV3_MODE_IS_EXCLUDE:
		case IGMPV3_CHANGE_TO_INCLUDE:
		case IGMPV3_CHANGE_TO_EXCLUDE:
		case IGMPV3_ALLOW_NEW_SOURCES:
		case IGMPV3_BLOCK_OLD_SOURCES:
			break;

		default:
			continue;
		}

		err = br_ip4_multicast_add_group(br, port, group, vid);
		if (err)
			break;
	}

	return err;
}
951
952 #if IS_ENABLED(CONFIG_IPV6)
953 static int br_ip6_multicast_mld2_report(struct net_bridge *br,
954                                         struct net_bridge_port *port,
955                                         struct sk_buff *skb)
956 {
957         struct icmp6hdr *icmp6h;
958         struct mld2_grec *grec;
959         int i;
960         int len;
961         int num;
962         int err = 0;
963         u16 vid = 0;
964
965         if (!pskb_may_pull(skb, sizeof(*icmp6h)))
966                 return -EINVAL;
967
968         br_vlan_get_tag(skb, &vid);
969         icmp6h = icmp6_hdr(skb);
970         num = ntohs(icmp6h->icmp6_dataun.un_data16[1]);
971         len = sizeof(*icmp6h);
972
973         for (i = 0; i < num; i++) {
974                 __be16 *nsrcs, _nsrcs;
975
976                 nsrcs = skb_header_pointer(skb,
977                                            len + offsetof(struct mld2_grec,
978                                                           grec_nsrcs),
979                                            sizeof(_nsrcs), &_nsrcs);
980                 if (!nsrcs)
981                         return -EINVAL;
982
983                 if (!pskb_may_pull(skb,
984                                    len + sizeof(*grec) +
985                                    sizeof(struct in6_addr) * ntohs(*nsrcs)))
986                         return -EINVAL;
987
988                 grec = (struct mld2_grec *)(skb->data + len);
989                 len += sizeof(*grec) +
990                        sizeof(struct in6_addr) * ntohs(*nsrcs);
991
992                 /* We treat these as MLDv1 reports for now. */
993                 switch (grec->grec_type) {
994                 case MLD2_MODE_IS_INCLUDE:
995                 case MLD2_MODE_IS_EXCLUDE:
996                 case MLD2_CHANGE_TO_INCLUDE:
997                 case MLD2_CHANGE_TO_EXCLUDE:
998                 case MLD2_ALLOW_NEW_SOURCES:
999                 case MLD2_BLOCK_OLD_SOURCES:
1000                         break;
1001
1002                 default:
1003                         continue;
1004                 }
1005
1006                 err = br_ip6_multicast_add_group(br, port, &grec->grec_mca,
1007                                                  vid);
1008                 if (!err)
1009                         break;
1010         }
1011
1012         return err;
1013 }
1014 #endif
1015
/*
 * Add port to router_list
 *  list is maintained ordered by pointer value
 *  and locked by br->multicast_lock and RCU
 */
1021 static void br_multicast_add_router(struct net_bridge *br,
1022                                     struct net_bridge_port *port)
1023 {
1024         struct net_bridge_port *p;
1025         struct hlist_node *slot = NULL;
1026
1027         hlist_for_each_entry(p, &br->router_list, rlist) {
1028                 if ((unsigned long) port >= (unsigned long) p)
1029                         break;
1030                 slot = &p->rlist;
1031         }
1032
1033         if (slot)
1034                 hlist_add_after_rcu(slot, &port->rlist);
1035         else
1036                 hlist_add_head_rcu(&port->rlist, &br->router_list);
1037 }
1038
1039 static void br_multicast_mark_router(struct net_bridge *br,
1040                                      struct net_bridge_port *port)
1041 {
1042         unsigned long now = jiffies;
1043
1044         if (!port) {
1045                 if (br->multicast_router == 1)
1046                         mod_timer(&br->multicast_router_timer,
1047                                   now + br->multicast_querier_interval);
1048                 return;
1049         }
1050
1051         if (port->multicast_router != 1)
1052                 return;
1053
1054         if (!hlist_unhashed(&port->rlist))
1055                 goto timer;
1056
1057         br_multicast_add_router(br, port);
1058
1059 timer:
1060         mod_timer(&port->multicast_router_timer,
1061                   now + br->multicast_querier_interval);
1062 }
1063
1064 static void br_multicast_query_received(struct net_bridge *br,
1065                                         struct net_bridge_port *port,
1066                                         int saddr)
1067 {
1068         if (saddr)
1069                 mod_timer(&br->multicast_querier_timer,
1070                           jiffies + br->multicast_querier_interval);
1071         else if (timer_pending(&br->multicast_querier_timer))
1072                 return;
1073
1074         br_multicast_mark_router(br, port);
1075 }
1076
/* Process an IGMP membership query received on @port (NULL when it
 * arrived on the bridge device itself).  Records the querier/router,
 * and for group-specific queries shortens the group and per-port
 * timers so memberships expire unless a report answers the query.
 * Returns 0 or a negative errno for malformed packets.
 */
static int br_ip4_multicast_query(struct net_bridge *br,
				  struct net_bridge_port *port,
				  struct sk_buff *skb)
{
	const struct iphdr *iph = ip_hdr(skb);
	struct igmphdr *ih = igmp_hdr(skb);
	struct net_bridge_mdb_entry *mp;
	struct igmpv3_query *ih3;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;
	unsigned long max_delay;
	unsigned long now = jiffies;
	__be32 group;
	int err = 0;
	u16 vid = 0;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) ||
	    (port && port->state == BR_STATE_DISABLED))
		goto out;

	/* A non-zero IP source means another querier is present. */
	br_multicast_query_received(br, port, !!iph->saddr);

	group = ih->group;

	if (skb->len == sizeof(*ih)) {
		/* Plain igmphdr: IGMPv1 or IGMPv2 query. */
		max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);

		if (!max_delay) {
			/* IGMPv1 sets code 0: use a 10s window and
			 * treat it as a general query.
			 */
			max_delay = 10 * HZ;
			group = 0;
		}
	} else {
		/* Longer packet: IGMPv3 query. */
		if (!pskb_may_pull(skb, sizeof(struct igmpv3_query))) {
			err = -EINVAL;
			goto out;
		}

		ih3 = igmpv3_query_hdr(skb);
		/* Group-and-source-specific queries are not handled. */
		if (ih3->nsrcs)
			goto out;

		max_delay = ih3->code ?
			    IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1;
	}

	/* General query: nothing group-specific left to do. */
	if (!group)
		goto out;

	br_vlan_get_tag(skb, &vid);
	mp = br_mdb_ip4_get(mlock_dereference(br->mdb, br), group, vid);
	if (!mp)
		goto out;

	max_delay *= br->multicast_last_member_count;

	/* Shorten the group timer, but never extend an earlier expiry;
	 * try_to_del_timer_sync() avoids racing a concurrently running
	 * timer handler.
	 */
	if (mp->mglist &&
	    (timer_pending(&mp->timer) ?
	     time_after(mp->timer.expires, now + max_delay) :
	     try_to_del_timer_sync(&mp->timer) >= 0))
		mod_timer(&mp->timer, now + max_delay);

	/* Apply the same shorten-only rule to every member port. */
	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (timer_pending(&p->timer) ?
		    time_after(p->timer.expires, now + max_delay) :
		    try_to_del_timer_sync(&p->timer) >= 0)
			mod_timer(&p->timer, now + max_delay);
	}

out:
	spin_unlock(&br->multicast_lock);
	return err;
}
1152
1153 #if IS_ENABLED(CONFIG_IPV6)
/* Process an MLD membership query received on @port (NULL when it
 * arrived on the bridge device itself).  Records the querier/router,
 * and for group-specific queries shortens the group and per-port
 * timers so memberships expire unless a report answers the query.
 * Returns 0 or a negative errno for malformed packets.
 */
static int br_ip6_multicast_query(struct net_bridge *br,
				  struct net_bridge_port *port,
				  struct sk_buff *skb)
{
	const struct ipv6hdr *ip6h = ipv6_hdr(skb);
	struct mld_msg *mld;
	struct net_bridge_mdb_entry *mp;
	struct mld2_query *mld2q;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;
	unsigned long max_delay;
	unsigned long now = jiffies;
	const struct in6_addr *group = NULL;
	int err = 0;
	u16 vid = 0;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) ||
	    (port && port->state == BR_STATE_DISABLED))
		goto out;

	/* A specified source address means another querier is present. */
	br_multicast_query_received(br, port, !ipv6_addr_any(&ip6h->saddr));

	if (skb->len == sizeof(*mld)) {
		/* MLDv1 query. */
		if (!pskb_may_pull(skb, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *) icmp6_hdr(skb);
		max_delay = msecs_to_jiffies(ntohs(mld->mld_maxdelay));
		/* max_delay == 0 is treated as a general query. */
		if (max_delay)
			group = &mld->mld_mca;
	} else if (skb->len >= sizeof(*mld2q)) {
		/* MLDv2 query. */
		if (!pskb_may_pull(skb, sizeof(*mld2q))) {
			err = -EINVAL;
			goto out;
		}
		mld2q = (struct mld2_query *)icmp6_hdr(skb);
		/* Queries with a source list get no group-specific
		 * handling here.
		 */
		if (!mld2q->mld2q_nsrcs)
			group = &mld2q->mld2q_mca;
		max_delay = mld2q->mld2q_mrc ? MLDV2_MRC(ntohs(mld2q->mld2q_mrc)) : 1;
	}

	/* General query (or unrecognized length): nothing more to do. */
	if (!group)
		goto out;

	br_vlan_get_tag(skb, &vid);
	mp = br_mdb_ip6_get(mlock_dereference(br->mdb, br), group, vid);
	if (!mp)
		goto out;

	max_delay *= br->multicast_last_member_count;
	/* Shorten the group timer, but never extend an earlier expiry;
	 * try_to_del_timer_sync() avoids racing a running handler.
	 */
	if (mp->mglist &&
	    (timer_pending(&mp->timer) ?
	     time_after(mp->timer.expires, now + max_delay) :
	     try_to_del_timer_sync(&mp->timer) >= 0))
		mod_timer(&mp->timer, now + max_delay);

	/* Apply the same shorten-only rule to every member port. */
	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (timer_pending(&p->timer) ?
		    time_after(p->timer.expires, now + max_delay) :
		    try_to_del_timer_sync(&p->timer) >= 0)
			mod_timer(&p->timer, now + max_delay);
	}

out:
	spin_unlock(&br->multicast_lock);
	return err;
}
1225 #endif
1226
/* Handle an IGMP/MLD leave for @group from @port (NULL when the bridge
 * device itself leaves).  State is not deleted immediately: timers are
 * shortened so the entry expires quickly unless another host answers
 * the last-member queries.  Fast-leave ports are the exception and are
 * removed at once.
 */
static void br_multicast_leave_group(struct net_bridge *br,
				     struct net_bridge_port *port,
				     struct br_ip *group)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	struct net_bridge_port_group *p;
	unsigned long now;
	unsigned long time;

	spin_lock(&br->multicast_lock);
	/* Ignore the leave if the bridge is down, the port cannot
	 * forward, or another querier is active (it drives expiry).
	 */
	if (!netif_running(br->dev) ||
	    (port && port->state == BR_STATE_DISABLED) ||
	    timer_pending(&br->multicast_querier_timer))
		goto out;

	mdb = mlock_dereference(br->mdb, br);
	mp = br_mdb_ip_get(mdb, group);
	if (!mp)
		goto out;

	/* Fast-leave: drop this port's membership immediately instead
	 * of waiting out the last-member query interval.
	 */
	if (port && (port->flags & BR_MULTICAST_FAST_LEAVE)) {
		struct net_bridge_port_group __rcu **pp;

		for (pp = &mp->ports;
		     (p = mlock_dereference(*pp, br)) != NULL;
		     pp = &p->next) {
			if (p->port != port)
				continue;

			rcu_assign_pointer(*pp, p->next);
			hlist_del_init(&p->mglist);
			del_timer(&p->timer);
			call_rcu_bh(&p->rcu, br_multicast_free_pg);
			br_mdb_notify(br->dev, port, group, RTM_DELMDB);

			/* Last reference gone: let the group entry
			 * expire right away.
			 */
			if (!mp->ports && !mp->mglist &&
			    netif_running(br->dev))
				mod_timer(&mp->timer, jiffies);
		}
		goto out;
	}

	now = jiffies;
	time = now + br->multicast_last_member_count *
		     br->multicast_last_member_interval;

	if (!port) {
		/* Leave on the bridge device: shorten the group timer,
		 * but never push an earlier expiry further out.
		 */
		if (mp->mglist &&
		    (timer_pending(&mp->timer) ?
		     time_after(mp->timer.expires, time) :
		     try_to_del_timer_sync(&mp->timer) >= 0)) {
			mod_timer(&mp->timer, time);
		}

		goto out;
	}

	/* Leave on a member port: same shorten-only rule for its timer. */
	for (p = mlock_dereference(mp->ports, br);
	     p != NULL;
	     p = mlock_dereference(p->next, br)) {
		if (p->port != port)
			continue;

		if (!hlist_unhashed(&p->mglist) &&
		    (timer_pending(&p->timer) ?
		     time_after(p->timer.expires, time) :
		     try_to_del_timer_sync(&p->timer) >= 0)) {
			mod_timer(&p->timer, time);
		}

		break;
	}

out:
	spin_unlock(&br->multicast_lock);
}
1304
1305 static void br_ip4_multicast_leave_group(struct net_bridge *br,
1306                                          struct net_bridge_port *port,
1307                                          __be32 group,
1308                                          __u16 vid)
1309 {
1310         struct br_ip br_group;
1311
1312         if (ipv4_is_local_multicast(group))
1313                 return;
1314
1315         br_group.u.ip4 = group;
1316         br_group.proto = htons(ETH_P_IP);
1317         br_group.vid = vid;
1318
1319         br_multicast_leave_group(br, port, &br_group);
1320 }
1321
1322 #if IS_ENABLED(CONFIG_IPV6)
1323 static void br_ip6_multicast_leave_group(struct net_bridge *br,
1324                                          struct net_bridge_port *port,
1325                                          const struct in6_addr *group,
1326                                          __u16 vid)
1327 {
1328         struct br_ip br_group;
1329
1330         if (!ipv6_is_transient_multicast(group))
1331                 return;
1332
1333         br_group.u.ip6 = *group;
1334         br_group.proto = htons(ETH_P_IPV6);
1335         br_group.vid = vid;
1336
1337         br_multicast_leave_group(br, port, &br_group);
1338 }
1339 #endif
1340
/* Validate an IPv4 packet and, if it carries IGMP, dispatch it to the
 * per-message-type snooping handlers.  When the frame has trailing
 * padding beyond the IP total length, a trimmed clone is used so the
 * original skb is never modified.  Returns 0 for valid packets (IGMP
 * or not), negative errno on malformed input or allocation failure.
 */
static int br_multicast_ipv4_rcv(struct net_bridge *br,
				 struct net_bridge_port *port,
				 struct sk_buff *skb)
{
	struct sk_buff *skb2 = skb;
	const struct iphdr *iph;
	struct igmphdr *ih;
	unsigned int len;
	unsigned int offset;
	int err;
	u16 vid = 0;

	/* We treat OOM as packet loss for now. */
	if (!pskb_may_pull(skb, sizeof(*iph)))
		return -EINVAL;

	iph = ip_hdr(skb);

	if (iph->ihl < 5 || iph->version != 4)
		return -EINVAL;

	if (!pskb_may_pull(skb, ip_hdrlen(skb)))
		return -EINVAL;

	/* Re-read: pskb_may_pull() may have relocated skb->data. */
	iph = ip_hdr(skb);

	if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
		return -EINVAL;

	if (iph->protocol != IPPROTO_IGMP) {
		/* Non-IGMP multicast outside 224.0.0.x: flag it so the
		 * forwarding path restricts it to router ports.
		 */
		if ((iph->daddr & IGMP_LOCAL_GROUP_MASK) != IGMP_LOCAL_GROUP)
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		return 0;
	}

	len = ntohs(iph->tot_len);
	if (skb->len < len || len < ip_hdrlen(skb))
		return -EINVAL;

	/* Trailing padding: trim a clone, keep the original intact. */
	if (skb->len > len) {
		skb2 = skb_clone(skb, GFP_ATOMIC);
		if (!skb2)
			return -ENOMEM;

		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto err_out;
	}

	/* Advance to the IGMP payload. */
	len -= ip_hdrlen(skb2);
	offset = skb_network_offset(skb2) + ip_hdrlen(skb2);
	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);

	err = -EINVAL;
	if (!pskb_may_pull(skb2, sizeof(*ih)))
		goto out;

	/* Verify the IGMP checksum before trusting any field. */
	switch (skb2->ip_summed) {
	case CHECKSUM_COMPLETE:
		if (!csum_fold(skb2->csum))
			break;
		/* fall through */
	case CHECKSUM_NONE:
		skb2->csum = 0;
		if (skb_checksum_complete(skb2))
			goto out;
	}

	err = 0;

	br_vlan_get_tag(skb2, &vid);
	BR_INPUT_SKB_CB(skb)->igmp = 1;
	ih = igmp_hdr(skb2);

	switch (ih->type) {
	case IGMP_HOST_MEMBERSHIP_REPORT:
	case IGMPV2_HOST_MEMBERSHIP_REPORT:
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		err = br_ip4_multicast_add_group(br, port, ih->group, vid);
		break;
	case IGMPV3_HOST_MEMBERSHIP_REPORT:
		err = br_ip4_multicast_igmp3_report(br, port, skb2);
		break;
	case IGMP_HOST_MEMBERSHIP_QUERY:
		err = br_ip4_multicast_query(br, port, skb2);
		break;
	case IGMP_HOST_LEAVE_MESSAGE:
		br_ip4_multicast_leave_group(br, port, ih->group, vid);
		break;
	}

out:
	/* Undo the pull so a shared skb leaves in the state it arrived. */
	__skb_push(skb2, offset);
err_out:
	if (skb2 != skb)
		kfree_skb(skb2);
	return err;
}
1440
1441 #if IS_ENABLED(CONFIG_IPV6)
/* Validate an IPv6 packet and, if it carries MLD, dispatch it to the
 * per-message-type snooping handlers.  Always works on a clone so the
 * original frame is never modified.  Returns 0 for valid packets (MLD
 * or not), negative errno on malformed input or allocation failure.
 */
static int br_multicast_ipv6_rcv(struct net_bridge *br,
				 struct net_bridge_port *port,
				 struct sk_buff *skb)
{
	struct sk_buff *skb2;
	const struct ipv6hdr *ip6h;
	u8 icmp6_type;
	u8 nexthdr;
	__be16 frag_off;
	unsigned int len;
	int offset;
	int err;
	u16 vid = 0;

	if (!pskb_may_pull(skb, sizeof(*ip6h)))
		return -EINVAL;

	ip6h = ipv6_hdr(skb);

	/*
	 * We're interested in MLD messages only.
	 *  - Version is 6
	 *  - MLD has always Router Alert hop-by-hop option
	 *  - But we do not support jumbrograms.
	 */
	if (ip6h->version != 6 ||
	    ip6h->nexthdr != IPPROTO_HOPOPTS ||
	    ip6h->payload_len == 0)
		return 0;

	len = ntohs(ip6h->payload_len) + sizeof(*ip6h);
	if (skb->len < len)
		return -EINVAL;

	/* Walk the extension headers to find the transport protocol. */
	nexthdr = ip6h->nexthdr;
	offset = ipv6_skip_exthdr(skb, sizeof(*ip6h), &nexthdr, &frag_off);

	if (offset < 0 || nexthdr != IPPROTO_ICMPV6)
		return 0;

	/* Okay, we found ICMPv6 header */
	skb2 = skb_clone(skb, GFP_ATOMIC);
	if (!skb2)
		return -ENOMEM;

	err = -EINVAL;
	if (!pskb_may_pull(skb2, offset + sizeof(struct icmp6hdr)))
		goto out;

	len -= offset - skb_network_offset(skb2);

	/* Advance the clone to the ICMPv6 payload, fixing up the
	 * checksum for the bytes pulled off.
	 */
	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);
	skb_postpull_rcsum(skb2, skb_network_header(skb2),
			   skb_network_header_len(skb2));

	icmp6_type = icmp6_hdr(skb2)->icmp6_type;

	switch (icmp6_type) {
	case ICMPV6_MGM_QUERY:
	case ICMPV6_MGM_REPORT:
	case ICMPV6_MGM_REDUCTION:
	case ICMPV6_MLD2_REPORT:
		break;
	default:
		/* Not an MLD message: nothing to snoop. */
		err = 0;
		goto out;
	}

	/* Okay, we found MLD message. Check further. */
	if (skb2->len > len) {
		/* Drop trailing padding from the clone. */
		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto out;
		err = -EINVAL;
	}

	ip6h = ipv6_hdr(skb2);

	/* Verify the ICMPv6 checksum before trusting any field. */
	switch (skb2->ip_summed) {
	case CHECKSUM_COMPLETE:
		if (!csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, skb2->len,
					IPPROTO_ICMPV6, skb2->csum))
			break;
		/*FALLTHROUGH*/
	case CHECKSUM_NONE:
		skb2->csum = ~csum_unfold(csum_ipv6_magic(&ip6h->saddr,
							&ip6h->daddr,
							skb2->len,
							IPPROTO_ICMPV6, 0));
		if (__skb_checksum_complete(skb2))
			goto out;
	}

	err = 0;

	br_vlan_get_tag(skb, &vid);
	BR_INPUT_SKB_CB(skb)->igmp = 1;

	switch (icmp6_type) {
	case ICMPV6_MGM_REPORT:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		err = br_ip6_multicast_add_group(br, port, &mld->mld_mca, vid);
		break;
	    }
	case ICMPV6_MLD2_REPORT:
		err = br_ip6_multicast_mld2_report(br, port, skb2);
		break;
	case ICMPV6_MGM_QUERY:
		err = br_ip6_multicast_query(br, port, skb2);
		break;
	case ICMPV6_MGM_REDUCTION:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		br_ip6_multicast_leave_group(br, port, &mld->mld_mca, vid);
	    }
	}

out:
	kfree_skb(skb2);
	return err;
}
1576 #endif
1577
1578 int br_multicast_rcv(struct net_bridge *br, struct net_bridge_port *port,
1579                      struct sk_buff *skb)
1580 {
1581         BR_INPUT_SKB_CB(skb)->igmp = 0;
1582         BR_INPUT_SKB_CB(skb)->mrouters_only = 0;
1583
1584         if (br->multicast_disabled)
1585                 return 0;
1586
1587         switch (skb->protocol) {
1588         case htons(ETH_P_IP):
1589                 return br_multicast_ipv4_rcv(br, port, skb);
1590 #if IS_ENABLED(CONFIG_IPV6)
1591         case htons(ETH_P_IPV6):
1592                 return br_multicast_ipv6_rcv(br, port, skb);
1593 #endif
1594         }
1595
1596         return 0;
1597 }
1598
1599 static void br_multicast_query_expired(unsigned long data)
1600 {
1601         struct net_bridge *br = (void *)data;
1602
1603         spin_lock(&br->multicast_lock);
1604         if (br->multicast_startup_queries_sent <
1605             br->multicast_startup_query_count)
1606                 br->multicast_startup_queries_sent++;
1607
1608         br_multicast_send_query(br, NULL, br->multicast_startup_queries_sent);
1609
1610         spin_unlock(&br->multicast_lock);
1611 }
1612
1613 void br_multicast_init(struct net_bridge *br)
1614 {
1615         br->hash_elasticity = 4;
1616         br->hash_max = 512;
1617
1618         br->multicast_router = 1;
1619         br->multicast_querier = 0;
1620         br->multicast_last_member_count = 2;
1621         br->multicast_startup_query_count = 2;
1622
1623         br->multicast_last_member_interval = HZ;
1624         br->multicast_query_response_interval = 10 * HZ;
1625         br->multicast_startup_query_interval = 125 * HZ / 4;
1626         br->multicast_query_interval = 125 * HZ;
1627         br->multicast_querier_interval = 255 * HZ;
1628         br->multicast_membership_interval = 260 * HZ;
1629
1630         spin_lock_init(&br->multicast_lock);
1631         setup_timer(&br->multicast_router_timer,
1632                     br_multicast_local_router_expired, 0);
1633         setup_timer(&br->multicast_querier_timer,
1634                     br_multicast_querier_expired, (unsigned long)br);
1635         setup_timer(&br->multicast_query_timer, br_multicast_query_expired,
1636                     (unsigned long)br);
1637 }
1638
1639 void br_multicast_open(struct net_bridge *br)
1640 {
1641         br->multicast_startup_queries_sent = 0;
1642
1643         if (br->multicast_disabled)
1644                 return;
1645
1646         mod_timer(&br->multicast_query_timer, jiffies);
1647 }
1648
/* Tear down all multicast state when the bridge goes down: stop the
 * bridge-level timers and release every MDB entry, then free the hash
 * table itself via RCU (draining any in-flight rehash first).
 */
void br_multicast_stop(struct net_bridge *br)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	struct hlist_node *n;
	u32 ver;
	int i;

	del_timer_sync(&br->multicast_router_timer);
	del_timer_sync(&br->multicast_querier_timer);
	del_timer_sync(&br->multicast_query_timer);

	spin_lock_bh(&br->multicast_lock);
	mdb = mlock_dereference(br->mdb, br);
	if (!mdb)
		goto out;

	br->mdb = NULL;

	/* Free every entry in the currently active hash version. */
	ver = mdb->ver;
	for (i = 0; i < mdb->max; i++) {
		hlist_for_each_entry_safe(mp, n, &mdb->mhash[i],
					  hlist[ver]) {
			del_timer(&mp->timer);
			call_rcu_bh(&mp->rcu, br_multicast_free_group);
		}
	}

	/* An older table from a rehash is still being freed: wait for
	 * its RCU callbacks to finish before freeing this one.
	 */
	if (mdb->old) {
		spin_unlock_bh(&br->multicast_lock);
		rcu_barrier_bh();
		spin_lock_bh(&br->multicast_lock);
		WARN_ON(mdb->old);
	}

	/* NOTE(review): the self-reference appears to mark the table as
	 * final for br_mdb_free() — confirm against its definition.
	 */
	mdb->old = mdb;
	call_rcu_bh(&mdb->rcu, br_mdb_free);

out:
	spin_unlock_bh(&br->multicast_lock);
}
1690
1691 int br_multicast_set_router(struct net_bridge *br, unsigned long val)
1692 {
1693         int err = -ENOENT;
1694
1695         spin_lock_bh(&br->multicast_lock);
1696         if (!netif_running(br->dev))
1697                 goto unlock;
1698
1699         switch (val) {
1700         case 0:
1701         case 2:
1702                 del_timer(&br->multicast_router_timer);
1703                 /* fall through */
1704         case 1:
1705                 br->multicast_router = val;
1706                 err = 0;
1707                 break;
1708
1709         default:
1710                 err = -EINVAL;
1711                 break;
1712         }
1713
1714 unlock:
1715         spin_unlock_bh(&br->multicast_lock);
1716
1717         return err;
1718 }
1719
/* Set the per-port multicast router mode: 0 = never a router port,
 * 1 = learn from queries (automatic), 2 = always a router port.
 * Returns 0, -EINVAL for an unknown mode, or -ENOENT when the bridge
 * or port is down.
 */
int br_multicast_set_port_router(struct net_bridge_port *p, unsigned long val)
{
	struct net_bridge *br = p->br;
	int err = -ENOENT;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) || p->state == BR_STATE_DISABLED)
		goto unlock;

	switch (val) {
	case 0:
	case 1:
	case 2:
		p->multicast_router = val;
		err = 0;

		/* Modes 0 and 1 drop any existing router-list entry;
		 * automatic mode re-adds it when a query is seen.
		 */
		if (val < 2 && !hlist_unhashed(&p->rlist))
			hlist_del_init_rcu(&p->rlist);

		if (val == 1)
			break;

		/* Modes 0 and 2 are static: stop the detection timer. */
		del_timer(&p->multicast_router_timer);

		if (val == 0)
			break;

		/* Mode 2: keep the port on the router list permanently. */
		br_multicast_add_router(br, p);
		break;

	default:
		err = -EINVAL;
		break;
	}

unlock:
	spin_unlock(&br->multicast_lock);

	return err;
}
1760
1761 static void br_multicast_start_querier(struct net_bridge *br)
1762 {
1763         struct net_bridge_port *port;
1764
1765         br_multicast_open(br);
1766
1767         list_for_each_entry(port, &br->port_list, list) {
1768                 if (port->state == BR_STATE_DISABLED ||
1769                     port->state == BR_STATE_BLOCKING)
1770                         continue;
1771
1772                 __br_multicast_enable_port(port);
1773         }
1774 }
1775
/* Enable (val != 0) or disable multicast snooping on the bridge.  On
 * enable, an existing MDB is rebuilt via br_mdb_rehash() and the
 * querier machinery is restarted.  Returns 0 or a negative errno; on
 * failure the previous (disabled) setting is restored.
 */
int br_multicast_toggle(struct net_bridge *br, unsigned long val)
{
	int err = 0;
	struct net_bridge_mdb_htable *mdb;

	spin_lock_bh(&br->multicast_lock);
	if (br->multicast_disabled == !val)
		goto unlock;

	br->multicast_disabled = !val;
	if (br->multicast_disabled)
		goto unlock;

	if (!netif_running(br->dev))
		goto unlock;

	mdb = mlock_dereference(br->mdb, br);
	if (mdb) {
		/* A previous rehash is still being freed: refuse. */
		if (mdb->old) {
			err = -EEXIST;
rollback:
			/* Re-disable; we only get here when enabling. */
			br->multicast_disabled = !!val;
			goto unlock;
		}

		err = br_mdb_rehash(&br->mdb, mdb->max,
				    br->hash_elasticity);
		if (err)
			goto rollback;
	}

	br_multicast_start_querier(br);

unlock:
	spin_unlock_bh(&br->multicast_lock);

	return err;
}
1814
1815 int br_multicast_set_querier(struct net_bridge *br, unsigned long val)
1816 {
1817         val = !!val;
1818
1819         spin_lock_bh(&br->multicast_lock);
1820         if (br->multicast_querier == val)
1821                 goto unlock;
1822
1823         br->multicast_querier = val;
1824         if (val)
1825                 br_multicast_start_querier(br);
1826
1827 unlock:
1828         spin_unlock_bh(&br->multicast_lock);
1829
1830         return 0;
1831 }
1832
/* Change the maximum MDB hash table size.  @val must be a power of two
 * and not below the current occupancy; an existing table is rehashed
 * to the new size.  Returns 0, -EINVAL for a bad value, -ENOENT when
 * the bridge is down, or a rehash error (old limit restored).
 */
int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
{
	int err = -ENOENT;
	u32 old;
	struct net_bridge_mdb_htable *mdb;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev))
		goto unlock;

	err = -EINVAL;
	if (!is_power_of_2(val))
		goto unlock;

	/* Never shrink below mdb->size (presumably the current entry
	 * count — NOTE(review): confirm against the struct definition).
	 */
	mdb = mlock_dereference(br->mdb, br);
	if (mdb && val < mdb->size)
		goto unlock;

	err = 0;

	old = br->hash_max;
	br->hash_max = val;

	if (mdb) {
		/* A previous rehash is still being freed: refuse. */
		if (mdb->old) {
			err = -EEXIST;
rollback:
			br->hash_max = old;
			goto unlock;
		}

		err = br_mdb_rehash(&br->mdb, br->hash_max,
				    br->hash_elasticity);
		if (err)
			goto rollback;
	}

unlock:
	spin_unlock(&br->multicast_lock);

	return err;
}