-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathPostgresClientImpl.scala
322 lines (271 loc) · 10.1 KB
/
PostgresClientImpl.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
package com.twitter.finagle.postgres
import java.nio.charset.{Charset, StandardCharsets}
import java.util.concurrent.atomic.AtomicInteger
import com.twitter.cache.Refresh
import com.twitter.concurrent.AsyncStream
import com.twitter.conversions.DurationOps._
import com.twitter.finagle.postgres.messages._
import com.twitter.finagle.postgres.values._
import com.twitter.finagle.{Service, ServiceFactory, Status}
import com.twitter.logging.Logger
import com.twitter.util._
import io.netty.buffer.ByteBuf
import scala.language.{existentials, implicitConversions}
import scala.util.Random
/**
 * A Finagle client for communicating with Postgres.
 *
 * @param factory          source of connections. One-off queries check a service out per request;
 *                         prepared statements and transactions hold one for their whole duration.
 * @param id               identifier embedded in the names of prepared statements created by this client.
 * @param types            an already-known OID -> type map. When `None`, the map is read from the
 *                         server's `pg_type` catalog and refreshed every hour.
 * @param receiveFunctions decoders keyed by receive-function name (as stored in `pg_type.typreceive`);
 *                         handed to every [[ResultSet]] this client builds.
 * @param binaryResults    ask the server for binary-format results instead of text.
 * @param binaryParams     send bound parameters in binary format instead of text.
 */
class PostgresClientImpl(
  factory: ServiceFactory[PgRequest, PgResponse],
  id: String,
  types: Option[Map[Int, PostgresClient.TypeSpecifier]] = None,
  receiveFunctions: PartialFunction[String, ValueDecoder[T] forSome { type T }],
  binaryResults: Boolean = false,
  binaryParams: Boolean = false
) extends PostgresClient {
  private[this] val counter = new AtomicInteger(0)
  private[this] val logger = Logger(getClass.getName)

  // Postgres format codes: 0 = text, 1 = binary.
  private val resultFormats = if (binaryResults) Seq(1) else Seq(0)
  private val paramFormats = if (binaryParams) Seq(1) else Seq(0)

  /** Character set used to encode bound parameters and to decode result rows. */
  val charset = StandardCharsets.UTF_8

  /**
   * Reads the OID -> [[PostgresClient.TypeSpecifier]] mapping for every type in the remote
   * database that has a receive function, on a connection checked out just for this lookup.
   */
  private def retrieveTypeMap() = {
    // typreceive is the most reliable way to determine how a type should be decoded.
    val customTypesQuery =
      """
        |SELECT DISTINCT
        | CAST(t.typname AS text) AS type,
        | CAST(t.oid AS integer) AS oid,
        | CAST(t.typreceive AS text) AS typreceive,
        | CAST(t.typelem AS integer) AS typelem
        |FROM pg_type t
        |WHERE CAST(t.typreceive AS text) <> '-'
      """.stripMargin

    val serviceF = factory()

    def extractTypes(response: PgResponse): Future[Map[Int, PostgresClient.TypeSpecifier]] =
      response match {
        case SelectResult(fields, rows) =>
          val rowValues = ResultSet(fields, charset, rows, PostgresClient.defaultTypes, receiveFunctions).rows
          rowValues.map { row =>
            row.get[Int]("oid") -> PostgresClient.TypeSpecifier(
              row.get[String]("typreceive"),
              row.get[String]("type"),
              row.get[Int]("typelem"))
          }.toSeq().map(_.toMap)
        // Fail with a descriptive error (as `send` does) rather than a MatchError.
        case unexpected =>
          Future.exception(new IllegalStateException(s"Unexpected response $unexpected"))
      }

    val customTypesResult = for {
      service <- serviceF
      response <- service(PgRequest(Query(customTypesQuery)))
      types <- extractTypes(response)
    } yield types

    // Release the dedicated connection whether or not the lookup succeeded.
    customTypesResult.ensure {
      serviceF.foreach(_.close())
    }
    customTypesResult
  }

  /** The OID -> type map, memoised by `Refresh` for one hour; `types` is used verbatim when provided. */
  private[postgres] val typeMap = Refresh.every(1.hour) {
    types.map(Future(_)).getOrElse(retrieveTypeMap())
  }

  /**
   * Type name -> OID used when sending parameters; when several OIDs share a name the smallest wins.
   *
   * Deliberately a `def`: a `val` would pin the first `typeMap()` future, so it would never pick up
   * hourly refreshes and a failed first lookup would be cached forever.
   */
  private[postgres] def encodeOids: Future[Map[String, Int]] =
    typeMap().map { tm =>
      tm.toIndexedSeq
        .map { case (oid, PostgresClient.TypeSpecifier(_, typeName, _)) => typeName -> oid }
        .groupBy(_._1)
        .map { case (typeName, namedOids) => typeName -> namedOids.map(_._2).min }
    }

  /**
   * Executes `fn` inside a transaction on a single dedicated connection.
   *
   * Issues BEGIN, runs `fn` with a client bound to that connection, then COMMITs on success or
   * ROLLBACKs on failure; the connection is released once COMMIT/ROLLBACK completes (or BEGIN fails).
   * If `fn` fails, the returned future fails with `fn`'s error even when ROLLBACK also fails.
   */
  override def inTransaction[T](fn: PostgresClient => Future[T]): Future[T] = for {
    types <- typeMap()
    service <- factory()
    constFactory = ServiceFactory.const(service)
    id = Random.alphanumeric.take(28).mkString
    transactionalClient = new PostgresClientImpl(constFactory, id, Some(types), receiveFunctions, binaryResults, binaryParams)
    closeTransaction = () => transactionalClient.close().ensure(constFactory.close().ensure(service.close()))
    completeTransactionQuery = (sql: String) => transactionalClient.query(sql).ensure(closeTransaction())
    _ <- transactionalClient.query("BEGIN").onFailure(_ => closeTransaction())
    // Future(...).flatten turns a synchronous throw from `fn` into a failed future, so it is
    // rolled back and the connection released instead of leaking.
    result <- Future(fn(transactionalClient)).flatten.rescue {
      case err =>
        completeTransactionQuery("ROLLBACK").transform {
          case Throw(rollbackErr) =>
            logger.warning(rollbackErr, "ROLLBACK failed; propagating the original transaction error")
            Future.exception[T](err)
          case Return(_) =>
            Future.exception[T](err)
        }
    }
    _ <- completeTransactionQuery("COMMIT")
  } yield result

  /**
   * Issue an arbitrary SQL query and get the response.
   */
  override def query(sql: String): Future[QueryResponse] = sendQuery(sql) {
    case SelectResult(fields, rows) => typeMap().map {
      types => ResultSet(fields, charset, rows, types, receiveFunctions)
    }
    case CommandCompleteResponse(affected) => Future(OK(affected))
  }

  /**
   * Issue a single SELECT query and get the response.
   */
  override def fetch(sql: String): Future[SelectResult] = sendQuery(sql) {
    case rs: SelectResult => Future(rs)
  }

  /**
   * Execute an update command (e.g., INSERT, DELETE) and get the response.
   */
  override def executeUpdate(sql: String): Future[OK] = sendQuery(sql) {
    case CommandCompleteResponse(rows) => Future(OK(rows))
  }

  override def execute(sql: String): Future[OK] = executeUpdate(sql)

  /**
   * Run a single SELECT query and wrap the results with the provided function.
   */
  override def selectToStream[T](sql: String)(f: Row => T): AsyncStream[T] =
    AsyncStream.fromFuture {
      for {
        types <- typeMap()
        SelectResult(fields, rows) <- fetch(sql)
      } yield ResultSet(fields, charset, rows, types, receiveFunctions).rows.map(f)
    }.flatten

  /**
   * Issue a single, prepared SELECT query and wrap the response rows with the provided function.
   */
  override def prepareAndQueryToStream[T](sql: String, params: Param[_]*)(f: Row => T): AsyncStream[T] =
    AsyncStream.fromFuture {
      typeMap().flatMap { _ =>
        for {
          service <- factory()
          statement = new PreparedStatementImpl(sql, service)
          result <- statement.selectToStream(params: _*)(f)
        } yield result
      }
    }.flatten

  /**
   * Issue a single, prepared arbitrary query without an expected result set, and provide the affected row count.
   */
  override def prepareAndExecute(sql: String, params: Param[_]*): Future[Int] = {
    typeMap().flatMap { _ =>
      for {
        service <- factory()
        statement = new PreparedStatementImpl(sql, service)
        OK(count) <- statement.exec(params: _*)
      } yield count
    }
  }

  /**
   * Close the underlying connection pool and make this Client eternally down.
   */
  override def close(): Future[Unit] = {
    factory.close()
  }

  /**
   * The current availability [[Status]] of this client.
   */
  override def status: Status = factory.status

  /**
   * Determines whether this client is available (can accept requests
   * with a reasonable likelihood of success).
   */
  override def isAvailable: Boolean = status == Status.Open

  /** Sends `sql` as a simple query on a pooled connection and handles the reply with `handler`. */
  private[this] def sendQuery[T](sql: String)(handler: PartialFunction[PgResponse, Future[T]]) = {
    send(PgRequest(Query(sql)))(handler)
  }

  /**
   * Sends `r` on `optionalService`, or on a service from `factory` when `None`.
   * Replies that `handler` does not accept become an `IllegalStateException`.
   */
  private[this] def send[T](
    r: PgRequest,
    optionalService: Option[Service[PgRequest, PgResponse]] = None)(
    handler: PartialFunction[PgResponse, Future[T]]
  ) = {
    val service = optionalService.getOrElse(factory.toService)
    service(r).flatMap(handler orElse {
      case unexpected => Future.exception(new IllegalStateException(s"Unexpected response $unexpected"))
    })
  }

  /** A prepared statement executed on `service`; `fire` closes `service` when it finishes. */
  private[this] class PreparedStatementImpl(
    sql: String,
    service: Service[PgRequest, PgResponse]
  ) extends PreparedStatement {
    // Built from the client `id` and a per-client counter, so statements from one client never share a name.
    private[this] val name = s"fin-pg-$id-" + counter.incrementAndGet

    def closeService = service.close()

    /** Encodes each parameter in the wire format selected by `binaryParams`. */
    private[this] def encodeParams(params: Seq[Param[_]]) =
      if (binaryParams) params.map(_.encodeBinary(charset))
      else params.map(_.encodeText(charset))

    private[this] def parse(params: Param[_]*): Future[Unit] = {
      val paramTypes = encodeOids.map {
        oidMap => params.map {
          param => oidMap.getOrElse(param.encoder.typeName, 0)
        }
      }
      paramTypes.flatMap {
        types =>
          val req = Parse(name, sql, types)
          send(PgRequest(req, flush = true), Some(service)) {
            case ParseCompletedResponse => Future.value(())
          }
      }
    }

    private[this] def bind(params: Seq[ByteBuf]): Future[Unit] = {
      val req = Bind(
        portal = name,
        name = name,
        formats = paramFormats,
        params = params,
        resultFormats = resultFormats
      )
      send(PgRequest(req, flush = true), Some(service)) {
        case BindCompletedResponse => Future.value(())
      }
    }

    private[this] def describe(): Future[Array[Field]] = {
      val req = PgRequest(Describe(portal = true, name = name), flush = true)
      send(req, Some(service)) {
        case RowDescriptions(fields) => Future.value(fields)
      }
    }

    private[this] def execute(
      maxRows: Int = 0
    ) = {
      val req = PgRequest(Execute(name, maxRows), flush = true)
      send(req, Some(service)) {
        case rep => Future.value(rep)
      }
    }

    private[this] def sync(
      optionalService: Option[Service[PgRequest, PgResponse]] = None
    ): Future[Unit] = send(PgRequest(Sync), optionalService) {
      case ReadyForQueryResponse => Future.value(())
    }

    /**
     * Runs Parse -> Bind -> Describe -> Execute for this statement, always follows up with Sync,
     * and finally closes `service`.
     */
    override def fire(params: Param[_]*): Future[QueryResponse] = {
      val outcome: Future[QueryResponse] = for {
        // Encoding inside the chain means an encoder failure still reaches
        // `ensure(service.close())` below instead of leaking the checked-out connection.
        paramBuffers <- Future(encodeParams(params))
        types <- typeMap()
        _ <- parse(params: _*)
        _ <- bind(paramBuffers)
        fields <- describe()
        response <- execute()
        decoded <- response match {
          case CommandCompleteResponse(rows) =>
            Future.value[QueryResponse](OK(rows))
          case Rows(rows) =>
            Future.value[QueryResponse](ResultSet(fields, charset, rows, types, receiveFunctions))
          case unexpected =>
            Future.exception[QueryResponse](new IllegalStateException(s"Unexpected response $unexpected"))
        }
      } yield decoded

      outcome.transform { result =>
        // Sync is sent after success and failure alike. A Sync failure is reported only when the
        // statement itself succeeded, so it never masks the statement's own error.
        sync(Some(service)).transform[QueryResponse] {
          case Throw(syncErr) if result.isReturn => Future.exception(syncErr)
          case _ => Future.const(result)
        }
      }.ensure(service.close())
    }
  }
}
/**
 * A value bound to a prepared-statement placeholder, together with the [[ValueEncoder]]
 * responsible for serialising it.
 */
case class Param[T](value: T)(implicit val encoder: ValueEncoder[T]) {

  /** Encodes this parameter using its encoder's text representation. */
  def encodeText(charset: Charset = StandardCharsets.UTF_8) =
    ValueEncoder.encodeText(this.value, this.encoder, charset)

  /** Encodes this parameter using its encoder's binary representation. */
  def encodeBinary(charset: Charset = StandardCharsets.UTF_8) =
    ValueEncoder.encodeBinary(this.value, this.encoder, charset)
}
/** Companion providing implicit lifting of encodable values into [[Param]]s. */
object Param {
  /** Wraps any value that has a [[ValueEncoder]] in scope, so plain values can be passed where a `Param` is expected. */
  implicit def convert[T](t: T)(implicit encoder: ValueEncoder[T]): Param[T] =
    new Param[T](t)(encoder)
}