1
1
# -*- coding: utf-8 -*-
2
2
from __future__ import absolute_import
3
- import inspect
3
+ import mongoengine , inspect
4
4
5
5
from flask import abort , current_app
6
+ from mongoengine .base .fields import BaseField
7
+ from mongoengine .queryset import (MultipleObjectsReturned ,
8
+ DoesNotExist , QuerySet )
6
9
7
- import mongoengine
8
-
9
- if mongoengine .__version__ == '0.7.10' :
10
- from mongoengine .base import BaseField
11
- else :
12
- from mongoengine .base .fields import BaseField
13
-
14
-
15
- from mongoengine .queryset import MultipleObjectsReturned , DoesNotExist , QuerySet
16
10
from mongoengine .base import ValidationError
17
-
18
11
from pymongo import uri_parser
19
-
20
12
from .sessions import *
21
13
from .pagination import *
22
14
from .metadata import *
23
- from .json import overide_json_encoder
15
+ from .json import override_json_encoder
24
16
from .wtf import WtfBaseField
17
+ from .connection import *
18
+ import flask_mongoengine
25
19
26
- def _patch_base_field (object , name ):
20
+ def redirect_connection_calls (cls ):
21
+ """
22
+ Redirect mongonengine.connection
23
+ calls via flask_mongoengine.connection
24
+ """
25
+
26
+ # Proxy all 'mongoengine.connection'
27
+ # specific attr via 'flask_mongoengine'
28
+ connection_methods = {
29
+ 'get_db' : get_db ,
30
+ 'DEFAULT_CONNECTION_NAME' : DEFAULT_CONNECTION_NAME ,
31
+ 'get_connection' : get_connection
32
+ }
33
+
34
+ cls_module = inspect .getmodule (cls )
35
+ if cls_module != mongoengine .connection :
36
+ for attr in inspect .getmembers (cls_module ):
37
+ n = attr [0 ]
38
+ if connection_methods .get (n , None ):
39
+ setattr (cls_module , n , connection_methods .get (n , None ))
40
+
41
+ def _patch_base_field (obj , name ):
27
42
"""
28
43
If the object submitted has a class whose base class is
29
44
mongoengine.base.fields.BaseField, then monkey patch to
@@ -36,12 +51,12 @@ def _patch_base_field(object, name):
36
51
@see: flask_mongoengine.wtf.base.WtfBaseField.
37
52
@see: model_form in flask_mongoengine.wtf.orm
38
53
39
- @param object: The object whose footprint to locate the class.
54
+ @param obj: The object whose footprint to locate the class.
40
55
@param name: Name of the class to locate.
41
56
42
57
"""
43
58
# locate class
44
- cls = getattr (object , name )
59
+ cls = getattr (obj , name )
45
60
if not inspect .isclass (cls ):
46
61
return
47
62
@@ -57,9 +72,9 @@ def _patch_base_field(object, name):
57
72
58
73
# re-assign class back to
59
74
# object footprint
60
- delattr (object , name )
61
- setattr (object , name , cls )
62
-
75
+ delattr (obj , name )
76
+ setattr (obj , name , cls )
77
+ redirect_connection_calls ( cls )
63
78
64
79
def _include_mongoengine (obj ):
65
80
for module in mongoengine , mongoengine .fields :
@@ -70,30 +85,19 @@ def _include_mongoengine(obj):
70
85
# patch BaseField if available
71
86
_patch_base_field (obj , key )
72
87
73
-
74
- def _create_connection (conn_settings ):
75
-
76
- # Handle multiple connections recursively
77
- if isinstance (conn_settings , list ):
78
- connections = {}
79
- for conn in conn_settings :
80
- connections [conn .get ('alias' )] = _create_connection (conn )
81
- return connections
82
-
83
- # Ugly dict comprehention in order to support python 2.6
84
- conn = dict ((k .lower (), v ) for k , v in conn_settings .items () if v is not None )
85
-
86
- if 'replicaset' in conn :
87
- conn ['replicaSet' ] = conn .pop ('replicaset' )
88
-
89
- # Handle uri style connections
90
- if "://" in conn .get ('host' , '' ):
91
- uri_dict = uri_parser .parse_uri (conn ['host' ])
92
- conn ['db' ] = uri_dict ['database' ]
93
-
94
- return mongoengine .connect (conn .pop ('db' , 'test' ), ** conn )
95
-
96
-
88
+ def current_mongoengine_instance ():
89
+ """
90
+ Obtain instance of MongoEngine in the
91
+ current working app instance.
92
+ """
93
+ me = current_app .extensions .get ('mongoengine' , None )
94
+ if current_app and me :
95
+ instance_dict = me .items ()\
96
+ if (sys .version_info >= (3 , 0 )) else me .iteritems ()
97
+ for k , v in instance_dict :
98
+ if isinstance (k , MongoEngine ):
99
+ return k
100
+ return None
97
101
98
102
class MongoEngine (object ):
99
103
@@ -107,11 +111,10 @@ def __init__(self, app=None, config=None):
107
111
self .init_app (app , config )
108
112
109
113
def init_app (self , app , config = None ):
110
-
111
114
app .extensions = getattr (app , 'extensions' , {})
112
115
113
116
# Make documents JSON serializable
114
- overide_json_encoder (app )
117
+ override_json_encoder (app )
115
118
116
119
if not 'mongoengine' in app .extensions :
117
120
app .extensions ['mongoengine' ] = {}
@@ -122,27 +125,30 @@ def init_app(self, app, config=None):
122
125
raise Exception ('Extension already initialized' )
123
126
124
127
if not config :
125
- # If not passed a config then we read the connection settings
126
- # from the app config.
128
+ # If not passed a config then we
129
+ # read the connection settings from
130
+ # the app config.
127
131
config = app .config
128
132
129
- if 'MONGODB_SETTINGS' in config :
130
- # Connection settings provided as a dictionary.
131
- connection = _create_connection (config ['MONGODB_SETTINGS' ])
133
+ # Obtain db connection
134
+ connection = create_connection (config )
135
+
136
+ # Store objects in application instance
137
+ # so that multiple apps do not end up
138
+ # accessing the same objects.
139
+ s = {'app' : app , 'conn' : connection }
140
+ app .extensions ['mongoengine' ][self ] = s
141
+
142
+ def disconnect (self ):
143
+ conn_settings = fetch_connection_settings (current_app .config )
144
+ if isinstance (conn_settings , list ):
145
+ for setting in conn_settings :
146
+ alias = setting .get ('alias' , DEFAULT_CONNECTION_NAME )
147
+ disconnect (alias , setting .get ('preserve_temp_db' , False ))
132
148
else :
133
- # Connection settings provided in standard format.
134
- settings = {'alias' : config .get ('MONGODB_ALIAS' , None ),
135
- 'db' : config .get ('MONGODB_DB' , None ),
136
- 'host' : config .get ('MONGODB_HOST' , None ),
137
- 'password' : config .get ('MONGODB_PASSWORD' , None ),
138
- 'port' : config .get ('MONGODB_PORT' , None ),
139
- 'username' : config .get ('MONGODB_USERNAME' , None )}
140
- connection = _create_connection (settings )
141
-
142
- # Store objects in application instance so that multiple apps do
143
- # not end up accessing the same objects.
144
- app .extensions ['mongoengine' ][self ] = {'app' : app ,
145
- 'conn' : connection }
149
+ alias = conn_settings .get ('alias' , DEFAULT_CONNECTION_NAME )
150
+ disconnect (alias , conn_settings .get ('preserve_temp_db' , False ))
151
+ return True
146
152
147
153
@property
148
154
def connection (self ):
@@ -179,7 +185,6 @@ def paginate_field(self, field_name, doc_id, page, per_page,
179
185
return ListFieldPagination (self , doc_id , field_name , page , per_page ,
180
186
total = total )
181
187
182
-
183
188
class Document (mongoengine .Document ):
184
189
"""Abstract document with extra helpers in the queryset class"""
185
190
0 commit comments