-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathmenstrual_predictor_utils.py
119 lines (93 loc) · 3.05 KB
/
menstrual_predictor_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import datetime
from loguru import logger
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.database.models import Event, UserMenstrualPeriodLength
from app.routers.event import create_event
# Event.category_id shared by every menstrual-period event this module
# creates, queries, or deletes.
MENSTRUAL_PERIOD_CATEGORY_ID: int = 111
def get_avg_period_gap(db: Session, user_id: int, *, default_gap: int = 28):
    """Return the average number of whole days between consecutive period starts.

    The user's period events come from ``get_all_period_days``, which sorts them
    by ``start``. The day difference between each adjacent pair of start times
    is averaged with integer floor division (``get_list_avg``).

    Args:
        db: Active database session.
        user_id: Owner of the period events.
        default_gap: Returned when fewer than two period events exist, so no gap
            can be measured. This includes the case where ``get_all_period_days``
            logged a database error and returned ``[]``. Defaults to 28 days,
            a commonly cited average cycle length.

    Returns:
        The average gap in days, as an int (or ``default_gap``).
    """
    period_days = get_all_period_days(db, user_id)
    gaps = [
        get_date_diff(earlier.start, later.start).days
        for earlier, later in zip(period_days, period_days[1:])
    ]
    if not gaps:
        # Zero or one recorded period gives nothing to average; fall back
        # instead of letting get_list_avg divide by zero.
        return default_gap
    return get_list_avg(gaps)
def get_date_diff(
    date_1: datetime.datetime, date_2: datetime.datetime
) -> datetime.timedelta:
    """Return ``date_2 - date_1``.

    The result is positive when ``date_2`` is later than ``date_1``.
    """
    return date_2 - date_1
def get_list_avg(received_list: List):
    """Return the mean of *received_list*, rounded down via floor division.

    An empty list raises ``ZeroDivisionError``; callers must guard against it.
    """
    total = sum(received_list)
    count = len(received_list)
    return total // count
def remove_existing_period_dates(db: Session, user_id: int):
    """Delete the user's period events that start after the current time, then commit.

    Only events whose ``start`` is later than ``datetime.datetime.now()`` are
    removed, so past periods are kept. The comparison uses a naive local
    timestamp.
    """
    upcoming_periods = db.query(Event).filter(
        Event.owner_id == user_id,
        Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID,
        Event.start > datetime.datetime.now(),
    )
    upcoming_periods.delete()
    db.commit()
    logger.info("Removed all period predictions to create new ones")
def generate_predicted_period_dates(
    db: Session,
    period_length: str,
    period_start_date: datetime.datetime,
    user_id: int,
):
    """Create and return one "period" event lasting *period_length* days.

    Args:
        db: Database session passed through to ``create_event``.
        period_length: Length of the period in days. Any value ``int()``
            accepts works (e.g. ``"5"`` or ``5``).
        period_start_date: Start of the event.
        user_id: Owner of the event.

    Returns:
        Whatever ``create_event`` returns for the newly created event.
    """
    period_end_date = period_start_date + datetime.timedelta(days=int(period_length))
    return create_event(
        db,
        "period",
        period_start_date,
        period_end_date,
        user_id,
        category_id=MENSTRUAL_PERIOD_CATEGORY_ID,
    )
def add_3_month_predictions(
    db: Session,
    period_length: str,
    period_start_date: datetime.datetime,
    user_id: int,
    *,
    num_periods: int = 4,
):
    """Create *num_periods* consecutive period events, starting at *period_start_date*.

    The first event starts exactly at ``period_start_date``. Each later event is
    shifted by the user's average gap from ``get_avg_period_gap``. That gap is
    computed once, before any new event is created. With the default of 4,
    this yields the given period plus three predicted ones, hence the
    "3 month" name.

    Args:
        db: Database session.
        period_length: Length of each period in days (passed to
            ``generate_predicted_period_dates``).
        period_start_date: Start of the first event.
        user_id: Owner of the events.
        num_periods: How many events to create.

    Returns:
        List of the created events, in chronological order.

    Any exception raised by ``get_avg_period_gap`` propagates.
    """
    avg_gap_delta = datetime.timedelta(days=get_avg_period_gap(db, user_id))
    generated_periods = []
    next_start = period_start_date
    for _ in range(num_periods):
        generated_periods.append(
            generate_predicted_period_dates(db, period_length, next_start, user_id)
        )
        next_start += avg_gap_delta
    logger.info(f"Generated predictions: {generated_periods}")
    return generated_periods
def get_all_period_days(session: Session, user_id: int) -> List[Event]:
    """Fetch every menstrual-period event owned by *user_id*, ordered by ``start``.

    If the query raises ``SQLAlchemyError``, the error is logged and an empty
    list is returned instead.
    """
    try:
        matching_events = (
            session.query(Event)
            .filter(Event.owner_id == user_id)
            .filter(Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID)
            .all()
        )
        ordered_events = sorted(matching_events, key=lambda event: event.start)
    except SQLAlchemyError as err:
        logger.exception(err)
        return []
    return ordered_events
def is_user_signed_up_to_menstrual_predictor(session: Session, user_id: int):
    """Return the stored period length for *user_id*, or ``False`` if there is none.

    Args:
        session: Database session.
        user_id: User to look up.

    Returns:
        The ``period_length`` of the user's ``UserMenstrualPeriodLength`` row,
        or ``False`` when the user has no such row.
    """
    # The filter must reference the model column. The previous
    # `user_id == user_id` compared the argument with itself (always True),
    # so the first row of the table was returned for any user.
    # NOTE(review): assumes the model's owner column is named `user_id`;
    # confirm against app.database.models.
    user_menstrual_period_length = (
        session.query(UserMenstrualPeriodLength)
        .filter(UserMenstrualPeriodLength.user_id == user_id)
        .first()
    )
    if user_menstrual_period_length is None:
        return False
    return user_menstrual_period_length.period_length