实现扩展方法WebRequest.GetResponseAsync并支持CancellationToken

时间:2013-07-05 17:54:19

标签: .net asynchronous httpwebrequest task-parallel-library cancellation

这里的思路很简单,但实现起来有一些值得注意的细节。下面是我想在 .NET 4 中实现的扩展方法的签名:

public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token);

这是我的初步实现。根据我读到的资料,Web 请求可能需要因超时而被取消(cancelled due to a timeout)。除了该页面上描述的超时支持之外,我还希望在通过 CancellationToken 请求取消时,能够正确地调用 request.Abort()。

/// <summary>
/// Exposes <see cref="WebRequest.BeginGetResponse"/> / <see cref="WebRequest.EndGetResponse"/>
/// as a <see cref="Task{TResult}"/> through <c>Task.Factory.FromAsync</c>, routing the begin call
/// through the <c>BeginGetResponse</c> helper so it can watch <paramref name="token"/>.
/// </summary>
/// <param name="request">The request to send. Must not be <c>null</c>.</param>
/// <param name="token">Token forwarded to the <c>BeginGetResponse</c> helper.</param>
/// <returns>A task wrapping the asynchronous response.</returns>
/// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token)
{
    if (request != null)
    {
        Func<WebRequest, CancellationToken, AsyncCallback, object, IAsyncResult> beginMethod = BeginGetResponse;
        Func<IAsyncResult, WebResponse> endMethod = request.EndGetResponse;

        // request and token become the helper's first two arguments; no extra state object.
        return Task.Factory.FromAsync<WebRequest, CancellationToken, WebResponse>(beginMethod, endMethod, request, token, null);
    }

    throw new ArgumentNullException("request");
}

/// <summary>
/// Begin-method adapter used by <c>Task.Factory.FromAsync</c>: starts the request and, if it did
/// not complete synchronously, arranges for it to be aborted on timeout or cancellation.
/// </summary>
/// <param name="request">The request to start.</param>
/// <param name="token">Token whose cancellation should abort the request.</param>
/// <param name="callback">Completion callback supplied by FromAsync.</param>
/// <param name="state">State object supplied by FromAsync.</param>
/// <returns>The <see cref="IAsyncResult"/> returned by <see cref="WebRequest.BeginGetResponse"/>.</returns>
private static IAsyncResult BeginGetResponse(WebRequest request, CancellationToken token, AsyncCallback callback, object state)
{
    IAsyncResult pending = request.BeginGetResponse(callback, state);

    // A synchronously completed request has nothing left to time out or cancel.
    if (pending.IsCompleted)
        return pending;

    // Abort if the operation's wait handle is not signalled within WebRequest.Timeout milliseconds.
    if (request.Timeout != Timeout.Infinite)
    {
        ThreadPool.RegisterWaitForSingleObject(
            pending.AsyncWaitHandle,
            WebRequestTimeoutCallback,
            request,
            request.Timeout,
            true);
    }

    // Abort once the caller's token is signalled.
    // NOTE(review): this registration uses Timeout.Infinite and its RegisteredWaitHandle is
    // discarded, so it is never unregistered; it only goes away if the token is cancelled.
    if (token != CancellationToken.None)
    {
        Tuple<WebRequest, CancellationToken> cancellationState = Tuple.Create(request, token);
        ThreadPool.RegisterWaitForSingleObject(
            token.WaitHandle,
            WebRequestCancelledCallback,
            cancellationState,
            Timeout.Infinite,
            true);
    }

    return pending;
}

/// <summary>
/// <see cref="WaitOrTimerCallback"/> that aborts the request passed as <paramref name="state"/>
/// when the wait ended because its timeout elapsed.
/// </summary>
/// <param name="state">The <see cref="WebRequest"/> to abort; any other value is ignored.</param>
/// <param name="timedOut"><c>true</c> if the wait timed out rather than being signalled.</param>
private static void WebRequestTimeoutCallback(object state, bool timedOut)
{
    // The wait handle was signalled: the request finished in time.
    if (!timedOut)
        return;

    WebRequest pendingRequest = state as WebRequest;
    if (pendingRequest == null)
        return;

    pendingRequest.Abort();
}

/// <summary>
/// <see cref="WaitOrTimerCallback"/> that aborts the request carried in <paramref name="state"/>
/// when the accompanying token reports that cancellation was requested.
/// </summary>
/// <param name="state">A <c>Tuple&lt;WebRequest, CancellationToken&gt;</c>; any other value is ignored.</param>
/// <param name="timedOut">Unused; the registration in <c>BeginGetResponse</c> never times out.</param>
private static void WebRequestCancelledCallback(object state, bool timedOut)
{
    Tuple<WebRequest, CancellationToken> requestAndToken = state as Tuple<WebRequest, CancellationToken>;

    if (requestAndToken == null || !requestAndToken.Item2.IsCancellationRequested)
        return;

    requestAndToken.Item1.Abort();
}

我的问题既简单又具有挑战性。当与TPL一起使用时,此实现是否会按预期运行?

1 个答案:

答案 0(得分:6)

  

当与TPL一起使用时,此实现是否会按预期运行?

没有。

  1. 它不会把返回的 Task<T> 标记为已取消(Canceled),因此其行为不能完全符合预期。
  2. 如果发生超时,Task.Exception 所报告的 AggregateException 中包含的 WebException,其状态为 WebExceptionStatus.RequestCanceled,而它应当是 WebExceptionStatus.Timeout。
  3. 我其实建议改用 TaskCompletionSource<T> 来实现。这样就不必自己编写 APM 风格的 Begin 方法:

    /// <summary>
    /// Starts an asynchronous request and exposes it as a <see cref="Task{TResult}"/> that honors
    /// both <see cref="WebRequest.Timeout"/> and <paramref name="token"/>.
    /// </summary>
    /// <param name="request">The request to send. Must not be <c>null</c>.</param>
    /// <param name="token">
    /// Cancelling this token aborts the request. If it is already cancelled when this method is
    /// called, the request is not started and an already-cancelled task is returned.
    /// </param>
    /// <returns>
    /// A task that completes with the response. It faults with a <see cref="WebException"/> whose
    /// status is <see cref="WebExceptionStatus.Timeout"/> if <see cref="WebRequest.Timeout"/> elapses,
    /// and is marked cancelled if <paramref name="token"/> aborted the request.
    /// Any other failure is propagated as-is.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
    public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token)
    {
        if (request == null)
            throw new ArgumentNullException("request");

        TaskCompletionSource<WebResponse> completionSource = new TaskCompletionSource<WebResponse>();

        // Don't start any network I/O for a caller that has already given up.
        if (token.IsCancellationRequested)
        {
            completionSource.SetCanceled();
            return completionSource.Task;
        }

        bool timeout = false;

        AsyncCallback completedCallback =
            result =>
            {
                try
                {
                    completionSource.TrySetResult(request.EndGetResponse(result));
                }
                catch (WebException ex)
                {
                    // Both the timeout and the cancellation paths end in request.Abort(), which
                    // surfaces here as WebExceptionStatus.RequestCanceled; translate it back.
                    if (timeout)
                        completionSource.TrySetException(new WebException("No response was received during the time-out period for a request.", ex, WebExceptionStatus.Timeout, null));
                    else if (token.IsCancellationRequested)
                        completionSource.TrySetCanceled();
                    else
                        completionSource.TrySetException(ex);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                }
            };

        IAsyncResult asyncResult = request.BeginGetResponse(completedCallback, null);
        if (!asyncResult.IsCompleted)
        {
            RegisteredWaitHandle timeoutRegistration = null;
            if (request.Timeout != Timeout.Infinite)
            {
                // Fires with timedOut == true only if the operation's wait handle is not
                // signalled within request.Timeout milliseconds.
                WaitOrTimerCallback timedOutCallback =
                    (state, timedOut) =>
                    {
                        if (timedOut)
                        {
                            timeout = true;
                            request.Abort();
                        }
                    };

                timeoutRegistration = ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, timedOutCallback, null, request.Timeout, true);
            }

            // Unlike an infinite wait on token.WaitHandle, a callback registration does not allocate
            // the token's kernel event and can be removed once the request finishes. It is a no-op
            // for tokens that cannot be cancelled, and if the token was cancelled after the check
            // above, Abort runs synchronously right here.
            CancellationTokenRegistration cancellationRegistration = token.Register(request.Abort);

            // Release both registrations once the task reaches a final state. ContinueWith runs
            // immediately if the task is already complete, so this cannot miss a fast completion.
            completionSource.Task.ContinueWith(
                _ =>
                {
                    cancellationRegistration.Dispose();
                    if (timeoutRegistration != null)
                        timeoutRegistration.Unregister(null);
                },
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }

        return completionSource.Task;
    }
    

    这样做的好处是,返回的 Task<T> 会完全按预期运行(取消时会被标记为已取消;超时时会抛出与同步版本相同、带有超时信息的异常,等等)。这也避免了使用 Task.Factory.FromAsync 的开销,因为大部分相关的繁重工作你已经自己处理了。


    附录(由 280Z28 补充)

    下面的单元测试演示了上述方法的正确行为(这些测试会向 google.com 发送真实请求,因此需要网络连接)。

    /// <summary>
    /// Integration tests for the <c>GetResponseAsync(this WebRequest, CancellationToken)</c> extension method.
    /// </summary>
    /// <remarks>
    /// These tests send real HTTP requests to google.com and therefore need network access.
    /// Every call passes a token explicitly: on .NET 4.5 and later a parameterless
    /// <c>request.GetResponseAsync()</c> binds to the framework's built-in instance method instead of
    /// the extension under test (and on .NET 4 it does not compile at all).
    /// Responses are disposed so their connections are released.
    /// </remarks>
    [TestClass]
    public class AsyncWebRequestTests
    {
        /// <summary>A reachable URI completes the task with a response.</summary>
        [TestMethod]
        public void TestAsyncWebRequest()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
            response.Wait();

            using (WebResponse webResponse = response.Result)
            {
                Assert.AreEqual(TaskStatus.RanToCompletion, response.Status);
                Assert.IsNotNull(webResponse);
            }
        }

        /// <summary>An elapsed <see cref="WebRequest.Timeout"/> faults the task with a timeout WebException.</summary>
        [TestMethod]
        public void TestAsyncWebRequestTimeout()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            // A zero timeout is expected to elapse before any response arrives.
            request.Timeout = 0;
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
            try
            {
                response.Wait();
                Assert.Fail("Expected an exception");
            }
            catch (AggregateException exception)
            {
                Assert.AreEqual(TaskStatus.Faulted, response.Status);

                ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                Assert.AreEqual(1, exceptions.Count);
                Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

                WebException webException = (WebException)exceptions[0];
                Assert.AreEqual(WebExceptionStatus.Timeout, webException.Status);
            }
        }

        /// <summary>Cancelling the token marks the task as cancelled rather than faulted.</summary>
        [TestMethod]
        public void TestAsyncWebRequestCancellation()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource())
            {
                Task<WebResponse> response = request.GetResponseAsync(cancellationTokenSource.Token);
                cancellationTokenSource.Cancel();
                try
                {
                    response.Wait();
                    Assert.Fail("Expected an exception");
                }
                catch (AggregateException exception)
                {
                    Assert.AreEqual(TaskStatus.Canceled, response.Status);

                    ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                    Assert.AreEqual(1, exceptions.Count);
                    Assert.IsInstanceOfType(exceptions[0], typeof(OperationCanceledException));
                }
            }
        }

        /// <summary>An HTTP error status faults the task with the server's WebException.</summary>
        [TestMethod]
        public void TestAsyncWebRequestError()
        {
            Uri uri = new Uri("http://google.com/fail");
            WebRequest request = WebRequest.Create(uri);
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
            try
            {
                response.Wait();
                Assert.Fail("Expected an exception");
            }
            catch (AggregateException exception)
            {
                Assert.AreEqual(TaskStatus.Faulted, response.Status);

                ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                Assert.AreEqual(1, exceptions.Count);
                Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

                WebException webException = (WebException)exceptions[0];
                Assert.IsInstanceOfType(webException.Response, typeof(HttpWebResponse));
                using (HttpWebResponse errorResponse = (HttpWebResponse)webException.Response)
                {
                    Assert.AreEqual(HttpStatusCode.NotFound, errorResponse.StatusCode);
                }
            }
        }
    }