]> git.kernelconcepts.de Git - karo-tx-linux.git/commitdiff
Merge branch 'vsock-pkt-cancel'
authorDavid S. Miller <davem@davemloft.net>
Tue, 21 Mar 2017 21:41:47 +0000 (14:41 -0700)
committerDavid S. Miller <davem@davemloft.net>
Tue, 21 Mar 2017 21:41:47 +0000 (14:41 -0700)
Peng Tao says:

====================
vsock: cancel connect packets when failing to connect

Currently, if a connect call fails on a signal or timeout (e.g., guest is still
in the process of starting up), we'll just return to caller and leave the connect
packet queued and they are sent even though the connection is considered a failure,
which can confuse applications with unwanted false connect attempt.

The patchset enables vsock (both host and guest) to cancel queued packets when
a connect attempt is considered to fail.

v5 changelog:
  - change virtio_vsock_pkt->cancel_token back to virtio_vsock_pkt->vsk
v4 changelog:
  - drop two unnecessary void * cast
  - update new callback comment
v3 changelog:
  - define cancel_pkt callback in struct vsock_transport rather than struct virtio_transport
  - rename virtio_vsock_pkt->vsk to virtio_vsock_pkt->cancel_token
v2 changelog:
  - fix queued_replies counting and resume tx/rx when necessary
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
drivers/vhost/vsock.c
include/linux/virtio_vsock.h
include/net/af_vsock.h
net/vmw_vsock/af_vsock.c
net/vmw_vsock/virtio_transport.c
net/vmw_vsock/virtio_transport_common.c

index ce5e63d2c66aac7d019c422ec294cab025e94e5e..44eed8eb0725b25e3c9765e19387e7c338ab9bbb 100644 (file)
@@ -223,6 +223,46 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
        return len;
 }
 
+static int
+vhost_transport_cancel_pkt(struct vsock_sock *vsk)
+{
+       struct vhost_vsock *vsock;
+       struct virtio_vsock_pkt *pkt, *n;
+       int cnt = 0;
+       LIST_HEAD(freeme);
+
+       /* Find the vhost_vsock according to guest context id  */
+       vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+       if (!vsock)
+               return -ENODEV;
+
+       spin_lock_bh(&vsock->send_pkt_list_lock);
+       list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
+               if (pkt->vsk != vsk)
+                       continue;
+               list_move(&pkt->list, &freeme);
+       }
+       spin_unlock_bh(&vsock->send_pkt_list_lock);
+
+       list_for_each_entry_safe(pkt, n, &freeme, list) {
+               if (pkt->reply)
+                       cnt++;
+               list_del(&pkt->list);
+               virtio_transport_free_pkt(pkt);
+       }
+
+       if (cnt) {
+               struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
+               int new_cnt;
+
+               new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
+               if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
+                       vhost_poll_queue(&tx_vq->poll);
+       }
+
+       return 0;
+}
+
 static struct virtio_vsock_pkt *
 vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
@@ -675,6 +715,7 @@ static struct virtio_transport vhost_transport = {
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
+               .cancel_pkt               = vhost_transport_cancel_pkt,
 
                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
index 9638bfeb0d1f639ae310d1586b4e2fca567ba2f7..584f9a647ad4acca191ff6116a47c14da1385fa3 100644 (file)
@@ -48,6 +48,8 @@ struct virtio_vsock_pkt {
        struct virtio_vsock_hdr hdr;
        struct work_struct work;
        struct list_head list;
+       /* socket refcnt not held, only use for cancellation */
+       struct vsock_sock *vsk;
        void *buf;
        u32 len;
        u32 off;
@@ -56,6 +58,7 @@ struct virtio_vsock_pkt {
 
 struct virtio_vsock_pkt_info {
        u32 remote_cid, remote_port;
+       struct vsock_sock *vsk;
        struct msghdr *msg;
        u32 pkt_len;
        u16 type;
index f2758964ce6f890e3b11df5ba5bf2eefe636abd1..f32ed9ac181a47c00757596fc3b8c5733426c468 100644 (file)
@@ -100,6 +100,9 @@ struct vsock_transport {
        void (*destruct)(struct vsock_sock *);
        void (*release)(struct vsock_sock *);
 
+       /* Cancel all pending packets sent on vsock. */
+       int (*cancel_pkt)(struct vsock_sock *vsk);
+
        /* Connections. */
        int (*connect)(struct vsock_sock *);
 
index 9f770f33c10098fd3fcccfd9c739ab9a28a6b6f5..6f7f6757ceefb500551fafbf40c462835c4baf88 100644 (file)
@@ -1102,10 +1102,19 @@ static const struct proto_ops vsock_dgram_ops = {
        .sendpage = sock_no_sendpage,
 };
 
+static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
+{
+       if (!transport->cancel_pkt)
+               return -EOPNOTSUPP;
+
+       return transport->cancel_pkt(vsk);
+}
+
 static void vsock_connect_timeout(struct work_struct *work)
 {
        struct sock *sk;
        struct vsock_sock *vsk;
+       int cancel = 0;
 
        vsk = container_of(work, struct vsock_sock, dwork.work);
        sk = sk_vsock(vsk);
@@ -1116,8 +1125,11 @@ static void vsock_connect_timeout(struct work_struct *work)
                sk->sk_state = SS_UNCONNECTED;
                sk->sk_err = ETIMEDOUT;
                sk->sk_error_report(sk);
+               cancel = 1;
        }
        release_sock(sk);
+       if (cancel)
+               vsock_transport_cancel_pkt(vsk);
 
        sock_put(sk);
 }
@@ -1224,11 +1236,13 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
                        err = sock_intr_errno(timeout);
                        sk->sk_state = SS_UNCONNECTED;
                        sock->state = SS_UNCONNECTED;
+                       vsock_transport_cancel_pkt(vsk);
                        goto out_wait;
                } else if (timeout == 0) {
                        err = -ETIMEDOUT;
                        sk->sk_state = SS_UNCONNECTED;
                        sock->state = SS_UNCONNECTED;
+                       vsock_transport_cancel_pkt(vsk);
                        goto out_wait;
                }
 
index 9d24c0e958b18e614e30b24c0fcfbbe2152941f3..68675a151f22b8b63c02b25a67b833d9a6046d84 100644 (file)
@@ -213,6 +213,47 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
        return len;
 }
 
+static int
+virtio_transport_cancel_pkt(struct vsock_sock *vsk)
+{
+       struct virtio_vsock *vsock;
+       struct virtio_vsock_pkt *pkt, *n;
+       int cnt = 0;
+       LIST_HEAD(freeme);
+
+       vsock = virtio_vsock_get();
+       if (!vsock) {
+               return -ENODEV;
+       }
+
+       spin_lock_bh(&vsock->send_pkt_list_lock);
+       list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
+               if (pkt->vsk != vsk)
+                       continue;
+               list_move(&pkt->list, &freeme);
+       }
+       spin_unlock_bh(&vsock->send_pkt_list_lock);
+
+       list_for_each_entry_safe(pkt, n, &freeme, list) {
+               if (pkt->reply)
+                       cnt++;
+               list_del(&pkt->list);
+               virtio_transport_free_pkt(pkt);
+       }
+
+       if (cnt) {
+               struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
+               int new_cnt;
+
+               new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
+               if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
+                   new_cnt < virtqueue_get_vring_size(rx_vq))
+                       queue_work(virtio_vsock_workqueue, &vsock->rx_work);
+       }
+
+       return 0;
+}
+
 static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
 {
        int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
@@ -462,6 +503,7 @@ static struct virtio_transport virtio_transport = {
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
+               .cancel_pkt               = virtio_transport_cancel_pkt,
 
                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
index 8d592a45b59786746d186e12d0c362d07c30bdac..af087b44ceea2311e53060e2442b4af2024bb037 100644 (file)
@@ -58,6 +58,7 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
        pkt->len                = len;
        pkt->hdr.len            = cpu_to_le32(len);
        pkt->reply              = info->reply;
+       pkt->vsk                = info->vsk;
 
        if (info->msg && len > 0) {
                pkt->buf = kmalloc(len, GFP_KERNEL);
@@ -180,6 +181,7 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
                .type = type,
+               .vsk = vsk,
        };
 
        return virtio_transport_send_pkt_info(vsk, &info);
@@ -519,6 +521,7 @@ int virtio_transport_connect(struct vsock_sock *vsk)
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
+               .vsk = vsk,
        };
 
        return virtio_transport_send_pkt_info(vsk, &info);
@@ -534,6 +537,7 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
+               .vsk = vsk,
        };
 
        return virtio_transport_send_pkt_info(vsk, &info);
@@ -560,6 +564,7 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
+               .vsk = vsk,
        };
 
        return virtio_transport_send_pkt_info(vsk, &info);
@@ -581,6 +586,7 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
                .op = VIRTIO_VSOCK_OP_RST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
+               .vsk = vsk,
        };
 
        /* Send RST only if the original pkt is not a RST pkt */
@@ -826,6 +832,7 @@ virtio_transport_send_response(struct vsock_sock *vsk,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
+               .vsk = vsk,
        };
 
        return virtio_transport_send_pkt_info(vsk, &info);