1
1
import datetime
2
+ import json
2
3
4
+ import sqlparse
5
+ import websocket # Using websocket-client library for synchronous operations
3
6
from django .db import IntegrityError , DatabaseError
4
7
from django .db .backends .sqlite3 .base import DatabaseWrapper as SQLiteDatabaseWrapper
5
- from django .db .backends .sqlite3 .features import DatabaseFeatures as SQLiteDatabaseFeatures
6
- from django .db .backends .sqlite3 .operations import DatabaseOperations as SQLiteDatabaseOperations
7
- from django .db .backends .sqlite3 .schema import DatabaseSchemaEditor as SQLiteDatabaseSchemaEditor
8
8
from django .db .backends .sqlite3 .client import DatabaseClient as SQLiteDatabaseClient
9
9
from django .db .backends .sqlite3 .creation import DatabaseCreation as SQLiteDatabaseCreation
10
+ from django .db .backends .sqlite3 .features import DatabaseFeatures as SQLiteDatabaseFeatures
10
11
from django .db .backends .sqlite3 .introspection import DatabaseIntrospection as SQLiteDatabaseIntrospection
11
-
12
- import websocket # Using websocket-client library for synchronous operations
13
- import json
14
-
12
+ from django .db .backends .sqlite3 .operations import DatabaseOperations as SQLiteDatabaseOperations
13
+ from django .db .backends .sqlite3 .schema import DatabaseSchemaEditor as SQLiteDatabaseSchemaEditor
15
14
from django .utils import timezone
15
+ from sqlparse .sql import IdentifierList , Identifier
16
+ from sqlparse .tokens import DML
16
17
17
18
18
19
class DatabaseFeatures (SQLiteDatabaseFeatures ):
@@ -27,33 +28,63 @@ def _quote_columns(self, sql):
27
28
"""
28
29
Ensure column names are properly quoted and aliased to avoid collisions.
29
30
"""
30
- # Split the SQL to find the SELECT and FROM clauses
31
- select_start = sql .lower ().find ('select' )
32
- from_start = sql .lower ().find ('from' )
33
-
34
- if select_start == - 1 or from_start == - 1 :
35
- return sql # Not a SELECT query, skip processing
36
-
37
- # Extract the columns between SELECT and FROM
38
- columns_section = sql [select_start + len ('select' ):from_start ].strip ()
39
- columns = columns_section .split (',' )
40
-
41
- # Quote and alias columns
42
- aliased_columns = []
43
- for column in columns :
44
- column = column .strip ()
45
-
46
- if '.' in column : # It's a "table.column" format
47
- table , col = column .split ('.' )
48
- aliased_columns .append ((f'{ self .quote_name (table )} .{ self .quote_name (col )} AS { table } _{ col } ' ).replace ('"' , '' ))
31
+ parsed = sqlparse .parse (sql )
32
+ if not parsed :
33
+ return sql # Unable to parse, return original SQL
34
+
35
+ stmt = parsed [0 ]
36
+ new_tokens = []
37
+ select_seen = False
38
+
39
+ for token in stmt .tokens :
40
+ if token .ttype is DML and token .value .upper () == 'SELECT' :
41
+ select_seen = True
42
+ new_tokens .append (token )
43
+ elif select_seen :
44
+ if isinstance (token , IdentifierList ):
45
+ # Process each identifier in the SELECT clause
46
+ new_identifiers = []
47
+ for identifier in token .get_identifiers ():
48
+ new_identifiers .append (self ._process_identifier (identifier ))
49
+ # Rebuild the IdentifierList
50
+ new_token = IdentifierList (new_identifiers )
51
+ new_tokens .append (new_token )
52
+ elif isinstance (token , Identifier ):
53
+ # Single column without commas
54
+ new_token = self ._process_identifier (token )
55
+ new_tokens .append (new_token )
56
+ else :
57
+ new_tokens .append (token )
58
+ select_seen = False # Assuming SELECT clause is only once
49
59
else :
50
- aliased_columns .append (column )
60
+ new_tokens .append (token )
51
61
52
- # Rebuild the SQL with quoted and aliased columns
53
- new_columns_section = ', ' .join (aliased_columns )
54
- new_sql = f'SELECT { new_columns_section } { sql [from_start :]} '
62
+ # Reconstruct the SQL statement
63
+ new_sql = '' .join (str (token ) for token in new_tokens )
55
64
return new_sql
56
65
66
+ def _process_identifier (self , identifier ):
67
+ # Get the real name and alias if present
68
+ real_name = identifier .get_real_name ()
69
+ alias = identifier .get_alias ()
70
+ parent_name = identifier .get_parent_name ()
71
+
72
+ if real_name :
73
+ if parent_name :
74
+ # Format: table.column
75
+ table = self .quote_name (parent_name )
76
+ column = self .quote_name (real_name )
77
+ new_alias = f"{ parent_name } _{ real_name } "
78
+ return f"{ table } .{ column } AS { new_alias } "
79
+ else :
80
+ # Simple column
81
+ column = self .quote_name (real_name )
82
+ new_alias = real_name
83
+ return f"{ column } AS { new_alias } "
84
+ else :
85
+ # Complex expression (e.g., functions), return as is
86
+ return str (identifier )
87
+
57
88
def _format_params (self , sql , params ):
58
89
def quote_param (param ):
59
90
if isinstance (param , str ):
@@ -95,7 +126,6 @@ def _parse_datetime(self, value):
95
126
96
127
return value # If it's not a datetime string, return the original value
97
128
98
-
99
129
def _convert_results (self , results ):
100
130
"""
101
131
Convert any datetime strings in the result set to actual timezone-aware datetime objects.
@@ -244,7 +274,6 @@ def bulk_insert_sql(self, fields, placeholder_rows):
244
274
return "VALUES " + values_sql
245
275
246
276
247
-
248
277
class DatabaseWrapper (SQLiteDatabaseWrapper ):
249
278
vendor = 'websocket'
250
279
@@ -281,7 +310,8 @@ def get_new_connection(self, conn_params):
281
310
header = headers
282
311
)
283
312
except Exception as e :
284
- raise ValueError ("Unable to connect to websocket, please check your credentials and make sure you have websockets enabled in your domain" )
313
+ raise ValueError (
314
+ "Unable to connect to websocket, please check your credentials and make sure you have websockets enabled in your domain" )
285
315
286
316
return self ._websocket
287
317
0 commit comments