]> git.kernelconcepts.de Git - karo-tx-linux.git/blob - net/mpls/af_mpls.c
mpls: Properly validate RTA_VIA payload length
[karo-tx-linux.git] / net / mpls / af_mpls.c
1 #include <linux/types.h>
2 #include <linux/skbuff.h>
3 #include <linux/socket.h>
4 #include <linux/sysctl.h>
5 #include <linux/net.h>
6 #include <linux/module.h>
7 #include <linux/if_arp.h>
8 #include <linux/ipv6.h>
9 #include <linux/mpls.h>
10 #include <linux/vmalloc.h>
11 #include <net/ip.h>
12 #include <net/dst.h>
13 #include <net/sock.h>
14 #include <net/arp.h>
15 #include <net/ip_fib.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include "internal.h"
19
20 #define LABEL_NOT_SPECIFIED (1<<20)
21 #define MAX_NEW_LABELS 2
22
23 /* This maximum ha length copied from the definition of struct neighbour */
24 #define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
25
26 struct mpls_route { /* next hop label forwarding entry */
27         struct net_device       *rt_dev;
28         struct rcu_head         rt_rcu;
29         u32                     rt_label[MAX_NEW_LABELS];
30         u8                      rt_protocol; /* routing protocol that set this entry */
31         u8                      rt_labels:2,
32                                 rt_via_alen:6;
33         unsigned short          rt_via_family;
34         u8                      rt_via[0];
35 };
36
37 static int zero = 0;
38 static int label_limit = (1 << 20) - 1;
39
40 static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
41                        struct nlmsghdr *nlh, struct net *net, u32 portid,
42                        unsigned int nlm_flags);
43
44 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
45 {
46         struct mpls_route *rt = NULL;
47
48         if (index < net->mpls.platform_labels) {
49                 struct mpls_route __rcu **platform_label =
50                         rcu_dereference(net->mpls.platform_label);
51                 rt = rcu_dereference(platform_label[index]);
52         }
53         return rt;
54 }
55
56 static bool mpls_output_possible(const struct net_device *dev)
57 {
58         return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
59 }
60
61 static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
62 {
63         /* The size of the layer 2.5 labels to be added for this route */
64         return rt->rt_labels * sizeof(struct mpls_shim_hdr);
65 }
66
67 static unsigned int mpls_dev_mtu(const struct net_device *dev)
68 {
69         /* The amount of data the layer 2 frame can hold */
70         return dev->mtu;
71 }
72
73 static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
74 {
75         if (skb->len <= mtu)
76                 return false;
77
78         if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
79                 return false;
80
81         return true;
82 }
83
84 static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
85                         struct mpls_entry_decoded dec)
86 {
87         /* RFC4385 and RFC5586 encode other packets in mpls such that
88          * they don't conflict with the ip version number, making
89          * decoding by examining the ip version correct in everything
90          * except for the strangest cases.
91          *
92          * The strange cases if we choose to support them will require
93          * manual configuration.
94          */
95         struct iphdr *hdr4 = ip_hdr(skb);
96         bool success = true;
97
98         if (hdr4->version == 4) {
99                 skb->protocol = htons(ETH_P_IP);
100                 csum_replace2(&hdr4->check,
101                               htons(hdr4->ttl << 8),
102                               htons(dec.ttl << 8));
103                 hdr4->ttl = dec.ttl;
104         }
105         else if (hdr4->version == 6) {
106                 struct ipv6hdr *hdr6 = ipv6_hdr(skb);
107                 skb->protocol = htons(ETH_P_IPV6);
108                 hdr6->hop_limit = dec.ttl;
109         }
110         else
111                 /* version 0 and version 1 are used by pseudo wires */
112                 success = false;
113         return success;
114 }
115
116 static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
117                         struct packet_type *pt, struct net_device *orig_dev)
118 {
119         struct net *net = dev_net(dev);
120         struct mpls_shim_hdr *hdr;
121         struct mpls_route *rt;
122         struct mpls_entry_decoded dec;
123         struct net_device *out_dev;
124         unsigned int hh_len;
125         unsigned int new_header_size;
126         unsigned int mtu;
127         int err;
128
129         /* Careful this entire function runs inside of an rcu critical section */
130
131         if (skb->pkt_type != PACKET_HOST)
132                 goto drop;
133
134         if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
135                 goto drop;
136
137         if (!pskb_may_pull(skb, sizeof(*hdr)))
138                 goto drop;
139
140         /* Read and decode the label */
141         hdr = mpls_hdr(skb);
142         dec = mpls_entry_decode(hdr);
143
144         /* Pop the label */
145         skb_pull(skb, sizeof(*hdr));
146         skb_reset_network_header(skb);
147
148         skb_orphan(skb);
149
150         rt = mpls_route_input_rcu(net, dec.label);
151         if (!rt)
152                 goto drop;
153
154         /* Find the output device */
155         out_dev = rt->rt_dev;
156         if (!mpls_output_possible(out_dev))
157                 goto drop;
158
159         if (skb_warn_if_lro(skb))
160                 goto drop;
161
162         skb_forward_csum(skb);
163
164         /* Verify ttl is valid */
165         if (dec.ttl <= 2)
166                 goto drop;
167         dec.ttl -= 1;
168
169         /* Verify the destination can hold the packet */
170         new_header_size = mpls_rt_header_size(rt);
171         mtu = mpls_dev_mtu(out_dev);
172         if (mpls_pkt_too_big(skb, mtu - new_header_size))
173                 goto drop;
174
175         hh_len = LL_RESERVED_SPACE(out_dev);
176         if (!out_dev->header_ops)
177                 hh_len = 0;
178
179         /* Ensure there is enough space for the headers in the skb */
180         if (skb_cow(skb, hh_len + new_header_size))
181                 goto drop;
182
183         skb->dev = out_dev;
184         skb->protocol = htons(ETH_P_MPLS_UC);
185
186         if (unlikely(!new_header_size && dec.bos)) {
187                 /* Penultimate hop popping */
188                 if (!mpls_egress(rt, skb, dec))
189                         goto drop;
190         } else {
191                 bool bos;
192                 int i;
193                 skb_push(skb, new_header_size);
194                 skb_reset_network_header(skb);
195                 /* Push the new labels */
196                 hdr = mpls_hdr(skb);
197                 bos = dec.bos;
198                 for (i = rt->rt_labels - 1; i >= 0; i--) {
199                         hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
200                         bos = false;
201                 }
202         }
203
204         err = neigh_xmit(rt->rt_via_family, out_dev, rt->rt_via, skb);
205         if (err)
206                 net_dbg_ratelimited("%s: packet transmission failed: %d\n",
207                                     __func__, err);
208         return 0;
209
210 drop:
211         kfree_skb(skb);
212         return NET_RX_DROP;
213 }
214
215 static struct packet_type mpls_packet_type __read_mostly = {
216         .type = cpu_to_be16(ETH_P_MPLS_UC),
217         .func = mpls_forward,
218 };
219
220 static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
221         [RTA_DST]               = { .type = NLA_U32 },
222         [RTA_OIF]               = { .type = NLA_U32 },
223 };
224
225 struct mpls_route_config {
226         u32             rc_protocol;
227         u32             rc_ifindex;
228         u16             rc_via_family;
229         u16             rc_via_alen;
230         u8              rc_via[MAX_VIA_ALEN];
231         u32             rc_label;
232         u32             rc_output_labels;
233         u32             rc_output_label[MAX_NEW_LABELS];
234         u32             rc_nlflags;
235         struct nl_info  rc_nlinfo;
236 };
237
238 static struct mpls_route *mpls_rt_alloc(size_t alen)
239 {
240         struct mpls_route *rt;
241
242         rt = kzalloc(GFP_KERNEL, sizeof(*rt) + alen);
243         if (rt)
244                 rt->rt_via_alen = alen;
245         return rt;
246 }
247
248 static void mpls_rt_free(struct mpls_route *rt)
249 {
250         if (rt)
251                 kfree_rcu(rt, rt_rcu);
252 }
253
254 static void mpls_notify_route(struct net *net, unsigned index,
255                               struct mpls_route *old, struct mpls_route *new,
256                               const struct nl_info *info)
257 {
258         struct nlmsghdr *nlh = info ? info->nlh : NULL;
259         unsigned portid = info ? info->portid : 0;
260         int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
261         struct mpls_route *rt = new ? new : old;
262         unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
263         /* Ignore reserved labels for now */
264         if (rt && (index >= 16))
265                 rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
266 }
267
268 static void mpls_route_update(struct net *net, unsigned index,
269                               struct net_device *dev, struct mpls_route *new,
270                               const struct nl_info *info)
271 {
272         struct mpls_route *rt, *old = NULL;
273
274         ASSERT_RTNL();
275
276         rt = net->mpls.platform_label[index];
277         if (!dev || (rt && (rt->rt_dev == dev))) {
278                 rcu_assign_pointer(net->mpls.platform_label[index], new);
279                 old = rt;
280         }
281
282         mpls_notify_route(net, index, old, new, info);
283
284         /* If we removed a route free it now */
285         mpls_rt_free(old);
286 }
287
288 static unsigned find_free_label(struct net *net)
289 {
290         unsigned index;
291         for (index = 16; index < net->mpls.platform_labels; index++) {
292                 if (!net->mpls.platform_label[index])
293                         return index;
294         }
295         return LABEL_NOT_SPECIFIED;
296 }
297
298 static int mpls_route_add(struct mpls_route_config *cfg)
299 {
300         struct net *net = cfg->rc_nlinfo.nl_net;
301         struct net_device *dev = NULL;
302         struct mpls_route *rt, *old;
303         unsigned index;
304         int i;
305         int err = -EINVAL;
306
307         index = cfg->rc_label;
308
309         /* If a label was not specified during insert pick one */
310         if ((index == LABEL_NOT_SPECIFIED) &&
311             (cfg->rc_nlflags & NLM_F_CREATE)) {
312                 index = find_free_label(net);
313         }
314
315         /* The first 16 labels are reserved, and may not be set */
316         if (index < 16)
317                 goto errout;
318
319         /* The full 20 bit range may not be supported. */
320         if (index >= net->mpls.platform_labels)
321                 goto errout;
322
323         /* Ensure only a supported number of labels are present */
324         if (cfg->rc_output_labels > MAX_NEW_LABELS)
325                 goto errout;
326
327         err = -ENODEV;
328         dev = dev_get_by_index(net, cfg->rc_ifindex);
329         if (!dev)
330                 goto errout;
331
332         /* For now just support ethernet devices */
333         err = -EINVAL;
334         if ((dev->type != ARPHRD_ETHER) && (dev->type != ARPHRD_LOOPBACK))
335                 goto errout;
336
337         err = -EINVAL;
338         if ((cfg->rc_via_family == AF_PACKET) &&
339             (dev->addr_len != cfg->rc_via_alen))
340                 goto errout;
341
342         /* Append makes no sense with mpls */
343         err = -EINVAL;
344         if (cfg->rc_nlflags & NLM_F_APPEND)
345                 goto errout;
346
347         err = -EEXIST;
348         old = net->mpls.platform_label[index];
349         if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
350                 goto errout;
351
352         err = -EEXIST;
353         if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
354                 goto errout;
355
356         err = -ENOENT;
357         if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
358                 goto errout;
359
360         err = -ENOMEM;
361         rt = mpls_rt_alloc(cfg->rc_via_alen);
362         if (!rt)
363                 goto errout;
364
365         rt->rt_labels = cfg->rc_output_labels;
366         for (i = 0; i < rt->rt_labels; i++)
367                 rt->rt_label[i] = cfg->rc_output_label[i];
368         rt->rt_protocol = cfg->rc_protocol;
369         rt->rt_dev = dev;
370         rt->rt_via_family = cfg->rc_via_family;
371         memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
372
373         mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
374
375         dev_put(dev);
376         return 0;
377
378 errout:
379         if (dev)
380                 dev_put(dev);
381         return err;
382 }
383
384 static int mpls_route_del(struct mpls_route_config *cfg)
385 {
386         struct net *net = cfg->rc_nlinfo.nl_net;
387         unsigned index;
388         int err = -EINVAL;
389
390         index = cfg->rc_label;
391
392         /* The first 16 labels are reserved, and may not be removed */
393         if (index < 16)
394                 goto errout;
395
396         /* The full 20 bit range may not be supported */
397         if (index >= net->mpls.platform_labels)
398                 goto errout;
399
400         mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
401
402         err = 0;
403 errout:
404         return err;
405 }
406
407 static void mpls_ifdown(struct net_device *dev)
408 {
409         struct net *net = dev_net(dev);
410         unsigned index;
411
412         for (index = 0; index < net->mpls.platform_labels; index++) {
413                 struct mpls_route *rt = net->mpls.platform_label[index];
414                 if (!rt)
415                         continue;
416                 if (rt->rt_dev != dev)
417                         continue;
418                 rt->rt_dev = NULL;
419         }
420 }
421
422 static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
423                            void *ptr)
424 {
425         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
426
427         switch(event) {
428         case NETDEV_UNREGISTER:
429                 mpls_ifdown(dev);
430                 break;
431         }
432         return NOTIFY_OK;
433 }
434
435 static struct notifier_block mpls_dev_notifier = {
436         .notifier_call = mpls_dev_notify,
437 };
438
439 static int nla_put_via(struct sk_buff *skb,
440                        u16 family, const void *addr, int alen)
441 {
442         struct nlattr *nla;
443         struct rtvia *via;
444
445         nla = nla_reserve(skb, RTA_VIA, alen + 2);
446         if (!nla)
447                 return -EMSGSIZE;
448
449         via = nla_data(nla);
450         via->rtvia_family = family;
451         memcpy(via->rtvia_addr, addr, alen);
452         return 0;
453 }
454
455 int nla_put_labels(struct sk_buff *skb, int attrtype,
456                    u8 labels, const u32 label[])
457 {
458         struct nlattr *nla;
459         struct mpls_shim_hdr *nla_label;
460         bool bos;
461         int i;
462         nla = nla_reserve(skb, attrtype, labels*4);
463         if (!nla)
464                 return -EMSGSIZE;
465
466         nla_label = nla_data(nla);
467         bos = true;
468         for (i = labels - 1; i >= 0; i--) {
469                 nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
470                 bos = false;
471         }
472
473         return 0;
474 }
475
476 int nla_get_labels(const struct nlattr *nla,
477                    u32 max_labels, u32 *labels, u32 label[])
478 {
479         unsigned len = nla_len(nla);
480         unsigned nla_labels;
481         struct mpls_shim_hdr *nla_label;
482         bool bos;
483         int i;
484
485         /* len needs to be an even multiple of 4 (the label size) */
486         if (len & 3)
487                 return -EINVAL;
488
489         /* Limit the number of new labels allowed */
490         nla_labels = len/4;
491         if (nla_labels > max_labels)
492                 return -EINVAL;
493
494         nla_label = nla_data(nla);
495         bos = true;
496         for (i = nla_labels - 1; i >= 0; i--, bos = false) {
497                 struct mpls_entry_decoded dec;
498                 dec = mpls_entry_decode(nla_label + i);
499
500                 /* Ensure the bottom of stack flag is properly set
501                  * and ttl and tc are both clear.
502                  */
503                 if ((dec.bos != bos) || dec.ttl || dec.tc)
504                         return -EINVAL;
505
506                 label[i] = dec.label;
507         }
508         *labels = nla_labels;
509         return 0;
510 }
511
512 static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
513                                struct mpls_route_config *cfg)
514 {
515         struct rtmsg *rtm;
516         struct nlattr *tb[RTA_MAX+1];
517         int index;
518         int err;
519
520         err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_mpls_policy);
521         if (err < 0)
522                 goto errout;
523
524         err = -EINVAL;
525         rtm = nlmsg_data(nlh);
526         memset(cfg, 0, sizeof(*cfg));
527
528         if (rtm->rtm_family != AF_MPLS)
529                 goto errout;
530         if (rtm->rtm_dst_len != 20)
531                 goto errout;
532         if (rtm->rtm_src_len != 0)
533                 goto errout;
534         if (rtm->rtm_tos != 0)
535                 goto errout;
536         if (rtm->rtm_table != RT_TABLE_MAIN)
537                 goto errout;
538         /* Any value is acceptable for rtm_protocol */
539
540         /* As mpls uses destination specific addresses
541          * (or source specific address in the case of multicast)
542          * all addresses have universal scope.
543          */
544         if (rtm->rtm_scope != RT_SCOPE_UNIVERSE)
545                 goto errout;
546         if (rtm->rtm_type != RTN_UNICAST)
547                 goto errout;
548         if (rtm->rtm_flags != 0)
549                 goto errout;
550
551         cfg->rc_label           = LABEL_NOT_SPECIFIED;
552         cfg->rc_protocol        = rtm->rtm_protocol;
553         cfg->rc_nlflags         = nlh->nlmsg_flags;
554         cfg->rc_nlinfo.portid   = NETLINK_CB(skb).portid;
555         cfg->rc_nlinfo.nlh      = nlh;
556         cfg->rc_nlinfo.nl_net   = sock_net(skb->sk);
557
558         for (index = 0; index <= RTA_MAX; index++) {
559                 struct nlattr *nla = tb[index];
560                 if (!nla)
561                         continue;
562
563                 switch(index) {
564                 case RTA_OIF:
565                         cfg->rc_ifindex = nla_get_u32(nla);
566                         break;
567                 case RTA_NEWDST:
568                         if (nla_get_labels(nla, MAX_NEW_LABELS,
569                                            &cfg->rc_output_labels,
570                                            cfg->rc_output_label))
571                                 goto errout;
572                         break;
573                 case RTA_DST:
574                 {
575                         u32 label_count;
576                         if (nla_get_labels(nla, 1, &label_count,
577                                            &cfg->rc_label))
578                                 goto errout;
579
580                         /* The first 16 labels are reserved, and may not be set */
581                         if (cfg->rc_label < 16)
582                                 goto errout;
583
584                         break;
585                 }
586                 case RTA_VIA:
587                 {
588                         struct rtvia *via = nla_data(nla);
589                         if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
590                                 goto errout;
591                         cfg->rc_via_family = via->rtvia_family;
592                         cfg->rc_via_alen   = nla_len(nla) -
593                                 offsetof(struct rtvia, rtvia_addr);
594                         if (cfg->rc_via_alen > MAX_VIA_ALEN)
595                                 goto errout;
596
597                         /* Validate the address family */
598                         switch(cfg->rc_via_family) {
599                         case AF_PACKET:
600                                 break;
601                         case AF_INET:
602                                 if (cfg->rc_via_alen != 4)
603                                         goto errout;
604                                 break;
605                         case AF_INET6:
606                                 if (cfg->rc_via_alen != 16)
607                                         goto errout;
608                                 break;
609                         default:
610                                 /* Unsupported address family */
611                                 goto errout;
612                         }
613
614                         memcpy(cfg->rc_via, via->rtvia_addr, cfg->rc_via_alen);
615                         break;
616                 }
617                 default:
618                         /* Unsupported attribute */
619                         goto errout;
620                 }
621         }
622
623         err = 0;
624 errout:
625         return err;
626 }
627
628 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
629 {
630         struct mpls_route_config cfg;
631         int err;
632
633         err = rtm_to_route_config(skb, nlh, &cfg);
634         if (err < 0)
635                 return err;
636
637         return mpls_route_del(&cfg);
638 }
639
640
641 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
642 {
643         struct mpls_route_config cfg;
644         int err;
645
646         err = rtm_to_route_config(skb, nlh, &cfg);
647         if (err < 0)
648                 return err;
649
650         return mpls_route_add(&cfg);
651 }
652
653 static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
654                            u32 label, struct mpls_route *rt, int flags)
655 {
656         struct nlmsghdr *nlh;
657         struct rtmsg *rtm;
658
659         nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
660         if (nlh == NULL)
661                 return -EMSGSIZE;
662
663         rtm = nlmsg_data(nlh);
664         rtm->rtm_family = AF_MPLS;
665         rtm->rtm_dst_len = 20;
666         rtm->rtm_src_len = 0;
667         rtm->rtm_tos = 0;
668         rtm->rtm_table = RT_TABLE_MAIN;
669         rtm->rtm_protocol = rt->rt_protocol;
670         rtm->rtm_scope = RT_SCOPE_UNIVERSE;
671         rtm->rtm_type = RTN_UNICAST;
672         rtm->rtm_flags = 0;
673
674         if (rt->rt_labels &&
675             nla_put_labels(skb, RTA_NEWDST, rt->rt_labels, rt->rt_label))
676                 goto nla_put_failure;
677         if (nla_put_via(skb, rt->rt_via_family, rt->rt_via, rt->rt_via_alen))
678                 goto nla_put_failure;
679         if (rt->rt_dev && nla_put_u32(skb, RTA_OIF, rt->rt_dev->ifindex))
680                 goto nla_put_failure;
681         if (nla_put_labels(skb, RTA_DST, 1, &label))
682                 goto nla_put_failure;
683
684         nlmsg_end(skb, nlh);
685         return 0;
686
687 nla_put_failure:
688         nlmsg_cancel(skb, nlh);
689         return -EMSGSIZE;
690 }
691
692 static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
693 {
694         struct net *net = sock_net(skb->sk);
695         unsigned int index;
696
697         ASSERT_RTNL();
698
699         index = cb->args[0];
700         if (index < 16)
701                 index = 16;
702
703         for (; index < net->mpls.platform_labels; index++) {
704                 struct mpls_route *rt;
705                 rt = net->mpls.platform_label[index];
706                 if (!rt)
707                         continue;
708
709                 if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
710                                     cb->nlh->nlmsg_seq, RTM_NEWROUTE,
711                                     index, rt, NLM_F_MULTI) < 0)
712                         break;
713         }
714         cb->args[0] = index;
715
716         return skb->len;
717 }
718
719 static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
720 {
721         size_t payload =
722                 NLMSG_ALIGN(sizeof(struct rtmsg))
723                 + nla_total_size(2 + rt->rt_via_alen)   /* RTA_VIA */
724                 + nla_total_size(4);                    /* RTA_DST */
725         if (rt->rt_labels)                              /* RTA_NEWDST */
726                 payload += nla_total_size(rt->rt_labels * 4);
727         if (rt->rt_dev)                                 /* RTA_OIF */
728                 payload += nla_total_size(4);
729         return payload;
730 }
731
732 static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
733                        struct nlmsghdr *nlh, struct net *net, u32 portid,
734                        unsigned int nlm_flags)
735 {
736         struct sk_buff *skb;
737         u32 seq = nlh ? nlh->nlmsg_seq : 0;
738         int err = -ENOBUFS;
739
740         skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
741         if (skb == NULL)
742                 goto errout;
743
744         err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
745         if (err < 0) {
746                 /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
747                 WARN_ON(err == -EMSGSIZE);
748                 kfree_skb(skb);
749                 goto errout;
750         }
751         rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
752
753         return;
754 errout:
755         if (err < 0)
756                 rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
757 }
758
759 static int resize_platform_label_table(struct net *net, size_t limit)
760 {
761         size_t size = sizeof(struct mpls_route *) * limit;
762         size_t old_limit;
763         size_t cp_size;
764         struct mpls_route __rcu **labels = NULL, **old;
765         struct mpls_route *rt0 = NULL, *rt2 = NULL;
766         unsigned index;
767
768         if (size) {
769                 labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
770                 if (!labels)
771                         labels = vzalloc(size);
772
773                 if (!labels)
774                         goto nolabels;
775         }
776
777         /* In case the predefined labels need to be populated */
778         if (limit > LABEL_IPV4_EXPLICIT_NULL) {
779                 struct net_device *lo = net->loopback_dev;
780                 rt0 = mpls_rt_alloc(lo->addr_len);
781                 if (!rt0)
782                         goto nort0;
783                 rt0->rt_dev = lo;
784                 rt0->rt_protocol = RTPROT_KERNEL;
785                 rt0->rt_via_family = AF_PACKET;
786                 memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
787         }
788         if (limit > LABEL_IPV6_EXPLICIT_NULL) {
789                 struct net_device *lo = net->loopback_dev;
790                 rt2 = mpls_rt_alloc(lo->addr_len);
791                 if (!rt2)
792                         goto nort2;
793                 rt2->rt_dev = lo;
794                 rt2->rt_protocol = RTPROT_KERNEL;
795                 rt2->rt_via_family = AF_PACKET;
796                 memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
797         }
798
799         rtnl_lock();
800         /* Remember the original table */
801         old = net->mpls.platform_label;
802         old_limit = net->mpls.platform_labels;
803
804         /* Free any labels beyond the new table */
805         for (index = limit; index < old_limit; index++)
806                 mpls_route_update(net, index, NULL, NULL, NULL);
807
808         /* Copy over the old labels */
809         cp_size = size;
810         if (old_limit < limit)
811                 cp_size = old_limit * sizeof(struct mpls_route *);
812
813         memcpy(labels, old, cp_size);
814
815         /* If needed set the predefined labels */
816         if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) &&
817             (limit > LABEL_IPV6_EXPLICIT_NULL)) {
818                 labels[LABEL_IPV6_EXPLICIT_NULL] = rt2;
819                 rt2 = NULL;
820         }
821
822         if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) &&
823             (limit > LABEL_IPV4_EXPLICIT_NULL)) {
824                 labels[LABEL_IPV4_EXPLICIT_NULL] = rt0;
825                 rt0 = NULL;
826         }
827
828         /* Update the global pointers */
829         net->mpls.platform_labels = limit;
830         net->mpls.platform_label = labels;
831
832         rtnl_unlock();
833
834         mpls_rt_free(rt2);
835         mpls_rt_free(rt0);
836
837         if (old) {
838                 synchronize_rcu();
839                 kvfree(old);
840         }
841         return 0;
842
843 nort2:
844         mpls_rt_free(rt0);
845 nort0:
846         kvfree(labels);
847 nolabels:
848         return -ENOMEM;
849 }
850
851 static int mpls_platform_labels(struct ctl_table *table, int write,
852                                 void __user *buffer, size_t *lenp, loff_t *ppos)
853 {
854         struct net *net = table->data;
855         int platform_labels = net->mpls.platform_labels;
856         int ret;
857         struct ctl_table tmp = {
858                 .procname       = table->procname,
859                 .data           = &platform_labels,
860                 .maxlen         = sizeof(int),
861                 .mode           = table->mode,
862                 .extra1         = &zero,
863                 .extra2         = &label_limit,
864         };
865
866         ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
867
868         if (write && ret == 0)
869                 ret = resize_platform_label_table(net, platform_labels);
870
871         return ret;
872 }
873
874 static struct ctl_table mpls_table[] = {
875         {
876                 .procname       = "platform_labels",
877                 .data           = NULL,
878                 .maxlen         = sizeof(int),
879                 .mode           = 0644,
880                 .proc_handler   = mpls_platform_labels,
881         },
882         { }
883 };
884
885 static int mpls_net_init(struct net *net)
886 {
887         struct ctl_table *table;
888
889         net->mpls.platform_labels = 0;
890         net->mpls.platform_label = NULL;
891
892         table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
893         if (table == NULL)
894                 return -ENOMEM;
895
896         table[0].data = net;
897         net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
898         if (net->mpls.ctl == NULL)
899                 return -ENOMEM;
900
901         return 0;
902 }
903
904 static void mpls_net_exit(struct net *net)
905 {
906         struct ctl_table *table;
907         unsigned int index;
908
909         table = net->mpls.ctl->ctl_table_arg;
910         unregister_net_sysctl_table(net->mpls.ctl);
911         kfree(table);
912
913         /* An rcu grace period haselapsed since there was a device in
914          * the network namespace (and thus the last in fqlight packet)
915          * left this network namespace.  This is because
916          * unregister_netdevice_many and netdev_run_todo has completed
917          * for each network device that was in this network namespace.
918          *
919          * As such no additional rcu synchronization is necessary when
920          * freeing the platform_label table.
921          */
922         rtnl_lock();
923         for (index = 0; index < net->mpls.platform_labels; index++) {
924                 struct mpls_route *rt = net->mpls.platform_label[index];
925                 rcu_assign_pointer(net->mpls.platform_label[index], NULL);
926                 mpls_rt_free(rt);
927         }
928         rtnl_unlock();
929
930         kvfree(net->mpls.platform_label);
931 }
932
933 static struct pernet_operations mpls_net_ops = {
934         .init = mpls_net_init,
935         .exit = mpls_net_exit,
936 };
937
938 static int __init mpls_init(void)
939 {
940         int err;
941
942         BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
943
944         err = register_pernet_subsys(&mpls_net_ops);
945         if (err)
946                 goto out;
947
948         err = register_netdevice_notifier(&mpls_dev_notifier);
949         if (err)
950                 goto out_unregister_pernet;
951
952         dev_add_pack(&mpls_packet_type);
953
954         rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
955         rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
956         rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
957         err = 0;
958 out:
959         return err;
960
961 out_unregister_pernet:
962         unregister_pernet_subsys(&mpls_net_ops);
963         goto out;
964 }
965 module_init(mpls_init);
966
967 static void __exit mpls_exit(void)
968 {
969         rtnl_unregister_all(PF_MPLS);
970         dev_remove_pack(&mpls_packet_type);
971         unregister_netdevice_notifier(&mpls_dev_notifier);
972         unregister_pernet_subsys(&mpls_net_ops);
973 }
974 module_exit(mpls_exit);
975
976 MODULE_DESCRIPTION("MultiProtocol Label Switching");
977 MODULE_LICENSE("GPL v2");
978 MODULE_ALIAS_NETPROTO(PF_MPLS);