Motivation
While Entity Framework is a great object-relational mapper for .NET, the queries it generates can sometimes be slow. For instance, at my previous job, EF Core's query performance struggled with tables of over 500,000 rows. The solution turned out to be Dapper, since it lets you write your own optimized queries. In this blog post, we will explore how to write a clean implementation of the Generic Repository pattern with Dapper.
Code
The source code can be found here.
Credit
This article builds on the brilliant work of Zuraiz Ahmed Shehzad on Medium. This implementation is asynchronous, and its queries are more robust, especially the SELECT queries, which dynamically alias each column name to its property name.
Approach
We will have an N-tier ASP.NET Web API with four layers:
API --> Application --> Repository (Infrastructure) --> Database (Domain).
The generic repository implementation will be, you guessed it, in the Repository layer.
Furthermore, we will connect to a Postgres database, since Postgres is easy to set up on any operating system.
Step 1: The Database
We will create a simple API that allows us to do CRUD operations on Products and Categories. The following database script creates the database and its tables:
-- Creates the application database.
-- Locale options (LC_COLLATE / LC_CTYPE / LOCALE_PROVIDER) are intentionally omitted so the
-- server's defaults apply: Windows-specific locale names such as 'English_United States.1252'
-- do not exist on Linux or macOS servers, and LOCALE_PROVIDER requires PostgreSQL 15+.
-- If the server's template database is not UTF8-compatible, add TEMPLATE = template0
-- together with matching LC_COLLATE / LC_CTYPE values for your platform.
CREATE DATABASE "GenericRepoDapperDb"
    WITH
    OWNER = postgres
    ENCODING = 'UTF8'
    TABLESPACE = pg_default
    CONNECTION LIMIT = -1
    IS_TEMPLATE = False;
-- Product categories.
-- Both audit columns default to the time of insert; this script defines no trigger,
-- so updated_at only changes when the application writes it.
CREATE TABLE categories (
    id          SERIAL        PRIMARY KEY,
    name        VARCHAR(255)  NOT NULL,
    created_at  TIMESTAMP     DEFAULT now() NOT NULL,
    updated_at  TIMESTAMP     DEFAULT now() NOT NULL
);
-- Products; each one belongs to exactly one category (category_id is required).
-- description is the only nullable column.
CREATE TABLE products (
    id           SERIAL        PRIMARY KEY,
    name         VARCHAR(255)  NOT NULL,
    description  TEXT,
    category_id  INTEGER       NOT NULL REFERENCES categories (id),
    created_at   TIMESTAMP     DEFAULT now() NOT NULL,
    updated_at   TIMESTAMP     DEFAULT now() NOT NULL
);
Step 2: The Domain Layer
We will create Plain Old C# Objects (POCOs) that map to those tables:
/// <summary>
/// Common shape of the persisted entities: an integer key plus creation and
/// last-update timestamps.
/// </summary>
public interface IEntity
{
    /// <summary>Primary key of the record.</summary>
    int Id { get; set; }
    /// <summary>When the record was created.</summary>
    DateTime CreatedAt { get; set; }
    /// <summary>When the record was last updated.</summary>
    DateTime UpdatedAt { get; set; }
}
/// <summary>
/// Base class for mapped entities. Supplies the key and the two audit timestamps,
/// mapped to the <c>id</c>, <c>created_at</c> and <c>updated_at</c> columns.
/// </summary>
public class Entity : IEntity
{
    /// <summary>Primary key (<c>id</c> column).</summary>
    [Key, Column("id")]
    public int Id { get; set; }

    /// <summary>Creation timestamp (<c>created_at</c> column).</summary>
    [Column("created_at")] public DateTime CreatedAt { get; set; }

    /// <summary>Last-update timestamp (<c>updated_at</c> column).</summary>
    [Column("updated_at")] public DateTime UpdatedAt { get; set; }
}
/// <summary>A product category, stored in the <c>categories</c> table.</summary>
[Table("categories")]
public class Category : Entity
{
    /// <summary>Category name (<c>name</c> column).</summary>
    [Column("name")] public string? Name { get; set; }
}
/// <summary>A product, stored in the <c>products</c> table.</summary>
[Table("products")]
public class Product : Entity
{
    /// <summary>Product name (<c>name</c> column).</summary>
    [Column("name")] public string? Name { get; set; }

    /// <summary>Optional free-text description (<c>description</c> column).</summary>
    [Column("description")] public string? Description { get; set; }

    /// <summary>Id of the category this product belongs to (<c>category_id</c> column).</summary>
    [Column("category_id")] public int CategoryId { get; set; }
}
Now, we will create the database context, which will be registered as a Singleton in a later step.
/// <summary>
/// Creates PostgreSQL connections from connection strings stored in the application
/// configuration. It holds no connection state of its own, only the configuration.
/// </summary>
public class ApplicationDbContext
{
    private readonly IConfiguration _configuration;

    /// <summary>Creates the factory over the given configuration.</summary>
    /// <param name="configuration">Configuration containing a <c>ConnectionStrings</c> section.</param>
    public ApplicationDbContext(IConfiguration configuration)
    {
        _configuration = configuration;
    }

    /// <summary>Creates a new, not-yet-opened PostgreSQL connection.</summary>
    /// <param name="connectionString">
    /// The <em>name</em> of the entry under <c>ConnectionStrings</c> in configuration,
    /// not the connection string itself.
    /// </param>
    /// <returns>A new connection. The caller owns it and is responsible for disposing it.</returns>
    /// <exception cref="InvalidOperationException">
    /// No non-empty connection string is configured under <paramref name="connectionString"/>.
    /// </exception>
    public IDbConnection CreateConnection(string connectionString = "DefaultConnection")
    {
        string? connection = _configuration.GetConnectionString(connectionString);

        // Fail here with a clear message instead of handing out a connection that can
        // only fail later, when it is opened.
        if (string.IsNullOrWhiteSpace(connection))
        {
            throw new InvalidOperationException(
                $"Connection string '{connectionString}' is not configured.");
        }

        return new NpgsqlConnection(connection);
    }
}
Note the use of inheritance to avoid repeating the common fields, and the Table, Key, and Column annotations, which will be useful in step 3.
Step 3: The Generic Repository implementation
In this step, we will implement a GenericRepository that takes a type T, and expose it through a Unit of Work that we can easily register in the dependency injection container (in a later step).
/// <summary>
/// Asynchronous CRUD operations over a single mapped entity type.
/// </summary>
/// <typeparam name="T">The entity type the repository reads and writes.</typeparam>
public interface IGenericRepository<T>
{
    /// <summary>Fetches the record whose key equals <paramref name="id"/>.</summary>
    /// <returns>
    /// The entity. NOTE(review): GenericRepository returns the query's default (null for a class)
    /// when no row matches, so callers should null-check.
    /// </returns>
    Task<T> GetById(int id);

    /// <summary>Fetches every record of the entity's table.</summary>
    Task<IEnumerable<T>> GetAll();

    /// <summary>Counts the records in the entity's table.</summary>
    Task<int> CountAll();

    /// <summary>Inserts <paramref name="entity"/>.</summary>
    /// <returns>Rows affected. GenericRepository returns -1 when the insert fails.</returns>
    Task<int> Add(T entity);

    /// <summary>Updates the record identified by <paramref name="entity"/>'s key.</summary>
    /// <returns>Rows affected. GenericRepository returns -1 when the update fails.</returns>
    Task<int> Update(T entity);

    /// <summary>Deletes the record identified by <paramref name="entity"/>'s key.</summary>
    /// <returns>Rows affected. GenericRepository returns -1 when the delete fails.</returns>
    Task<int> Delete(T entity);
}
/// <summary>
/// Dapper-based implementation of <see cref="IGenericRepository{T}"/> that builds its SQL from
/// the <see cref="TableAttribute"/>, <see cref="KeyAttribute"/> and <see cref="ColumnAttribute"/>
/// annotations on <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// <para>
/// Every public property of <typeparamref name="T"/> is treated as a column. A property without a
/// named <see cref="ColumnAttribute"/> maps to a column with the property's own name, and a type
/// without a <see cref="TableAttribute"/> maps to a table named after the type.
/// </para>
/// <para>
/// A new connection is created and disposed for each operation. No connection is kept open
/// between calls, and concurrent calls on the same instance never share a connection.
/// Identifiers in the SQL come only from <typeparamref name="T"/>'s metadata; all values are
/// passed to Dapper as parameters.
/// </para>
/// <para>
/// Read operations rethrow failures as an <see cref="Exception"/> whose
/// <see cref="Exception.InnerException"/> is the original error. Write operations log and
/// return -1 instead.
/// </para>
/// </remarks>
/// <typeparam name="T">The mapped entity type.</typeparam>
public class GenericRepository<T> : IGenericRepository<T> where T : class
{
    private readonly ApplicationDbContext _context;

    /// <summary>Creates a repository that obtains its connections from <paramref name="context"/>.</summary>
    /// <param name="context">Factory used to create one connection per operation.</param>
    public GenericRepository(ApplicationDbContext context)
    {
        _context = context;
    }

    /// <summary>Fetches the row whose key column equals <paramref name="id"/>.</summary>
    /// <returns>The mapped entity, or <c>null</c> when no row matches.</returns>
    /// <exception cref="Exception">The query failed; the cause is the inner exception.</exception>
    public async Task<T> GetById(int id)
    {
        try
        {
            string query = $"SELECT {GetColumnsAsProperties()} FROM {GetTableName()} WHERE {GetKeyColumnName()} = @id";
            using IDbConnection connection = _context.CreateConnection();
            // The id is bound as a parameter rather than spliced into the SQL text.
            return await connection.QueryFirstOrDefaultAsync<T>(query, new { id });
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error fetching a record from db: {ex.Message}");
            throw new Exception("Unable to fetch data. Please contact the administrator.", ex);
        }
    }

    /// <summary>Fetches every row of the table.</summary>
    /// <exception cref="Exception">The query failed; the cause is the inner exception.</exception>
    public async Task<IEnumerable<T>> GetAll()
    {
        try
        {
            string query = $"SELECT {GetColumnsAsProperties()} FROM {GetTableName()}";
            using IDbConnection connection = _context.CreateConnection();
            // Dapper buffers results by default, so the rows are fully read before the
            // connection is disposed.
            return await connection.QueryAsync<T>(query);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error fetching records from db: {ex.Message}");
            throw new Exception("Unable to fetch data. Please contact the administrator.", ex);
        }
    }

    /// <summary>Counts the rows in the table.</summary>
    /// <exception cref="Exception">The query failed; the cause is the inner exception.</exception>
    public async Task<int> CountAll()
    {
        try
        {
            string query = $"SELECT COUNT(*) FROM {GetTableName()}";
            using IDbConnection connection = _context.CreateConnection();
            return await connection.QueryFirstOrDefaultAsync<int>(query);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error counting records in db: {ex.Message}");
            throw new Exception("Unable to count data. Please contact the administrator.", ex);
        }
    }

    /// <summary>Inserts <paramref name="entity"/>, writing every non-key property.</summary>
    /// <returns>Rows affected, or -1 if the insert failed.</returns>
    public async Task<int> Add(T entity)
    {
        try
        {
            string query = $"INSERT INTO {GetTableName()} ({GetColumns(excludeKey: true)}) VALUES ({GetPropertyNames(excludeKey: true)})";
            using IDbConnection connection = _context.CreateConnection();
            return await connection.ExecuteAsync(query, entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error adding a record to db: {ex.Message}");
            return -1;
        }
    }

    /// <summary>Overwrites every non-key column of the row identified by <paramref name="entity"/>'s key.</summary>
    /// <returns>Rows affected, or -1 if the update failed.</returns>
    public async Task<int> Update(T entity)
    {
        try
        {
            // One "column = @Property" assignment per non-key property. Column names use the
            // same mapping as GetColumns, so undecorated properties are not emitted as " = @X".
            string assignments = string.Join(", ", GetProperties(excludeKey: true)
                .Select(p => $"{GetColumnName(p)} = @{p.Name}"));
            string query = $"UPDATE {GetTableName()} SET {assignments} WHERE {GetKeyColumnName()} = @{GetKeyPropertyName()}";
            using IDbConnection connection = _context.CreateConnection();
            return await connection.ExecuteAsync(query, entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error updating a record in db: {ex.Message}");
            return -1;
        }
    }

    /// <summary>Deletes the row identified by <paramref name="entity"/>'s key.</summary>
    /// <returns>Rows affected, or -1 if the delete failed.</returns>
    public async Task<int> Delete(T entity)
    {
        try
        {
            string query = $"DELETE FROM {GetTableName()} WHERE {GetKeyColumnName()} = @{GetKeyPropertyName()}";
            using IDbConnection connection = _context.CreateConnection();
            return await connection.ExecuteAsync(query, entity);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Error deleting a record in db: {ex.Message}");
            return -1;
        }
    }

    // Table name from [Table], falling back to the type name.
    private static string GetTableName() =>
        typeof(T).GetCustomAttribute<TableAttribute>()?.Name ?? typeof(T).Name;

    // Column name for a property: its [Column] name when one is given, otherwise the property name.
    private static string GetColumnName(PropertyInfo property) =>
        property.GetCustomAttribute<ColumnAttribute>()?.Name ?? property.Name;

    // The property marked [Key]. Key-based statements cannot be built without one, so fail
    // with a clear message instead of generating "WHERE  = ...".
    private static PropertyInfo GetKeyProperty() =>
        typeof(T).GetProperties().FirstOrDefault(p => p.IsDefined(typeof(KeyAttribute)))
        ?? throw new InvalidOperationException($"Type {typeof(T).Name} has no property marked with [Key].");

    // Column name of the [Key] property.
    private static string GetKeyColumnName() => GetColumnName(GetKeyProperty());

    // Name of the [Key] property, used as the Dapper parameter name.
    private static string GetKeyPropertyName() => GetKeyProperty().Name;

    // Public properties of T, optionally without the [Key] property. Every other helper uses
    // this, so column lists and parameter lists always come out in the same order.
    private static IEnumerable<PropertyInfo> GetProperties(bool excludeKey = false) =>
        typeof(T).GetProperties().Where(p => !excludeKey || !p.IsDefined(typeof(KeyAttribute)));

    // Comma-separated column names, e.g. "name, description".
    private static string GetColumns(bool excludeKey = false) =>
        string.Join(", ", GetProperties(excludeKey).Select(GetColumnName));

    // Comma-separated SELECT list that aliases each mapped column to its property name,
    // e.g. "created_at AS CreatedAt", so Dapper can map rows onto T.
    private static string GetColumnsAsProperties(bool excludeKey = false) =>
        string.Join(", ", GetProperties(excludeKey).Select(p =>
        {
            string? column = p.GetCustomAttribute<ColumnAttribute>()?.Name;
            return column is null ? p.Name : $"{column} AS {p.Name}";
        }));

    // Comma-separated Dapper parameter placeholders, e.g. "@Name, @Description".
    private static string GetPropertyNames(bool excludeKey = false) =>
        string.Join(", ", GetProperties(excludeKey).Select(p => $"@{p.Name}"));
}
/// <summary>
/// Unit of work through which services obtain repositories for entity types.
/// </summary>
public interface IUnit
{
    /// <summary>Returns a repository for <typeparamref name="T"/>.</summary>
    /// <typeparam name="T">An entity class implementing <see cref="IEntity"/>.</typeparam>
    GenericRepository<T> GetRepository<T>() where T : class, IEntity;
}
/// <summary>
/// Unit of work that creates a new <see cref="GenericRepository{T}"/> on every request.
/// All repositories it creates share the same <see cref="ApplicationDbContext"/>.
/// </summary>
public class Unit : IUnit
{
    private readonly ApplicationDbContext _context;

    /// <summary>Creates the unit of work over <paramref name="context"/>.</summary>
    public Unit(ApplicationDbContext context) => _context = context;

    /// <summary>Builds a fresh repository for <typeparamref name="T"/>.</summary>
    public GenericRepository<T> GetRepository<T>() where T : class, IEntity => new(_context);
}
Step 4: The Service (Application) layer
In this layer, we will simply use the repository to access the database and execute our CRUD operations for both Products and Categories (only the Products code is shown below, for brevity).
/// <summary>
/// Application-layer operations on products.
/// </summary>
public interface IProductService
{
    /// <summary>Creates a product from <paramref name="productDto"/>.</summary>
    /// <returns>Rows inserted, as reported by the repository.</returns>
    Task<int> Create(ProductDto productDto);

    /// <summary>Updates the product with key <paramref name="id"/> from <paramref name="productDto"/>.</summary>
    /// <returns>Rows updated, as reported by the repository.</returns>
    Task<int> Update(int id, ProductDto productDto);

    /// <summary>Fetches all products.</summary>
    Task<IEnumerable<Product>> GetAll();

    /// <summary>Counts all products.</summary>
    Task<int> CountAll();

    /// <summary>Fetches the product with key <paramref name="id"/>.</summary>
    /// <remarks>ProductService throws when no such product exists.</remarks>
    Task<Product> GetById(int id);

    /// <summary>Deletes the product with key <paramref name="id"/>.</summary>
    /// <returns><c>true</c> if at least one row was deleted.</returns>
    Task<bool> Delete(int id);
}
/// <summary>
/// Application-layer service performing CRUD operations on <see cref="Product"/> records
/// through the generic repository obtained from the unit of work.
/// </summary>
public class ProductService : IProductService
{
    private readonly IUnit _unit;
    private readonly IGenericRepository<Product> _repository;
    private readonly IMapper _mapper;

    /// <summary>Creates the service and resolves the product repository from <paramref name="unit"/>.</summary>
    public ProductService(IUnit unit, IMapper mapper)
    {
        _unit = unit;
        _repository = _unit.GetRepository<Product>();
        _mapper = mapper;
    }

    /// <summary>Maps <paramref name="productDto"/> to a product and inserts it.</summary>
    /// <returns>Rows inserted, as reported by the repository.</returns>
    public async Task<int> Create(ProductDto productDto)
    {
        var product = _mapper.Map<Product>(productDto);

        // Use a real UTC timestamp (DateTime.Now is local time and must not be labelled UTC),
        // and set UpdatedAt too: the repository's INSERT writes every non-key property, so an
        // unset UpdatedAt would store default(DateTime) instead of the creation time.
        DateTime now = DateTime.UtcNow;
        product.CreatedAt = now;
        product.UpdatedAt = now;

        return await _repository.Add(product);
    }

    /// <summary>Copies the editable fields from <paramref name="productDto"/> onto product <paramref name="id"/> and saves it.</summary>
    /// <returns>Rows updated, as reported by the repository.</returns>
    /// <exception cref="KeyNotFoundException">No product with key <paramref name="id"/> exists.</exception>
    public async Task<int> Update(int id, ProductDto productDto)
    {
        var product = await GetById(id);
        product.Name = productDto.Name;
        product.Description = productDto.Description;
        product.CategoryId = productDto.CategoryId;
        product.UpdatedAt = DateTime.UtcNow;
        return await _repository.Update(product);
    }

    /// <summary>Fetches all products.</summary>
    public async Task<IEnumerable<Product>> GetAll()
    {
        return await _repository.GetAll();
    }

    /// <summary>Counts all products.</summary>
    public async Task<int> CountAll()
    {
        return await _repository.CountAll();
    }

    /// <summary>Fetches the product with key <paramref name="id"/>.</summary>
    /// <exception cref="KeyNotFoundException">No product with key <paramref name="id"/> exists.</exception>
    public async Task<Product> GetById(int id)
    {
        var product = await _repository.GetById(id);
        if (product is null)
        {
            // KeyNotFoundException derives from Exception, so existing catch blocks still apply.
            throw new KeyNotFoundException("Product record does not exist.");
        }
        return product;
    }

    /// <summary>Deletes the product with key <paramref name="id"/>.</summary>
    /// <returns><c>true</c> if at least one row was deleted.</returns>
    /// <exception cref="KeyNotFoundException">No product with key <paramref name="id"/> exists.</exception>
    public async Task<bool> Delete(int id)
    {
        var product = await GetById(id);
        int result = await _repository.Delete(product); // A soft delete (isDeleted = true) is an alternative.
        return result > 0;
    }
}
Step 5: The Web API
We will create a Web API with two controllers, and give it a Swagger UI to be used for testing.
Program.cs
using Application.Logic.CategoryService;
using Application.Logic.ProductService;
using Domain.Database;
using GenericRepo_Dapper.Configuration;
using Infrastructure.UnitOfWork;
using Microsoft.OpenApi.Models;
var builder = WebApplication.CreateBuilder(args);

// CORS policy name, shared by AddCors (registration) and UseCors (pipeline).
const string CorsPolicyName = "CorsName";

builder.Services.AddCors(options =>
{
    // Origins are matched exactly (scheme, host and port), so only pages served from
    // http://localhost or https://localhost on the default ports are allowed.
    options.AddPolicy(CorsPolicyName, policyBuilder => policyBuilder
        .WithOrigins("http://localhost", "https://localhost")
        .AllowAnyMethod()
        .AllowAnyHeader());
});

// ApplicationDbContext only wraps configuration and hands out connections,
// so a single instance serves the whole application.
builder.Services.AddSingleton<ApplicationDbContext>();
builder.Services.AddRouting(options => options.LowercaseUrls = true);
builder.Services.AddControllersWithViews().AddJsonOptions(options =>
{
    options.JsonSerializerOptions.ReferenceHandler = System.Text.Json.Serialization.ReferenceHandler.IgnoreCycles;
    options.JsonSerializerOptions.WriteIndented = true;
});
builder.Services.AddAutoMapper(AppDomain.CurrentDomain.GetAssemblies());

// Services
builder.Services.AddScoped<ICategoryService, CategoryService>();
builder.Services.AddScoped<IProductService, ProductService>();

// Repository (unit of work)
builder.Services.AddScoped<IUnit, Unit>();

// Swagger
builder.Services.AddEndpointsApiExplorer();
builder.Services.AddSwaggerGen(option =>
{
    option.SwaggerDoc("v1", info: new OpenApiInfo { Title = "Generic Repository and Dapper API", Version = "v1" });
    option.OperationFilter<HeaderFilter>();
});

var app = builder.Build();

// The exception handler goes first in the pipeline so that it also catches exceptions
// thrown by the middleware registered after it (Swagger, CORS), not only by controllers.
app.UseExceptionHandler("/Error");

var swaggerConfig = new SwaggerConfig();
app.Configuration.GetSection(nameof(SwaggerConfig)).Bind(swaggerConfig);
app.UseSwagger(option => { option.RouteTemplate = swaggerConfig.JsonRoute; });
app.UseSwaggerUI(option => { option.SwaggerEndpoint(swaggerConfig.UIEndpoint, swaggerConfig.Description); });

app.UseCors(CorsPolicyName);
app.MapControllers();
app.Run();
Result
Room for improvement
- Add bulk operations.
- Pagination.
- Filters for the Get functions.
- Returning a view model instead of the object itself.
End Note
Feel free to comment with questions or inquiries, or to add any notes in the section below. Happy developing!