米国株売買シミュレーションで学ぶPythonプログラミング - 第5回 SQLAlchemyの使い方

米国株売買シミュレーションで学ぶPythonプログラミング」の5回目です。この回ではSQLAlchemyの使い方を紹介します。SQLAlchemyはPythonで使えるO/Rマッパーで、オブジェクトを操作するようにしてSQLを実行できます。

また、withを使ってSQLAlchemyのトランザクションを管理する方法や、作成するテーブル名を動的に決定する方法も紹介します。

今回の例ではDBにSQLite3を使い、保存先はメモリとします。

※当連載で扱うアプリのソースはGitHub上に公開してあります。 https://github.com/yusukemurayama/ppytrading

SQLAlchemyの基本的な使い方

SQLAlchemyの基本的な使い方を説明します。モデルのフィールドがDBのカラムに対応します。

モデルの作成

SQLAlchemyを使うためにモデルを定義します。このモデルとDBのテーブルが一対一の関係になります。また、モデルのフィールドに対応した、カラムがテーブルに追加されます。

モデルを定義するために、まずは継承元のクラスを取得します。

Base = declarative_base()

そして、このクラスを継承してモデルを定義していきます。まずは、Stockモデルを定義します。

class Stock(Base):
    __tablename__ = 'stock'
    id = Column(Integer, primary_key=True)
    name = Column(String(length=64), nullable=False)
    symbol = Column(String(length=32), nullable=False, unique=True)
    created_at = Column(DateTime, default=datetime.now())
    updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now())
    histories = relationship("History", backref='stock', cascade='delete')

    @reconstructor
    def initialize(self):
        pass

上記クラスで、「stock」という名前のテーブルとマッピングします。また、そのテーブルには「id」、「name」、「symbol」、「created_at」、「updated_at」カラムが追加されます。

「id」はInteger型のPrimary Keyで、自動採番されます。(AUTO_INCREMENT Behavior

「created_at」と「updated_at」はレコード作成時に現在日時が自動的に設定されます。また、「updated_at」はレコード更新時にも自動的に上書きされます。

「histories」は次に定義する「History」との関係を表しているだけで、カラムは作成されません。「Stock」クラスのインスタンスからstock.hitoriesで関連付いている「History」クラスのインスタンスのリストを取得することができます。また、「History」からも「stock」で、「Stock」クラスのインスタンスを取得できるようにしています。

「initialize」メソッドはコンストラクタです。何かしら初期化が必要な場合は__init__.pyではなく、reconstructorデコレータをつけたメソッドに記述します。

続けて「History」クラスを定義します。

class History(Base):
    __tablename__ = 'history'
    stock_id = Column(Integer, ForeignKey(Stock.id), primary_key=True)
    date = Column(Date, primary_key=True)
    price = Column(Float, nullable=False)
    volume = Column(Integer, nullable=False)

上記クラスで、「history」というテーブルで、「stock_id」、「date」、「price」、「volume」というカラムがあるテーブルとマッピングするようになります。

「stock_id」は「Stock」クラスと「History」クラスとが、一対多の関係になるようにForeign Keyを設定しています。

モデルを作成したら、create_allメソッドでテーブルを作成することができます。テーブル作成時には、DSNを設定したengineオブジェクトを渡します。下記の例ではSqlite3を使い、データはメモリに保存されます。

engine = create_engine('sqlite://', echo=True)
Base.metadata.create_all(engine, checkfirst=True)

このコードを実行すると、stock・historyテーブルが作成されます。

sqlite> .table
sqlite> .schema stock
CREATE TABLE stock (
    id INTEGER NOT NULL,
    name VARCHAR(64) NOT NULL,
    symbol VARCHAR(32) NOT NULL,
    created_at DATETIME,
    updated_at DATETIME,
    PRIMARY KEY (id),
    UNIQUE (symbol)
);
sqlite> .schema history
CREATE TABLE history (
    stock_id INTEGER NOT NULL,
    date DATE NOT NULL,
    price FLOAT NOT NULL,
    volume INTEGER NOT NULL,
    PRIMARY KEY (stock_id, date),
    FOREIGN KEY(stock_id) REFERENCES stock (id)
);

sessionインスタンスの生成

モデルを定義したら、次はSessionインスタンスの生成です。このインスタンスを通じてSELECT文などを実行できるようになります。

Session = sessionmaker(bind=engine, autocommit=False)
session = Session()

SQLを実行

sessionインスタンスを生成できるようになったので、実際にSQLを実行できるようになりました。

※モデルを定義したファイルを「sqlalchemy_samples」とします。

# coding: utf-8
from datetime import date
from sqlalchemy_samples import Session, Stock, History

session = Session()
# autocommit=Falseなので、自動的にトランザクションが開始されます。
# session.begin()

# stockテーブルにレコードを追加します。
stock = Stock()
stock.name = 'Test'
stock.symbol = 'test'
session.add(stock)

# SQLを実行してstock.idを確定させます。
session.flush()
session.refresh(stock)
stock_id = stock.id

# historyテーブルにレコードを追加します。
hist = History()
hist.stock_id = stock_id
hist.date = date.today()
hist.price = 100.12
hist.volume = 1000
session.add(hist)
session.commit()  # COMMITします。
session.close()

# SELECT文を実行します。
session = Session()

# 全件取得します。
stocks = session.query(Stock).all()
print('stocks: {}'.format(stocks))

# filterで絞り込めます。
stocks = session.query(Stock).filter(Stock.name == 'Test')
print('stocks(filter): {}'.format(list(stocks)))

# filter_byでも絞り込めます。
stocks = session.query(Stock).filter_by(name='Test')
print('stocks(filter_by): {}'.format(list(stocks)))

# 「.」で繋いでプロパティにアクセスすることができます。
print('name: {}'.format(stocks[0].name))  # name: Test
print('symbol: {}'.format(stocks[0].symbol))  # symbol: test

# historiesで紐づくhistoryを取得できます。
histories = stocks[0].histories
print('histories: {}'.format(histories))

# COUNTやEXISTSも実行できます。
print('count: {}'.format(session.query(Stock).filter_by(name='Test').count()))  # count: 1
query = session.query(Stock).filter_by(name='Test')
print('exists: {}'.format(session.query(query.exists()).scalar()))  # exists: True

session.close()

# stockテーブルのレコードを更新します。
session = Session()
stock = session.query(Stock).get(stock_id)
stock.name = 'Mod Test'
session.commit()
session.close()

# 更新されたことを確認します。
session = Session()
print('count(Test): {}'.format(session.query(Stock).filter_by(name='Test').count()))  # count: 0
print('count(Mod Test): {}'.format(session.query(Stock).filter_by(name='Mod Test').count()))  # count: 1
session.close()

# stockテーブルのレコードを削除します。
session = Session()
stock = session.query(Stock).get(stock_id)
session.delete(stock)  # レコードを削除します。
print('count(stocks): {}'.format(session.query(Stock).count()))  # count(stocks): 0

# historyテーブルのレコードも削除されていることを確認します。
# ※ relationshipに「cascade='delete'」が指定されているから削除されます。
print('count(histories) {}'.format(session.query(History).count()))  # count(histories): 0
session.commit()  # COMMITします。

SQLAlchemyの便利な使い方

トランザクションを管理

SQLAlchemyを使っていくと、トランザクションの開始(Session()session.close()など)が面倒になってきます。また、closeの書き忘れのような、トラブルとなりかねないことも起こりえます。これらはwithを使ってトランザクションを開始できるようにすると解決します。

まずは、モデルを定義したファイルにstart_session関数を追加します。

...
from contextlib import contextmanager

...

@contextmanager
def start_session(commit=False):
    session = None
    try:
        # トランザクションを開始します。
        # ※autocommit=Falseなので、自動的にトランザクションが開始されます。
        session = Session()
        try:
            yield session
            if commit:
                session.commit()
        except:
            # 例外発生時はトランザクションをロールバックして、その例外をそのまま投げます。
            session.rollback()
            raise
    finally:
        if session is not None:
            session.close()

後は、このstart_sessionとwithを使って以下のようにSQLを実行できます。

# coding: utf-8
from datetime import date
from sqlalchemy_samples import start_session, Stock


# commit=Trueでトランザクションを開始します。
with start_session(commit=True) as session:
    stock = Stock()
    stock.name = 'Test'
    stock.symbol = 'test'
    stock.date = date.today()
    session.add(stock)

# commit=Falseでトランザクションを開始します。
with start_session() as session:
    stock = session.query(Stock).all()[0]
    print('name: {}'.format(stock.name))  # name: Test

テーブル名を動的に決定

テーブル名を動的に決定する方法を紹介します。まずは、カラム(に対応したフィールド)を定義したMixinを作成します。

class DynamicTableMixin(object):
    id = Column(Integer, primary_key=True)

    @declared_attr
    def stock_id(cls):
        return Column(Integer, ForeignKey('stock.id'))

このクラスで作成されるテーブルは、「id」と「stock_id」という2つのカラムを持つようにしてあります。

「stock_id」はForeign Keyで、stockと動的なテーブルで一対多になるようにしてあります。

declared_attrは、Foreign KeyなどをMixinで定義するときに付ける必要があるデコレータです。もし、「stock_id」をHistoryクラスのようにフィールドとしてを定義すると、以下のようなエラーが出力されます。

sqlalchemy.exc.InvalidRequestError: Columns with foreign keys to other columns must be declared as @declared_attr callables on declarative mixin classes.

Mixinを定義したら、それを継承してモデルを作成します。

from random import randint
from sqlalchemy_samples import engine, Base, DynamicTableMixin

tablename = 'new_table_{}'.format(randint(10, 99))  # テーブル名を決定します。


class NewTable(DynamicTableMixin, Base):
    __tablename__ = tablename

Base.metadata.create_all(engine)  # new_tableを作成します。

これで「new_table_xx」という名前が作成されます。このテーブル名の「xx」の部分は10から99までの数字がランダムで入ります。

また、type関数を使って新しいクラスを定義しても、テーブルを作成することができます。

# coding: utf-8
from random import randint
from sqlalchemy_samples import engine, Base, DynamicTableMixin

tablename = 'new_table_{}'.format(randint(10, 99))  # テーブル名を決定します。

klass = type('NewTable', (DynamicTableMixin, Base), {'__tablename__': tablename})

Base.metadata.create_all(engine)  # new_tableを作成します。

モデルのソースとテストケース

モデルの定義周りと、基本的なクエリ実行のテストケースを掲載します。

この記事が役に立った場合、シェアしていただけると励みになります!!

この記事に関する質問は@ysk_murayamaでご連絡ください。可能な内容であれば回答します!