diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index d207044..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,67 +0,0 @@ -version: 2 -jobs: - MacOS: - macos: - xcode: "10.0.0" - steps: - - checkout - - restore_cache: - keys: - - v1-spm-deps-{{ checksum "Package.swift" }} - - run: - name: Install dependencies - command: | - brew tap vapor/homebrew-tap - brew install cmysql - brew install ctls - brew install libressl - brew install cstack - - run: - name: Build and Run Tests - no_output_timeout: 1800 - command: | - swift package generate-xcodeproj --enable-code-coverage - xcodebuild -scheme Gatekeeper-Package -enableCodeCoverage YES test | xcpretty - - run: - name: Report coverage to Codecov - command: | - bash <(curl -s https://codecov.io/bash) - - save_cache: - key: v1-spm-deps-{{ checksum "Package.swift" }} - paths: - - .build - Linux: - docker: - - image: nodesvapor/vapor-ci:swift-4.2 - steps: - - checkout - - restore_cache: - keys: - - v1-spm-deps-{{ checksum "Package.swift" }} - - run: - name: Copy Package File - command: cp Package.swift res - - run: - name: Build and Run Tests - no_output_timeout: 1800 - command: | - swift test -Xswiftc -DNOJSON - - run: - name: Restoring Package File - command: mv res Package.swift - - save_cache: - key: v1-spm-deps-{{ checksum "Package.swift" }} - paths: - - .build -workflows: - version: 2 - build-and-test: - jobs: - - MacOS - - Linux -experimental: - notify: - branches: - only: - - master - - develop diff --git a/.codebeatignore b/.codebeatignore deleted file mode 100644 index 2d9084a..0000000 --- a/.codebeatignore +++ /dev/null @@ -1,2 +0,0 @@ -Public/** -Resources/Assets/** diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..874b1a7 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: test +on: + pull_request: + push: + branches: + - master +jobs: + linux: + runs-on: ubuntu-latest + container: swift:5.3-focal + steps: + - name: Check out code + uses: actions/checkout@v2 + - name: Run tests with Thread Sanitizer + run: swift test --enable-test-discovery --sanitize=thread + macOS: + runs-on: macos-latest + steps: + - name: Select latest available Xcode + uses: maxim-lobanov/setup-xcode@v1 + with: + xcode-version: latest + - name: Check out code + uses: actions/checkout@v2 + - name: Run tests with Thread Sanitizer + run: swift test --enable-test-discovery --sanitize=thread \ No newline at end of file diff --git a/.gitignore b/.gitignore index 460f13c..b2915af 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ Config/secrets/ .DS_Store .swift-version CMakeLists.txt -Package.resolved \ No newline at end of file +Package.resolved +.swiftpm \ No newline at end of file diff --git a/LICENSE b/LICENSE index 8de12da..0db1d3d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2017-2019 Nodes +Copyright (c) 2017-2021 Nodes Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Package.swift b/Package.swift index 1d0257e..cab3258 100644 --- a/Package.swift +++ b/Package.swift @@ -1,24 +1,30 @@ -// swift-tools-version:4.2 +// swift-tools-version:5.3 import PackageDescription let package = Package( - name: "Gatekeeper", + name: "gatekeeper", + platforms: [ + .macOS(.v10_15), + ], products: [ .library( name: "Gatekeeper", targets: ["Gatekeeper"]), ], dependencies: [ - .package(url: "https://github.com/vapor/vapor.git", from: "3.0.0"), + .package(url: "https://github.com/vapor/vapor.git", from: "4.38.0"), ], targets: [ .target( name: "Gatekeeper", dependencies: [ - "Vapor" + .product(name: "Vapor", package: "vapor") ]), .testTarget( name: "GatekeeperTests", - dependencies: ["Gatekeeper"]), + dependencies: [ + "Gatekeeper", + .product(name: "XCTVapor", package: "vapor") + ]), ] ) diff --git a/README.md b/README.md index f7f876e..6be2584 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,10 @@ # Gatekeeper 👮 -[![Swift Version](https://img.shields.io/badge/Swift-4.2-brightgreen.svg)](http://swift.org) -[![Vapor Version](https://img.shields.io/badge/Vapor-3-30B6FC.svg)](http://vapor.codes) -[![Circle CI](https://circleci.com/gh/nodes-vapor/gatekeeper/tree/master.svg?style=shield)](https://circleci.com/gh/nodes-vapor/gatekeeper) -[![codebeat badge](https://codebeat.co/badges/35c7b0bb-1662-44ae-b953-ab1d4aaf231f)](https://codebeat.co/projects/github-com-nodes-vapor-gatekeeper-master) -[![codecov](https://codecov.io/gh/nodes-vapor/gatekeeper/branch/master/graph/badge.svg)](https://codecov.io/gh/nodes-vapor/gatekeeper) -[![Readme Score](http://readme-score-api.herokuapp.com/score.svg?url=https://github.com/nodes-vapor/gatekeeper)](http://clayallsopp.github.io/readme-score?url=https://github.com/nodes-vapor/gatekeeper) +[![Swift Version](https://img.shields.io/badge/Swift-5.3-brightgreen.svg)](http://swift.org) +[![Vapor Version](https://img.shields.io/badge/Vapor-4-30B6FC.svg)](http://vapor.codes) [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://raw.githubusercontent.com/nodes-vapor/gatekeeper/master/LICENSE) -Gatekeeper is a middleware that restricts the number of requests from clients, based on their IP address. -It works by adding the clients IP address to the cache and count how many requests the clients can make during the Gatekeeper's defined lifespan and give back an HTTP 429(Too Many Requests) if the limit has been reached. The number of requests left will be reset when the defined timespan has been reached. +Gatekeeper is a middleware that restricts the number of requests from clients, based on their IP address **(can be customized)**. +It works by adding the clients identifier to the cache and count how many requests the clients can make during the Gatekeeper's defined lifespan and give back an HTTP 429(Too Many Requests) if the limit has been reached. The number of requests left will be reset when the defined timespan has been reached. **Please take into consideration that multiple clients can be using the same IP address. eg. public wifi** @@ -18,7 +14,7 @@ It works by adding the clients IP address to the cache and count how many reques Update your `Package.swift` dependencies: ```swift -.package(url: "https://github.com/nodes-vapor/gatekeeper.git", from: "3.0.0"), +.package(url: "https://github.com/nodes-vapor/gatekeeper.git", from: "4.0.0"), ``` as well as to your target (e.g. "App"): @@ -26,7 +22,7 @@ as well as to your target (e.g. "App"): ```swift targets: [ .target(name: "App", dependencies: [..., "Gatekeeper", ...]), -// ... + // ... ] ``` @@ -40,15 +36,8 @@ import Gatekeeper // [...] -// Register providers first -try services.register( - GatekeeperProvider( - config: GatekeeperConfig(maxRequests: 10, per: .second), - cacheFactory: { container -> KeyedCache in - return try container.make() - } - ) -) +app.caches.use(.memory) +app.gatekeeper.config = .init(maxRequests: 10, per: .second) ``` ### Add to routes @@ -58,7 +47,7 @@ You can add the `GatekeeperMiddleware` to specific routes or to all. **Specific routes** in routes.swift: ```swift -let protectedRoutes = router.grouped(GatekeeperMiddleware.self) +let protectedRoutes = router.grouped(GatekeeperMiddleware()) protectedRoutes.get("protected/hello") { req in return "Protected Hello, World!" } @@ -68,15 +57,67 @@ protectedRoutes.get("protected/hello") { req in in configure.swift: ```swift // Register middleware -var middlewares = MiddlewareConfig() // Create _empty_ middleware config -middlewares.use(GatekeeperMiddleware.self) -services.register(middlewares) +app.middlewares.use(GatekeeperMiddleware()) +``` + +#### Customizing config +By default `GatekeeperMiddleware` uses `app.gatekeeper.config` as its configuration. +However, you can pass a custom configuration to each `GatekeeperMiddleware` type via the initializer +`GatekeeperMiddleware(config:)`. This allows you to set configuration on a per-route basis. + +## Key Makers 🔑 +By default Gatekeeper uses the client's hostname (IP address) to identify them. This can cause issues where multiple clients are connected from the same network. Therefore, you can customize how Gatekeeper should identify the client by using the `GatekeeperKeyMaker` protocol. + +`GatekeeperHostnameKeyMaker` is used by default. + +You can configure which key maker Gatekeeper should use in `configure.swift`: +```swift +app.gatekeeper.keyMakers.use(.hostname) // default +``` + +### Custom key maker +This is an example of a key maker that uses the user's ID to identify them. +```swift +struct UserIDKeyMaker: GatekeeperKeyMaker { + public func make(for req: Request) -> EventLoopFuture { + let userID = try req.auth.require(User.self).requireID() + return req.eventLoop.future("gatekeeper_" + userID.uuidString) + } +} +``` + +```swift +extension Application.Gatekeeper.KeyMakers.Provider { + public static var userID: Self { + .init { app in + app.gatekeeper.keyMakers.use { _ in UserIDKeyMaker() } + } + } +} +``` +**configure.swift:** +```swift +app.gatekeeper.keyMakers.use(.userID) ``` +## Cache 🗄 +Gatekeeper uses the same cache as configured by `app.caches.use()` from Vapor, by default. +Therefore it is **important** to set up Vapor's cache if you're using this default behaviour. You can use an in-memory cache for Vapor like so: + +**configure.swift**: +```swift +app.cache.use(.memory) +``` + +### Custom cache +You can override which cache to use by creating your own type that conforms to the `Cache` protocol from Vapor. Use `app.gatekeeper.caches.use()` to configure which cache to use. + + ## Credits 🏆 This package is developed and maintained by the Vapor team at [Nodes](https://www.nodesagency.com). The package owner for this project is [Christian](https://github.com/cweinberger). +Special thanks goes to [madsodgaard](https://github.com/madsodgaard) for his work on the Vapor 4 version! ## License 📄 diff --git a/Sources/Gatekeeper/Gatekeeper+Vapor/Application+Gatekeeper.swift b/Sources/Gatekeeper/Gatekeeper+Vapor/Application+Gatekeeper.swift new file mode 100644 index 0000000..d11daea --- /dev/null +++ b/Sources/Gatekeeper/Gatekeeper+Vapor/Application+Gatekeeper.swift @@ -0,0 +1,144 @@ +import Vapor + +// MARK: - Application+Gatekeeper +extension Application { + public struct Gatekeeper { + private let app: Application + + init(app: Application) { + self.app = app + } + + private final class Storage { + var config: GatekeeperConfig? = nil + var makeCache: ((Application) -> Cache)? = nil + var makeKeyMaker: ((Application) -> GatekeeperKeyMaker)? = nil + } + + private struct Key: StorageKey { + typealias Value = Storage + } + + private var storage: Storage { + if app.storage[Key.self] == nil { + initialize() + } + + return app.storage[Key.self]! + } + + private func initialize() { + app.storage[Key.self] = Storage() + app.gatekeeper.caches.use(.default) + app.gatekeeper.keyMakers.use(.hostname) + } + + /// The default config used for middlewares. + public var config: GatekeeperConfig { + get { + guard let config = storage.config else { + fatalError("Gatekeeper not configured, use: app.gatekeeper.config = ...") + } + + return config + } + nonmutating set { storage.config = newValue } + } + } + + public var gatekeeper: Gatekeeper { + .init(app: self) + } +} + +// MARK: - Gatekeeper+Caches +extension Application.Gatekeeper { + public struct Caches { + private let gatekeeper: Application.Gatekeeper + + public init(_ gatekeeper: Application.Gatekeeper) { + self.gatekeeper = gatekeeper + } + + public struct Provider { + public let run: (Application) -> Void + + public init(_ run: @escaping (Application) -> Void) { + self.run = run + } + + /// A provider that uses the default Vapor cache. + public static var `default`: Self { + .init { app in + app.gatekeeper.caches.use { $0.cache } + } + } + } + + public func use(_ makeCache: @escaping (Application) -> Cache) { + gatekeeper.storage.makeCache = makeCache + } + + public func use(_ provider: Provider) { + provider.run(gatekeeper.app) + } + + public var cache: Cache { + guard let factory = gatekeeper.storage.makeCache else { + fatalError("Gatekeeper not configured, use: app.gatekeeper.caches.use(...)") + } + + return factory(gatekeeper.app) + } + } + + public var caches: Caches { + .init(self) + } +} + +// MARK: - Gatekeeper+Keymakers +extension Application.Gatekeeper { + public struct KeyMakers { + private let gatekeeper: Application.Gatekeeper + + public init(_ gatekeeper: Application.Gatekeeper) { + self.gatekeeper = gatekeeper + } + + public struct Provider { + public let run: (Application) -> Void + + public init(_ run: @escaping (Application) -> Void) { + self.run = run + } + + /// A provider that the request hostname to generate a cache key. + public static var hostname: Self { + .init { app in + app.gatekeeper.keyMakers.use { _ in GatekeeperHostnameKeyMaker() } + } + } + } + + public func use(_ makeKeyMaker: @escaping (Application) -> GatekeeperKeyMaker) { + gatekeeper.storage.makeKeyMaker = makeKeyMaker + } + + public func use(_ provider: Provider) { + provider.run(gatekeeper.app) + } + + public var keyMaker: GatekeeperKeyMaker { + guard let factory = gatekeeper.storage.makeKeyMaker else { + fatalError("Gatekeeper not configured, use: app.gatekeeper.keyMakers.use(...)") + } + + return factory(gatekeeper.app) + } + } + + public var keyMakers: KeyMakers { + .init(self) + } +} diff --git a/Sources/Gatekeeper/Gatekeeper+Vapor/Request+Gatekeeper.swift b/Sources/Gatekeeper/Gatekeeper+Vapor/Request+Gatekeeper.swift new file mode 100644 index 0000000..b432bf0 --- /dev/null +++ b/Sources/Gatekeeper/Gatekeeper+Vapor/Request+Gatekeeper.swift @@ -0,0 +1,15 @@ +import Vapor + +extension Request { + func gatekeeper( + config: GatekeeperConfig? = nil, + cache: Cache? = nil, + keyMaker: GatekeeperKeyMaker? = nil + ) -> Gatekeeper { + .init( + cache: cache ?? application.gatekeeper.caches.cache.for(self), + config: config ?? application.gatekeeper.config, + identifier: keyMaker ?? application.gatekeeper.keyMakers.keyMaker + ) + } +} diff --git a/Sources/Gatekeeper/Gatekeeper.swift b/Sources/Gatekeeper/Gatekeeper.swift index 202d0e5..5e67b17 100644 --- a/Sources/Gatekeeper/Gatekeeper.swift +++ b/Sources/Gatekeeper/Gatekeeper.swift @@ -1,77 +1,50 @@ import Vapor -public struct Gatekeeper: Service { - - internal let config: GatekeeperConfig - internal let cacheFactory: ((Container) throws -> KeyedCache) - - public init( - config: GatekeeperConfig, - cacheFactory: @escaping ((Container) throws -> KeyedCache) = { container in try container.make() } - ) { +public struct Gatekeeper { + private let cache: Cache + private let config: GatekeeperConfig + private let keyMaker: GatekeeperKeyMaker + + public init(cache: Cache, config: GatekeeperConfig, identifier: GatekeeperKeyMaker) { + self.cache = cache self.config = config - self.cacheFactory = cacheFactory + self.keyMaker = identifier } - - public func accessEndpoint( - on request: Request - ) throws -> Future { - - guard let peerHostName = request.http.remotePeer.hostname else { - throw Abort( - .forbidden, - reason: "Unable to verify peer" - ) - } - - let peerCacheKey = cacheKey(for: peerHostName) - let cache = try cacheFactory(request) - - return cache.get(peerCacheKey, as: Entry.self) - .map(to: Entry.self) { entry in - if let entry = entry { - return entry - } else { - return Entry( - peerHostname: peerHostName, - createdAt: Date(), - requestsLeft: self.config.limit - ) - } - } - .map(to: Entry.self) { entry in - - let now = Date() - var mutableEntry = entry - if now.timeIntervalSince1970 - entry.createdAt.timeIntervalSince1970 >= self.config.refreshInterval { - mutableEntry.createdAt = now - mutableEntry.requestsLeft = self.config.limit - } - mutableEntry.requestsLeft -= 1 - return mutableEntry - }.then { entry in - return cache.set(peerCacheKey, to: entry).transform(to: entry) - }.map(to: Entry.self) { entry in - - if entry.requestsLeft < 0 { - throw Abort( - .tooManyRequests, - reason: "Slow down. You sent too many requests." - ) - } - return entry + + public func gatekeep(on req: Request) -> EventLoopFuture { + keyMaker + .make(for: req) + .flatMap { cacheKey in + fetchOrCreateEntry(for: cacheKey, on: req) + .map(updateEntry) + .flatMap { entry in + cache + .set(cacheKey, to: entry) + .transform(to: entry) + } } + .guard( + { $0.requestsLeft > 0 }, + else: Abort(.tooManyRequests, reason: "Slow down. You sent too many requests.")) + .transform(to: ()) } - - private func cacheKey(for hostname: String) -> String { - return "gatekeeper_\(hostname)" + + private func updateEntry(_ entry: Entry) -> Entry { + var newEntry = entry + if newEntry.hasExpired(within: config.refreshInterval) { + newEntry.reset(remainingRequests: config.limit) + } + newEntry.touch() + return newEntry } -} - -extension Gatekeeper { - public struct Entry: Codable { - let peerHostname: String - var createdAt: Date - var requestsLeft: Int + + private func fetchOrCreateEntry(for key: String, on req: Request) -> EventLoopFuture { + guard let hostname = req.hostname else { + return req.eventLoop.future(error: Abort(.forbidden, reason: "Unable to verify peer")) + } + + return cache + .get(key, as: Entry.self) + .unwrap(orReplace: Entry(hostname: hostname, createdAt: Date(), requestsLeft: config.limit)) } } diff --git a/Sources/Gatekeeper/GatekeeperConfig.swift b/Sources/Gatekeeper/GatekeeperConfig.swift index 79597d6..53c7483 100644 --- a/Sources/Gatekeeper/GatekeeperConfig.swift +++ b/Sources/Gatekeeper/GatekeeperConfig.swift @@ -1,7 +1,6 @@ import Vapor -public struct GatekeeperConfig: Service { - +public struct GatekeeperConfig { public enum Interval { case second case minute @@ -17,7 +16,7 @@ public struct GatekeeperConfig: Service { self.interval = interval } - internal var refreshInterval: Double { + var refreshInterval: Double { switch interval { case .second: return 1 diff --git a/Sources/Gatekeeper/GatekeeperEntry.swift b/Sources/Gatekeeper/GatekeeperEntry.swift new file mode 100644 index 0000000..a1bf908 --- /dev/null +++ b/Sources/Gatekeeper/GatekeeperEntry.swift @@ -0,0 +1,25 @@ +import Vapor + +extension Gatekeeper { + /// A model representing a entry in the cache for a specific client + public struct Entry: Codable { + let hostname: String + var createdAt: Date + var requestsLeft: Int + } +} + +extension Gatekeeper.Entry { + func hasExpired(within interval: Double) -> Bool { + Date().timeIntervalSince1970 - createdAt.timeIntervalSince1970 >= interval + } + + mutating func reset(remainingRequests: Int) { + createdAt = Date() + requestsLeft = remainingRequests + } + + mutating func touch() { + requestsLeft -= 1 + } +} diff --git a/Sources/Gatekeeper/GatekeeperMiddleware.swift b/Sources/Gatekeeper/GatekeeperMiddleware.swift index 3f91575..e6f6f01 100644 --- a/Sources/Gatekeeper/GatekeeperMiddleware.swift +++ b/Sources/Gatekeeper/GatekeeperMiddleware.swift @@ -1,23 +1,18 @@ import Vapor -public struct GatekeeperMiddleware { - let gatekeeper: Gatekeeper -} - -extension GatekeeperMiddleware: Middleware { - public func respond( - to request: Request, - chainingTo next: Responder - ) throws -> Future { - - return try gatekeeper.accessEndpoint(on: request).flatMap { _ in - return try next.respond(to: request) - } +/// Middleware used to rate-limit a single route or a group of routes. +public struct GatekeeperMiddleware: Middleware { + private let config: GatekeeperConfig? + + /// Initialize with a custom `GatekeeperConfig` instead of using the default `app.gatekeeper.config` + public init(config: GatekeeperConfig? = nil) { + self.config = config } -} - -extension GatekeeperMiddleware: ServiceType { - public static func makeService(for container: Container) throws -> GatekeeperMiddleware { - return try .init(gatekeeper: container.make()) + + public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture { + request + .gatekeeper(config: config) + .gatekeep(on: request) + .flatMap { next.respond(to: request) } } } diff --git a/Sources/Gatekeeper/GatekeeperProvider.swift b/Sources/Gatekeeper/GatekeeperProvider.swift deleted file mode 100644 index 551f7d6..0000000 --- a/Sources/Gatekeeper/GatekeeperProvider.swift +++ /dev/null @@ -1,33 +0,0 @@ -import Vapor - -public final class GatekeeperProvider { - - internal let config: GatekeeperConfig - internal let cacheFactory: ((Container) throws -> KeyedCache) - - public init( - config: GatekeeperConfig, - cacheFactory: @escaping ((Container) throws -> KeyedCache) = { container in try container.make() } - ) { - self.config = config - self.cacheFactory = cacheFactory - } -} - -extension GatekeeperProvider: Provider { - public func register(_ services: inout Services) throws { - services.register(config) - services.register( - Gatekeeper( - config: config, - cacheFactory: cacheFactory - ), - as: Gatekeeper.self - ) - services.register(GatekeeperMiddleware.self) - } - - public func didBoot(_ container: Container) throws -> EventLoopFuture { - return .done(on: container) - } -} diff --git a/Sources/Gatekeeper/KeyMaker/GatekeeperHostnameKeyMaker.swift b/Sources/Gatekeeper/KeyMaker/GatekeeperHostnameKeyMaker.swift new file mode 100644 index 0000000..fab0b12 --- /dev/null +++ b/Sources/Gatekeeper/KeyMaker/GatekeeperHostnameKeyMaker.swift @@ -0,0 +1,12 @@ +import Vapor + +/// Uses the hostname of the client to create a cache key. +public struct GatekeeperHostnameKeyMaker: GatekeeperKeyMaker { + public func make(for req: Request) -> EventLoopFuture { + guard let hostname = req.hostname else { + return req.eventLoop.future(error: Abort(.forbidden, reason: "Unable to verify peer")) + } + + return req.eventLoop.future("gatekeeper_" + hostname) + } +} diff --git a/Sources/Gatekeeper/KeyMaker/GatekeeperKeyMaker.swift b/Sources/Gatekeeper/KeyMaker/GatekeeperKeyMaker.swift new file mode 100644 index 0000000..54903f0 --- /dev/null +++ b/Sources/Gatekeeper/KeyMaker/GatekeeperKeyMaker.swift @@ -0,0 +1,6 @@ +import Vapor + +/// Reponsible for generating a cache key for a specific `Request` +public protocol GatekeeperKeyMaker { + func make(for req: Request) -> EventLoopFuture +} diff --git a/Sources/Gatekeeper/Request+Hostname.swift b/Sources/Gatekeeper/Request+Hostname.swift new file mode 100644 index 0000000..491d853 --- /dev/null +++ b/Sources/Gatekeeper/Request+Hostname.swift @@ -0,0 +1,7 @@ +import Vapor + +extension Request { + var hostname: String? { + headers.first(name: .xForwardedFor) ?? remoteAddress?.hostname + } +} diff --git a/Tests/GatekeeperTests/GatekeeperTests.swift b/Tests/GatekeeperTests/GatekeeperTests.swift index 60b1bf3..bca7120 100644 --- a/Tests/GatekeeperTests/GatekeeperTests.swift +++ b/Tests/GatekeeperTests/GatekeeperTests.swift @@ -1,122 +1,67 @@ import XCTest -import Vapor +import XCTVapor @testable import Gatekeeper class GatekeeperTests: XCTestCase { - func testGateKeeper() throws { - - let request = try Request.test( - gatekeeperConfig: GatekeeperConfig(maxRequests: 10, per: .minute), - peerName: "::1" - ) - - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) - - for i in 1...11 { - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - XCTAssertTrue(i <= 10, "ran \(i) times.") - } catch let error as Abort { - switch error.status { - case .tooManyRequests: - //success - XCTAssertEqual(i, 11, "Should've failed after the 11th attempt.") - break - default: - XCTFail("Expected too many request: \(error)") + let app = Application(.testing) + defer { app.shutdown() } + app.gatekeeper.config = .init(maxRequests: 10, per: .second) + + app.grouped(GatekeeperMiddleware()).get("test") { req -> HTTPStatus in + return .ok + } + + for i in 1...10 { + try app.test(.GET, "test", headers: ["X-Forwarded-For": "::1"], afterResponse: { res in + if i == 10 { + XCTAssertEqual(res.status, .tooManyRequests) + } else { + XCTAssertEqual(res.status, .ok, "failed for request \(i) with status: \(res.status)") } - } catch { - XCTFail("Caught wrong error: \(error)") - } + }) } } - func testGateKeeperNoPeer() throws { - - let request = try Request.test( - gatekeeperConfig: GatekeeperConfig(maxRequests: 10, per: .minute), - peerName: nil - ) - - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) - - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - XCTFail("Gatekeeper should throw") - } catch let error as Abort { - switch error.status { - case .forbidden: - //success - break - default: - XCTFail("Expected forbidden") - } - } catch { - XCTFail("Rate limiter failed: \(error)") + func testGateKeeperNoPeerReturnsForbidden() throws { + let app = Application(.testing) + defer { app.shutdown() } + app.gatekeeper.config = .init(maxRequests: 10, per: .second) + + app.grouped(GatekeeperMiddleware()).get("test") { req -> HTTPStatus in + return .ok } - } + try app.test(.GET, "test", afterResponse: { res in + XCTAssertEqual(res.status, .forbidden) + }) + } + func testGateKeeperCountRefresh() throws { - - let request = try Request.test( - gatekeeperConfig: GatekeeperConfig(maxRequests: 100, per: .second), - peerName: "192.168.1.2" - ) - - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) - + let app = Application(.testing) + defer { app.shutdown() } + app.gatekeeper.config = .init(maxRequests: 100, per: .second) + app.grouped(GatekeeperMiddleware()).get("test") { req -> HTTPStatus in + return .ok + } + for _ in 0..<50 { - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - } catch { - XCTFail("Rate limiter failed: \(error)") - break - } + try app.test(.GET, "test", headers: ["X-Forwarded-For": "::1"], afterResponse: { res in + XCTAssertEqual(res.status, .ok) + }) } - let cache = try request.make(KeyedCache.self) - var entry = try cache.get("gatekeeper_192.168.1.2", as: Gatekeeper.Entry.self).wait() - XCTAssertEqual(entry!.requestsLeft, 50) + let entryBefore = try app.gatekeeper.caches.cache .get("gatekeeper_::1", as: Gatekeeper.Entry.self).wait() + XCTAssertEqual(entryBefore!.requestsLeft, 50) Thread.sleep(forTimeInterval: 1) - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - } catch { - XCTFail("Rate limiter failed: \(error)") - } - - entry = try! cache.get("gatekeeper_192.168.1.2", as: Gatekeeper.Entry.self).wait() - XCTAssertEqual(entry!.requestsLeft, 99, "Requests left should've reset") - } - - func testGateKeeperWithCacheFactory() throws { - - let request = try Request.test( - gatekeeperConfig: GatekeeperConfig(maxRequests: 10, per: .minute), - peerName: "::1", - cacheFactory: { try $0.make(KeyedCache.self) } - ) + + try app.test(.GET, "test", headers: ["X-Forwarded-For": "::1"], afterResponse: { res in + XCTAssertEqual(res.status, .ok) + }) - let gatekeeperMiddleware = try request.make(GatekeeperMiddleware.self) - - for i in 1...11 { - do { - _ = try gatekeeperMiddleware.respond(to: request, chainingTo: TestResponder()).wait() - XCTAssertTrue(i <= 10, "ran \(i) times.") - } catch let error as Abort { - switch error.status { - case .tooManyRequests: - //success - XCTAssertEqual(i, 11, "Should've failed after the 11th attempt.") - break - default: - XCTFail("Expected too many request: \(error)") - } - } catch { - XCTFail("Caught wrong error: \(error)") - } - } + let entryAfter = try app.gatekeeper.caches.cache .get("gatekeeper_::1", as: Gatekeeper.Entry.self).wait() + XCTAssertEqual(entryAfter!.requestsLeft, 99, "Requests left should've reset") } func testRefreshIntervalValues() { @@ -132,4 +77,35 @@ class GatekeeperTests: XCTestCase { XCTAssertEqual(rate.refreshInterval, expected) } } + + func testGatekeeperUsesKeyMaker() throws { + struct DummyKeyMaker: GatekeeperKeyMaker { + func make(for req: Request) -> EventLoopFuture { + req.eventLoop.future("dummy") + } + } + + let app = Application(.testing) + defer { app.shutdown() } + app.gatekeeper.config = .init(maxRequests: 10, per: .second) + app.gatekeeper.keyMakers.use { _ in + DummyKeyMaker() + } + + app.grouped(GatekeeperMiddleware()).get("test") { req -> HTTPStatus in + return .ok + } + + try app.test(.GET, "test", headers: ["X-Forwarded-For": "::1"], afterResponse: { _ in }) + + let entry = try app.gatekeeper.caches.cache.get("dummy", as: Gatekeeper.Entry.self).wait() + XCTAssertNotNil(entry) + } + + func testGatekeeperDefaultProviders() throws { + let app = Application(.testing) + defer { app.shutdown() } + + XCTAssertTrue(app.gatekeeper.keyMakers.keyMaker is GatekeeperHostnameKeyMaker) + } } diff --git a/Tests/GatekeeperTests/Utilities/Request+test.swift b/Tests/GatekeeperTests/Utilities/Request+test.swift deleted file mode 100644 index 04d8088..0000000 --- a/Tests/GatekeeperTests/Utilities/Request+test.swift +++ /dev/null @@ -1,57 +0,0 @@ -import Gatekeeper -import HTTP -import Vapor - -extension Request { - static func test( - gatekeeperConfig: GatekeeperConfig, - url: URLRepresentable = "http://localhost:8080/test", - peerName: String? = "::1", - cacheFactory: ((Container) throws -> KeyedCache)? = nil - ) throws -> Request { - let config = Config() - - var services = Services() - services.register(KeyedCache.self) { container in - return MemoryKeyedCache() - } - - if let cacheFactory = cacheFactory { - try services.register( - GatekeeperProvider( - config: gatekeeperConfig, - cacheFactory: cacheFactory - ) - ) - } else { - try services.register( - GatekeeperProvider( - config: gatekeeperConfig - ) - ) - } - - services.register(GatekeeperMiddleware.self) - - let sharedThreadPool = BlockingIOThreadPool(numberOfThreads: 2) - sharedThreadPool.start() - services.register(sharedThreadPool) - - let app = try Application(config: config, environment: .testing, services: services) - let request = Request( - http: HTTPRequest( - method: .GET, - url: url - ), - using: app - ) - - var http = request.http - if let peerName = peerName { - http.headers.add(name: .init("X-Forwarded-For"), value: peerName) - } - request.http = http - - return request - } -} diff --git a/Tests/GatekeeperTests/Utilities/TestResponder.swift b/Tests/GatekeeperTests/Utilities/TestResponder.swift deleted file mode 100644 index 1c53793..0000000 --- a/Tests/GatekeeperTests/Utilities/TestResponder.swift +++ /dev/null @@ -1,7 +0,0 @@ -import Vapor - -public struct TestResponder: Responder { - public func respond(to req: Request) throws -> EventLoopFuture { - return req.future(req.response()) - } -} diff --git a/Tests/GatekeeperTests/XCTestManifests.swift b/Tests/GatekeeperTests/XCTestManifests.swift deleted file mode 100644 index eba17e3..0000000 --- a/Tests/GatekeeperTests/XCTestManifests.swift +++ /dev/null @@ -1,18 +0,0 @@ -import XCTest - -extension GatekeeperTests { - static let __allTests = [ - ("testGateKeeper", testGateKeeper), - ("testGateKeeperNoPeer", testGateKeeperNoPeer), - ("testGateKeeperCountRefresh", testGateKeeperCountRefresh), - ("testRefreshIntervalValues", testRefreshIntervalValues), - ] -} - -#if !os(macOS) -public func __allTests() -> [XCTestCaseEntry] { - return [ - testCase(GatekeeperTests.__allTests), - ] -} -#endif diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift deleted file mode 100644 index 4a838d0..0000000 --- a/Tests/LinuxMain.swift +++ /dev/null @@ -1,8 +0,0 @@ -import XCTest - -import GatekeeperTests - -var tests = [XCTestCaseEntry]() -tests += GatekeeperTests.__allTests() - -XCTMain(tests)