使用 Entity Framework 的通用存储库

时间:2017-03-16 10:14:53

标签: c# entity-framework repository-pattern

我想使用 Entity Framework 实现一个通用的存储库模式(我知道关于存储库模式有很多争议性的意见,但这仍然是我需要的)。我希望它的接口如下:

/// <summary>
/// Single repository covering every entity type deriving from <see cref="Entity"/>.
/// Per the usage example, Save and Delete are meant to commit immediately
/// (no separate unit-of-work step).
/// </summary>
public interface IRepository
{
    /// <summary>Exposes entities of type <typeparamref name="TEntity"/> as a composable query.</summary>
    IQueryable<TEntity> Query<TEntity>() where TEntity : Entity;

    /// <summary>Persists <paramref name="entity"/>: intended as INSERT for new entities, UPDATE for existing ones.</summary>
    void Save<TEntity>(TEntity entity) where TEntity : Entity;

    /// <summary>Removes <paramref name="entity"/> from the store.</summary>
    void Delete<TEntity>(TEntity entity) where TEntity : Entity;
}

Entity 是一个只有 int 类型 Id 属性的基类。使用方式如下:

        IRepository repository = ... // get repository (connects to DB)
        int userId = GetCurrentUserId();

        // Refuse to create an order for a user that does not exist.
        if (!repository.Query<User>().Any(u => u.Id == userId)) // performs SELECT query
        {    /*return error*/    }

        var newOrder = new Order { UserId = userId, Status = "New" };
        repository.Save(newOrder); // performs INSERT query
        // ...
        newOrder.Status = "Completed";
        repository.Save(newOrder); // performs UPDATE query

我想避免使用 UnitOfWork,而是在调用 Save() 或 Delete() 后立即将所有对象更改提交到数据库。我想做的事情看起来很简单,但我没有找到任何使用 Entity Framework 的例子。我能找到的最接近的例子是这个答案,但它使用了 UnitOfWork 以及“每个实体一个存储库”(repository-per-entity)的方式,这比我需要的更复杂。

3 个答案:

答案 0 :(得分:3)

我曾经用过这种方式,但正如许多开发人员所说,它会增加代码的复杂性,并可能导致问题:

我的interface IRepositoryBase的代码:

/// <summary>
/// Generic repository contract for entities of type <typeparamref name="TEntity"/>.
/// Extends <see cref="System.IDisposable"/> (instead of declaring a bare Dispose method)
/// so implementations can be used in <c>using</c> blocks and are recognised as
/// disposable by containers that dispose IDisposable services.
/// </summary>
/// <typeparam name="TEntity">Entity type managed by the repository.</typeparam>
public interface IRepositoryBase<TEntity> : System.IDisposable where TEntity : class
{
    /// <summary>Adds a single entity.</summary>
    void Add(TEntity objModel);

    /// <summary>Adds several entities at once.</summary>
    void AddRange(IEnumerable<TEntity> objModel);

    /// <summary>Looks up an entity by its integer key.</summary>
    TEntity GetId(int id);

    /// <summary>Asynchronously looks up an entity by its integer key.</summary>
    Task<TEntity> GetIdAsync(int id);

    /// <summary>Returns the first entity matching <paramref name="predicate"/>.</summary>
    TEntity Get(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/>.</summary>
    Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Returns all entities matching <paramref name="predicate"/>.</summary>
    IEnumerable<TEntity> GetList(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Asynchronously returns all entities matching <paramref name="predicate"/>.</summary>
    Task<IEnumerable<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>Returns every entity of this type.</summary>
    IEnumerable<TEntity> GetAll();

    /// <summary>Asynchronously returns every entity of this type.</summary>
    Task<IEnumerable<TEntity>> GetAllAsync();

    /// <summary>Counts the entities of this type.</summary>
    int Count();

    /// <summary>Asynchronously counts the entities of this type.</summary>
    Task<int> CountAsync();

    /// <summary>Marks <paramref name="objModel"/> as changed and persists it.</summary>
    void Update(TEntity objModel);

    /// <summary>Removes <paramref name="objModel"/>.</summary>
    void Remove(TEntity objModel);

    // Dispose() is inherited from System.IDisposable, so existing calls still compile.
}

实现该 interface 的 RepositoryBase 代码:

/// <summary>
/// Generic EF repository whose write operations call SaveChanges immediately.
/// Owns its own <see cref="EntityContext"/> and releases it in <see cref="Dispose"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type managed by the repository.</typeparam>
public class RepositoryBase<TEntity> : IRepositoryBase<TEntity>, System.IDisposable where TEntity : class
{
    #region Fields

    // Created per repository instance and disposed in Dispose().
    protected readonly EntityContext _context = new EntityContext();

    #endregion

    #region Methods

    /// <summary>Adds <paramref name="objModel"/> and saves immediately.</summary>
    public void Add(TEntity objModel)
    {
        _context.Set<TEntity>().Add(objModel);
        _context.SaveChanges();
    }

    /// <summary>Adds all of <paramref name="objModel"/> and saves once.</summary>
    public void AddRange(IEnumerable<TEntity> objModel)
    {
        _context.Set<TEntity>().AddRange(objModel);
        _context.SaveChanges();
    }

    /// <summary>Finds an entity by primary key via <c>DbSet.Find</c>.</summary>
    public TEntity GetId(int id)
    {
        return _context.Set<TEntity>().Find(id);
    }

    /// <summary>Finds an entity by primary key via <c>DbSet.FindAsync</c>.</summary>
    public async Task<TEntity> GetIdAsync(int id)
    {
        return await _context.Set<TEntity>().FindAsync(id);
    }

    /// <summary>Returns the first entity matching <paramref name="predicate"/>, or null.</summary>
    public TEntity Get(Expression<Func<TEntity, bool>> predicate)
    {
        return _context.Set<TEntity>().FirstOrDefault(predicate);
    }

    /// <summary>Asynchronously returns the first entity matching <paramref name="predicate"/>, or null.</summary>
    public async Task<TEntity> GetAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await _context.Set<TEntity>().FirstOrDefaultAsync(predicate);
    }

    /// <summary>Returns a materialized list of entities matching <paramref name="predicate"/>.</summary>
    public IEnumerable<TEntity> GetList(Expression<Func<TEntity, bool>> predicate)
    {
        return _context.Set<TEntity>().Where<TEntity>(predicate).ToList();
    }

    /// <summary>Asynchronously returns a materialized list of entities matching <paramref name="predicate"/>.</summary>
    public async Task<IEnumerable<TEntity>> GetListAsync(Expression<Func<TEntity, bool>> predicate)
    {
        // Run the query with EF's async API and materialize it here. Wrapping the
        // deferred query in Task.Run executed nothing asynchronously; the SQL
        // only ran later, when the caller enumerated the result.
        return await _context.Set<TEntity>().Where<TEntity>(predicate).ToListAsync();
    }

    /// <summary>Returns a materialized list of all entities of this type.</summary>
    public IEnumerable<TEntity> GetAll()
    {
        return _context.Set<TEntity>().ToList();
    }

    /// <summary>Asynchronously returns a materialized list of all entities of this type.</summary>
    public async Task<IEnumerable<TEntity>> GetAllAsync()
    {
        // Previously Task.Run returned the DbSet itself, so no query ran asynchronously.
        return await _context.Set<TEntity>().ToListAsync();
    }

    /// <summary>Counts the entities of this type.</summary>
    public int Count()
    {
        return _context.Set<TEntity>().Count();
    }

    /// <summary>Asynchronously counts the entities of this type.</summary>
    public async Task<int> CountAsync()
    {
        return await _context.Set<TEntity>().CountAsync();
    }

    /// <summary>Marks <paramref name="objModel"/> as Modified and saves immediately.</summary>
    public void Update(TEntity objModel)
    {
        _context.Entry(objModel).State = EntityState.Modified;
        _context.SaveChanges();
    }

    /// <summary>Removes <paramref name="objModel"/> and saves immediately.</summary>
    public void Remove(TEntity objModel)
    {
        _context.Set<TEntity>().Remove(objModel);
        _context.SaveChanges();
    }

    /// <summary>Disposes the context owned by this repository.</summary>
    public void Dispose()
    {
        _context.Dispose();
    }

    #endregion
}

我的实体存储库接口及其实现:

/// <summary>
/// Repository contract for <see cref="MyEntity"/>. Declare entity-specific
/// operations here that the generic base contract does not provide.
/// </summary>
public interface IMyEntityRepository : IRepositoryBase<MyEntity> { }

/// <summary>
/// Concrete repository for <see cref="MyEntity"/>; all behavior is inherited
/// from <see cref="RepositoryBase{TEntity}"/>.
/// </summary>
public class MyEntityRepository : RepositoryBase<MyEntity>, IMyEntityRepository { }

如何调用它(我在使用依赖注入):

/// <summary>
/// Example consumer that receives its repository through constructor injection.
/// </summary>
public class MyServiceOrController
{
    #region Fields

    private readonly IMyEntityRepository _myEntityRepository;

    #endregion

    #region Constructors

    /// <summary>Creates the service around an injected repository.</summary>
    /// <param name="myEntityRepository">Repository used for all data access.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="myEntityRepository"/> is null.</exception>
    public MyServiceOrController(IMyEntityRepository myEntityRepository)
    {
        // Fail at construction instead of with a NullReferenceException on first use.
        if (myEntityRepository == null)
        {
            throw new System.ArgumentNullException("myEntityRepository");
        }

        _myEntityRepository = myEntityRepository;
    }

    #endregion

    #region Methods

    /// <summary>Returns every <see cref="MyEntity"/> as a list.</summary>
    public IList<MyEntity> TestGetAll()
    {
        // GetAll() is declared as IEnumerable<MyEntity>, which does not convert
        // implicitly to IList<MyEntity>; copy the results into a list.
        return new List<MyEntity>(_myEntityRepository.GetAll());
    }

    #endregion
}

答案 1 :(得分:2)

1-创建一个接口

/// <summary>
/// Minimal repository contract in which writes are staged and committed
/// separately through <see cref="Savechange"/>.
/// </summary>
interface IMain<T> where T : class
{
    /// <summary>Returns all entities of type <typeparamref name="T"/> as a list.</summary>
    List<T> GetAll();

    /// <summary>Looks up an entity by its integer key.</summary>
    T GetById(int id);

    /// <summary>Adds <paramref name="entity"/>.</summary>
    void Add(T entity);

    /// <summary>Marks <paramref name="entity"/> as edited.</summary>
    void Edit(T entity);

    /// <summary>Deletes the entity with key <paramref name="id"/>.</summary>
    void Del(int id);

    /// <summary>Commits pending changes and returns the resulting integer count.</summary>
    int Savechange();
}

2-创建一个类

/// <summary>
/// Generic repository over <see cref="DataContext"/>. Add, Edit and Del only stage
/// changes; nothing is written until <see cref="Savechange"/> is called.
/// </summary>
public class Main<T> : IMain<T> where T : class
{
    /// <summary>
    /// Context used by every operation. It has no initializer, so it must be assigned
    /// (e.g. in a derived class's constructor) before any other member is called.
    /// </summary>
    public DataContext db;

    /// <summary>Stages <paramref name="entity"/> for insertion.</summary>
    public void Add(T entity)
    {
        db.Set<T>().Add(entity);
    }

    /// <summary>
    /// Stages deletion of the entity with key <paramref name="id"/>.
    /// Does nothing when no entity with that key exists.
    /// </summary>
    public void Del(int id)
    {
        var existing = GetById(id);

        // Find returns null for an unknown key, and Remove rejects null, so skip
        // the removal instead of throwing.
        if (existing != null)
        {
            db.Set<T>().Remove(existing);
        }
    }

    /// <summary>Marks <paramref name="entity"/> as Modified so all its properties are updated on save.</summary>
    public void Edit(T entity)
    {
        db.Entry<T>(entity).State = EntityState.Modified;
    }

    /// <summary>Loads every entity of type <typeparamref name="T"/> into a list.</summary>
    public List<T> GetAll()
    {
        return db.Set<T>().ToList();
    }

    /// <summary>Finds an entity by primary key; returns null when not found.</summary>
    public T GetById(int id)
    {
        return db.Set<T>().Find(id);
    }

    /// <summary>Commits staged changes and returns the value of <c>SaveChanges</c>.</summary>
    public int Savechange()
    {
        return db.SaveChanges();
    }
}

3-为每张表创建一个存储库类,例如为 Tbl_Student 表创建 Student 类:

 /// <summary>
 /// Repository for the <c>Tbl_Student</c> table; creates its own <see cref="DataContext"/>.
 /// </summary>
 public class Student : Main<Tbl_Student>
 {
     public Student()
     {
         // The base class leaves its context unassigned, so each concrete repository supplies one.
         this.db = new DataContext();
     }
 }

4-执行操作:

// Stage deletion of the student with key 3, then commit; `a` receives SaveChanges' result.
var student = new Student();
student.Del(3);
var a = student.Savechange();

答案 2 :(得分:0)

您可以借助表达式树(System.Linq.Expressions 中的 Expression)来实现这一点:

set;

然后创建您的存储库:

    /// <summary>
    /// Per-entity repository contract: filtered queries plus save and delete operations.
    /// </summary>
    /// <typeparam name="TEntity">Entity type; must derive from <see cref="Entity"/>.</typeparam>
    public interface IRepository<TEntity> where TEntity : Entity
    {
        /// <summary>Returns a composable query of the entities matching <paramref name="predicate"/>.</summary>
        IQueryable<TEntity> Query(Expression<Func<TEntity, bool>> predicate);

        /// <summary>Inserts or updates <paramref name="entity"/>.</summary>
        void Save(TEntity entity);

        /// <summary>Deletes <paramref name="entity"/>.</summary>
        void Delete(TEntity entity);
    }

    /// <summary>
    /// Entity Framework implementation of <see cref="IRepository{TEntity}"/> whose
    /// Save and Delete commit immediately via SaveChanges.
    /// </summary>
    /// <typeparam name="T">Entity type; its <c>Id</c> decides insert vs. update in <see cref="Save"/>.</typeparam>
    public abstract class EfRepository<T> : IRepository<T> where T : Entity
    {
        private readonly DbContext _dbContext;
        protected readonly DbSet<T> _dbSet;

        /// <summary>Binds the repository to <paramref name="dbContext"/> and its set for <typeparamref name="T"/>.</summary>
        /// <param name="dbContext">
        /// Any <see cref="DbContext"/>. Taking the base type lets derived repositories pass
        /// their concrete context (e.g. <c>YourDbContext</c>).
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is null.</exception>
        protected EfRepository(DbContext dbContext)
        {
            if (dbContext == null)
            {
                throw new ArgumentNullException("dbContext");
            }

            _dbContext = dbContext;
            _dbSet = dbContext.Set<T>();
        }

        /// <summary>Deletes <paramref name="entity"/> and commits immediately. A null entity is ignored.</summary>
        public void Delete(T entity)
        {
            if (entity == null)
            {
                return;
            }

            // A detached instance must be attached before the context can remove it.
            if (_dbContext.Entry(entity).State == EntityState.Detached)
            {
                _dbSet.Attach(entity);
            }

            // The previous version only flagged the entry as Deleted and never saved,
            // so the delete was not committed.
            _dbSet.Remove(entity);
            _dbContext.SaveChanges();
        }

        /// <summary>Returns a composable query of the entities matching <paramref name="predicate"/>.</summary>
        public IQueryable<T> Query(Expression<Func<T, bool>> predicate)
        {
            return _dbSet.Where(predicate);
        }

        /// <summary>
        /// Inserts <paramref name="entity"/> when its Id is not positive, otherwise updates it,
        /// then commits immediately.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public void Save(T entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException("entity");
            }

            if (entity.Id > 0)
            {
                // Setting the state attaches a detached instance and marks every
                // property as modified, so a separate Attach call is not needed.
                _dbContext.Entry(entity).State = EntityState.Modified;
            }
            else
            {
                _dbSet.Add(entity);
            }

            _dbContext.SaveChanges();
        }
    }
    /// <summary>
    /// Base type for repository-managed entities. It carries only the integer key,
    /// which EfRepository.Save compares against zero to choose insert or update.
    /// </summary>
    public class Entity
    {
        /// <summary>Primary key.</summary>
        public int Id
        {
            get;
            set;
        }
    }

然后使用它:

     /// <summary>
     /// Repository contract for <see cref="User"/>; add user-specific operations here as needed.
     /// </summary>
     public interface IUserRepository : IRepository<User>
     {
     }
     /// <summary>
     /// Entity Framework-backed repository for <see cref="User"/>.
     /// </summary>
     public class UserRepository : EfRepository<User>, IUserRepository
     {
         /// <summary>Forwards <paramref name="yourDbContext"/> to the base repository.</summary>
         public UserRepository(YourDbContext yourDbContext)
             : base(yourDbContext)
         {
         }
     }