Skip to content

Commit a24369f

Browse files
committed
🌟 feat: Category and Payment API
1 parent a82ba63 commit a24369f

File tree

10 files changed

+329
-55
lines changed

10 files changed

+329
-55
lines changed

‎k_backend/db.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from alembic import command
44
from alembic.config import Config
55
from loguru import logger
6-
from sqlmodel import create_engine
6+
from sqlmodel import Session, create_engine
77

88
try:
99
POSTGRES_USER = os.environ["POSTGRES_USER"]
@@ -23,3 +23,8 @@ def alembic_upgrade():
2323
alembic_cfg = Config("alembic.ini")
2424
command.upgrade(alembic_cfg, "head")
2525
logger.info("Alembic upgrade completed.")
26+
27+
28+
def get_session():
29+
with Session(engine) as session:
30+
yield session

‎k_backend/routers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@
99
auth_router,
1010
account_router,
1111
currency_router,
12+
category_router,
13+
payment_router,
1214
invoice_router,
1315
]

‎k_backend/routers/category.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from sqlmodel import Session, select
3+
4+
from ..auth import get_client
5+
from ..db import engine
6+
from ..schemas.category import Category, CategoryCreate, CategoryRead
7+
8+
TAG_NAME = "Category"
9+
tag = {
10+
"name": TAG_NAME,
11+
"description": "Create and manage payment entry categories",
12+
}
13+
14+
category_router = APIRouter(
15+
prefix="/category",
16+
tags=[TAG_NAME],
17+
dependencies=[Depends(get_client)],
18+
responses={404: {"description": "Not found"}},
19+
)
20+
21+
22+
@category_router.post("", response_model=CategoryRead, tags=[TAG_NAME])
23+
def create_category(category: CategoryCreate):
24+
with Session(engine) as session:
25+
db_category = Category.from_orm(category)
26+
session.add(db_category)
27+
session.commit()
28+
session.refresh(db_category)
29+
return db_category
30+
31+
32+
@category_router.get("", response_model=list[CategoryRead], tags=[TAG_NAME])
33+
def read_categories():
34+
with Session(engine) as session:
35+
categories = session.exec(select(Category)).all()
36+
return categories
37+
38+
39+
@category_router.patch("", response_model=CategoryRead, tags=[TAG_NAME])
40+
def update_category(category: Category):
41+
with Session(engine) as session:
42+
session.merge(category)
43+
session.commit()
44+
session.refresh(category)
45+
return category
46+
47+
48+
# TODO: Think about how this should work
49+
# @category_router.delete("/{id}", tags=[TAG_NAME])
50+
# def delete_category(id: int):
51+
# with Session(engine) as session:
52+
# category = session.query(Category).get(id)
53+
# if category is None:
54+
# raise HTTPException(status_code=404, detail="Category not found")
55+
# session.delete(category)
56+
# session.commit()

‎k_backend/routers/payment.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from sqlmodel import Session, select
3+
4+
from ..auth import get_client
5+
from ..db import engine, get_session
6+
from ..schemas.payment import (
7+
Payment,
8+
PaymentCreate,
9+
PaymentEntry,
10+
PaymentEntryCreate,
11+
PaymentRead,
12+
PaymentReadWithEntries,
13+
)
14+
15+
TAG_NAME = "Payment"
16+
tag = {
17+
"name": TAG_NAME,
18+
"description": "Create and edit payment records",
19+
}
20+
21+
payment_router = APIRouter(
22+
prefix="/payment",
23+
tags=[TAG_NAME],
24+
dependencies=[Depends(get_client)],
25+
responses={404: {"description": "Not found"}},
26+
)
27+
28+
29+
@payment_router.post(
30+
"",
31+
response_model=PaymentReadWithEntries,
32+
tags=[TAG_NAME],
33+
openapi_extra={
34+
"requestBody": {
35+
"content": {
36+
"application/json": {
37+
"examples": {
38+
"Taipei with 3 entries": {
39+
"summary": "Taipei with 3 entries",
40+
"value": {
41+
"payment": {
42+
"timestamp": "2022-09-08T08:07:08.000",
43+
"timezone": "Asia/Taipei",
44+
"description": "Some payment description",
45+
},
46+
"entries": [
47+
{
48+
"category_id": 1,
49+
"amount": 20,
50+
"quantity": 2,
51+
"description": "First entry",
52+
},
53+
{
54+
"category_id": 2,
55+
"amount": 30,
56+
"quantity": 2,
57+
"description": "Second entry",
58+
},
59+
{
60+
"category_id": 3,
61+
"amount": 10,
62+
"quantity": 1,
63+
"description": "Third entry",
64+
},
65+
],
66+
},
67+
}
68+
}
69+
}
70+
}
71+
}
72+
},
73+
)
74+
def create_payment(
75+
*,
76+
session: Session = Depends(get_session),
77+
payment: PaymentCreate,
78+
entries: list[PaymentEntryCreate]
79+
):
80+
# Calculate total
81+
payment.total = sum([entry.amount * entry.quantity for entry in entries])
82+
83+
# Store payment
84+
db_payment = Payment.from_orm(payment)
85+
session.add(db_payment)
86+
session.commit()
87+
session.refresh(db_payment)
88+
89+
# Store entries
90+
for entry in entries:
91+
entry.payment_id = db_payment.id
92+
db_entry = PaymentEntry.from_orm(entry)
93+
session.add(db_entry)
94+
session.commit()
95+
96+
new_payment = session.get(Payment, db_payment.id)
97+
return new_payment
98+
99+
100+
@payment_router.get("", response_model=list[PaymentReadWithEntries], tags=[TAG_NAME])
101+
def read_payments(*, session: Session = Depends(get_session)):
102+
payments = session.exec(select(Payment)).all()
103+
return payments
104+
105+
106+
@payment_router.patch("", response_model=PaymentRead, tags=[TAG_NAME])
107+
def update_payment(*, session: Session = Depends(get_session), payment: Payment):
108+
session.merge(payment)
109+
session.commit()
110+
session.refresh(payment)
111+
return payment
112+
113+
114+
@payment_router.get("/{id}", response_model=PaymentReadWithEntries, tags=[TAG_NAME])
115+
def read_payment(*, session: Session = Depends(get_session), id: int):
116+
payment = session.query(Payment).get(id)
117+
return payment
118+
119+
120+
@payment_router.delete("/{id}", tags=[TAG_NAME])
121+
def delete_payment(*, session: Session = Depends(get_session), id: int):
122+
payment = session.query(Payment).get(id)
123+
if payment is None:
124+
raise HTTPException(status_code=404, detail="Payment not found")
125+
session.delete(payment)
126+
session.commit()

‎k_backend/schemas/_custom_types.py

+25
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,28 @@ def validate(cls, v):
3636
except ZoneInfoNotFoundError:
3737
raise ValueError(f"{v} is not a valid timezone")
3838
raise TypeError("zoneinfo.ZoneInfo or str required")
39+
40+
41+
def create_timestamp_validator(values):
42+
try:
43+
timestamp = values["timestamp"]
44+
timezone = values["timezone"]
45+
if timestamp.tzinfo is None:
46+
values["timestamp"] = timestamp.replace(tzinfo=timezone)
47+
return values
48+
stamp_offset = timestamp.tzinfo.utcoffset(timestamp)
49+
zone_offset = timezone.utcoffset(timestamp)
50+
if stamp_offset != zone_offset:
51+
raise ValueError(
52+
f"Inconsistent timestamp offset ({stamp_offset}) and timezone offset ({zone_offset})"
53+
)
54+
except KeyError:
55+
raise ValueError("Missing timestamp or timezone")
56+
return values
57+
58+
59+
def tz_timestamp_reader(values):
60+
timezone = values.get("timezone")
61+
if timezone is not None:
62+
values["timestamp"] = values["timestamp"].astimezone(timezone)
63+
return values

‎k_backend/schemas/account.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
class AccountBase(SQLModel):
77
name: str
88
currency_code: str = Field(foreign_key="currency.code", nullable=False)
9-
currency: str = Relationship(back_populates="accounts")
10-
transactions: List["Transaction"] = Relationship(back_populates="account")
119

1210

1311
class Account(AccountBase, table=True):
1412
__tablename__ = "account"
1513
id: Optional[int] = Field(primary_key=True, nullable=False)
14+
currency: "Currency" = Relationship(back_populates="accounts")
15+
transactions: List["Transaction"] = Relationship(back_populates="account")
1616

1717

1818
class AccountCreate(AccountBase):

‎k_backend/schemas/category.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
class CategoryBase(SQLModel):
77
name: str
88
description: str
9-
entries: List["PaymentEntry"] = Relationship(back_populates="category")
109

1110

1211
class Category(CategoryBase, table=True):
1312
__tablename__ = "category"
1413
id: Optional[int] = Field(primary_key=True, nullable=False)
14+
entries: List["PaymentEntry"] = Relationship(back_populates="category")
1515

1616

1717
class CategoryCreate(CategoryBase):

‎k_backend/schemas/payment.py

+71-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,65 @@
11
from datetime import datetime
22
from typing import List, Optional
33

4+
from pydantic import root_validator
45
from sqlmodel import Column, DateTime, Field, Relationship, SQLModel
56

6-
from ._custom_types import PydanticTimezone, SATimezone
7+
from ._custom_types import (
8+
EXTENDED_JSON_ENCODERS,
9+
PydanticTimezone,
10+
SATimezone,
11+
create_timestamp_validator,
12+
tz_timestamp_reader,
13+
)
714

15+
#
16+
# Payment
17+
#
818

9-
class Payment(SQLModel, table=True):
10-
__tablename__ = "payment"
11-
id: Optional[int] = Field(primary_key=True, nullable=False)
12-
total: float
19+
20+
class PaymentBase(SQLModel):
1321
timestamp: datetime = Field(
14-
sa_column=Column(DateTime(timezone=True)), nullable=False
22+
sa_column=Column(DateTime(timezone=True), nullable=False),
23+
nullable=False,
24+
title="Local timestamp, or timezone-aware timestamp",
25+
)
26+
timezone: PydanticTimezone = Field(
27+
sa_column=Column(SATimezone(), nullable=False), nullable=False
1528
)
16-
timezone: PydanticTimezone = Field(sa_column=Column(SATimezone()), nullable=False)
1729
description: str
30+
31+
32+
class Payment(PaymentBase, table=True):
33+
__tablename__ = "payment"
34+
id: Optional[int] = Field(primary_key=True, nullable=False)
35+
total: float
1836
transactions: List["Transaction"] = Relationship(back_populates="payment")
37+
entries: List["PaymentEntry"] = Relationship(back_populates="payment")
38+
39+
40+
class PaymentCreate(PaymentBase):
41+
total: Optional[float]
42+
43+
@root_validator
44+
def verify_timezone(cls, values):
45+
return create_timestamp_validator(values)
46+
47+
48+
class PaymentRead(PaymentBase):
49+
id: int
50+
total: float
51+
52+
@root_validator
53+
def convert_timezone(cls, values):
54+
return tz_timestamp_reader(values)
55+
56+
class Config:
57+
json_encoders = EXTENDED_JSON_ENCODERS
58+
59+
60+
#
61+
# PaymentEntry
62+
#
1963

2064

2165
class PaymentEntryBase(SQLModel):
@@ -24,14 +68,14 @@ class PaymentEntryBase(SQLModel):
2468
amount: float
2569
quantity: int
2670
description: str
27-
payment: Payment = Relationship(back_populates="entries")
28-
category: "Category" = Relationship(back_populates="entries")
2971

3072

3173
class PaymentEntry(PaymentEntryBase, table=True):
3274
__tablename__ = "payment_entry"
3375
id: Optional[int] = Field(primary_key=True, nullable=False)
3476
payment_id: int = Field(foreign_key="payment.id", nullable=False)
77+
payment: Payment = Relationship(back_populates="entries")
78+
category: "Category" = Relationship(back_populates="entries")
3579

3680

3781
class PaymentEntryCreate(PaymentEntryBase):
@@ -43,19 +87,35 @@ class PaymentEntryRead(PaymentEntryBase):
4387
payment_id: int
4488

4589

90+
#
91+
# Transaction
92+
#
93+
94+
4695
class Transaction(SQLModel, table=True):
4796

4897
__tablename__ = "transaction"
4998
account_id: int = Field(primary_key=True, foreign_key="account.id", nullable=False)
5099
payment_id: int = Field(primary_key=True, foreign_key="payment.id", nullable=False)
51100
amount: float
52101
timestamp: datetime = Field(
53-
sa_column=Column(DateTime(timezone=True)), nullable=False
102+
sa_column=Column(DateTime(timezone=True), nullable=False), nullable=False
103+
)
104+
timezone: PydanticTimezone = Field(
105+
sa_column=Column(SATimezone(), nullable=False), nullable=False
54106
)
55-
timezone: PydanticTimezone = Field(sa_column=Column(SATimezone()), nullable=False)
56107
account: "Account" = Relationship(back_populates="transactions")
57108
payment: Payment = Relationship(back_populates="transactions")
58109

59110

111+
#
112+
# Relationship Models
113+
#
114+
115+
116+
class PaymentReadWithEntries(PaymentRead):
117+
entries: List[PaymentEntryRead] = []
118+
119+
60120
from .account import Account
61121
from .category import Category

0 commit comments

Comments
 (0)