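// A C# base class for a disconnected generic repository (Entity Framework 6,
// System.Data.Entity), followed by an example MVC controller that uses it and
// integration tests for it.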

using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using SharedKernel.Data;

namespace DisconnectedGenericRepository
{
    /// <summary>
    /// Generic repository for "disconnected" use: every read runs with
    /// <c>AsNoTracking()</c> and is materialized immediately, and every write
    /// calls <see cref="DbContext.SaveChanges"/> before returning.
    /// </summary>
    /// <typeparam name="TEntity">Entity type exposed by the context's <see cref="DbSet{TEntity}"/>.</typeparam>
    public class GenericRepository<TEntity> where TEntity : class
    {
        internal DbContext _context;
        internal DbSet<TEntity> _dbSet;

        /// <summary>
        /// Creates a repository over <paramref name="context"/>'s set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="context">The EF context used for all reads and writes.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public GenericRepository(DbContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            _context = context;
            _dbSet = context.Set<TEntity>();
        }

        /// <summary>Loads every <typeparamref name="TEntity"/> row, untracked.</summary>
        /// <returns>A fully materialized list.</returns>
        public IEnumerable<TEntity> All()
        {
            return _dbSet.AsNoTracking().ToList();
        }

        /// <summary>Loads every row, eagerly including the given navigation properties, untracked.</summary>
        /// <param name="includeProperties">Navigation property selectors passed to <c>Include</c>.</param>
        /// <returns>A fully materialized list.</returns>
        public IEnumerable<TEntity> AllInclude (params Expression<Func<TEntity, object>>[] includeProperties)
        {
            return GetAllIncluding(includeProperties).ToList();
        }

        /// <summary>Loads rows matching <paramref name="predicate"/>, eagerly including the given navigations, untracked.</summary>
        /// <param name="predicate">Filter translated to SQL by the query provider.</param>
        /// <param name="includeProperties">Navigation property selectors passed to <c>Include</c>.</param>
        /// <returns>A fully materialized list.</returns>
        public IEnumerable<TEntity> FindByInclude (Expression<Func<TEntity, bool>> predicate, params Expression<Func<TEntity, object>>[] includeProperties)
        {
            var query = GetAllIncluding(includeProperties);
            IEnumerable<TEntity> results = query.Where(predicate).ToList();
            return results;
        }

        // Builds (but does not execute) an untracked query with one Include per selector.
        private IQueryable<TEntity> GetAllIncluding (params Expression<Func<TEntity, object>>[] includeProperties)
        {
            IQueryable<TEntity> queryable = _dbSet.AsNoTracking();

            return includeProperties.Aggregate (queryable, (current, includeProperty) => current.Include(includeProperty));
        }

        /// <summary>Loads rows matching <paramref name="predicate"/>, untracked.</summary>
        /// <param name="predicate">Filter translated to SQL by the query provider.</param>
        /// <returns>A fully materialized list.</returns>
        public IEnumerable<TEntity> FindBy(Expression<Func<TEntity, bool>> predicate)
        {
            IEnumerable<TEntity> results = _dbSet.AsNoTracking().Where(predicate).ToList();
            return results;
        }

        /// <summary>Loads the single entity whose key equals <paramref name="id"/>, untracked.</summary>
        /// <param name="id">Key value to look up.</param>
        /// <returns>The matching entity, or <c>null</c> when none matches.</returns>
        /// <remarks>
        /// The key predicate comes from <c>Utilities.BuildLambdaForFindByKey</c>
        /// (presumably <c>e =&gt; e.&lt;Key&gt; == id</c>; confirm in SharedKernel.Data).
        /// <c>SingleOrDefault</c> throws if more than one row matches.
        /// </remarks>
        public TEntity FindByKey(int id)
        {
            Expression<Func<TEntity, bool>> lambda = Utilities.BuildLambdaForFindByKey<TEntity>(id);
            return _dbSet.AsNoTracking().SingleOrDefault(lambda);
        }

        /// <summary>Adds <paramref name="entity"/> and saves immediately.</summary>
        public void Insert(TEntity entity)
        {
            _dbSet.Add(entity);
            _context.SaveChanges();
        }

        /// <summary>
        /// Attaches a detached <paramref name="entity"/>, marks all of it Modified, and saves immediately.
        /// </summary>
        /// <remarks>
        /// NOTE(review): if another instance with the same key is already tracked by this
        /// context (e.g. after an earlier Insert/Update on a shared context), Attach is
        /// expected to throw; confirm the context lifetime is per request.
        /// </remarks>
        public void Update(TEntity entity)
        {
            _dbSet.Attach(entity);
            _context.Entry(entity).State = EntityState.Modified;
            _context.SaveChanges();
        }

        /// <summary>
        /// Deletes the entity with key <paramref name="id"/> and saves immediately.
        /// Does nothing when no entity has that key.
        /// </summary>
        /// <param name="id">Key value of the entity to delete.</param>
        public void Delete(int id)
        {
            var entity = FindByKey(id);
            if (entity == null)
            {
                // Nothing to delete. Previously this passed null to DbSet.Remove and threw.
                return;
            }

            // FindByKey queries with AsNoTracking, so the entity is detached and
            // DbSet.Remove would reject it. Setting the entry state attaches the
            // entity and marks it Deleted in one step.
            _context.Entry(entity).State = EntityState.Deleted;
            _context.SaveChanges();
        }
    }
}


//
// Example Usage (via MVC Controller)
//

using MvcSalesApp.Domain;
using DisconnectedGenericRepository;
using System.Data.Entity;
using System.Linq;
using System.Net;
using System.Web.Mvc;

namespace MvcSalesApp.Controllers
{
    /// <summary>
    /// CRUD controller for <see cref="Customer"/> that goes through a
    /// <see cref="GenericRepository{TEntity}"/>.
    /// </summary>
    public class CustomersController : Controller
    {
        // Injected once through the constructor and never reassigned.
        private readonly GenericRepository<Customer> _repository;

        public CustomersController(GenericRepository<Customer> _repo)
        {
            _repository = _repo;
        }

        // GET: list every customer.
        public ActionResult Index()
        {
            var customers = _repository.All();
            return View(customers);
        }

        // GET: show one customer.
        public ActionResult Details(int? id)
        {
            return CustomerViewOrError(id);
        }

        // GET: empty create form.
        public ActionResult Create()
        {
            return View();
        }

        // POST: persist a new customer, or redisplay the form when validation fails.
        [HttpPost]
        [ValidateAntiForgeryToken]
        public ActionResult Create([Bind(Include = "CustomerId,FirstName,LastName,DateOfBirth")] Customer customer)
        {
            if (!ModelState.IsValid)
            {
                return View(customer);
            }

            _repository.Insert(customer);
            return RedirectToAction("Index");
        }

        // GET: edit form for one customer.
        public ActionResult Edit(int? id)
        {
            return CustomerViewOrError(id);
        }

        // POST: save edits, or redisplay the form when validation fails.
        [HttpPost]
        [ValidateAntiForgeryToken]
        public ActionResult Edit([Bind(Include = "CustomerId,FirstName,LastName,DateOfBirth")] Customer customer)
        {
            if (!ModelState.IsValid)
            {
                return View(customer);
            }

            _repository.Update(customer);
            return RedirectToAction("Index");
        }

        // GET: delete confirmation page for one customer.
        public ActionResult Delete(int? id)
        {
            return CustomerViewOrError(id);
        }

        // POST (action name "Delete"): remove the customer and return to the list.
        [HttpPost, ActionName("Delete")]
        [ValidateAntiForgeryToken]
        public ActionResult DeleteConfirmed(int id)
        {
            _repository.Delete(id);
            return RedirectToAction("Index");
        }

        // Shared lookup for the single-customer GET actions: 400 when no id was
        // supplied, 404 when the key is unknown, otherwise the current action's
        // view rendered with the customer as its model.
        private ActionResult CustomerViewOrError(int? id)
        {
            if (!id.HasValue)
            {
                return new HttpStatusCodeResult((int)HttpStatusCode.BadRequest);
            }

            var customer = _repository.FindByKey(id.Value);
            if (customer == null)
            {
                return HttpNotFound();
            }

            return View(customer);
        }
    }
}

//
// Generic Repository Integration Tests
//

using DisconnectedGenericRepository;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MvcSalesApp.Data;
using MvcSalesApp.Domain;
using System;
using System.Data.Entity;
using System.Diagnostics;
using System.Linq;
using System.Text;

namespace MvcSalesApp.Tests.Data
{
    [TestClass]
    /// <summary>
    /// Integration tests for <see cref="GenericRepository{TEntity}"/> against a real
    /// <see cref="OrderSystemContext"/>. Most assertions inspect the SQL that EF writes
    /// to <c>Database.Log</c>. The database initializer is disabled, so the tests
    /// neither create nor migrate the database.
    /// </summary>
    [TestClass]
    public class GenericRepositoryIntegrationTests
    {
        private readonly StringBuilder _logBuilder = new StringBuilder();
        // Starts empty rather than null, so an assertion fails cleanly (instead of
        // throwing NullReferenceException) if nothing was logged.
        private string _log = string.Empty;
        private readonly OrderSystemContext _context;
        private readonly GenericRepository<Customer> _customerRepository;

        public GenericRepositoryIntegrationTests()
        {
            Database.SetInitializer(new NullDatabaseInitializer<OrderSystemContext>());
            _context = new OrderSystemContext();
            _customerRepository = new GenericRepository<Customer>(_context);
            SetupLogging();
        }

        // MSTest constructs this class once per test, so each test owns a context
        // and must dispose it.
        [TestCleanup]
        public void DisposeContext()
        {
            _context.Dispose();
        }

        [TestMethod]
        public void CanFindByCustomerByKeyWithDynamicLambda()
        {
            _customerRepository.FindByKey(1);
            WriteLog();
            Assert.IsTrue(_log.Contains("FROM [dbo].[Customers"));
        }

        [TestMethod]
        public void CanFindByProductByKeyWithDynamicLambda()
        {
            new GenericRepository<Product>(_context).FindByKey(1);
            WriteLog();
            Assert.IsTrue(_log.Contains("FROM [dbo].[Products"));
        }

        [TestMethod]
        public void NoTrackingQueriesDoNotCacheObjects()
        {
            _customerRepository.All();
            Assert.AreEqual(0, _context.ChangeTracker.Entries().Count());
        }

        [TestMethod]
        public void CanQueryWithSinglePredicate()
        {
            _customerRepository.FindBy(c => c.LastName.StartsWith("L"));
            WriteLog();
            Assert.IsTrue(_log.Contains("'L%'"));
        }

        [TestMethod]
        public void CanQueryWithDualPredicate()
        {
            var date = new DateTime(2001, 1, 1);
            _customerRepository.FindBy(c => c.LastName.StartsWith("L") && c.DateOfBirth >= date);
            WriteLog();
            // The logged parameter value is expected to use the default (culture-dependent)
            // DateTime formatting, so format the same way instead of hard-coding en-US "1/1/2001".
            Assert.IsTrue(_log.Contains("'L%'") && _log.Contains(date.ToString()));
        }

        [TestMethod]
        public void CanQueryWithComplexRelatedPredicate()
        {
            var date = new DateTime(2001, 1, 1);
            _customerRepository.FindBy(c => c.LastName.StartsWith("L") && c.DateOfBirth >= date
                                  && c.Orders.Any());
            WriteLog();
            // See CanQueryWithDualPredicate: match the culture-dependent date formatting.
            Assert.IsTrue(_log.Contains("'L%'") && _log.Contains(date.ToString()) && _log.Contains("Orders"));
        }

        // NOTE: the disabled tests below call GetAllIncluding, which is private on
        // GenericRepository; the public equivalent is AllInclude.

        //[TestMethod]
        //public void GetAllIncludingComprehendsSingleNavigation() {
        //  var results = _customerRepository.GetAllIncluding(c => c.Orders);
        //  Assert.IsTrue(results.Any(c => c.Orders.Any()));
        //}

        //[TestMethod]
        //public void GetAllIncludingComprehendsTwoChildNavigation() {
        //  var results = _customerRepository
        //    .GetAllIncluding(c => c.Orders, c => c.ContactDetail);
        //  WriteLog();
        //  Assert.IsTrue(_log.Contains("ContactDetails"));
        //  Assert.IsTrue(results.Any(c => c.Orders.Any()));
        //}

        //[TestMethod]
        //public void GetAllIncludingComprehendsTwoLevelNavigation() {
        //  var results = _customerRepository
        //    .GetAllIncluding(c => c.Orders, c => c.Orders.Select(o => o.LineItems));
        //  WriteLog();
        //  Assert.IsTrue(_log.Contains("LineItems"));
        //  Assert.IsTrue(results.Any(c => c.Orders.Any()));
        //}

        //[TestMethod]
        //public void CanIncludeNavigationProperties() {
        //  var results = _customerRepository.GetAllIncluding(c => c.Orders);
        //  WriteLog();
        //  Assert.IsTrue(_log.Contains("Orders"));
        //  Assert.IsTrue(results.Any(c => c.Orders.Any()));
        //}

        // All() returns a materialized list, so a Where composed afterwards runs in
        // memory and the filter value never reaches the SQL log.
        [TestMethod]
        public void ComposedOnAllListExecutedInMemory()
        {
            _customerRepository.All().Where(c => c.FirstName == "Julie").ToList();
            WriteLog();
            Assert.IsFalse(_log.Contains("Julie"));
        }

        private void WriteLog()
        {
            Debug.WriteLine(_log);
        }

        // Routes EF's SQL/diagnostic output into _logBuilder.
        private void SetupLogging()
        {
            _context.Database.Log = BuildLogString;
        }

        // Appends one log message and refreshes the _log snapshot used by assertions.
        private void BuildLogString(string message)
        {
            _logBuilder.Append(message);
            _log = _logBuilder.ToString();
        }
    }
}