Skip to content

Commit e9c238d

Browse files
committed
Introduce Async::Redis::Endpoint.
- Handles authentication and database selection.
1 parent 95e8a8c commit e9c238d

File tree

6 files changed

+316
-15
lines changed

6 files changed

+316
-15
lines changed

lib/async/redis/client.rb

+2-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
require_relative 'context/pipeline'
1111
require_relative 'context/transaction'
1212
require_relative 'context/subscribe'
13-
require_relative 'protocol/resp2'
13+
require_relative 'endpoint'
1414

1515
require 'io/endpoint/host_endpoint'
1616
require 'async/pool/controller'
@@ -23,14 +23,10 @@ module Redis
2323
# Legacy.
2424
ServerError = ::Protocol::Redis::ServerError
2525

26-
def self.local_endpoint(port: 6379)
27-
::IO::Endpoint.tcp('localhost', port)
28-
end
29-
3026
class Client
3127
include ::Protocol::Redis::Methods
3228

33-
def initialize(endpoint = Redis.local_endpoint, protocol: Protocol::RESP2, **options)
29+
def initialize(endpoint = Endpoint.local, protocol: endpoint.protocol, **options)
3430
@endpoint = endpoint
3531
@protocol = protocol
3632

lib/async/redis/endpoint.rb

+250
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
# frozen_string_literal: true
2+
3+
# Released under the MIT License.
4+
# Copyright, 2024, by Samuel Williams.
5+
6+
require 'io/endpoint'
7+
require 'io/endpoint/host_endpoint'
8+
require 'io/endpoint/ssl_endpoint'
9+
10+
require_relative 'protocol/resp2'
11+
require_relative 'protocol/authenticated'
12+
require_relative 'protocol/selected'
13+
14+
module Async
15+
module Redis
16+
def self.local_endpoint(**options)
17+
Endpoint.local(**options)
18+
end
19+
20+
# Represents a way to connect to a remote Redis server.
21+
class Endpoint < ::IO::Endpoint::Generic
22+
LOCALHOST = URI.parse("redis://localhost").freeze
23+
24+
def self.local(**options)
25+
self.new(LOCALHOST, **options)
26+
end
27+
28+
SCHEMES = {
29+
'redis' => URI::Generic,
30+
'rediss' => URI::Generic,
31+
}
32+
33+
def self.parse(string, endpoint = nil, **options)
34+
url = URI.parse(string).normalize
35+
36+
return self.new(url, endpoint, **options)
37+
end
38+
39+
# Construct an endpoint with a specified scheme, hostname, optional path, and options.
40+
#
41+
# @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss".
42+
# @parameter hostname [String] The hostname to connect to (or bind to).
43+
# @parameter *options [Hash] Additional options, passed to {#initialize}.
44+
def self.for(scheme, hostname, credentials: nil, port: nil, database: nil, **options)
45+
uri_klass = SCHEMES.fetch(scheme.downcase) do
46+
raise ArgumentError, "Unsupported scheme: #{scheme.inspect}"
47+
end
48+
49+
if database
50+
path = "/#{database}"
51+
end
52+
53+
self.new(
54+
uri_klass.new(scheme, credentials&.join(":"), hostname, port, nil, path, nil, nil, nil).normalize,
55+
**options
56+
)
57+
end
58+
59+
# Coerce the given object into an endpoint.
60+
# @parameter url [String | Endpoint] The URL or endpoint to convert.
61+
def self.[](object)
62+
if object.is_a?(self)
63+
return object
64+
else
65+
self.parse(object.to_s)
66+
end
67+
end
68+
69+
# @option scheme [String] the scheme to use, overrides the URL scheme.
70+
# @option hostname [String] the hostname to connect to (or bind to), overrides the URL hostname (used for SNI).
71+
# @option port [Integer] the port to bind to, overrides the URL port.
72+
# @option ssl_context [OpenSSL::SSL::SSLContext] the context to use for TLS.
73+
# @option alpn_protocols [Array<String>] the alpn protocols to negotiate.
74+
def initialize(url, endpoint = nil, **options)
75+
super(**options)
76+
77+
raise ArgumentError, "URL must be absolute (include scheme, host): #{url}" unless url.absolute?
78+
79+
@url = url
80+
81+
if endpoint
82+
@endpoint = self.build_endpoint(endpoint)
83+
else
84+
@endpoint = nil
85+
end
86+
end
87+
88+
def to_url
89+
url = @url.dup
90+
91+
unless default_port?
92+
url.port = self.port
93+
end
94+
95+
return url
96+
end
97+
98+
def to_s
99+
"\#<#{self.class} #{self.to_url} #{@options}>"
100+
end
101+
102+
def inspect
103+
"\#<#{self.class} #{self.to_url} #{@options.inspect}>"
104+
end
105+
106+
attr :url
107+
108+
def address
109+
endpoint.address
110+
end
111+
112+
def secure?
113+
['rediss'].include?(self.scheme)
114+
end
115+
116+
def protocol
117+
protocol = @options.fetch(:protocol, Protocol::RESP2)
118+
119+
if database = self.database
120+
protocol = Protocol::Selected.new(database, protocol)
121+
end
122+
123+
if credentials = self.credentials
124+
protocol = Protocol::Authenticated.new(credentials, protocol)
125+
end
126+
127+
return protocol
128+
end
129+
130+
def default_port
131+
6379
132+
end
133+
134+
def default_port?
135+
port == default_port
136+
end
137+
138+
def port
139+
@options[:port] || @url.port || default_port
140+
end
141+
142+
# The hostname is the server we are connecting to:
143+
def hostname
144+
@options[:hostname] || @url.hostname
145+
end
146+
147+
def scheme
148+
@options[:scheme] || @url.scheme
149+
end
150+
151+
def database
152+
@options[:database] || @url.path[1..-1].to_i
153+
end
154+
155+
def credentials
156+
@options[:credentials] || @url.userinfo&.split(":")
157+
end
158+
159+
def localhost?
160+
@url.hostname =~ /^(.*?\.)?localhost\.?$/
161+
end
162+
163+
# We don't try to validate peer certificates when talking to localhost because they would always be self-signed.
164+
def ssl_verify_mode
165+
if self.localhost?
166+
OpenSSL::SSL::VERIFY_NONE
167+
else
168+
OpenSSL::SSL::VERIFY_PEER
169+
end
170+
end
171+
172+
def ssl_context
173+
@options[:ssl_context] || OpenSSL::SSL::SSLContext.new.tap do |context|
174+
context.set_params(
175+
verify_mode: self.ssl_verify_mode
176+
)
177+
end
178+
end
179+
180+
def build_endpoint(endpoint = nil)
181+
endpoint ||= tcp_endpoint
182+
183+
if secure?
184+
# Wrap it in SSL:
185+
return ::IO::Endpoint::SSLEndpoint.new(endpoint,
186+
ssl_context: self.ssl_context,
187+
hostname: @url.hostname,
188+
timeout: self.timeout,
189+
)
190+
end
191+
192+
return endpoint
193+
end
194+
195+
def endpoint
196+
@endpoint ||= build_endpoint
197+
end
198+
199+
def endpoint=(endpoint)
200+
@endpoint = build_endpoint(endpoint)
201+
end
202+
203+
def bind(*arguments, &block)
204+
endpoint.bind(*arguments, &block)
205+
end
206+
207+
def connect(&block)
208+
endpoint.connect(&block)
209+
end
210+
211+
def each
212+
return to_enum unless block_given?
213+
214+
self.tcp_endpoint.each do |endpoint|
215+
yield self.class.new(@url, endpoint, **@options)
216+
end
217+
end
218+
219+
def key
220+
[@url, @options]
221+
end
222+
223+
def eql? other
224+
self.key.eql? other.key
225+
end
226+
227+
def hash
228+
self.key.hash
229+
end
230+
231+
protected
232+
233+
def tcp_options
234+
options = @options.dup
235+
236+
options.delete(:scheme)
237+
options.delete(:port)
238+
options.delete(:hostname)
239+
options.delete(:ssl_context)
240+
options.delete(:protocol)
241+
242+
return options
243+
end
244+
245+
def tcp_endpoint
246+
::IO::Endpoint.tcp(self.hostname, port, **tcp_options)
247+
end
248+
end
249+
end
250+
end

lib/async/redis/protocol/authenticated.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class AuthenticationError < StandardError
1818
#
1919
# @parameter credentials [Array] The credentials to use for authentication.
2020
# @parameter protocol [Object] The delegated protocol for connecting.
21-
def initialize(credentials, protocol: Async::Redis::Protocol::RESP2)
21+
def initialize(credentials, protocol = Async::Redis::Protocol::RESP2)
2222
@credentials = credentials
2323
@protocol = protocol
2424
end

lib/async/redis/protocol/selected.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SelectionError < StandardError
1818
#
1919
# @parameter index [Integer] The database index to select.
2020
# @parameter protocol [Object] The delegated protocol for connecting.
21-
def initialize(index, protocol: Async::Redis::Protocol::RESP2)
21+
def initialize(index, protocol = Async::Redis::Protocol::RESP2)
2222
@index = index
2323
@protocol = protocol
2424
end

test/async/redis/disconnect.rb

+19-7
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,30 @@
1010

1111
describe Async::Redis::Client do
1212
include Sus::Fixtures::Async::ReactorContext
13-
14-
let(:endpoint) {::IO::Endpoint.tcp('localhost', 5555)}
15-
13+
14+
# Intended to not be connected:
15+
let(:endpoint) {Async::Redis::Endpoint.local(port: 5555)}
16+
17+
before do
18+
@server_endpoint = ::IO::Endpoint.tcp("localhost").bound
19+
end
20+
21+
after do
22+
@server_endpoint&.close
23+
end
24+
1625
it "should raise error on unexpected disconnect" do
17-
server_task = reactor.async do
18-
endpoint.accept do |connection|
26+
server_task = Async do
27+
@server_endpoint.accept do |connection|
1928
connection.read(8)
2029
connection.close
2130
end
2231
end
23-
24-
client = Async::Redis::Client.new(endpoint)
32+
33+
client = Async::Redis::Client.new(
34+
@server_endpoint.local_address_endpoint,
35+
protocol: Async::Redis::Protocol::RESP2,
36+
)
2537

2638
expect do
2739
client.call("GET", "test")

test/async/redis/endpoint.rb

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# frozen_string_literal: true
2+
3+
# Released under the MIT License.
4+
# Copyright, 2024, by Samuel Williams.
5+
6+
require 'async/redis/client'
7+
require 'async/redis/protocol/authenticated'
8+
require 'sus/fixtures/async'
9+
10+
describe Async::Redis::Protocol::Authenticated do
11+
include Sus::Fixtures::Async::ReactorContext
12+
13+
let(:endpoint) {Async::Redis.local_endpoint}
14+
let(:credentials) {["testuser", "testpassword"]}
15+
let(:protocol) {subject.new(credentials)}
16+
let(:client) {Async::Redis::Client.new(endpoint, protocol: protocol)}
17+
18+
before do
19+
# Setup ACL user with limited permissions for testing.
20+
admin_client = Async::Redis::Client.new(endpoint)
21+
admin_client.call("ACL", "SETUSER", "testuser", "on", ">" + credentials[1], "+ping", "+auth")
22+
ensure
23+
admin_client.close
24+
end
25+
26+
after do
27+
# Cleanup ACL user after tests.
28+
admin_client = Async::Redis::Client.new(endpoint)
29+
admin_client.call("ACL", "DELUSER", "testuser")
30+
admin_client.close
31+
end
32+
33+
it "can authenticate and send allowed commands" do
34+
response = client.call("PING")
35+
expect(response).to be == "PONG"
36+
end
37+
38+
it "rejects commands not allowed by ACL" do
39+
expect do
40+
client.call("SET", "key", "value")
41+
end.to raise_exception(Protocol::Redis::ServerError, message: be =~ /NOPERM/)
42+
end
43+
end

0 commit comments

Comments
 (0)