]> git.kernelconcepts.de Git - karo-tx-linux.git/blob - net/ipv4/fou.c
ASoC: sti-asoc-card: update tdm mode
[karo-tx-linux.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 struct fou {
20         struct socket *sock;
21         u8 protocol;
22         u8 flags;
23         __be16 port;
24         u16 type;
25         struct udp_offload udp_offloads;
26         struct list_head list;
27         struct rcu_head rcu;
28 };
29
30 #define FOU_F_REMCSUM_NOPARTIAL BIT(0)
31
32 struct fou_cfg {
33         u16 type;
34         u8 protocol;
35         u8 flags;
36         struct udp_port_cfg udp_config;
37 };
38
39 static unsigned int fou_net_id;
40
41 struct fou_net {
42         struct list_head fou_list;
43         struct mutex fou_lock;
44 };
45
46 static inline struct fou *fou_from_sock(struct sock *sk)
47 {
48         return sk->sk_user_data;
49 }
50
51 static int fou_recv_pull(struct sk_buff *skb, size_t len)
52 {
53         struct iphdr *iph = ip_hdr(skb);
54
55         /* Remove 'len' bytes from the packet (UDP header and
56          * FOU header if present).
57          */
58         iph->tot_len = htons(ntohs(iph->tot_len) - len);
59         __skb_pull(skb, len);
60         skb_postpull_rcsum(skb, udp_hdr(skb), len);
61         skb_reset_transport_header(skb);
62         return iptunnel_pull_offloads(skb);
63 }
64
65 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
66 {
67         struct fou *fou = fou_from_sock(sk);
68
69         if (!fou)
70                 return 1;
71
72         if (fou_recv_pull(skb, sizeof(struct udphdr)))
73                 goto drop;
74
75         return -fou->protocol;
76
77 drop:
78         kfree_skb(skb);
79         return 0;
80 }
81
82 static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
83                                   void *data, size_t hdrlen, u8 ipproto,
84                                   bool nopartial)
85 {
86         __be16 *pd = data;
87         size_t start = ntohs(pd[0]);
88         size_t offset = ntohs(pd[1]);
89         size_t plen = sizeof(struct udphdr) + hdrlen +
90             max_t(size_t, offset + sizeof(u16), start);
91
92         if (skb->remcsum_offload)
93                 return guehdr;
94
95         if (!pskb_may_pull(skb, plen))
96                 return NULL;
97         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
98
99         skb_remcsum_process(skb, (void *)guehdr + hdrlen,
100                             start, offset, nopartial);
101
102         return guehdr;
103 }
104
105 static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
106 {
107         /* No support yet */
108         kfree_skb(skb);
109         return 0;
110 }
111
112 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
113 {
114         struct fou *fou = fou_from_sock(sk);
115         size_t len, optlen, hdrlen;
116         struct guehdr *guehdr;
117         void *data;
118         u16 doffset = 0;
119
120         if (!fou)
121                 return 1;
122
123         len = sizeof(struct udphdr) + sizeof(struct guehdr);
124         if (!pskb_may_pull(skb, len))
125                 goto drop;
126
127         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
128
129         optlen = guehdr->hlen << 2;
130         len += optlen;
131
132         if (!pskb_may_pull(skb, len))
133                 goto drop;
134
135         /* guehdr may change after pull */
136         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
137
138         hdrlen = sizeof(struct guehdr) + optlen;
139
140         if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
141                 goto drop;
142
143         hdrlen = sizeof(struct guehdr) + optlen;
144
145         ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
146
147         /* Pull csum through the guehdr now . This can be used if
148          * there is a remote checksum offload.
149          */
150         skb_postpull_rcsum(skb, udp_hdr(skb), len);
151
152         data = &guehdr[1];
153
154         if (guehdr->flags & GUE_FLAG_PRIV) {
155                 __be32 flags = *(__be32 *)(data + doffset);
156
157                 doffset += GUE_LEN_PRIV;
158
159                 if (flags & GUE_PFLAG_REMCSUM) {
160                         guehdr = gue_remcsum(skb, guehdr, data + doffset,
161                                              hdrlen, guehdr->proto_ctype,
162                                              !!(fou->flags &
163                                                 FOU_F_REMCSUM_NOPARTIAL));
164                         if (!guehdr)
165                                 goto drop;
166
167                         data = &guehdr[1];
168
169                         doffset += GUE_PLEN_REMCSUM;
170                 }
171         }
172
173         if (unlikely(guehdr->control))
174                 return gue_control_message(skb, guehdr);
175
176         __skb_pull(skb, sizeof(struct udphdr) + hdrlen);
177         skb_reset_transport_header(skb);
178
179         if (iptunnel_pull_offloads(skb))
180                 goto drop;
181
182         return -guehdr->proto_ctype;
183
184 drop:
185         kfree_skb(skb);
186         return 0;
187 }
188
189 static struct sk_buff **fou_gro_receive(struct sk_buff **head,
190                                         struct sk_buff *skb,
191                                         struct udp_offload *uoff)
192 {
193         const struct net_offload *ops;
194         struct sk_buff **pp = NULL;
195         u8 proto = NAPI_GRO_CB(skb)->proto;
196         const struct net_offload **offloads;
197
198         rcu_read_lock();
199         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
200         ops = rcu_dereference(offloads[proto]);
201         if (!ops || !ops->callbacks.gro_receive)
202                 goto out_unlock;
203
204         pp = ops->callbacks.gro_receive(head, skb);
205
206 out_unlock:
207         rcu_read_unlock();
208
209         return pp;
210 }
211
212 static int fou_gro_complete(struct sk_buff *skb, int nhoff,
213                             struct udp_offload *uoff)
214 {
215         const struct net_offload *ops;
216         u8 proto = NAPI_GRO_CB(skb)->proto;
217         int err = -ENOSYS;
218         const struct net_offload **offloads;
219
220         udp_tunnel_gro_complete(skb, nhoff);
221
222         rcu_read_lock();
223         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
224         ops = rcu_dereference(offloads[proto]);
225         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
226                 goto out_unlock;
227
228         err = ops->callbacks.gro_complete(skb, nhoff);
229
230 out_unlock:
231         rcu_read_unlock();
232
233         return err;
234 }
235
236 static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
237                                       struct guehdr *guehdr, void *data,
238                                       size_t hdrlen, struct gro_remcsum *grc,
239                                       bool nopartial)
240 {
241         __be16 *pd = data;
242         size_t start = ntohs(pd[0]);
243         size_t offset = ntohs(pd[1]);
244
245         if (skb->remcsum_offload)
246                 return guehdr;
247
248         if (!NAPI_GRO_CB(skb)->csum_valid)
249                 return NULL;
250
251         guehdr = skb_gro_remcsum_process(skb, (void *)guehdr, off, hdrlen,
252                                          start, offset, grc, nopartial);
253
254         skb->remcsum_offload = 1;
255
256         return guehdr;
257 }
258
259 static struct sk_buff **gue_gro_receive(struct sk_buff **head,
260                                         struct sk_buff *skb,
261                                         struct udp_offload *uoff)
262 {
263         const struct net_offload **offloads;
264         const struct net_offload *ops;
265         struct sk_buff **pp = NULL;
266         struct sk_buff *p;
267         struct guehdr *guehdr;
268         size_t len, optlen, hdrlen, off;
269         void *data;
270         u16 doffset = 0;
271         int flush = 1;
272         struct fou *fou = container_of(uoff, struct fou, udp_offloads);
273         struct gro_remcsum grc;
274
275         skb_gro_remcsum_init(&grc);
276
277         off = skb_gro_offset(skb);
278         len = off + sizeof(*guehdr);
279
280         guehdr = skb_gro_header_fast(skb, off);
281         if (skb_gro_header_hard(skb, len)) {
282                 guehdr = skb_gro_header_slow(skb, len, off);
283                 if (unlikely(!guehdr))
284                         goto out;
285         }
286
287         optlen = guehdr->hlen << 2;
288         len += optlen;
289
290         if (skb_gro_header_hard(skb, len)) {
291                 guehdr = skb_gro_header_slow(skb, len, off);
292                 if (unlikely(!guehdr))
293                         goto out;
294         }
295
296         if (unlikely(guehdr->control) || guehdr->version != 0 ||
297             validate_gue_flags(guehdr, optlen))
298                 goto out;
299
300         hdrlen = sizeof(*guehdr) + optlen;
301
302         /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
303          * this is needed if there is a remote checkcsum offload.
304          */
305         skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
306
307         data = &guehdr[1];
308
309         if (guehdr->flags & GUE_FLAG_PRIV) {
310                 __be32 flags = *(__be32 *)(data + doffset);
311
312                 doffset += GUE_LEN_PRIV;
313
314                 if (flags & GUE_PFLAG_REMCSUM) {
315                         guehdr = gue_gro_remcsum(skb, off, guehdr,
316                                                  data + doffset, hdrlen, &grc,
317                                                  !!(fou->flags &
318                                                     FOU_F_REMCSUM_NOPARTIAL));
319
320                         if (!guehdr)
321                                 goto out;
322
323                         data = &guehdr[1];
324
325                         doffset += GUE_PLEN_REMCSUM;
326                 }
327         }
328
329         skb_gro_pull(skb, hdrlen);
330
331         for (p = *head; p; p = p->next) {
332                 const struct guehdr *guehdr2;
333
334                 if (!NAPI_GRO_CB(p)->same_flow)
335                         continue;
336
337                 guehdr2 = (struct guehdr *)(p->data + off);
338
339                 /* Compare base GUE header to be equal (covers
340                  * hlen, version, proto_ctype, and flags.
341                  */
342                 if (guehdr->word != guehdr2->word) {
343                         NAPI_GRO_CB(p)->same_flow = 0;
344                         continue;
345                 }
346
347                 /* Compare optional fields are the same. */
348                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
349                                            guehdr->hlen << 2)) {
350                         NAPI_GRO_CB(p)->same_flow = 0;
351                         continue;
352                 }
353         }
354
355         rcu_read_lock();
356         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
357         ops = rcu_dereference(offloads[guehdr->proto_ctype]);
358         if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive))
359                 goto out_unlock;
360
361         pp = ops->callbacks.gro_receive(head, skb);
362         flush = 0;
363
364 out_unlock:
365         rcu_read_unlock();
366 out:
367         NAPI_GRO_CB(skb)->flush |= flush;
368         skb_gro_remcsum_cleanup(skb, &grc);
369
370         return pp;
371 }
372
373 static int gue_gro_complete(struct sk_buff *skb, int nhoff,
374                             struct udp_offload *uoff)
375 {
376         const struct net_offload **offloads;
377         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
378         const struct net_offload *ops;
379         unsigned int guehlen;
380         u8 proto;
381         int err = -ENOENT;
382
383         proto = guehdr->proto_ctype;
384
385         guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
386
387         rcu_read_lock();
388         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
389         ops = rcu_dereference(offloads[proto]);
390         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
391                 goto out_unlock;
392
393         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
394
395 out_unlock:
396         rcu_read_unlock();
397         return err;
398 }
399
400 static int fou_add_to_port_list(struct net *net, struct fou *fou)
401 {
402         struct fou_net *fn = net_generic(net, fou_net_id);
403         struct fou *fout;
404
405         mutex_lock(&fn->fou_lock);
406         list_for_each_entry(fout, &fn->fou_list, list) {
407                 if (fou->port == fout->port) {
408                         mutex_unlock(&fn->fou_lock);
409                         return -EALREADY;
410                 }
411         }
412
413         list_add(&fou->list, &fn->fou_list);
414         mutex_unlock(&fn->fou_lock);
415
416         return 0;
417 }
418
419 static void fou_release(struct fou *fou)
420 {
421         struct socket *sock = fou->sock;
422         struct sock *sk = sock->sk;
423
424         if (sk->sk_family == AF_INET)
425                 udp_del_offload(&fou->udp_offloads);
426         list_del(&fou->list);
427         udp_tunnel_sock_release(sock);
428
429         kfree_rcu(fou, rcu);
430 }
431
432 static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
433 {
434         udp_sk(sk)->encap_rcv = fou_udp_recv;
435         fou->protocol = cfg->protocol;
436         fou->udp_offloads.callbacks.gro_receive = fou_gro_receive;
437         fou->udp_offloads.callbacks.gro_complete = fou_gro_complete;
438         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
439         fou->udp_offloads.ipproto = cfg->protocol;
440
441         return 0;
442 }
443
444 static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
445 {
446         udp_sk(sk)->encap_rcv = gue_udp_recv;
447         fou->udp_offloads.callbacks.gro_receive = gue_gro_receive;
448         fou->udp_offloads.callbacks.gro_complete = gue_gro_complete;
449         fou->udp_offloads.port = cfg->udp_config.local_udp_port;
450
451         return 0;
452 }
453
454 static int fou_create(struct net *net, struct fou_cfg *cfg,
455                       struct socket **sockp)
456 {
457         struct socket *sock = NULL;
458         struct fou *fou = NULL;
459         struct sock *sk;
460         int err;
461
462         /* Open UDP socket */
463         err = udp_sock_create(net, &cfg->udp_config, &sock);
464         if (err < 0)
465                 goto error;
466
467         /* Allocate FOU port structure */
468         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
469         if (!fou) {
470                 err = -ENOMEM;
471                 goto error;
472         }
473
474         sk = sock->sk;
475
476         fou->flags = cfg->flags;
477         fou->port = cfg->udp_config.local_udp_port;
478
479         /* Initial for fou type */
480         switch (cfg->type) {
481         case FOU_ENCAP_DIRECT:
482                 err = fou_encap_init(sk, fou, cfg);
483                 if (err)
484                         goto error;
485                 break;
486         case FOU_ENCAP_GUE:
487                 err = gue_encap_init(sk, fou, cfg);
488                 if (err)
489                         goto error;
490                 break;
491         default:
492                 err = -EINVAL;
493                 goto error;
494         }
495
496         fou->type = cfg->type;
497
498         udp_sk(sk)->encap_type = 1;
499         udp_encap_enable();
500
501         sk->sk_user_data = fou;
502         fou->sock = sock;
503
504         inet_inc_convert_csum(sk);
505
506         sk->sk_allocation = GFP_ATOMIC;
507
508         if (cfg->udp_config.family == AF_INET) {
509                 err = udp_add_offload(net, &fou->udp_offloads);
510                 if (err)
511                         goto error;
512         }
513
514         err = fou_add_to_port_list(net, fou);
515         if (err)
516                 goto error;
517
518         if (sockp)
519                 *sockp = sock;
520
521         return 0;
522
523 error:
524         kfree(fou);
525         if (sock)
526                 udp_tunnel_sock_release(sock);
527
528         return err;
529 }
530
531 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
532 {
533         struct fou_net *fn = net_generic(net, fou_net_id);
534         __be16 port = cfg->udp_config.local_udp_port;
535         int err = -EINVAL;
536         struct fou *fou;
537
538         mutex_lock(&fn->fou_lock);
539         list_for_each_entry(fou, &fn->fou_list, list) {
540                 if (fou->port == port) {
541                         fou_release(fou);
542                         err = 0;
543                         break;
544                 }
545         }
546         mutex_unlock(&fn->fou_lock);
547
548         return err;
549 }
550
551 static struct genl_family fou_nl_family = {
552         .id             = GENL_ID_GENERATE,
553         .hdrsize        = 0,
554         .name           = FOU_GENL_NAME,
555         .version        = FOU_GENL_VERSION,
556         .maxattr        = FOU_ATTR_MAX,
557         .netnsok        = true,
558 };
559
560 static struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
561         [FOU_ATTR_PORT] = { .type = NLA_U16, },
562         [FOU_ATTR_AF] = { .type = NLA_U8, },
563         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
564         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
565         [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, },
566 };
567
568 static int parse_nl_config(struct genl_info *info,
569                            struct fou_cfg *cfg)
570 {
571         memset(cfg, 0, sizeof(*cfg));
572
573         cfg->udp_config.family = AF_INET;
574
575         if (info->attrs[FOU_ATTR_AF]) {
576                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
577
578                 if (family != AF_INET)
579                         return -EINVAL;
580
581                 cfg->udp_config.family = family;
582         }
583
584         if (info->attrs[FOU_ATTR_PORT]) {
585                 __be16 port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
586
587                 cfg->udp_config.local_udp_port = port;
588         }
589
590         if (info->attrs[FOU_ATTR_IPPROTO])
591                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
592
593         if (info->attrs[FOU_ATTR_TYPE])
594                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
595
596         if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
597                 cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;
598
599         return 0;
600 }
601
602 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
603 {
604         struct net *net = genl_info_net(info);
605         struct fou_cfg cfg;
606         int err;
607
608         err = parse_nl_config(info, &cfg);
609         if (err)
610                 return err;
611
612         return fou_create(net, &cfg, NULL);
613 }
614
615 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
616 {
617         struct net *net = genl_info_net(info);
618         struct fou_cfg cfg;
619         int err;
620
621         err = parse_nl_config(info, &cfg);
622         if (err)
623                 return err;
624
625         return fou_destroy(net, &cfg);
626 }
627
628 static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
629 {
630         if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
631             nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
632             nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
633             nla_put_u8(msg, FOU_ATTR_TYPE, fou->type))
634                 return -1;
635
636         if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
637                 if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
638                         return -1;
639         return 0;
640 }
641
642 static int fou_dump_info(struct fou *fou, u32 portid, u32 seq,
643                          u32 flags, struct sk_buff *skb, u8 cmd)
644 {
645         void *hdr;
646
647         hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd);
648         if (!hdr)
649                 return -ENOMEM;
650
651         if (fou_fill_info(fou, skb) < 0)
652                 goto nla_put_failure;
653
654         genlmsg_end(skb, hdr);
655         return 0;
656
657 nla_put_failure:
658         genlmsg_cancel(skb, hdr);
659         return -EMSGSIZE;
660 }
661
662 static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
663 {
664         struct net *net = genl_info_net(info);
665         struct fou_net *fn = net_generic(net, fou_net_id);
666         struct sk_buff *msg;
667         struct fou_cfg cfg;
668         struct fou *fout;
669         __be16 port;
670         int ret;
671
672         ret = parse_nl_config(info, &cfg);
673         if (ret)
674                 return ret;
675         port = cfg.udp_config.local_udp_port;
676         if (port == 0)
677                 return -EINVAL;
678
679         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
680         if (!msg)
681                 return -ENOMEM;
682
683         ret = -ESRCH;
684         mutex_lock(&fn->fou_lock);
685         list_for_each_entry(fout, &fn->fou_list, list) {
686                 if (port == fout->port) {
687                         ret = fou_dump_info(fout, info->snd_portid,
688                                             info->snd_seq, 0, msg,
689                                             info->genlhdr->cmd);
690                         break;
691                 }
692         }
693         mutex_unlock(&fn->fou_lock);
694         if (ret < 0)
695                 goto out_free;
696
697         return genlmsg_reply(msg, info);
698
699 out_free:
700         nlmsg_free(msg);
701         return ret;
702 }
703
704 static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb)
705 {
706         struct net *net = sock_net(skb->sk);
707         struct fou_net *fn = net_generic(net, fou_net_id);
708         struct fou *fout;
709         int idx = 0, ret;
710
711         mutex_lock(&fn->fou_lock);
712         list_for_each_entry(fout, &fn->fou_list, list) {
713                 if (idx++ < cb->args[0])
714                         continue;
715                 ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid,
716                                     cb->nlh->nlmsg_seq, NLM_F_MULTI,
717                                     skb, FOU_CMD_GET);
718                 if (ret)
719                         break;
720         }
721         mutex_unlock(&fn->fou_lock);
722
723         cb->args[0] = idx;
724         return skb->len;
725 }
726
727 static const struct genl_ops fou_nl_ops[] = {
728         {
729                 .cmd = FOU_CMD_ADD,
730                 .doit = fou_nl_cmd_add_port,
731                 .policy = fou_nl_policy,
732                 .flags = GENL_ADMIN_PERM,
733         },
734         {
735                 .cmd = FOU_CMD_DEL,
736                 .doit = fou_nl_cmd_rm_port,
737                 .policy = fou_nl_policy,
738                 .flags = GENL_ADMIN_PERM,
739         },
740         {
741                 .cmd = FOU_CMD_GET,
742                 .doit = fou_nl_cmd_get_port,
743                 .dumpit = fou_nl_dump,
744                 .policy = fou_nl_policy,
745         },
746 };
747
748 size_t fou_encap_hlen(struct ip_tunnel_encap *e)
749 {
750         return sizeof(struct udphdr);
751 }
752 EXPORT_SYMBOL(fou_encap_hlen);
753
754 size_t gue_encap_hlen(struct ip_tunnel_encap *e)
755 {
756         size_t len;
757         bool need_priv = false;
758
759         len = sizeof(struct udphdr) + sizeof(struct guehdr);
760
761         if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
762                 len += GUE_PLEN_REMCSUM;
763                 need_priv = true;
764         }
765
766         len += need_priv ? GUE_LEN_PRIV : 0;
767
768         return len;
769 }
770 EXPORT_SYMBOL(gue_encap_hlen);
771
772 static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
773                           struct flowi4 *fl4, u8 *protocol, __be16 sport)
774 {
775         struct udphdr *uh;
776
777         skb_push(skb, sizeof(struct udphdr));
778         skb_reset_transport_header(skb);
779
780         uh = udp_hdr(skb);
781
782         uh->dest = e->dport;
783         uh->source = sport;
784         uh->len = htons(skb->len);
785         udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
786                      fl4->saddr, fl4->daddr, skb->len);
787
788         *protocol = IPPROTO_UDP;
789 }
790
791 int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
792                      u8 *protocol, struct flowi4 *fl4)
793 {
794         int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
795                                                        SKB_GSO_UDP_TUNNEL;
796         __be16 sport;
797
798         skb = iptunnel_handle_offloads(skb, type);
799
800         if (IS_ERR(skb))
801                 return PTR_ERR(skb);
802
803         sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
804                                                skb, 0, 0, false);
805         fou_build_udp(skb, e, fl4, protocol, sport);
806
807         return 0;
808 }
809 EXPORT_SYMBOL(fou_build_header);
810
811 int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
812                      u8 *protocol, struct flowi4 *fl4)
813 {
814         int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
815                                                        SKB_GSO_UDP_TUNNEL;
816         struct guehdr *guehdr;
817         size_t hdrlen, optlen = 0;
818         __be16 sport;
819         void *data;
820         bool need_priv = false;
821
822         if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
823             skb->ip_summed == CHECKSUM_PARTIAL) {
824                 optlen += GUE_PLEN_REMCSUM;
825                 type |= SKB_GSO_TUNNEL_REMCSUM;
826                 need_priv = true;
827         }
828
829         optlen += need_priv ? GUE_LEN_PRIV : 0;
830
831         skb = iptunnel_handle_offloads(skb, type);
832
833         if (IS_ERR(skb))
834                 return PTR_ERR(skb);
835
836         /* Get source port (based on flow hash) before skb_push */
837         sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
838                                                skb, 0, 0, false);
839
840         hdrlen = sizeof(struct guehdr) + optlen;
841
842         skb_push(skb, hdrlen);
843
844         guehdr = (struct guehdr *)skb->data;
845
846         guehdr->control = 0;
847         guehdr->version = 0;
848         guehdr->hlen = optlen >> 2;
849         guehdr->flags = 0;
850         guehdr->proto_ctype = *protocol;
851
852         data = &guehdr[1];
853
854         if (need_priv) {
855                 __be32 *flags = data;
856
857                 guehdr->flags |= GUE_FLAG_PRIV;
858                 *flags = 0;
859                 data += GUE_LEN_PRIV;
860
861                 if (type & SKB_GSO_TUNNEL_REMCSUM) {
862                         u16 csum_start = skb_checksum_start_offset(skb);
863                         __be16 *pd = data;
864
865                         if (csum_start < hdrlen)
866                                 return -EINVAL;
867
868                         csum_start -= hdrlen;
869                         pd[0] = htons(csum_start);
870                         pd[1] = htons(csum_start + skb->csum_offset);
871
872                         if (!skb_is_gso(skb)) {
873                                 skb->ip_summed = CHECKSUM_NONE;
874                                 skb->encapsulation = 0;
875                         }
876
877                         *flags |= GUE_PFLAG_REMCSUM;
878                         data += GUE_PLEN_REMCSUM;
879                 }
880
881         }
882
883         fou_build_udp(skb, e, fl4, protocol, sport);
884
885         return 0;
886 }
887 EXPORT_SYMBOL(gue_build_header);
888
889 #ifdef CONFIG_NET_FOU_IP_TUNNELS
890
891 static const struct ip_tunnel_encap_ops fou_iptun_ops = {
892         .encap_hlen = fou_encap_hlen,
893         .build_header = fou_build_header,
894 };
895
896 static const struct ip_tunnel_encap_ops gue_iptun_ops = {
897         .encap_hlen = gue_encap_hlen,
898         .build_header = gue_build_header,
899 };
900
901 static int ip_tunnel_encap_add_fou_ops(void)
902 {
903         int ret;
904
905         ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
906         if (ret < 0) {
907                 pr_err("can't add fou ops\n");
908                 return ret;
909         }
910
911         ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
912         if (ret < 0) {
913                 pr_err("can't add gue ops\n");
914                 ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
915                 return ret;
916         }
917
918         return 0;
919 }
920
921 static void ip_tunnel_encap_del_fou_ops(void)
922 {
923         ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
924         ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
925 }
926
927 #else
928
929 static int ip_tunnel_encap_add_fou_ops(void)
930 {
931         return 0;
932 }
933
934 static void ip_tunnel_encap_del_fou_ops(void)
935 {
936 }
937
938 #endif
939
940 static __net_init int fou_init_net(struct net *net)
941 {
942         struct fou_net *fn = net_generic(net, fou_net_id);
943
944         INIT_LIST_HEAD(&fn->fou_list);
945         mutex_init(&fn->fou_lock);
946         return 0;
947 }
948
949 static __net_exit void fou_exit_net(struct net *net)
950 {
951         struct fou_net *fn = net_generic(net, fou_net_id);
952         struct fou *fou, *next;
953
954         /* Close all the FOU sockets */
955         mutex_lock(&fn->fou_lock);
956         list_for_each_entry_safe(fou, next, &fn->fou_list, list)
957                 fou_release(fou);
958         mutex_unlock(&fn->fou_lock);
959 }
960
961 static struct pernet_operations fou_net_ops = {
962         .init = fou_init_net,
963         .exit = fou_exit_net,
964         .id   = &fou_net_id,
965         .size = sizeof(struct fou_net),
966 };
967
968 static int __init fou_init(void)
969 {
970         int ret;
971
972         ret = register_pernet_device(&fou_net_ops);
973         if (ret)
974                 goto exit;
975
976         ret = genl_register_family_with_ops(&fou_nl_family,
977                                             fou_nl_ops);
978         if (ret < 0)
979                 goto unregister;
980
981         ret = ip_tunnel_encap_add_fou_ops();
982         if (ret == 0)
983                 return 0;
984
985         genl_unregister_family(&fou_nl_family);
986 unregister:
987         unregister_pernet_device(&fou_net_ops);
988 exit:
989         return ret;
990 }
991
992 static void __exit fou_fini(void)
993 {
994         ip_tunnel_encap_del_fou_ops();
995         genl_unregister_family(&fou_nl_family);
996         unregister_pernet_device(&fou_net_ops);
997 }
998
999 module_init(fou_init);
1000 module_exit(fou_fini);
1001 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
1002 MODULE_LICENSE("GPL");