net/vmw_vsock/virtio_transport_common.c
1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
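/*
 * Editorial note: for context, a minimal userspace client of the
 * AF_VSOCK stream sockets backed by this transport might look like the
 * sketch below (illustrative only; the port number is arbitrary and
 * error handling is omitted):
 *
 *      #include <sys/socket.h>
 *      #include <linux/vm_sockets.h>
 *      #include <unistd.h>
 *
 *      int main(void)
 *      {
 *              struct sockaddr_vm addr = {
 *                      .svm_family = AF_VSOCK,
 *                      .svm_cid    = VMADDR_CID_HOST,  // CID 2: the host
 *                      .svm_port   = 1234,             // arbitrary example port
 *              };
 *              int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *
 *              connect(fd, (struct sockaddr *)&addr, sizeof(addr));
 *              write(fd, "hello", 5);
 *              close(fd);
 *              return 0;
 *      }
 */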
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 #include <uapi/linux/vsockmon.h>
20
21 #include <net/sock.h>
22 #include <net/af_vsock.h>
23
24 #define CREATE_TRACE_POINTS
25 #include <trace/events/vsock_virtio_transport_common.h>
26
27 /* How long to wait for graceful shutdown of a connection */
28 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
29
30 static const struct virtio_transport *virtio_transport_get_ops(void)
31 {
32         const struct vsock_transport *t = vsock_core_get_transport();
33
34         return container_of(t, struct virtio_transport, transport);
35 }
36
37 static struct virtio_vsock_pkt *
38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
39                            size_t len,
40                            u32 src_cid,
41                            u32 src_port,
42                            u32 dst_cid,
43                            u32 dst_port)
44 {
45         struct virtio_vsock_pkt *pkt;
46         int err;
47
48         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
49         if (!pkt)
50                 return NULL;
51
52         pkt->hdr.type           = cpu_to_le16(info->type);
53         pkt->hdr.op             = cpu_to_le16(info->op);
54         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
55         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
56         pkt->hdr.src_port       = cpu_to_le32(src_port);
57         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
58         pkt->hdr.flags          = cpu_to_le32(info->flags);
59         pkt->len                = len;
60         pkt->hdr.len            = cpu_to_le32(len);
61         pkt->reply              = info->reply;
62         pkt->vsk                = info->vsk;
63
64         if (info->msg && len > 0) {
65                 pkt->buf = kmalloc(len, GFP_KERNEL);
66                 if (!pkt->buf)
67                         goto out_pkt;
68                 err = memcpy_from_msg(pkt->buf, info->msg, len);
69                 if (err)
70                         goto out;
71         }
72
73         trace_virtio_transport_alloc_pkt(src_cid, src_port,
74                                          dst_cid, dst_port,
75                                          len,
76                                          info->type,
77                                          info->op,
78                                          info->flags);
79
80         return pkt;
81
82 out:
83         kfree(pkt->buf);
84 out_pkt:
85         kfree(pkt);
86         return NULL;
87 }
88
89 /* Packet capture */
90 static struct sk_buff *virtio_transport_build_skb(void *opaque)
91 {
92         struct virtio_vsock_pkt *pkt = opaque;
93         struct af_vsockmon_hdr *hdr;
94         struct sk_buff *skb;
95
96         skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
97                         GFP_ATOMIC);
98         if (!skb)
99                 return NULL;
100
101         hdr = skb_put(skb, sizeof(*hdr));
102
103         /* pkt->hdr is little-endian so no need to byteswap here */
104         hdr->src_cid = pkt->hdr.src_cid;
105         hdr->src_port = pkt->hdr.src_port;
106         hdr->dst_cid = pkt->hdr.dst_cid;
107         hdr->dst_port = pkt->hdr.dst_port;
108
109         hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
110         hdr->len = cpu_to_le16(sizeof(pkt->hdr));
111         memset(hdr->reserved, 0, sizeof(hdr->reserved));
112
113         switch (le16_to_cpu(pkt->hdr.op)) {
114         case VIRTIO_VSOCK_OP_REQUEST:
115         case VIRTIO_VSOCK_OP_RESPONSE:
116                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
117                 break;
118         case VIRTIO_VSOCK_OP_RST:
119         case VIRTIO_VSOCK_OP_SHUTDOWN:
120                 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
121                 break;
122         case VIRTIO_VSOCK_OP_RW:
123                 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
124                 break;
125         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
126         case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
127                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
128                 break;
129         default:
130                 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
131                 break;
132         }
133
134         skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
135
136         if (pkt->len) {
137                 skb_put_data(skb, pkt->buf, pkt->len);
138         }
139
140         return skb;
141 }
142
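/*
 * Hand each packet to any registered vsock tap (e.g. a vsockmon
 * device) so userspace capture tools can observe AF_VSOCK traffic.
 * The skb built above starts with a struct af_vsockmon_hdr, followed
 * by the raw virtio header and the payload.
 */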
143 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
144 {
145         vsock_deliver_tap(virtio_transport_build_skb, pkt);
146 }
147 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
148
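/*
 * Common transmit path for both control and data packets: fill in the
 * source/destination CID and port from the socket (or from @info when
 * an explicit remote is given), clamp the payload to
 * VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE, reserve send credit, and hand the
 * packet to the transport's send_pkt() callback.
 */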
149 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
150                                           struct virtio_vsock_pkt_info *info)
151 {
152         u32 src_cid, src_port, dst_cid, dst_port;
153         struct virtio_vsock_sock *vvs;
154         struct virtio_vsock_pkt *pkt;
155         u32 pkt_len = info->pkt_len;
156
157         src_cid = vm_sockets_get_local_cid();
158         src_port = vsk->local_addr.svm_port;
159         if (!info->remote_cid) {
160                 dst_cid = vsk->remote_addr.svm_cid;
161                 dst_port = vsk->remote_addr.svm_port;
162         } else {
163                 dst_cid = info->remote_cid;
164                 dst_port = info->remote_port;
165         }
166
167         vvs = vsk->trans;
168
169         /* We can send fewer than pkt_len bytes */
170         if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
171                 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
172
173         /* virtio_transport_get_credit() may return less credit than pkt_len */
174         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
175
176         /* Do not send zero length OP_RW pkt */
177         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
178                 return pkt_len;
179
180         pkt = virtio_transport_alloc_pkt(info, pkt_len,
181                                          src_cid, src_port,
182                                          dst_cid, dst_port);
183         if (!pkt) {
184                 virtio_transport_put_credit(vvs, pkt_len);
185                 return -ENOMEM;
186         }
187
188         virtio_transport_inc_tx_pkt(vvs, pkt);
189
190         return virtio_transport_get_ops()->send_pkt(pkt);
191 }
192
193 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
194                                         struct virtio_vsock_pkt *pkt)
195 {
196         vvs->rx_bytes += pkt->len;
197 }
198
199 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
200                                         struct virtio_vsock_pkt *pkt)
201 {
202         vvs->rx_bytes -= pkt->len;
203         vvs->fwd_cnt += pkt->len;
204 }
205
206 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
207 {
208         spin_lock_bh(&vvs->tx_lock);
209         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
210         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
211         spin_unlock_bh(&vvs->tx_lock);
212 }
213 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
214
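/*
 * Credit accounting: the peer advertises its receive buffer size in
 * buf_alloc and the bytes it has consumed in fwd_cnt (both are carried
 * in every transmitted header by virtio_transport_inc_tx_pkt() above),
 * while the sender tracks tx_cnt, the total bytes it has queued.  The
 * credit still available is peer_buf_alloc - (tx_cnt - peer_fwd_cnt).
 *
 * Illustrative example (made-up numbers): with peer_buf_alloc = 256 KiB,
 * tx_cnt = 300 KiB and peer_fwd_cnt = 200 KiB, 100 KiB are still in
 * flight, so at most 156 KiB of new payload may be sent.
 */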
215 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
216 {
217         u32 ret;
218
219         spin_lock_bh(&vvs->tx_lock);
220         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
221         if (ret > credit)
222                 ret = credit;
223         vvs->tx_cnt += ret;
224         spin_unlock_bh(&vvs->tx_lock);
225
226         return ret;
227 }
228 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
229
230 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
231 {
232         spin_lock_bh(&vvs->tx_lock);
233         vvs->tx_cnt -= credit;
234         spin_unlock_bh(&vvs->tx_lock);
235 }
236 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
237
238 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
239                                                int type,
240                                                struct virtio_vsock_hdr *hdr)
241 {
242         struct virtio_vsock_pkt_info info = {
243                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
244                 .type = type,
245                 .vsk = vsk,
246         };
247
248         return virtio_transport_send_pkt_info(vsk, &info);
249 }
250
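/*
 * Copy up to @len bytes from the queued RX packets into @msg.  Packets
 * that are fully consumed are removed from rx_queue and freed, which
 * advances fwd_cnt; a credit update is then sent so the peer learns
 * that receive buffer space has been freed.
 */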
251 static ssize_t
252 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
253                                    struct msghdr *msg,
254                                    size_t len)
255 {
256         struct virtio_vsock_sock *vvs = vsk->trans;
257         struct virtio_vsock_pkt *pkt;
258         size_t bytes, total = 0;
259         int err = -EFAULT;
260
261         spin_lock_bh(&vvs->rx_lock);
262         while (total < len && !list_empty(&vvs->rx_queue)) {
263                 pkt = list_first_entry(&vvs->rx_queue,
264                                        struct virtio_vsock_pkt, list);
265
266                 bytes = len - total;
267                 if (bytes > pkt->len - pkt->off)
268                         bytes = pkt->len - pkt->off;
269
270                 /* sk_lock is held by caller so no one else can dequeue.
271                  * Unlock rx_lock since memcpy_to_msg() may sleep.
272                  */
273                 spin_unlock_bh(&vvs->rx_lock);
274
275                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
276                 if (err)
277                         goto out;
278
279                 spin_lock_bh(&vvs->rx_lock);
280
281                 total += bytes;
282                 pkt->off += bytes;
283                 if (pkt->off == pkt->len) {
284                         virtio_transport_dec_rx_pkt(vvs, pkt);
285                         list_del(&pkt->list);
286                         virtio_transport_free_pkt(pkt);
287                 }
288         }
289         spin_unlock_bh(&vvs->rx_lock);
290
291         /* Send a credit pkt to peer */
292         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
293                                             NULL);
294
295         return total;
296
297 out:
298         if (total)
299                 err = total;
300         return err;
301 }
302
303 ssize_t
304 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
305                                 struct msghdr *msg,
306                                 size_t len, int flags)
307 {
308         if (flags & MSG_PEEK)
309                 return -EOPNOTSUPP;
310
311         return virtio_transport_stream_do_dequeue(vsk, msg, len);
312 }
313 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
314
315 int
316 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
317                                struct msghdr *msg,
318                                size_t len, int flags)
319 {
320         return -EOPNOTSUPP;
321 }
322 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
323
324 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
325 {
326         struct virtio_vsock_sock *vvs = vsk->trans;
327         s64 bytes;
328
329         spin_lock_bh(&vvs->rx_lock);
330         bytes = vvs->rx_bytes;
331         spin_unlock_bh(&vvs->rx_lock);
332
333         return bytes;
334 }
335 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
336
337 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
338 {
339         struct virtio_vsock_sock *vvs = vsk->trans;
340         s64 bytes;
341
342         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
343         if (bytes < 0)
344                 bytes = 0;
345
346         return bytes;
347 }
348
349 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
350 {
351         struct virtio_vsock_sock *vvs = vsk->trans;
352         s64 bytes;
353
354         spin_lock_bh(&vvs->tx_lock);
355         bytes = virtio_transport_has_space(vsk);
356         spin_unlock_bh(&vvs->tx_lock);
357
358         return bytes;
359 }
360 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
361
362 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
363                                     struct vsock_sock *psk)
364 {
365         struct virtio_vsock_sock *vvs;
366
367         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
368         if (!vvs)
369                 return -ENOMEM;
370
371         vsk->trans = vvs;
372         vvs->vsk = vsk;
373         if (psk) {
374                 struct virtio_vsock_sock *ptrans = psk->trans;
375
376                 vvs->buf_size   = ptrans->buf_size;
377                 vvs->buf_size_min = ptrans->buf_size_min;
378                 vvs->buf_size_max = ptrans->buf_size_max;
379                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
380         } else {
381                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
382                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
383                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
384         }
385
386         vvs->buf_alloc = vvs->buf_size;
387
388         spin_lock_init(&vvs->rx_lock);
389         spin_lock_init(&vvs->tx_lock);
390         INIT_LIST_HEAD(&vvs->rx_queue);
391
392         return 0;
393 }
394 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
395
396 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
397 {
398         struct virtio_vsock_sock *vvs = vsk->trans;
399
400         return vvs->buf_size;
401 }
402 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
403
404 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
405 {
406         struct virtio_vsock_sock *vvs = vsk->trans;
407
408         return vvs->buf_size_min;
409 }
410 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
411
412 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
413 {
414         struct virtio_vsock_sock *vvs = vsk->trans;
415
416         return vvs->buf_size_max;
417 }
418 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
419
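/*
 * The setters below are used by the AF_VSOCK core to implement the
 * SO_VM_SOCKETS_BUFFER_* socket options.  Values are clamped to
 * VIRTIO_VSOCK_MAX_BUF_SIZE, and buf_size is adjusted whenever a new
 * min or max would otherwise exclude it.
 */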
420 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
421 {
422         struct virtio_vsock_sock *vvs = vsk->trans;
423
424         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
425                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
426         if (val < vvs->buf_size_min)
427                 vvs->buf_size_min = val;
428         if (val > vvs->buf_size_max)
429                 vvs->buf_size_max = val;
430         vvs->buf_size = val;
431         vvs->buf_alloc = val;
432 }
433 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
434
435 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
436 {
437         struct virtio_vsock_sock *vvs = vsk->trans;
438
439         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
440                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
441         if (val > vvs->buf_size)
442                 vvs->buf_size = val;
443         vvs->buf_size_min = val;
444 }
445 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
446
447 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
448 {
449         struct virtio_vsock_sock *vvs = vsk->trans;
450
451         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
452                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
453         if (val < vvs->buf_size)
454                 vvs->buf_size = val;
455         vvs->buf_size_max = val;
456 }
457 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
458
459 int
460 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
461                                 size_t target,
462                                 bool *data_ready_now)
463 {
464         if (vsock_stream_has_data(vsk))
465                 *data_ready_now = true;
466         else
467                 *data_ready_now = false;
468
469         return 0;
470 }
471 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
472
473 int
474 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
475                                  size_t target,
476                                  bool *space_avail_now)
477 {
478         s64 free_space;
479
480         free_space = vsock_stream_has_space(vsk);
481         if (free_space > 0)
482                 *space_avail_now = true;
483         else if (free_space == 0)
484                 *space_avail_now = false;
485
486         return 0;
487 }
488 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
489
490 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
491         size_t target, struct vsock_transport_recv_notify_data *data)
492 {
493         return 0;
494 }
495 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
496
497 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
498         size_t target, struct vsock_transport_recv_notify_data *data)
499 {
500         return 0;
501 }
502 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
503
504 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
505         size_t target, struct vsock_transport_recv_notify_data *data)
506 {
507         return 0;
508 }
509 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
510
511 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
512         size_t target, ssize_t copied, bool data_read,
513         struct vsock_transport_recv_notify_data *data)
514 {
515         return 0;
516 }
517 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
518
519 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
520         struct vsock_transport_send_notify_data *data)
521 {
522         return 0;
523 }
524 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
525
526 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
527         struct vsock_transport_send_notify_data *data)
528 {
529         return 0;
530 }
531 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
532
533 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
534         struct vsock_transport_send_notify_data *data)
535 {
536         return 0;
537 }
538 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
539
540 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
541         ssize_t written, struct vsock_transport_send_notify_data *data)
542 {
543         return 0;
544 }
545 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
546
547 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
548 {
549         struct virtio_vsock_sock *vvs = vsk->trans;
550
551         return vvs->buf_size;
552 }
553 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
554
555 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
556 {
557         return true;
558 }
559 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
560
561 bool virtio_transport_stream_allow(u32 cid, u32 port)
562 {
563         return true;
564 }
565 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
566
567 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
568                                 struct sockaddr_vm *addr)
569 {
570         return -EOPNOTSUPP;
571 }
572 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
573
574 bool virtio_transport_dgram_allow(u32 cid, u32 port)
575 {
576         return false;
577 }
578 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
579
580 int virtio_transport_connect(struct vsock_sock *vsk)
581 {
582         struct virtio_vsock_pkt_info info = {
583                 .op = VIRTIO_VSOCK_OP_REQUEST,
584                 .type = VIRTIO_VSOCK_TYPE_STREAM,
585                 .vsk = vsk,
586         };
587
588         return virtio_transport_send_pkt_info(vsk, &info);
589 }
590 EXPORT_SYMBOL_GPL(virtio_transport_connect);
591
592 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
593 {
594         struct virtio_vsock_pkt_info info = {
595                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
596                 .type = VIRTIO_VSOCK_TYPE_STREAM,
597                 .flags = (mode & RCV_SHUTDOWN ?
598                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
599                          (mode & SEND_SHUTDOWN ?
600                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
601                 .vsk = vsk,
602         };
603
604         return virtio_transport_send_pkt_info(vsk, &info);
605 }
606 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
607
608 int
609 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
610                                struct sockaddr_vm *remote_addr,
611                                struct msghdr *msg,
612                                size_t dgram_len)
613 {
614         return -EOPNOTSUPP;
615 }
616 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
617
618 ssize_t
619 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
620                                 struct msghdr *msg,
621                                 size_t len)
622 {
623         struct virtio_vsock_pkt_info info = {
624                 .op = VIRTIO_VSOCK_OP_RW,
625                 .type = VIRTIO_VSOCK_TYPE_STREAM,
626                 .msg = msg,
627                 .pkt_len = len,
628                 .vsk = vsk,
629         };
630
631         return virtio_transport_send_pkt_info(vsk, &info);
632 }
633 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
634
635 void virtio_transport_destruct(struct vsock_sock *vsk)
636 {
637         struct virtio_vsock_sock *vvs = vsk->trans;
638
639         kfree(vvs);
640 }
641 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
642
643 static int virtio_transport_reset(struct vsock_sock *vsk,
644                                   struct virtio_vsock_pkt *pkt)
645 {
646         struct virtio_vsock_pkt_info info = {
647                 .op = VIRTIO_VSOCK_OP_RST,
648                 .type = VIRTIO_VSOCK_TYPE_STREAM,
649                 .reply = !!pkt,
650                 .vsk = vsk,
651         };
652
653         /* Send RST only if the original pkt is not a RST pkt */
654         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
655                 return 0;
656
657         return virtio_transport_send_pkt_info(vsk, &info);
658 }
659
660 /* Normally packets are associated with a socket.  There may be no socket if an
661  * attempt was made to connect to a socket that does not exist.
662  */
663 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
664 {
665         struct virtio_vsock_pkt_info info = {
666                 .op = VIRTIO_VSOCK_OP_RST,
667                 .type = le16_to_cpu(pkt->hdr.type),
668                 .reply = true,
669         };
670
671         /* Send RST only if the original pkt is not a RST pkt */
672         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
673                 return 0;
674
675         pkt = virtio_transport_alloc_pkt(&info, 0,
676                                          le64_to_cpu(pkt->hdr.dst_cid),
677                                          le32_to_cpu(pkt->hdr.dst_port),
678                                          le64_to_cpu(pkt->hdr.src_cid),
679                                          le32_to_cpu(pkt->hdr.src_port));
680         if (!pkt)
681                 return -ENOMEM;
682
683         return virtio_transport_get_ops()->send_pkt(pkt);
684 }
685
686 static void virtio_transport_wait_close(struct sock *sk, long timeout)
687 {
688         if (timeout) {
689                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
690
691                 add_wait_queue(sk_sleep(sk), &wait);
692
693                 do {
694                         if (sk_wait_event(sk, &timeout,
695                                           sock_flag(sk, SOCK_DONE), &wait))
696                                 break;
697                 } while (!signal_pending(current) && timeout);
698
699                 remove_wait_queue(sk_sleep(sk), &wait);
700         }
701 }
702
703 static void virtio_transport_do_close(struct vsock_sock *vsk,
704                                       bool cancel_timeout)
705 {
706         struct sock *sk = sk_vsock(vsk);
707
708         sock_set_flag(sk, SOCK_DONE);
709         vsk->peer_shutdown = SHUTDOWN_MASK;
710         if (vsock_stream_has_data(vsk) <= 0)
711                 sk->sk_state = SS_DISCONNECTING;
712         sk->sk_state_change(sk);
713
714         if (vsk->close_work_scheduled &&
715             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
716                 vsk->close_work_scheduled = false;
717
718                 vsock_remove_sock(vsk);
719
720                 /* Release refcnt obtained when we scheduled the timeout */
721                 sock_put(sk);
722         }
723 }
724
725 static void virtio_transport_close_timeout(struct work_struct *work)
726 {
727         struct vsock_sock *vsk =
728                 container_of(work, struct vsock_sock, close_work.work);
729         struct sock *sk = sk_vsock(vsk);
730
731         sock_hold(sk);
732         lock_sock(sk);
733
734         if (!sock_flag(sk, SOCK_DONE)) {
735                 (void)virtio_transport_reset(vsk, NULL);
736
737                 virtio_transport_do_close(vsk, false);
738         }
739
740         vsk->close_work_scheduled = false;
741
742         release_sock(sk);
743         sock_put(sk);
744 }
745
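/*
 * Graceful close: send a SHUTDOWN covering both directions, optionally
 * linger until the peer answers with an RST (which sets SOCK_DONE).
 * If the peer never answers, close_work fires after VSOCK_CLOSE_TIMEOUT,
 * sends an RST and tears the socket down.
 */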
746 /* User context, vsk->sk is locked */
747 static bool virtio_transport_close(struct vsock_sock *vsk)
748 {
749         struct sock *sk = &vsk->sk;
750
751         if (!(sk->sk_state == SS_CONNECTED ||
752               sk->sk_state == SS_DISCONNECTING))
753                 return true;
754
755         /* Already received SHUTDOWN from peer, reply with RST */
756         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
757                 (void)virtio_transport_reset(vsk, NULL);
758                 return true;
759         }
760
761         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
762                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
763
764         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
765                 virtio_transport_wait_close(sk, sk->sk_lingertime);
766
767         if (sock_flag(sk, SOCK_DONE)) {
768                 return true;
769         }
770
771         sock_hold(sk);
772         INIT_DELAYED_WORK(&vsk->close_work,
773                           virtio_transport_close_timeout);
774         vsk->close_work_scheduled = true;
775         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
776         return false;
777 }
778
779 void virtio_transport_release(struct vsock_sock *vsk)
780 {
781         struct sock *sk = &vsk->sk;
782         bool remove_sock = true;
783
784         lock_sock(sk);
785         if (sk->sk_type == SOCK_STREAM)
786                 remove_sock = virtio_transport_close(vsk);
787         release_sock(sk);
788
789         if (remove_sock)
790                 vsock_remove_sock(vsk);
791 }
792 EXPORT_SYMBOL_GPL(virtio_transport_release);
793
794 static int
795 virtio_transport_recv_connecting(struct sock *sk,
796                                  struct virtio_vsock_pkt *pkt)
797 {
798         struct vsock_sock *vsk = vsock_sk(sk);
799         int err;
800         int skerr;
801
802         switch (le16_to_cpu(pkt->hdr.op)) {
803         case VIRTIO_VSOCK_OP_RESPONSE:
804                 sk->sk_state = SS_CONNECTED;
805                 sk->sk_socket->state = SS_CONNECTED;
806                 vsock_insert_connected(vsk);
807                 sk->sk_state_change(sk);
808                 break;
809         case VIRTIO_VSOCK_OP_INVALID:
810                 break;
811         case VIRTIO_VSOCK_OP_RST:
812                 skerr = ECONNRESET;
813                 err = 0;
814                 goto destroy;
815         default:
816                 skerr = EPROTO;
817                 err = -EINVAL;
818                 goto destroy;
819         }
820         return 0;
821
822 destroy:
823         virtio_transport_reset(vsk, pkt);
824         sk->sk_state = SS_UNCONNECTED;
825         sk->sk_err = skerr;
826         sk->sk_error_report(sk);
827         return err;
828 }
829
830 static int
831 virtio_transport_recv_connected(struct sock *sk,
832                                 struct virtio_vsock_pkt *pkt)
833 {
834         struct vsock_sock *vsk = vsock_sk(sk);
835         struct virtio_vsock_sock *vvs = vsk->trans;
836         int err = 0;
837
838         switch (le16_to_cpu(pkt->hdr.op)) {
839         case VIRTIO_VSOCK_OP_RW:
840                 pkt->len = le32_to_cpu(pkt->hdr.len);
841                 pkt->off = 0;
842
843                 spin_lock_bh(&vvs->rx_lock);
844                 virtio_transport_inc_rx_pkt(vvs, pkt);
845                 list_add_tail(&pkt->list, &vvs->rx_queue);
846                 spin_unlock_bh(&vvs->rx_lock);
847
848                 sk->sk_data_ready(sk);
849                 return err;
850         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
851                 sk->sk_write_space(sk);
852                 break;
853         case VIRTIO_VSOCK_OP_SHUTDOWN:
854                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
855                         vsk->peer_shutdown |= RCV_SHUTDOWN;
856                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
857                         vsk->peer_shutdown |= SEND_SHUTDOWN;
858                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
859                     vsock_stream_has_data(vsk) <= 0)
860                         sk->sk_state = SS_DISCONNECTING;
861                 if (le32_to_cpu(pkt->hdr.flags))
862                         sk->sk_state_change(sk);
863                 break;
864         case VIRTIO_VSOCK_OP_RST:
865                 virtio_transport_do_close(vsk, true);
866                 break;
867         default:
868                 err = -EINVAL;
869                 break;
870         }
871
872         virtio_transport_free_pkt(pkt);
873         return err;
874 }
875
876 static void
877 virtio_transport_recv_disconnecting(struct sock *sk,
878                                     struct virtio_vsock_pkt *pkt)
879 {
880         struct vsock_sock *vsk = vsock_sk(sk);
881
882         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
883                 virtio_transport_do_close(vsk, true);
884 }
885
886 static int
887 virtio_transport_send_response(struct vsock_sock *vsk,
888                                struct virtio_vsock_pkt *pkt)
889 {
890         struct virtio_vsock_pkt_info info = {
891                 .op = VIRTIO_VSOCK_OP_RESPONSE,
892                 .type = VIRTIO_VSOCK_TYPE_STREAM,
893                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
894                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
895                 .reply = true,
896                 .vsk = vsk,
897         };
898
899         return virtio_transport_send_pkt_info(vsk, &info);
900 }
901
902 /* Handle server socket */
903 static int
904 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
905 {
906         struct vsock_sock *vsk = vsock_sk(sk);
907         struct vsock_sock *vchild;
908         struct sock *child;
909
910         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
911                 virtio_transport_reset(vsk, pkt);
912                 return -EINVAL;
913         }
914
915         if (sk_acceptq_is_full(sk)) {
916                 virtio_transport_reset(vsk, pkt);
917                 return -ENOMEM;
918         }
919
920         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
921                                sk->sk_type, 0);
922         if (!child) {
923                 virtio_transport_reset(vsk, pkt);
924                 return -ENOMEM;
925         }
926
927         sk->sk_ack_backlog++;
928
929         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
930
931         child->sk_state = SS_CONNECTED;
932
933         vchild = vsock_sk(child);
934         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
935                         le32_to_cpu(pkt->hdr.dst_port));
936         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
937                         le32_to_cpu(pkt->hdr.src_port));
938
939         vsock_insert_connected(vchild);
940         vsock_enqueue_accept(sk, child);
941         virtio_transport_send_response(vchild, pkt);
942
943         release_sock(child);
944
945         sk->sk_data_ready(sk);
946         return 0;
947 }
948
949 static bool virtio_transport_space_update(struct sock *sk,
950                                           struct virtio_vsock_pkt *pkt)
951 {
952         struct vsock_sock *vsk = vsock_sk(sk);
953         struct virtio_vsock_sock *vvs = vsk->trans;
954         bool space_available;
955
956         /* buf_alloc and fwd_cnt are always included in the hdr */
957         spin_lock_bh(&vvs->tx_lock);
958         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
959         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
960         space_available = virtio_transport_has_space(vsk);
961         spin_unlock_bh(&vvs->tx_lock);
962         return space_available;
963 }
964
965 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
966  * lock.
967  */
968 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
969 {
970         struct sockaddr_vm src, dst;
971         struct vsock_sock *vsk;
972         struct sock *sk;
973         bool space_available;
974
975         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
976                         le32_to_cpu(pkt->hdr.src_port));
977         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
978                         le32_to_cpu(pkt->hdr.dst_port));
979
980         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
981                                         dst.svm_cid, dst.svm_port,
982                                         le32_to_cpu(pkt->hdr.len),
983                                         le16_to_cpu(pkt->hdr.type),
984                                         le16_to_cpu(pkt->hdr.op),
985                                         le32_to_cpu(pkt->hdr.flags),
986                                         le32_to_cpu(pkt->hdr.buf_alloc),
987                                         le32_to_cpu(pkt->hdr.fwd_cnt));
988
989         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
990                 (void)virtio_transport_reset_no_sock(pkt);
991                 goto free_pkt;
992         }
993
994         /* The socket must be in the connected or bound table,
995          * otherwise send a reset back.
996          */
997         sk = vsock_find_connected_socket(&src, &dst);
998         if (!sk) {
999                 sk = vsock_find_bound_socket(&dst);
1000                 if (!sk) {
1001                         (void)virtio_transport_reset_no_sock(pkt);
1002                         goto free_pkt;
1003                 }
1004         }
1005
1006         vsk = vsock_sk(sk);
1007
1008         space_available = virtio_transport_space_update(sk, pkt);
1009
1010         lock_sock(sk);
1011
1012         /* Update CID in case it has changed after a transport reset event */
1013         vsk->local_addr.svm_cid = dst.svm_cid;
1014
1015         if (space_available)
1016                 sk->sk_write_space(sk);
1017
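        /* Dispatch on the socket state: a listening socket handles
         * connection requests, a connecting socket waits for RESPONSE/RST,
         * a connected socket consumes RW/credit/shutdown packets (and takes
         * ownership of the packet), a disconnecting socket only cares about
         * RST.
         */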
1018         switch (sk->sk_state) {
1019         case VSOCK_SS_LISTEN:
1020                 virtio_transport_recv_listen(sk, pkt);
1021                 virtio_transport_free_pkt(pkt);
1022                 break;
1023         case SS_CONNECTING:
1024                 virtio_transport_recv_connecting(sk, pkt);
1025                 virtio_transport_free_pkt(pkt);
1026                 break;
1027         case SS_CONNECTED:
1028                 virtio_transport_recv_connected(sk, pkt);
1029                 break;
1030         case SS_DISCONNECTING:
1031                 virtio_transport_recv_disconnecting(sk, pkt);
1032                 virtio_transport_free_pkt(pkt);
1033                 break;
1034         default:
1035                 virtio_transport_free_pkt(pkt);
1036                 break;
1037         }
1038         release_sock(sk);
1039
1040         /* Release refcnt obtained when we fetched this socket out of the
1041          * bound or connected list.
1042          */
1043         sock_put(sk);
1044         return;
1045
1046 free_pkt:
1047         virtio_transport_free_pkt(pkt);
1048 }
1049 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1050
1051 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1052 {
1053         kfree(pkt->buf);
1054         kfree(pkt);
1055 }
1056 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1057
1058 MODULE_LICENSE("GPL v2");
1059 MODULE_AUTHOR("Asias He");
1060 MODULE_DESCRIPTION("common code for virtio vsock");