有没有办法限制任务并行库使用的线程?

时间:2010-08-23 14:20:55

标签: multithreading unit-testing concurrency .net-4.0 task-parallel-library

我正在使用TPL,但我发现使用它的棘手的单元测试代码。

我试图不introduce a wrapper因为我觉得它可能会引发问题。

我知道你可以在TPL中设置处理器关联,但真正好的是设置一个线程最大值(可能是每个app-domain)。因此,当将线程最大值设置为1时,TPL将被强制使用它所使用的任何线程。

你怎么看?这是可能的(我很确定它不是),应该可能吗?

编辑:这是一个例子

public class Foo
{
    public Foo( )
    {
        Task.Factory.StartNew( () => somethingLong( ) )
            .ContinueWith( a => Bar = 1 ) ;
    }
}

[Test] public void Foo_should_set_Bar_to_1( )
{
    Assert.Equal(1, new Foo( ).Bar ) ;
}

除非我引入延迟,否则测试可能不会通过。我希望有Task.MaximumThreads=1这样的东西,以便TPL能够连续运行。

3 个答案:

答案 0 :(得分:4)

您可以创建自己的TaskScheduler类派生自TaskScheduler,并将其传递到TaskFactory。现在,您可以创建针对 调度程序创建的任何Task个对象。

无需将其设置为使用一个线程。

然后,在断言之前,只需在其上调用Dispose()即可。在内部,如果你按照那里的样本来编写TaskScheduler: -

,它会做这样的事情
public void Dispose()
{
    if (tasks != null)
    {
        tasks.CompleteAdding();

        foreach (var thread in threads) thread.Join();

        tasks.Dispose();
        tasks = null;
    }
}

这将保证所有任务都已运行。现在,您可以继续使用Asserts。

如果你想在事情发生时检查进度,你也可以在任务运行后使用ContinueWith(...)添加断言。

答案 1 :(得分:2)

对于lambda重码的可测试性而言,这与TPL相关的问题更多。 Hightechrider的建议很好,但基本上你的测试仍在测试TPL,就像你的代码一样。当第一个任务结束并且ContinueWith开始下一个任务时,你真的不需要测试它。

如果lambdas中的代码非常大,那么将其拉出到一个具有明确定义参数的更可测试的方法可能会导致更容易阅读和更可测试的代码。你可以围绕它编写单元测试。在可能的情况下,我尝试限制或删除单元测试中的并行性。

话虽如此,我想看看调度程序方法是否有效。这是使用http://code.msdn.microsoft.com/ParExtSamples

中修改过的StaTaskScheduler的实现
    using System;
    using System.Collections.Concurrent;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading;
    using System.Threading.Tasks;
    using Xunit;

    namespace Example
    {
      public class Foo
      {
        private TaskScheduler _scheduler;

    public int Bar { get; set; }

    private void SomethingLong()
    {
      Thread.SpinWait(10000);
    }

    public Foo()
      : this(TaskScheduler.Default)
    {
    }

    public Foo(TaskScheduler scheduler)
    {
      _scheduler = scheduler;
    }

    public void DoWork()
    {
      var factory = new TaskFactory(_scheduler);

      factory.StartNew(() => SomethingLong())
      .ContinueWith(a => Bar = 1, _scheduler);
    }
  }

  public class FooTests
  {
    [Fact]
    public void Foo_should_set_Bar_to_1()
    {
      var sch = new StaTaskScheduler(3);
      var target = new Foo(sch);
      target.DoWork();

      sch.Dispose();
      Assert.Equal(1, target.Bar);
    }
  }

  public sealed class StaTaskScheduler : TaskScheduler, IDisposable
  {
    /// <summary>Stores the queued tasks to be executed by our pool of STA threads.</summary>
    private BlockingCollection<Task> _tasks;
    /// <summary>The STA threads used by the scheduler.</summary>
    private readonly List<Thread> _threads;

    /// <summary>Initializes a new instance of the StaTaskScheduler class with the specified concurrency level.</summary>
    /// <param name="numberOfThreads">The number of threads that should be created and used by this scheduler.</param>
    public StaTaskScheduler(int numberOfThreads)
    {
      // Validate arguments
      if (numberOfThreads < 1) throw new ArgumentOutOfRangeException("concurrencyLevel");

      // Initialize the tasks collection
      _tasks = new BlockingCollection<Task>();

      // Create the threads to be used by this scheduler
      _threads = Enumerable.Range(0, numberOfThreads).Select(i =>
      {
        var thread = new Thread(() =>
        {
          // Continually get the next task and try to execute it.
          // This will continue until the scheduler is disposed and no more tasks remain.
          foreach (var t in _tasks.GetConsumingEnumerable())
          {
            TryExecuteTask(t);
          }
        });
        thread.IsBackground = true;
        // NO STA REQUIREMENT!
        // thread.SetApartmentState(ApartmentState.STA);
        return thread;
      }).ToList();

      // Start all of the threads
      _threads.ForEach(t => t.Start());
    }

    /// <summary>Queues a Task to be executed by this scheduler.</summary>
    /// <param name="task">The task to be executed.</param>
    protected override void QueueTask(Task task)
    {
      // Push it into the blocking collection of tasks
      _tasks.Add(task);
    }

    /// <summary>Provides a list of the scheduled tasks for the debugger to consume.</summary>
    /// <returns>An enumerable of all tasks currently scheduled.</returns>
    protected override IEnumerable<Task> GetScheduledTasks()
    {
      // Serialize the contents of the blocking collection of tasks for the debugger
      return _tasks.ToArray();
    }

    /// <summary>Determines whether a Task may be inlined.</summary>
    /// <param name="task">The task to be executed.</param>
    /// <param name="taskWasPreviouslyQueued">Whether the task was previously queued.</param>
    /// <returns>true if the task was successfully inlined; otherwise, false.</returns>
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
      // Try to inline if the current thread is STA
      return
      Thread.CurrentThread.GetApartmentState() == ApartmentState.STA &&
      TryExecuteTask(task);
    }

    /// <summary>Gets the maximum concurrency level supported by this scheduler.</summary>
    public override int MaximumConcurrencyLevel
    {
      get { return _threads.Count; }
    }

    /// <summary>
    /// Cleans up the scheduler by indicating that no more tasks will be queued.
    /// This method blocks until all threads successfully shutdown.
    /// </summary>
    public void Dispose()
    {
      if (_tasks != null)
      {
        // Indicate that no new tasks will be coming in
        _tasks.CompleteAdding();

        // Wait for all threads to finish processing tasks
        foreach (var thread in _threads) thread.Join();

        // Cleanup
        _tasks.Dispose();
        _tasks = null;
      }
    }
  }
}

答案 2 :(得分:1)

如果您想摆脱重载构造函数的需要,可以将单元测试代码包装在Task.Factory.ContinueWhenAll(...)中。

public class Foo
{
    public Foo( )
    {
        Task.Factory.StartNew( () => somethingLong( ) )
            .ContinueWith( a => Bar = 1 ) ;
    }
}

[Test] public void Foo_should_set_Bar_to_1( )
{
    Foo foo;
    Task.Factory.ContinueWhenAll(
        new [] {
            new Task(() => {
                foo = new Foo();
            })
        },
        asserts => { 
            Assert.Equal(1, foo.Bar ) ;
        }
    ).Wait;
}

希望听到一些关于这种方法的反馈。