diff --git a/spanner/django/features.py b/spanner/django/features.py index 9d61761e71..0474310730 100644 --- a/spanner/django/features.py +++ b/spanner/django/features.py @@ -61,9 +61,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # https://github.com/orijtech/spanner-orm/issues/155 'timezones.tests.LegacyDatabaseTests.test_raw_sql', 'timezones.tests.NewDatabaseTests.test_raw_sql', - # implement DatabaseOperations.date_interval_sql() - # https://github.com/orijtech/spanner-orm/issues/184 - 'timezones.tests.NewDatabaseTests.test_update_with_timedelta', # Unable to infer type for parameter: # https://github.com/orijtech/spanner-orm/issues/185 'timezones.tests.LegacyDatabaseTests.test_cursor_execute_returns_naive_datetime', @@ -92,9 +89,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): 'lookup.test_decimalfield.DecimalFieldLookupTests.test_lt', 'lookup.test_decimalfield.DecimalFieldLookupTests.test_lte', 'model_fields.test_decimalfield.DecimalFieldTests.test_filter_with_strings', - # annotating DateTimeField + DurationField crashes: - # https://github.com/orijtech/spanner-orm/issues/202 - 'annotations.tests.NonAggregateAnnotationTestCase.test_mixed_type_annotation_date_interval', # using NULL with + crashes: https://github.com/orijtech/spanner-orm/issues/201 'annotations.tests.NonAggregateAnnotationTestCase.test_combined_annotation_commutative', # Spanner loses DecimalField precision due to conversion to float: diff --git a/spanner/django/operations.py b/spanner/django/operations.py index 10fd9e4d76..3adcc0a887 100644 --- a/spanner/django/operations.py +++ b/spanner/django/operations.py @@ -6,7 +6,9 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.utils import DatabaseError from django.utils import timezone +from django.utils.duration import duration_microseconds from spanner.dbapi.parse_utils import DateStr, TimestampStr, escape_name @@ -173,6 +175,20 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname): sql = 'TIMESTAMP_ADD(' + sql + ', INTERVAL 1 DAY)' return sql + def date_interval_sql(self, timedelta): + return 'INTERVAL %s MICROSECOND' % duration_microseconds(timedelta) + + def format_for_duration_arithmetic(self, sql): + return 'INTERVAL %s MICROSECOND' % sql + + def combine_duration_expression(self, connector, sub_expressions): + if connector == '+': + return 'TIMESTAMP_ADD(' + ', '.join(sub_expressions) + ')' + elif connector == '-': + return 'TIMESTAMP_SUB(' + ', '.join(sub_expressions) + ')' + else: + raise DatabaseError('Invalid connector for timedelta: %s.' % connector) + def lookup_cast(self, lookup_type, internal_type=None): # Cast text lookups to string to allow things like filter(x__contains=4) if lookup_type in ('contains', 'icontains', 'startswith', 'istartswith',