Skip to content

Commit

Permalink
shift to sqlalchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
selgamal committed Feb 10, 2024
1 parent 28b3450 commit fc00dbd
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 220 deletions.
2 changes: 1 addition & 1 deletion arelle-stubs/arelle/Cntlr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from arelle.Locale import getLanguageCodes as getLanguageCodes, setDisableRTL as
from arelle.ModelXbrl import ModelXbrl as ModelXbrl
from arelle.WebCache import WebCache as WebCache
from arelle.typing import TypeGetText as TypeGetText
from typing import Any, TextIO
from typing import Any, TextIO, Union

osPrcs: Incomplete
LOG_TEXT_MAX_LENGTH: int
Expand Down
4 changes: 2 additions & 2 deletions src/xbrlreportsindexes/cmd_line_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def do_search_tasks(
),
title="Query result",
)
assert isinstance(file, tuple)
return file[0]
assert isinstance(file, tuple)
return file[0]


def list_industry_tree(
Expand Down
102 changes: 54 additions & 48 deletions src/xbrlreportsindexes/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Union

from sqlalchemy import BOOLEAN
from sqlalchemy import CHAR
from sqlalchemy import Column
from sqlalchemy import false
from sqlalchemy import ForeignKey
Expand All @@ -21,12 +20,19 @@
from sqlalchemy import text
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import Mapped
from sqlalchemy.sql import TableClause
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.types import NullType
from xbrlreportsindexes.model import types_mapping
from sqlalchemy.types import (
BigInteger,
DateTime,
Integer,
Numeric,
SmallInteger,
String,
)

find_caps: Pattern[str] = re.compile(r"[a-zA-Z][^A-Z]*")

Expand Down Expand Up @@ -55,8 +61,8 @@ def __tablename__(self) -> str:
@classmethod
def cols_names(cls) -> list[str]:
"""Get column names"""
table: Table = getattr(cls, "__table__", Table())
columns: ImmutableColumnCollection[Any] = table.columns
table: Table = getattr(cls, "__table__", Table(cls.__tablename__, cls.__table__.metadata))
columns: Column[Any] = table.columns
return columns.keys()

def to_dict(self) -> dict[str, Any]:
Expand All @@ -83,16 +89,16 @@ def __tablename__(self) -> str:
def feed_id(self) -> Mapped[int]:
"""Creates feed_id column"""
if self.__tablename__ != "sec_feed":
return Column(
return mapped_column(
BigInteger(),
ForeignKey(
"sec_feed.feed_id", onupdate="RESTRICT", ondelete="CASCADE"
),
types_mapping.Bigint_type,
autoincrement=False,
nullable=False,
)
return Column(
types_mapping.Bigint_type,
BigInteger(),
primary_key=True,
autoincrement=False,
nullable=False,
Expand All @@ -111,18 +117,18 @@ def __tablename__(self) -> str:
def filing_id(self) -> Mapped[int]:
"""Create filing id column"""
if self.__tablename__ != "sec_filing":
return Column(
return mapped_column(
BigInteger(),
ForeignKey(
"sec_filing.filing_id",
onupdate="RESTRICT",
ondelete="CASCADE",
),
types_mapping.Bigint_type,
autoincrement=False,
nullable=False,
)
return Column(
types_mapping.Bigint_type,
BigInteger(),
primary_key=True,
autoincrement=False,
nullable=False,
Expand All @@ -137,7 +143,7 @@ class CreatedUpdatedAtColMixin:
def created_updated_at(self) -> Mapped[datetime.datetime]:
"""Creates created at col"""
return Column(
types_mapping.Timestamptz_type,
DateTime(timezone=True),
server_default=func.CURRENT_TIMESTAMP(),
onupdate=datetime.datetime.now,
)
Expand All @@ -151,7 +157,7 @@ class LogTablesMixin:
def log_id(self) -> Mapped[int]:
"""log id column"""
return Column(
types_mapping.Bigint_type,
BigInteger().with_variant(Integer, "sqlite"),
nullable=False,
primary_key=True,
autoincrement=True,
Expand All @@ -160,72 +166,72 @@ def log_id(self) -> Mapped[int]:
@declared_attr
def timestamp_at(self) -> Mapped[datetime.datetime]:
"""timestamp column"""
return Column(types_mapping.Timestamptz_type, nullable=False)
return Column(DateTime(timezone=True), nullable=False)

@declared_attr
def task(self) -> Mapped[str]:
"""Task column"""
return Column(types_mapping.Text_type, nullable=True)
return Column(String(), nullable=True)

@declared_attr
def time_taken(self) -> Mapped[float]:
"""time taken column"""
return Column(types_mapping.Float_type, nullable=True)
return Column(Numeric(), nullable=True)


class Location(Base, CreatedUpdatedAtColMixin):
"""Locations based on SEC/EDGAR coding, in addition to alpha 2 and 3."""

__table_args__ = {"comment": "all"}
# columns
code = Column(types_mapping.Text_type, nullable=False, primary_key=True)
latitude = Column(types_mapping.Numeric_type, nullable=True)
longitude = Column(types_mapping.Numeric_type, nullable=True)
country = Column(types_mapping.Text_type, nullable=True)
alpha_2 = Column(CHAR(2), nullable=True)
alpha_3 = Column(CHAR(3), nullable=True)
numeric = Column(CHAR(3), nullable=True)
state_province = Column(types_mapping.Text_type, nullable=True)
location_fix = Column(types_mapping.Text_type, nullable=True)
code = Column(String(), nullable=False, primary_key=True)
latitude = Column(Numeric(), nullable=True)
longitude = Column(Numeric(), nullable=True)
country = Column(String(), nullable=True)
alpha_2 = Column(String(2), nullable=True)
alpha_3 = Column(String(3), nullable=True)
numeric = Column(String(3), nullable=True)
state_province = Column(String(), nullable=True)
location_fix = Column(String(), nullable=True)


class ProcessingLog(Base, LogTablesMixin):
"""Stores log messages produced during processing"""

__table_args__ = {"comment": "all"}
message = Column(types_mapping.Text_type, nullable=True)
subject = Column(types_mapping.Text_type, nullable=True)
task_id = Column(types_mapping.Bigint_type, nullable=True)
message = Column(String(), nullable=True)
subject = Column(String(), nullable=True)
task_id = Column(BigInteger(), nullable=True)


class ActionLog(Base, LogTablesMixin):
"""Tracks tables bulk inserts and deletes"""

__table_args__ = {"comment": "all"}
feed_id = Column(
types_mapping.Bigint_type,
BigInteger(),
nullable=False,
server_default=text("999999"),
)
action = Column(types_mapping.Text_type, nullable=True)
table_name = Column(types_mapping.Text_type, nullable=True)
rowcount = Column(types_mapping.Integer_type, nullable=True)
action = Column(String(), nullable=True)
table_name = Column(String(), nullable=True)
rowcount = Column(Integer(), nullable=True)
is_committed = Column(BOOLEAN(), nullable=True)
task_id = Column(types_mapping.Bigint_type, nullable=True)
task_id = Column(BigInteger(), nullable=True)


class LastUpdate(Base, CreatedUpdatedAtColMixin):
"""Last database update date time"""

__table_args__ = {"comment": "all"}
id = Column(
types_mapping.Bigint_type,
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
nullable=False,
autoincrement=True,
)
task = Column(types_mapping.Text_type, nullable=False)
last_updated = Column(types_mapping.Timestamptz_type, nullable=False)
task = Column(String(), nullable=False)
last_updated = Column(DateTime(timezone=True), nullable=False)


class TaskTracker(Base, CreatedUpdatedAtColMixin):
Expand All @@ -235,38 +241,38 @@ class TaskTracker(Base, CreatedUpdatedAtColMixin):

__table_args__ = {"comment": "all"}
task_id = Column(
types_mapping.Bigint_type,
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
nullable=False,
autoincrement=True,
)
process_id = Column(types_mapping.Bigint_type)
process_id = Column(BigInteger())
# one of initialize-db, update-feeds, update-filers, refresh-tables
task_name = Column(types_mapping.Text_type, nullable=False)
task_name = Column(String(), nullable=False)
# is properly closed and everything handled
is_closed = Column(BOOLEAN, nullable=False, server_default=false())
is_completed = Column(BOOLEAN, nullable=False, server_default=false())
is_interrupted = Column(BOOLEAN, nullable=False, server_default=false())
task_parameters = Column(types_mapping.Text_type, nullable=False)
task_parameters = Column(String(), nullable=False)
started_at = Column(
types_mapping.Timestamptz_type,
DateTime(timezone=True),
nullable=True,
default=datetime.datetime.now,
)
ended_at: datetime.datetime | Column[NullType] = Column(
types_mapping.Timestamptz_type
DateTime(timezone=True)
)
time_taken = Column(types_mapping.Float_type)
time_taken = Column(Numeric())
total_items: Mapped[int] = Column(
types_mapping.Integer_type, default=text("0")
Integer(), default=text("0")
)
completed_items: Mapped[int] = Column(
types_mapping.Integer_type, default=text("0")
Integer(), default=text("0")
)
successful_items: Mapped[int] = Column(
types_mapping.Integer_type, default=text("0")
Integer(), default=text("0")
)
failed_items: Mapped[int] = Column(
types_mapping.Integer_type, default=text("0")
Integer(), default=text("0")
)
task_notes: str | Column[NullType] = Column(types_mapping.Text_type)
task_notes: Column[NullType] = Column(String())
Loading

0 comments on commit fc00dbd

Please sign in to comment.