Skip to main content

C# extension methods that wrap the Assert class from the Microsoft.VisualStudio.TestTools.UnitTesting namespace (MSTest).

using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Extensions
{
    /// <summary>
    /// Fluent "Should..." assertion helpers built on top of the MSTest <c>Assert</c> class.
    /// </summary>
    public static class TestExtensions
    {
        /// <summary>
        /// Asserts that the two values are equal by forwarding them, in order, to <c>Assert.AreEqual</c>.
        /// </summary>
        public static void ShouldEqual<T>(this T val1, T val2)
        {
            Assert.AreEqual(val1, val2);
        }

        /// <summary>
        /// Asserts that the two values are not equal by forwarding them, in order, to <c>Assert.AreNotEqual</c>.
        /// </summary>
        public static void ShouldNotEqual<T>(this T val1, T val2)
        {
            Assert.AreNotEqual(val1, val2);
        }

        /// <summary>Asserts that <paramref name="val"/> is null.</summary>
        public static void ShouldBeNull<T>(this T val)
        {
            Assert.IsNull(val);
        }

        /// <summary>Asserts that <paramref name="val"/> is not null.</summary>
        public static void ShouldNotBeNull<T>(this T val)
        {
            Assert.IsNotNull(val);
        }

        /// <summary>
        /// Asserts that <paramref name="val1"/> orders strictly after <paramref name="val2"/>
        /// according to <see cref="Compare"/>.
        /// </summary>
        public static void ShouldBeGreaterThan<T>(this T val1, T val2)
        {
            Assert.IsTrue(Compare(val1, val2) > 0, FailureMessage(val1, "greater than", val2));
        }

        /// <summary>
        /// Asserts that <paramref name="val1"/> orders after, or is equal to, <paramref name="val2"/>
        /// according to <see cref="Compare"/>.
        /// </summary>
        public static void ShouldBeGreaterThanOrEqualTo<T>(this T val1, T val2)
        {
            // The equality case must pass as well; checking only "greater than" rejected equal values.
            Assert.IsTrue(Compare(val1, val2) >= 0, FailureMessage(val1, "greater than or equal to", val2));
        }

        /// <summary>
        /// Asserts that <paramref name="val1"/> orders strictly before <paramref name="val2"/>
        /// according to <see cref="Compare"/>.
        /// </summary>
        public static void ShouldBeLessThan<T>(this T val1, T val2)
        {
            Assert.IsTrue(Compare(val1, val2) < 0, FailureMessage(val1, "less than", val2));
        }

        /// <summary>
        /// Asserts that <paramref name="val1"/> orders before, or is equal to, <paramref name="val2"/>
        /// according to <see cref="Compare"/>.
        /// </summary>
        public static void ShouldBeLessThanOrEqualTo<T>(this T val1, T val2)
        {
            Assert.IsTrue(Compare(val1, val2) <= 0, FailureMessage(val1, "less than or equal to", val2));
        }

        /// <summary>
        /// Builds the message shown when an ordering assertion fails.
        /// </summary>
        private static string FailureMessage(object actual, string relation, object expected)
        {
            return string.Format("Expected <{0}> to be {1} <{2}>.", actual, relation, expected);
        }

        /// <summary>
        /// Three-way comparison of two boxed values, dispatched on the runtime type of
        /// <paramref name="tVal"/>.
        /// </summary>
        /// <param name="tVal">Left-hand value; its runtime type selects the comparison used.</param>
        /// <param name="other">Right-hand value; expected to be of the same type as <paramref name="tVal"/>.</param>
        /// <returns>
        /// A negative number when <paramref name="tVal"/> orders before <paramref name="other"/>, zero when
        /// they are equal, and a positive number when it orders after. Only the sign is meaningful: the
        /// ordinal string comparison and the <see cref="IComparable"/> fallback are not limited to -1/0/1.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="tVal"/> is null.</exception>
        /// <exception cref="NotSupportedException">
        /// The runtime type of <paramref name="tVal"/> is not handled explicitly and does not implement
        /// <see cref="IComparable"/>.
        /// </exception>
        public static int Compare(object tVal, object other)
        {
            if (tVal == null)
            {
                // Fail with a descriptive exception instead of a NullReferenceException from GetType().
                throw new ArgumentNullException("tVal");
            }

            var mType = tVal.GetType();

            if (mType == typeof (int))
            {
                return new IntCompare(tVal).CompareTo(other);
            }

            if (mType == typeof (decimal))
            {
                return new DecimalCompare(tVal).CompareTo(other);
            }

            if (mType == typeof (char))
            {
                return new CharCompare(tVal).CompareTo(other);
            }

            if (mType == typeof (string))
            {
                // Ordinal comparison; the result may be any positive or negative integer.
                return string.CompareOrdinal((string) tVal, (string) other);
            }

            if (mType == typeof (DateTime))
            {
                return new DateTimeCompare(tVal).CompareTo(other);
            }

            // Fall back to the type's own ordering so that long, double, Guid, enums, etc. are supported.
            var comparable = tVal as IComparable;
            if (comparable != null)
            {
                return comparable.CompareTo(other);
            }

            throw new NotSupportedException("Unknown type!");
        }
    }

    /// <summary>
    /// Base class for comparers that hold a left-hand value of type <typeparamref name="T"/> and
    /// compare it against a boxed right-hand value.
    /// </summary>
    /// <typeparam name="T">Type the constructor argument is cast to.</typeparam>
    internal abstract class CompareBase<T>
    {
        /// <summary>The left-hand value, set once by the constructor.</summary>
        protected readonly T tVal;

        /// <summary>
        /// Stores <paramref name="val"/> after casting it to <typeparamref name="T"/>.
        /// </summary>
        /// <param name="val">Boxed value; the cast fails at runtime if it is not a <typeparamref name="T"/>.</param>
        protected CompareBase(object val)
        {
            tVal = (T) val;
        }

        /// <summary>
        /// Compares the stored value with <paramref name="other"/> and returns an integer describing
        /// their relative order.
        /// </summary>
        public abstract int CompareTo(object other);
    }

    /// <summary>
    /// Three-way comparer for <see cref="int"/> values.
    /// </summary>
    internal class IntCompare : CompareBase<int>
    {
        internal IntCompare(object val) : base(val)
        {
        }

        /// <summary>
        /// Compares the value given to the constructor with <paramref name="other"/>, which is cast to
        /// <see cref="int"/>.
        /// </summary>
        /// <returns>0 when both are equal, 1 when the stored value is larger, otherwise -1.</returns>
        public override int CompareTo(object other)
        {
            var rhs = (int) other;

            return tVal == rhs
                ? 0
                : (tVal > rhs ? 1 : -1);
        }
    }

    /// <summary>
    /// Three-way comparer for <see cref="decimal"/> values.
    /// </summary>
    internal class DecimalCompare : CompareBase<decimal>
    {
        internal DecimalCompare(object val) : base(val)
        {
        }

        /// <summary>
        /// Compares the value given to the constructor with <paramref name="other"/>, which is cast to
        /// <see cref="decimal"/>.
        /// </summary>
        /// <returns>0 when both are equal, 1 when the stored value is larger, otherwise -1.</returns>
        public override int CompareTo(object other)
        {
            var rhs = (decimal) other;

            return tVal == rhs
                ? 0
                : (tVal > rhs ? 1 : -1);
        }
    }

    /// <summary>
    /// Three-way comparer for <see cref="DateTime"/> values.
    /// </summary>
    internal class DateTimeCompare : CompareBase<DateTime>
    {
        internal DateTimeCompare(object val) : base(val)
        {
        }

        /// <summary>
        /// Compares the value given to the constructor with <paramref name="other"/>, which is cast to
        /// <see cref="DateTime"/>.
        /// </summary>
        /// <returns>0 when both are equal, 1 when the stored value is later, otherwise -1.</returns>
        public override int CompareTo(object other)
        {
            var rhs = (DateTime) other;

            return tVal == rhs
                ? 0
                : (tVal > rhs ? 1 : -1);
        }
    }

    /// <summary>
    /// Three-way comparer for <see cref="char"/> values.
    /// </summary>
    internal class CharCompare : CompareBase<char>
    {
        internal CharCompare(object val) : base(val)
        {
        }

        /// <summary>
        /// Compares the value given to the constructor with <paramref name="other"/>, which is cast to
        /// <see cref="char"/>.
        /// </summary>
        /// <returns>0 when both are equal, 1 when the stored value is larger, otherwise -1.</returns>
        public override int CompareTo(object other)
        {
            var rhs = (char) other;

            return tVal == rhs
                ? 0
                : (tVal > rhs ? 1 : -1);
        }
    }
}