// A fluent retry policy builder supporting synchronous and asynchronous execution, selective
// exception handling, exponential backoff with jitter, cancellation, and per-attempt logging.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

/// <summary>
/// A fluent retry policy builder supporting sync/async execution, selective exception handling,
/// exponential backoff with jitter, cancellation, and structured logging via ILogger.
/// </summary>
public class RetryPolicyBuilder
{
    // Exception types (matched together with their subclasses) that trigger a retry.
    // When empty, every exception is considered retryable.
    private readonly HashSet<Type> _handledExceptions = new();

    // Maximum total number of attempts, including the initial one.
    private int _retryCount = 3;

    // Base delay in milliseconds; attempt k (zero-based) waits a random 0..baseDelay * 2^k ms.
    private int _baseDelayMs = 500;

    private ILogger? _logger;

    // System.Random is not thread-safe and this instance is shared by every policy,
    // so all access is serialized through _randomLock.
    private static readonly Random _random = new();
    private static readonly object _randomLock = new();

    private RetryPolicyBuilder() { }

    /// <summary>
    /// Initializes a new retry policy builder instance.
    /// </summary>
    /// <returns>A new <see cref="RetryPolicyBuilder"/>.</returns>
    public static RetryPolicyBuilder Create() => new();

    /// <summary>
    /// Specifies an exception type to handle. Exceptions of this type or any type derived
    /// from it are retried. If no type is ever registered, all exceptions are retried.
    /// </summary>
    /// <typeparam name="TException">The exception type (and its subclasses) to retry on.</typeparam>
    /// <returns>This builder, for chaining.</returns>
    public RetryPolicyBuilder Handle<TException>() where TException : Exception
    {
        _handledExceptions.Add(typeof(TException));
        return this;
    }

    /// <summary>
    /// Specifies an additional exception type to handle. Equivalent to <see cref="Handle{TException}"/>.
    /// </summary>
    /// <typeparam name="TException">The exception type (and its subclasses) to retry on.</typeparam>
    /// <returns>This builder, for chaining.</returns>
    public RetryPolicyBuilder And<TException>() where TException : Exception => Handle<TException>();

    /// <summary>
    /// Sets the maximum number of attempts, <em>including</em> the initial call.
    /// For example, 3 means one initial call plus up to two retries.
    /// </summary>
    /// <param name="count">Total number of attempts; must be at least 1.</param>
    /// <returns>This builder, for chaining.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is less than 1.</exception>
    public RetryPolicyBuilder WithRetryCount(int count)
    {
        if (count < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(count), count, "At least one attempt is required.");
        }

        _retryCount = count;
        return this;
    }

    /// <summary>
    /// Sets the base delay in milliseconds for exponential backoff with full jitter.
    /// </summary>
    /// <param name="ms">Base delay in milliseconds; must not be negative.</param>
    /// <returns>This builder, for chaining.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="ms"/> is negative.</exception>
    public RetryPolicyBuilder WithBaseDelay(int ms)
    {
        if (ms < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(ms), ms, "Delay cannot be negative.");
        }

        _baseDelayMs = ms;
        return this;
    }

    /// <summary>
    /// Sets a structured logger for retry diagnostics.
    /// </summary>
    /// <param name="logger">An instance of <see cref="ILogger"/>.</param>
    /// <returns>This builder, for chaining.</returns>
    public RetryPolicyBuilder WithLogger(ILogger logger)
    {
        _logger = logger;
        return this;
    }

    /// <summary>
    /// Executes a synchronous action with retry logic.
    /// </summary>
    /// <param name="action">The work to run.</param>
    /// <param name="cancellationToken">Cancels further attempts and interrupts backoff waits.</param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled.</exception>
    public void Execute(Action action, CancellationToken cancellationToken = default)
    {
        if (action is null)
        {
            throw new ArgumentNullException(nameof(action));
        }

        // Reuse the generic retry loop so both overloads share one implementation.
        Execute<object?>(() =>
        {
            action();
            return null;
        }, cancellationToken);
    }

    /// <summary>
    /// Executes a synchronous function with retry logic and returns its result.
    /// The last handled exception is rethrown once all attempts are exhausted.
    /// </summary>
    /// <param name="func">The work to run.</param>
    /// <param name="cancellationToken">Cancels further attempts and interrupts backoff waits.</param>
    /// <returns>The value returned by the first successful call to <paramref name="func"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled.</exception>
    public T Execute<T>(Func<T> func, CancellationToken cancellationToken = default)
    {
        if (func is null)
        {
            throw new ArgumentNullException(nameof(func));
        }

        // Unbounded loop: every iteration either returns, rethrows, or retries,
        // so no trailing "unreachable" throw is needed.
        for (int attempt = 0; ; attempt++)
        {
            cancellationToken.ThrowIfCancellationRequested();

            try
            {
                return func();
            }
            catch (Exception ex) when (ShouldHandle(ex, cancellationToken))
            {
                _logger?.LogWarning(ex, "Attempt {Attempt} failed.", attempt + 1);

                if (attempt >= _retryCount - 1)
                {
                    throw;
                }

                // Synchronous, cancellable sleep: WaitOne returns true if the token fires first.
                if (cancellationToken.WaitHandle.WaitOne(NextDelayMs(attempt)))
                {
                    cancellationToken.ThrowIfCancellationRequested();
                }
            }
        }
    }

    /// <summary>
    /// Executes an asynchronous action with retry logic.
    /// </summary>
    /// <param name="action">Factory for the asynchronous work; invoked once per attempt.</param>
    /// <param name="cancellationToken">Cancels further attempts and interrupts backoff waits.</param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled.</exception>
    public async Task ExecuteAsync(Func<Task> action, CancellationToken cancellationToken = default)
    {
        if (action is null)
        {
            throw new ArgumentNullException(nameof(action));
        }

        await ExecuteAsync<object?>(async () =>
        {
            await action();
            return null;
        }, cancellationToken);
    }

    /// <summary>
    /// Executes an asynchronous function with retry logic and returns its result.
    /// The last handled exception is rethrown once all attempts are exhausted.
    /// </summary>
    /// <param name="func">Factory for the asynchronous work; invoked once per attempt.</param>
    /// <param name="cancellationToken">Cancels further attempts and interrupts backoff waits.</param>
    /// <returns>The result of the first successful attempt.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled.</exception>
    public async Task<T> ExecuteAsync<T>(Func<Task<T>> func, CancellationToken cancellationToken = default)
    {
        if (func is null)
        {
            throw new ArgumentNullException(nameof(func));
        }

        // ConfigureAwait(false) is deliberately not used: later attempts invoke the caller's
        // delegate after an await, and it should keep running on the caller's context.
        for (int attempt = 0; ; attempt++)
        {
            cancellationToken.ThrowIfCancellationRequested();

            try
            {
                return await func();
            }
            catch (Exception ex) when (ShouldHandle(ex, cancellationToken))
            {
                _logger?.LogWarning(ex, "Attempt {Attempt} failed.", attempt + 1);

                if (attempt >= _retryCount - 1)
                {
                    throw;
                }

                await Task.Delay(NextDelayMs(attempt), cancellationToken);
            }
        }
    }

    // PRIVATE HELPERS

    /// <summary>
    /// Picks a random backoff for the zero-based <paramref name="attempt"/>, uniformly from
    /// [0, baseDelay * 2^attempt] ms ("full jitter"), clamped so it never overflows <see cref="int"/>.
    /// </summary>
    private int NextDelayMs(int attempt)
    {
        // Compute in double: both (int)Math.Pow(2, attempt) and the product overflow int for
        // large attempts. The cap leaves room for the "+ 1" exclusive upper bound below.
        double uncapped = (double)_baseDelayMs * Math.Pow(2, attempt);
        int maxDelay = (int)Math.Min(uncapped, int.MaxValue - 1);

        lock (_randomLock)
        {
            return _random.Next(0, maxDelay + 1);
        }
    }

    /// <summary>
    /// Decides whether <paramref name="ex"/> should be retried: cancellation caused by the
    /// caller's own token never is; otherwise any exception is retried when no types were
    /// registered, or when it is an instance of (or derives from) a registered type.
    /// </summary>
    private bool ShouldHandle(Exception ex, CancellationToken cancellationToken)
    {
        if (ex is OperationCanceledException && cancellationToken.IsCancellationRequested)
        {
            return false;
        }

        if (_handledExceptions.Count == 0)
        {
            return true;
        }

        foreach (Type handled in _handledExceptions)
        {
            if (handled.IsInstanceOfType(ex))
            {
                return true;
            }
        }

        return false;
    }
}

// -------------------------------------------------------------------------
// Example usage
// -------------------------------------------------------------------------
//
// NOTE: these are C# top-level statements. The compiler requires top-level statements to
// precede every namespace and type declaration in a file, so put this section in its own
// Program.cs (or above RetryPolicyBuilder) rather than after the class.

//
// Setup logging
// -------------
// `using Microsoft.Extensions.Logging;` belongs at the top of the file with the other using
// directives; a using directive placed after statements does not compile.
// AddConsole() comes from the Microsoft.Extensions.Logging.Console package.

// Disposing the factory at program end lets the console logger flush queued messages.
using var loggerFactory = LoggerFactory.Create(builder =>
{
    builder.AddConsole();
});

var logger = loggerFactory.CreateLogger("RetryPolicy");

// One shared Random drives all simulated failures instead of a new instance per attempt.
var random = new Random();


//
// 1. Execute(Action action, CancellationToken)
// --------------------------------------------

using var syncCts = new CancellationTokenSource();
syncCts.CancelAfter(TimeSpan.FromSeconds(5));

RetryPolicyBuilder.Create()
    .Handle<InvalidOperationException>()
    .WithRetryCount(3)
    .WithBaseDelay(200)
    .WithLogger(logger)
    .Execute(() =>
    {
        Console.WriteLine("Trying...");
        if (random.Next(2) == 0)
            throw new InvalidOperationException("Random fail!");
    }, syncCts.Token);

//
// 2. Execute<T>(Func<T> func, CancellationToken)
// ----------------------------------------------

var syncResult = RetryPolicyBuilder.Create()
    .Handle<Exception>()
    .WithRetryCount(4)
    .WithLogger(logger)
    .Execute(() =>
    {
        if (random.Next(3) != 0)
            throw new Exception("Try again!");

        return 99;
    });

Console.WriteLine($"Final result: {syncResult}");

//
// 3. ExecuteAsync(Func<Task> action, CancellationToken)
// -----------------------------------------------------

using var asyncCts = new CancellationTokenSource();
asyncCts.CancelAfter(4000); // Cancel if it takes too long

await RetryPolicyBuilder.Create()
    .Handle<TimeoutException>()
    .WithRetryCount(5)
    .WithBaseDelay(300)
    .WithLogger(logger)
    .ExecuteAsync(async () =>
    {
        Console.WriteLine("Calling remote service...");
        await Task.Delay(100); // Simulate latency
        if (random.Next(2) == 0)
            throw new TimeoutException("Flaky API!");
    }, asyncCts.Token);

//
// 4. ExecuteAsync<T>(Func<Task<T>> func, CancellationToken)
// ---------------------------------------------------------

var fetched = await RetryPolicyBuilder.Create()
    .Handle<Exception>()
    .WithRetryCount(3)
    .WithBaseDelay(500)
    .WithLogger(logger)
    .ExecuteAsync(async () =>
    {
        await Task.Delay(100);
        if (random.Next(3) != 0)
            throw new Exception("Still failing...");
        return "Fetched from server!";
    });

Console.WriteLine($"Got result: {fetched}");

//
// 5. Combine With External Cancellation
// --------------------------------------------

using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));

try
{
    // The lambda never returns a value, so this binds to ExecuteAsync(Func<Task>);
    // there is no result to assign.
    await RetryPolicyBuilder.Create()
        .Handle<Exception>()
        .WithRetryCount(10)
        .WithLogger(logger)
        .ExecuteAsync(async () =>
        {
            // Passing the token lets an in-flight attempt observe the timeout too.
            await Task.Delay(300, timeoutCts.Token);
            throw new Exception("Fail hard!");
        }, timeoutCts.Token);
}
catch (OperationCanceledException)
{
    Console.WriteLine("Retry cancelled due to timeout.");
}
catch (Exception ex)
{
    // Only reached if all 10 attempts fail before the 5-second timeout fires.
    Console.WriteLine($"All retries failed: {ex.Message}");
}