Skip to main content

Process and ProcessStartInfo extension methods.

//
// Mannex - Extension methods for .NET
// Copyright (c) 2009 Atif Aziz. All rights reserved.
//
//  Author(s):
//
//      Atif Aziz, http://www.raboof.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

namespace ProcExtensions
{
    using System;
    using System.ComponentModel;
    using System.Diagnostics;
    using System.IO;
    using System.Threading;
    using System.Threading.Tasks;

    #region - Internal Extensions -

    /// <summary>
    /// Extension methods for <see cref="WaitHandle"/>.
    /// </summary>
    internal static class WaitHandleExtensions
    {
        /// <summary>
        /// Asynchronously and indefinitely waits for a
        /// <see cref="WaitHandle"/> to become signaled.
        /// </summary>
        public static Task<bool> WaitOneAsync(this WaitHandle handle)
        {
            return WaitOneAsync(handle, null, CancellationToken.None);
        }

        /// <summary>
        /// Asynchronously waits for a <see cref="WaitHandle"/> to become
        /// signaled. An additional parameter specifies a time-out for the
        /// wait to be satisfied.
        /// </summary>
        public static Task<bool> WaitOneAsync(this WaitHandle handle, TimeSpan? timeout)
        {
            return WaitOneAsync(handle, timeout, CancellationToken.None);
        }

        /// <summary>
        /// Asynchronously and indefinitely waits for a
        /// <see cref="WaitHandle"/> to become signaled. An additional
        /// parameter specifies a <see cref="CancellationToken"/> to be used
        /// for cancelling the wait.
        /// </summary>
        public static Task<bool> WaitOneAsync(this WaitHandle handle, CancellationToken cancellationToken)
        {
            return WaitOneAsync(handle, null, cancellationToken);
        }

        /// <summary>
        /// Asynchronously waits for a <see cref="WaitHandle"/> to become
        /// signaled. Additional parameters specify a time-out for the wait as
        /// well as a <see cref="CancellationToken"/> to be used for cancelling
        /// the wait.
        /// </summary>
        /// <param name="handle">The handle to wait on.</param>
        /// <param name="timeout">
        /// The maximum time to wait, or <c>null</c> to wait indefinitely.
        /// </param>
        /// <param name="cancellationToken">Token that cancels the wait.</param>
        /// <returns>
        /// A task whose result is <c>true</c> if the handle was signaled and
        /// <c>false</c> if the time-out elapsed first. The task is cancelled
        /// if <paramref name="cancellationToken"/> is cancelled first.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="handle"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="OperationCanceledException">
        /// <paramref name="cancellationToken"/> is already cancelled on entry.
        /// </exception>
        public static Task<bool> WaitOneAsync(this WaitHandle handle, TimeSpan? timeout, CancellationToken cancellationToken)
        {
            if (handle == null) throw new ArgumentNullException(nameof(handle));

            cancellationToken.ThrowIfCancellationRequested();

            var tcs = new TaskCompletionSource<bool>();

            // The wait callback only completes the task and never touches the
            // RegisteredWaitHandle. The callback can run before
            // RegisterWaitForSingleObject has even returned (for example, when
            // the handle is already signaled), so reading the registration from
            // inside it would be a race.
            var registeredWait = ThreadPool.RegisterWaitForSingleObject(handle,
                (_, timedOut) => tcs.TrySetResult(!timedOut),
                null, timeout.ToTimeout(), executeOnlyOnce: true);

            var cancellationRegistration = default(CancellationTokenRegistration);
            try
            {
                if (cancellationToken.CanBeCanceled)
                    cancellationRegistration = cancellationToken.Register(() => tcs.TrySetCanceled());
            }
            catch
            {
                // Register can fail (e.g. disposed token source); don't leave
                // the thread-pool wait behind.
                registeredWait.Unregister(null);
                throw;
            }

            // Release both registrations once the task settles, however it
            // settles. Both locals are assigned before this continuation is
            // attached, so it always observes their final values. Disposing the
            // cancellation registration here also avoids it lingering on a
            // long-lived token after a normal completion.
            tcs.Task.ContinueWith(_ =>
            {
                registeredWait.Unregister(null);
                cancellationRegistration.Dispose();
            }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

            return tcs.Task;
        }
    }

    /// <summary>
    /// Extension methods for <see cref="TimeSpan"/>.
    /// </summary>
    internal static class TimeSpanExtensions
    {
        /// <summary>
        /// Converts <see cref="TimeSpan"/> to milliseconds as expected by
        /// most of the <see cref="System.Threading"/> API.
        /// </summary>
        /// <remarks>
        /// Fractional milliseconds are truncated toward zero.
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="timeout"/> is a negative value other than -1
        /// millisecond (<see cref="Timeout.Infinite"/>), or is greater than
        /// <see cref="int.MaxValue"/> milliseconds.
        /// </exception>
        public static int ToTimeout(this TimeSpan timeout)
        {
            // Truncate through long first: casting an out-of-range double
            // directly to int is unspecified in C#, so a large TimeSpan could
            // otherwise turn into a bogus (often negative) timeout. Every
            // TimeSpan's TotalMilliseconds fits in a long.
            var milliseconds = (long) timeout.TotalMilliseconds;
            if (milliseconds < Timeout.Infinite || milliseconds > int.MaxValue)
            {
                throw new ArgumentOutOfRangeException(nameof(timeout), timeout,
                    "The time-out must be -1 millisecond (infinite) or a non-negative value no greater than Int32.MaxValue milliseconds.");
            }

            return (int) milliseconds;
        }

        /// <summary>
        /// Converts <see cref="TimeSpan"/> to milliseconds as expected by
        /// most of the <see cref="System.Threading"/> API. If the
        /// <see cref="TimeSpan"/> value is <c>null</c> then the result is
        /// same as <see cref="Timeout.Infinite"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// A non-null <paramref name="timeout"/> is out of range; see the
        /// non-nullable overload.
        /// </exception>
        public static int ToTimeout(this TimeSpan? timeout)
        {
            return timeout?.ToTimeout() ?? Timeout.Infinite;
        }
    }

    #endregion

    #region - ProcessExtensions -

    /// <summary>
    /// Extension methods for <see cref="Process"/>.
    /// </summary>
    public static class ProcessExtensions
    {
        /// <summary>
        /// Attempts to kill the process identified by the <see cref="Process"/>
        /// object and returns <c>null</c> on success otherwise the error
        /// that occurred in the attempt.
        /// </summary>
        /// <remarks>
        /// Only <see cref="InvalidOperationException"/> and
        /// <see cref="Win32Exception"/> are caught and returned; any other
        /// exception from <see cref="Process.Kill()"/> propagates to the caller.
        /// </remarks>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="process"/> is <c>null</c>.
        /// </exception>
        public static Exception TryKill(this Process process)
        {
            if (process == null) throw new ArgumentNullException(nameof(process));

            try
            {
                process.Kill();
                return null;
            }
            catch (InvalidOperationException e)
            {
                // Occurs when:
                // - process has already exited.
                // - no process is associated with this Process object.
                return e;
            }
            catch (Win32Exception e)
            {
                // Occurs when:
                // - associated process could not be terminated.
                // - process is terminating.
                // - associated process is a Win16 executable.
                return e;
            }
        }

        /// <summary>
        /// Instructs the <see cref="Process"/> component to wait the specified
        /// amount of time for the associated process to exit. If the specified
        /// time-out period is <c>null</c> then the wait is indefinite.
        /// </summary>
        /// <returns>
        /// The result of <see cref="Process.WaitForExit(int)"/>: <c>true</c>
        /// if the process exited within the time-out.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="process"/> is <c>null</c>.
        /// </exception>
        public static bool WaitForExit(this Process process, TimeSpan? timeout)
        {
            if (process == null) throw new ArgumentNullException(nameof(process));

            // Share the TimeSpan-to-milliseconds conversion used by the other
            // wait helpers in this file; a null time-out maps to
            // Timeout.Infinite (-1).
            return process.WaitForExit(timeout.ToTimeout());
        }

        /// <summary>
        /// Begins asynchronous read operations on the re-directed <see cref="Process.StandardOutput"/>
        /// and <see cref="Process.StandardError"/> of the application.
        /// Each line on either is written to a respective <see cref="TextWriter"/>.
        /// A <c>null</c> writer causes the lines of that stream to be discarded.
        /// </summary>
        /// <returns>
        /// A function that waits for both outputs to drain. It takes a
        /// time-out (<c>null</c> for no time-out) and returns <c>true</c> once
        /// both streams have reached end-of-file within that time.
        /// </returns>
        public static Func<TimeSpan?, bool> BeginReadLine(this Process process, TextWriter output, TextWriter error = null)
        {
            if (process == null) throw new ArgumentNullException(nameof(process));

            return BeginReadLine(process, (output ?? TextWriter.Null).WriteLine,
                (error ?? TextWriter.Null).WriteLine);
        }

        /// <summary>
        /// Begins asynchronous read operations on the re-directed <see cref="Process.StandardOutput"/>
        /// and <see cref="Process.StandardError"/> of the application. Each line on the standard output
        /// is sent to a callback; lines on the standard error are discarded.
        /// </summary>
        /// <returns>
        /// A function that waits for both outputs to drain, taking a time-out
        /// (<c>null</c> for no time-out).
        /// </returns>
        public static Func<TimeSpan?, bool> BeginReadLine(this Process process, Action<string> output)
        {
            return BeginReadLine(process, output, null);
        }

        /// <summary>
        /// Begins asynchronous read operations on the re-directed
        /// <see cref="Process.StandardOutput"/> and
        /// <see cref="Process.StandardError"/> of the application. Each line
        /// on either is sent to a respective callback. A <c>null</c> callback
        /// discards the lines of that stream.
        /// </summary>
        /// <returns>
        /// A function that waits for both outputs to drain. It takes a
        /// time-out (<c>null</c> for no time-out) and returns <c>true</c> once
        /// both streams have reached end-of-file within that time.
        /// </returns>
        public static Func<TimeSpan?, bool> BeginReadLine(this Process process, Action<string> output, Action<string> error)
        {
            if (process == null) throw new ArgumentNullException(nameof(process));

            var e = BeginReadLineImpl(process, output ?? TextWriter.Null.WriteLine,
                error ?? TextWriter.Null.WriteLine);

            return timeout => e.WaitOne(timeout.ToTimeout());
        }

        /// <summary>
        /// Hooks both data-received events, starts the asynchronous reads and
        /// returns an event that becomes signaled once both the output and the
        /// error stream have reported end-of-file.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the returned <see cref="ManualResetEvent"/> is never
        /// disposed; it lives as long as the drain function that captures it.
        /// </remarks>
        private static ManualResetEvent BeginReadLineImpl(Process process, Action<string> output, Action<string> error)
        {
            var done = new ManualResetEvent(false);
            var pending = 2; // one end-of-file expected per stream
            var onEof = new Action(() =>
            {
                if (Interlocked.Decrement(ref pending) == 0) done.Set();
            });

            process.OutputDataReceived += OnDataReceived(output, onEof);
            process.BeginOutputReadLine();

            process.ErrorDataReceived += OnDataReceived(error, onEof);
            process.BeginErrorReadLine();

            return done;
        }

        /// <summary>
        /// Begins asynchronous read operations on the re-directed
        /// <see cref="Process.StandardOutput"/> and
        /// <see cref="Process.StandardError"/> of the application. Each line
        /// on either is sent to a respective callback. A <c>null</c> callback
        /// discards the lines of that stream.
        /// </summary>
        /// <returns>
        /// A function that asynchronously waits for both outputs to drain,
        /// taking a time-out (<c>null</c> for no time-out).
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="process"/> is <c>null</c>.
        /// </exception>
        public static Func<TimeSpan?, Task<bool>> BeginReadLineAsync(this Process process, Action<string> output, Action<string> error)
        {
            // Same argument contract as the synchronous BeginReadLine.
            if (process == null) throw new ArgumentNullException(nameof(process));

            var e = BeginReadLineImpl(process, output ?? TextWriter.Null.WriteLine,
                error ?? TextWriter.Null.WriteLine);
            return timeout => e.WaitOneAsync(timeout);
        }

        /// <summary>
        /// Creates a data-received handler that forwards each non-null line
        /// to <paramref name="line"/> and calls <paramref name="eof"/> when the
        /// stream signals its end with a <c>null</c> line.
        /// </summary>
        private static DataReceivedEventHandler OnDataReceived(
            Action<string> line, Action eof)
        {
            return (sender, e) =>
            {
                if (e.Data != null)
                    line(e.Data);
                else
                    eof();
            };
        }

        /// <summary>
        /// Creates <see cref="Task"/> that completes when the process exits
        /// with an exit code of zero and throws an <see cref="Exception"/>
        /// otherwise. The process is disposed once it has exited.
        /// </summary>
        public static Task AsTask(this Process process)
        {
            return AsTask(process, p => new Exception(string.Format("Process exited with the non-zero code {0}.", p.ExitCode)));
        }

        /// <summary>
        /// Creates <see cref="Task"/> that completes when the process exits
        /// with an exit code of zero and throws an <see cref="Exception"/>
        /// otherwise. An additional parameter enables a function to
        /// customize the <see cref="Exception"/> object thrown. The process
        /// is disposed once it has exited.
        /// </summary>
        /// <param name="process">The process to observe.</param>
        /// <param name="errorSelector">
        /// Called only for a non-zero exit code to build the exception that
        /// faults the task.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="process"/> or <paramref name="errorSelector"/> is <c>null</c>.
        /// </exception>
        public static Task AsTask(this Process process, Func<Process, Exception> errorSelector)
        {
            // Validate eagerly: the selector is only invoked later, after the
            // process exits, where a null would surface far from the call site.
            if (errorSelector == null) throw new ArgumentNullException(nameof(errorSelector));

            return process.AsTask(true, p => p.ExitCode != 0 ? errorSelector(p) : null,
                e => e, _ => (object) null);
        }

        /// <summary>
        /// Creates <see cref="Task"/> that completes when the process exits.
        /// Additional parameters specify how to project the results from
        /// the execution of the process as a result or error for the task.
        /// </summary>
        /// <param name="process">The process to observe.</param>
        /// <param name="dispose">
        /// Whether to dispose <paramref name="process"/> after
        /// <paramref name="selector"/> has captured what it needs.
        /// </param>
        /// <param name="selector">Captures state from the exited process.</param>
        /// <param name="errorSelector">Maps the captured state to an error, or <c>null</c>.</param>
        /// <param name="resultSelector">Maps the captured state to the task result.</param>
        /// <remarks>
        /// If <paramref name="errorSelector"/> return <c>null</c> then the task
        /// is considered to have succeeded and <paramref name="resultSelector"/>
        /// determines its result. If <paramref name="errorSelector"/> returns
        /// an instance of <see cref="Exception"/> then the task is
        /// considered to have failed with that exception. An exception thrown
        /// by any of the selectors also faults the task. If the process has
        /// already exited, the selectors run before this method returns.
        /// </remarks>
        public static Task<TResult> AsTask<T, TResult>(this Process process, bool dispose,
            Func<Process, T> selector,
            Func<T, Exception> errorSelector,
            Func<T, TResult> resultSelector)
        {
            if (process == null) throw new ArgumentNullException(nameof(process));
            if (selector == null) throw new ArgumentNullException(nameof(selector));
            if (errorSelector == null) throw new ArgumentNullException(nameof(errorSelector));
            if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector));

            var tcs = new TaskCompletionSource<TResult>();

            if (process.HasExited)
            {
                OnExit(process, dispose, selector, errorSelector, resultSelector, tcs);
            }
            else
            {
                // Subscribe BEFORE enabling events. Otherwise the process can exit
                // (and raise Exited) in between, and the task would never complete.
                process.Exited += delegate { OnExit(process, dispose, selector, errorSelector, resultSelector, tcs); };
                process.EnableRaisingEvents = true;
            }

            return tcs.Task;
        }

        /// <summary>
        /// Completes <paramref name="tcs"/> from an exited process. It never
        /// throws: it may run on the thread raising <see cref="Process.Exited"/>,
        /// where an escaping exception would go unobserved (or tear down the
        /// process) and leave the task pending.
        /// </summary>
        private static void OnExit<T, TResult>(Process process, bool dispose,
            Func<Process, T> selector,
            Func<T, Exception> errorSelector,
            Func<T, TResult> resultSelector,
            TaskCompletionSource<TResult> tcs)
        {
            try
            {
                T capture;
                try
                {
                    capture = selector(process);
                }
                finally
                {
                    // Dispose even if the selector fails; nothing reads the
                    // process after this point.
                    if (dispose)
                        process.Dispose();
                }

                var e = errorSelector(capture);
                if (e != null)
                    tcs.TrySetException(e);
                else
                    tcs.TrySetResult(resultSelector(capture));
            }
            catch (Exception e)
            {
                tcs.TrySetException(e);
            }
        }
    }

    #endregion

    #region ProcessStartInfoExtensions

    /// <summary>
    /// Extension methods for the <see cref="ProcessStartInfo"/>.
    /// </summary>
    public static class ProcessStartInfoExtensions
    {
        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// </summary>
        public static Task<Process> StartAsync(this ProcessStartInfo startInfo, CancellationToken cancellationToken = new CancellationToken())
        {
            return StartAsync(startInfo, null, null, cancellationToken);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// The standard output is captured and passed, with the process, to
        /// <paramref name="selector"/> to produce the task result.
        /// </summary>
        public static Task<T> StartAsync<T>(this ProcessStartInfo startInfo, Func<Process, string, T> selector)
        {
            return StartAsync(startInfo, CancellationToken.None, selector);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// The standard output is captured and passed, with the process, to
        /// <paramref name="selector"/> to produce the task result.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="selector"/> is <c>null</c>.
        /// </exception>
        public static Task<T> StartAsync<T>(this ProcessStartInfo startInfo, CancellationToken cancellationToken, Func<Process, string, T> selector)
        {
            // Check here: the adapter lambda below is never null, so the check
            // in the private overload cannot catch a null selector.
            if (selector == null) throw new ArgumentNullException(nameof(selector));

            return StartAsync(startInfo, false, cancellationToken, (p, stdout, _) => selector(p, stdout));
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// The standard output and error are captured and passed, with the
        /// process, to <paramref name="selector"/> to produce the task result.
        /// </summary>
        public static Task<T> StartAsync<T>(this ProcessStartInfo startInfo, Func<Process, string, string, T> selector)
        {
            return StartAsync(startInfo, CancellationToken.None, selector);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// The standard output and error are captured and passed, with the
        /// process, to <paramref name="selector"/> to produce the task result.
        /// </summary>
        public static Task<T> StartAsync<T>(this ProcessStartInfo startInfo, CancellationToken cancellationToken, Func<Process, string, string, T> selector)
        {
            return StartAsync(startInfo, true, cancellationToken, selector);
        }

        /// <summary>
        /// Runs the process while capturing its output into string writers,
        /// then projects the process and captured text through
        /// <paramref name="selector"/>. The standard-error argument passed to
        /// the selector is <c>null</c> when <paramref name="captureStandardError"/>
        /// is <c>false</c>.
        /// </summary>
        private static Task<T> StartAsync<T>(ProcessStartInfo startInfo, bool captureStandardError, CancellationToken cancellationToken, Func<Process, string, string, T> selector)
        {
            if (selector == null) throw new ArgumentNullException(nameof(selector));
            var stdout = new StringWriter();
            var stderr = captureStandardError ? new StringWriter() : null;
            var task = StartAsync(startInfo, stdout, stderr, cancellationToken);
            // The antecedent is already complete when this runs, so
            // GetAwaiter().GetResult() does not block. Unlike t.Result, it
            // rethrows the original exception instead of wrapping it in a
            // second AggregateException.
            return task.ContinueWith(t => selector(t.GetAwaiter().GetResult(),
                    stdout.ToString(),
                    stderr?.ToString()),
                cancellationToken,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Current);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// Each line of standard output is written to <paramref name="stdout"/>.
        /// </summary>
        public static Task<Process> StartAsync(this ProcessStartInfo startInfo, TextWriter stdout)
        {
            return StartAsync(startInfo, stdout, CancellationToken.None);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// Each line of standard output is written to <paramref name="stdout"/>.
        /// </summary>
        public static Task<Process> StartAsync(this ProcessStartInfo startInfo,
            TextWriter stdout, CancellationToken cancellationToken)
        {
            return startInfo.StartAsync(stdout, null, cancellationToken);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// Lines of standard output and error are written to the respective writer.
        /// </summary>
        public static Task<Process> StartAsync(this ProcessStartInfo startInfo,
            TextWriter stdout, TextWriter stderr)
        {
            return StartAsync(startInfo, stdout, stderr, CancellationToken.None);
        }

        /// <summary>
        /// Starts the process and waits for it to complete asynchronously.
        /// Lines of standard output and error are written to the respective
        /// writer when one is supplied.
        /// </summary>
        /// <remarks>
        /// If either writer is non-null, both
        /// <see cref="ProcessStartInfo.RedirectStandardOutput"/> and
        /// <see cref="ProcessStartInfo.RedirectStandardError"/> are set to
        /// <c>true</c> on the supplied <paramref name="startInfo"/>. Cancelling
        /// <paramref name="cancellationToken"/> stops the output reads and tries
        /// to kill the process; the task then ends up cancelled. The
        /// <see cref="Process"/> returned on success is not disposed by this
        /// method.
        /// </remarks>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="startInfo"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// <see cref="Process.Start(ProcessStartInfo)"/> returned no process.
        /// </exception>
        public static Task<Process> StartAsync(this ProcessStartInfo startInfo,
            TextWriter stdout, TextWriter stderr,
            CancellationToken cancellationToken)
        {
            if (startInfo == null) throw new ArgumentNullException(nameof(startInfo));

            cancellationToken.ThrowIfCancellationRequested();

            var tcs = new TaskCompletionSource<Process>();
            Process ownedProcess = null;
            try
            {
                var capturingOutput = stdout != null || stderr != null;
                if (capturingOutput)
                    startInfo.RedirectStandardOutput = startInfo.RedirectStandardError = true;

                var process = ownedProcess = Process.Start(startInfo);
                if (process == null)
                    throw new InvalidOperationException("No process available for completion.");

                var drain = capturingOutput
                    ? process.BeginReadLine(stdout, stderr)
                    : _ => true;

                // Subscribe BEFORE enabling events. A short-lived process can
                // exit, and Exited can be raised, as soon as EnableRaisingEvents
                // is set; a late subscriber would never hear about it.
                process.Exited += delegate
                {
                    // Wait for output to drain in one-second slices so that
                    // cancellation is noticed between slices.
                    while (!drain(TimeSpan.FromSeconds(1)))
                    {
                        if (cancellationToken.IsCancellationRequested)
                        {
                            tcs.TrySetCanceled();
                            return;
                        }
                    }

                    // Without output capture the drain finishes immediately, so
                    // check again: an exit caused by cancellation (kill) must not
                    // be reported as a successful run.
                    if (cancellationToken.IsCancellationRequested)
                        tcs.TrySetCanceled();
                    else
                        tcs.TrySetResult(process);
                };

                process.EnableRaisingEvents = true;

                // Register only once the asynchronous reads are running;
                // CancelOutputRead/CancelErrorRead throw if no read is in progress.
                if (cancellationToken.CanBeCanceled)
                {
                    var registration = cancellationToken.Register(() =>
                    {
                        if (capturingOutput)
                        {
                            process.CancelOutputRead();
                            process.CancelErrorRead();
                        }

                        process.TryKill();
                    });

                    // Drop the registration once the task settles, so a
                    // long-lived token neither keeps the process reachable nor
                    // later calls into a process the caller may have disposed.
                    // Runs on the thread pool rather than inline on the Exited
                    // thread to avoid blocking that thread on a running callback.
                    tcs.Task.ContinueWith(_ => registration.Dispose(),
                        CancellationToken.None,
                        TaskContinuationOptions.None,
                        TaskScheduler.Default);
                }

                ownedProcess = null; // setup succeeded; the task now owns it
            }
            finally
            {
                ownedProcess?.Dispose();
            }

            return tcs.Task;
        }
    }

    #endregion
}