Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Permalink
validate filter values on db results (#4616)
Browse files Browse the repository at this point in the history
  • Loading branch information
brockallen committed Jul 3, 2020
1 parent 9e0b0cd commit 9558fac
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 24 deletions.
8 changes: 4 additions & 4 deletions src/EntityFramework.Storage/src/Stores/ClientStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -52,10 +52,10 @@ public ClientStore(IConfigurationDbContext context, ILogger<ClientStore> logger)
public virtual async Task<Client> FindClientByIdAsync(string clientId)
{
IQueryable<Entities.Client> baseQuery = Context.Clients
.Where(x => x.ClientId == clientId)
.Take(1);
.Where(x => x.ClientId == clientId);

var client = await baseQuery.FirstOrDefaultAsync();
var client = (await baseQuery.ToArrayAsync())
.SingleOrDefault(x => x.ClientId == clientId);
if (client == null) return null;

await baseQuery.Include(x => x.AllowedCorsOrigins).SelectMany(c => c.AllowedCorsOrigins).LoadAsync();
Expand Down
15 changes: 10 additions & 5 deletions src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


using System;
using System.Linq;
using System.Threading.Tasks;
using IdentityModel;
using IdentityServer4.EntityFramework.Entities;
Expand Down Expand Up @@ -73,7 +74,8 @@ public virtual async Task StoreDeviceAuthorizationAsync(string deviceCode, strin
/// <returns></returns>
public virtual async Task<DeviceCode> FindByUserCodeAsync(string userCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.UserCode == userCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.UserCode == userCode).ToArrayAsync())
.SingleOrDefault(x => x.UserCode == userCode);
var model = ToModel(deviceFlowCodes?.Data);

Logger.LogDebug("{userCode} found in database: {userCodeFound}", userCode, model != null);
Expand All @@ -88,7 +90,8 @@ public virtual async Task<DeviceCode> FindByUserCodeAsync(string userCode)
/// <returns></returns>
public virtual async Task<DeviceCode> FindByDeviceCodeAsync(string deviceCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.DeviceCode == deviceCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.DeviceCode == deviceCode).ToArrayAsync())
.SingleOrDefault(x => x.DeviceCode == deviceCode);
var model = ToModel(deviceFlowCodes?.Data);

Logger.LogDebug("{deviceCode} found in database: {deviceCodeFound}", deviceCode, model != null);
Expand All @@ -104,7 +107,8 @@ public virtual async Task<DeviceCode> FindByDeviceCodeAsync(string deviceCode)
/// <returns></returns>
public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data)
{
var existing = await Context.DeviceFlowCodes.SingleOrDefaultAsync(x => x.UserCode == userCode);
var existing = (await Context.DeviceFlowCodes.Where(x => x.UserCode == userCode).ToArrayAsync())
.SingleOrDefault(x => x.UserCode == userCode);
if (existing == null)
{
Logger.LogError("{userCode} not found in database", userCode);
Expand Down Expand Up @@ -134,7 +138,8 @@ public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data
/// <returns></returns>
public virtual async Task RemoveByDeviceCodeAsync(string deviceCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.FirstOrDefaultAsync(x => x.DeviceCode == deviceCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.Where(x => x.DeviceCode == deviceCode).ToArrayAsync())
.SingleOrDefault(x => x.DeviceCode == deviceCode);

if(deviceFlowCodes != null)
{
Expand Down
22 changes: 13 additions & 9 deletions src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -46,7 +46,8 @@ public PersistedGrantStore(IPersistedGrantDbContext context, ILogger<PersistedGr
/// <inheritdoc/>
public virtual async Task StoreAsync(PersistedGrant token)
{
var existing = await Context.PersistedGrants.SingleOrDefaultAsync(x => x.Key == token.Key);
var existing = (await Context.PersistedGrants.Where(x => x.Key == token.Key).ToArrayAsync())
.SingleOrDefault(x => x.Key == token.Key);
if (existing == null)
{
Logger.LogDebug("{persistedGrantKey} not found in database", token.Key);
Expand Down Expand Up @@ -74,7 +75,8 @@ public virtual async Task StoreAsync(PersistedGrant token)
/// <inheritdoc/>
public virtual async Task<PersistedGrant> GetAsync(string key)
{
var persistedGrant = await Context.PersistedGrants.AsNoTracking().FirstOrDefaultAsync(x => x.Key == key);
var persistedGrant = (await Context.PersistedGrants.AsNoTracking().Where(x => x.Key == key).ToArrayAsync())
.SingleOrDefault(x => x.Key == key);
var model = persistedGrant?.ToModel();

Logger.LogDebug("{persistedGrantKey} found in database: {persistedGrantKeyFound}", key, model != null);
Expand All @@ -87,7 +89,9 @@ public async Task<IEnumerable<PersistedGrant>> GetAllAsync(PersistedGrantFilter
{
filter.Validate();

var persistedGrants = await Filter(filter).ToArrayAsync();
var persistedGrants = await Filter(Context.PersistedGrants.AsQueryable(), filter).ToArrayAsync();
persistedGrants = Filter(persistedGrants.AsQueryable(), filter).ToArray();

var model = persistedGrants.Select(x => x.ToModel());

Logger.LogDebug("{persistedGrantCount} persisted grants found for {@filter}", persistedGrants.Length, filter);
Expand All @@ -98,7 +102,8 @@ public async Task<IEnumerable<PersistedGrant>> GetAllAsync(PersistedGrantFilter
/// <inheritdoc/>
public virtual async Task RemoveAsync(string key)
{
var persistedGrant = await Context.PersistedGrants.FirstOrDefaultAsync(x => x.Key == key);
var persistedGrant = (await Context.PersistedGrants.Where(x => x.Key == key).ToArrayAsync())
.SingleOrDefault(x => x.Key == key);
if (persistedGrant!= null)
{
Logger.LogDebug("removing {persistedGrantKey} persisted grant from database", key);
Expand All @@ -125,7 +130,8 @@ public async Task RemoveAllAsync(PersistedGrantFilter filter)
{
filter.Validate();

var persistedGrants = await Filter(filter).ToArrayAsync();
var persistedGrants = await Filter(Context.PersistedGrants.AsQueryable(), filter).ToArrayAsync();
persistedGrants = Filter(persistedGrants.AsQueryable(), filter).ToArray();

Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for {@filter}", persistedGrants.Length, filter);

Expand All @@ -142,10 +148,8 @@ public async Task RemoveAllAsync(PersistedGrantFilter filter)
}


private IQueryable<Entities.PersistedGrant> Filter(PersistedGrantFilter filter)
private IQueryable<Entities.PersistedGrant> Filter(IQueryable<Entities.PersistedGrant> query, PersistedGrantFilter filter)
{
var query = Context.PersistedGrants.AsQueryable();

if (!String.IsNullOrWhiteSpace(filter.ClientId))
{
query = query.Where(x => x.ClientId == filter.ClientId);
Expand Down
17 changes: 11 additions & 6 deletions src/EntityFramework.Storage/src/Stores/ResourceStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -56,15 +56,17 @@ public virtual async Task<IEnumerable<ApiResource>> FindApiResourcesByNameAsync(
from apiResource in Context.ApiResources
where apiResourceNames.Contains(apiResource.Name)
select apiResource;

var apis = query
.Include(x => x.Secrets)
.Include(x => x.Scopes)
.Include(x => x.UserClaims)
.Include(x => x.Properties)
.AsNoTracking();

var result = (await apis.ToArrayAsync()).Select(x => x.ToModel()).ToArray();
var result = (await apis.ToArrayAsync())
.Where(x => apiResourceNames.Contains(x.Name))
.Select(x => x.ToModel()).ToArray();

if (result.Any())
{
Expand Down Expand Up @@ -99,7 +101,8 @@ where api.Scopes.Where(x => names.Contains(x.Scope)).Any()
.Include(x => x.Properties)
.AsNoTracking();

var results = await apis.ToArrayAsync();
var results = (await apis.ToArrayAsync())
.Where(api => api.Scopes.Any(x => names.Contains(x.Scope)));
var models = results.Select(x => x.ToModel()).ToArray();

Logger.LogDebug("Found {apis} API resources in database", models.Select(x => x.Name));
Expand All @@ -126,7 +129,8 @@ where scopes.Contains(identityResource.Name)
.Include(x => x.Properties)
.AsNoTracking();

var results = await resources.ToArrayAsync();
var results = (await resources.ToArrayAsync())
.Where(x => scopes.Contains(x.Name));

Logger.LogDebug("Found {scopes} identity scopes in database", results.Select(x => x.Name));

Expand All @@ -152,7 +156,8 @@ where scopes.Contains(scope.Name)
.Include(x => x.Properties)
.AsNoTracking();

var results = await resources.ToArrayAsync();
var results = (await resources.ToArrayAsync())
.Where(x => scopes.Contains(x.Name));

Logger.LogDebug("Found {scopes} scopes in database", results.Select(x => x.Name));

Expand Down

0 comments on commit 9558fac

Please sign in to comment.