@@ -90,7 +90,7 @@ def _reset():
90
90
ctr .clear ()
91
91
92
92
93
- def _default_name (obj_class ):
93
+ def default_name (obj_class ):
94
94
"""Convert a class name to the registry's default name for the class.
95
95
96
96
Args:
@@ -99,7 +99,6 @@ def _default_name(obj_class):
99
99
Returns:
100
100
The registry's default name for the class.
101
101
"""
102
-
103
102
return _convert_camel_to_snake (obj_class .__name__ )
104
103
105
104
@@ -112,25 +111,25 @@ def default_object_name(obj):
112
111
Returns:
113
112
The registry's default name for the class of the object.
114
113
"""
115
-
116
- return _default_name (obj .__class__ )
114
+ return default_name (obj .__class__ )
117
115
118
116
119
117
def register_model (name = None ):
120
118
"""Register a model. name defaults to class name snake-cased."""
121
119
122
120
def decorator (model_cls , registration_name = None ):
123
121
"""Registers & returns model_cls with registration_name or default name."""
124
- model_name = registration_name or _default_name (model_cls )
122
+ model_name = registration_name or default_name (model_cls )
125
123
if model_name in _MODELS :
126
124
raise LookupError ("Model %s already registered." % model_name )
125
+ model_cls .REGISTERED_NAME = property (lambda _ : model_name )
127
126
_MODELS [model_name ] = model_cls
128
127
return model_cls
129
128
130
129
# Handle if decorator was used without parens
131
130
if callable (name ):
132
131
model_cls = name
133
- return decorator (model_cls , registration_name = _default_name (model_cls ))
132
+ return decorator (model_cls , registration_name = default_name (model_cls ))
134
133
135
134
return lambda model_cls : decorator (model_cls , name )
136
135
@@ -150,7 +149,7 @@ def register_hparams(name=None):
150
149
151
150
def decorator (hp_fn , registration_name = None ):
152
151
"""Registers & returns hp_fn with registration_name or default name."""
153
- hp_name = registration_name or _default_name (hp_fn )
152
+ hp_name = registration_name or default_name (hp_fn )
154
153
if hp_name in _HPARAMS :
155
154
raise LookupError ("HParams set %s already registered." % hp_name )
156
155
_HPARAMS [hp_name ] = hp_fn
@@ -159,7 +158,7 @@ def decorator(hp_fn, registration_name=None):
159
158
# Handle if decorator was used without parens
160
159
if callable (name ):
161
160
hp_fn = name
162
- return decorator (hp_fn , registration_name = _default_name (hp_fn ))
161
+ return decorator (hp_fn , registration_name = default_name (hp_fn ))
163
162
164
163
return lambda hp_fn : decorator (hp_fn , name )
165
164
@@ -182,7 +181,7 @@ def register_ranged_hparams(name=None):
182
181
183
182
def decorator (rhp_fn , registration_name = None ):
184
183
"""Registers & returns hp_fn with registration_name or default name."""
185
- rhp_name = registration_name or _default_name (rhp_fn )
184
+ rhp_name = registration_name or default_name (rhp_fn )
186
185
if rhp_name in _RANGED_HPARAMS :
187
186
raise LookupError ("RangedHParams set %s already registered." % rhp_name )
188
187
# Check that the fn takes a single argument
@@ -197,7 +196,7 @@ def decorator(rhp_fn, registration_name=None):
197
196
# Handle if decorator was used without parens
198
197
if callable (name ):
199
198
rhp_fn = name
200
- return decorator (rhp_fn , registration_name = _default_name (rhp_fn ))
199
+ return decorator (rhp_fn , registration_name = default_name (rhp_fn ))
201
200
202
201
return lambda rhp_fn : decorator (rhp_fn , name )
203
202
@@ -217,7 +216,7 @@ def register_problem(name=None):
217
216
218
217
def decorator (p_cls , registration_name = None ):
219
218
"""Registers & returns p_cls with registration_name or default name."""
220
- p_name = registration_name or _default_name (p_cls )
219
+ p_name = registration_name or default_name (p_cls )
221
220
if p_name in _PROBLEMS :
222
221
raise LookupError ("Problem %s already registered." % p_name )
223
222
@@ -228,7 +227,7 @@ def decorator(p_cls, registration_name=None):
228
227
# Handle if decorator was used without parens
229
228
if callable (name ):
230
229
p_cls = name
231
- return decorator (p_cls , registration_name = _default_name (p_cls ))
230
+ return decorator (p_cls , registration_name = default_name (p_cls ))
232
231
233
232
return lambda p_cls : decorator (p_cls , name )
234
233
@@ -313,7 +312,7 @@ def _internal_register_modality(name, mod_collection, collection_str):
313
312
314
313
def decorator (mod_cls , registration_name = None ):
315
314
"""Registers & returns mod_cls with registration_name or default name."""
316
- mod_name = registration_name or _default_name (mod_cls )
315
+ mod_name = registration_name or default_name (mod_cls )
317
316
if mod_name in mod_collection :
318
317
raise LookupError ("%s modality %s already registered." % (collection_str ,
319
318
mod_name ))
@@ -323,7 +322,7 @@ def decorator(mod_cls, registration_name=None):
323
322
# Handle if decorator was used without parens
324
323
if callable (name ):
325
324
mod_cls = name
326
- return decorator (mod_cls , registration_name = _default_name (mod_cls ))
325
+ return decorator (mod_cls , registration_name = default_name (mod_cls ))
327
326
328
327
return lambda mod_cls : decorator (mod_cls , name )
329
328
0 commit comments