「米国株売買シミュレーションで学ぶPythonプログラミング」の5回目です。この回ではSQLAlchemyの使い方を紹介します。SQLAlchemyはPythonで使えるO/Rマッパーで、オブジェクトを操作するようにしてSQLを実行できます。
また、withを使ってSQLAlchemyのトランザクションを管理する方法や、作成するテーブル名を動的に決定する方法も紹介します。
今回の例ではDBにSQLite3を使い、保存先はメモリとします。
※当連載で扱うアプリのソースはGitHub上に公開してあります。 https://github.com/yusukemurayama/ppytrading
SQLAlchemyの基本的な使い方
SQLAlchemyの基本的な使い方を説明します。モデルのフィールドがDBのカラムに対応します。
モデルの作成
SQLAlchemyを使うためにモデルを定義します。このモデルとDBのテーブルが一対一の関係になります。また、モデルのフィールドに対応した、カラムがテーブルに追加されます。
モデルを定義するために、まずは継承元のクラスを取得します。
Base = declarative_base()
そして、このクラスを継承してモデルを定義していきます。まずは、Stockモデルを定義します。
class Stock(Base):
__tablename__ = 'stock'
id = Column(Integer, primary_key=True)
name = Column(String(length=64), nullable=False)
symbol = Column(String(length=32), nullable=False, unique=True)
created_at = Column(DateTime, default=datetime.now())
updated_at = Column(DateTime, default=datetime.now(), onupdate=datetime.now())
histories = relationship("History", backref='stock', cascade='delete')
@reconstructor
def initialize(self):
pass
上記クラスで、「stock」という名前のテーブルとマッピングします。また、そのテーブルには「id」、「name」、「symbol」、「created_at」、「updated_at」カラムが追加されます。
「id」はInteger型のPrimary Keyで、自動採番されます。(AUTO_INCREMENT Behavior)
「created_at」と「updated_at」はレコード作成時に現在日時が自動的に設定されます。また、「updated_at」はレコード更新時にも自動的に上書きされます。
「histories」は次に定義する「History」との関係を表しているだけで、カラムは作成されません。「Stock」クラスのインスタンスからstock.hitories
で関連付いている「History」クラスのインスタンスのリストを取得することができます。また、「History」からも「stock」で、「Stock」クラスのインスタンスを取得できるようにしています。
「initialize」メソッドはコンストラクタです。何かしら初期化が必要な場合は__init__.pyではなく、reconstructorデコレータをつけたメソッドに記述します。
続けて「History」クラスを定義します。
class History(Base):
__tablename__ = 'history'
stock_id = Column(Integer, ForeignKey(Stock.id), primary_key=True)
date = Column(Date, primary_key=True)
price = Column(Float, nullable=False)
volume = Column(Integer, nullable=False)
上記クラスで、「history」というテーブルで、「stock_id」、「date」、「price」、「volume」というカラムがあるテーブルとマッピングするようになります。
「stock_id」は「Stock」クラスと「History」クラスとが、一対多の関係になるようにForeign Keyを設定しています。
モデルを作成したら、create_all
メソッドでテーブルを作成することができます。テーブル作成時には、DSNを設定したengineオブジェクトを渡します。下記の例ではSqlite3を使い、データはメモリに保存されます。
engine = create_engine('sqlite://', echo=True)
Base.metadata.create_all(engine, checkfirst=True)
このコードを実行すると、stock・historyテーブルが作成されます。
sqlite> .table
sqlite> .schema stock
CREATE TABLE stock (
id INTEGER NOT NULL,
name VARCHAR(64) NOT NULL,
symbol VARCHAR(32) NOT NULL,
created_at DATETIME,
updated_at DATETIME,
PRIMARY KEY (id),
UNIQUE (symbol)
);
sqlite> .schema history
CREATE TABLE history (
stock_id INTEGER NOT NULL,
date DATE NOT NULL,
price FLOAT NOT NULL,
volume INTEGER NOT NULL,
PRIMARY KEY (stock_id, date),
FOREIGN KEY(stock_id) REFERENCES stock (id)
);
sessionインスタンスの生成
モデルを定義したら、次はSessionインスタンスの生成です。このインスタンスを通じてSELECT文などを実行できるようになります。
Session = sessionmaker(bind=engine, autocommit=False)
session = Session()
SQLを実行
sessionインスタンスを生成できるようになったので、実際にSQLを実行できるようになりました。
※モデルを定義したファイルを「sqlalchemy_samples」とします。
# coding: utf-8
from datetime import date
from sqlalchemy_samples import Session, Stock, History
session = Session()
# autocommit=Falseなので、自動的にトランザクションが開始されます。
# session.begin()
# stockテーブルにレコードを追加します。
stock = Stock()
stock.name = 'Test'
stock.symbol = 'test'
session.add(stock)
# SQLを実行してstock.idを確定させます。
session.flush()
session.refresh(stock)
stock_id = stock.id
# historyテーブルにレコードを追加します。
hist = History()
hist.stock_id = stock_id
hist.date = date.today()
hist.price = 100.12
hist.volume = 1000
session.add(hist)
session.commit() # COMMITします。
session.close()
# SELECT文を実行します。
session = Session()
# 全件取得します。
stocks = session.query(Stock).all()
print('stocks: {}'.format(stocks))
# filterで絞り込めます。
stocks = session.query(Stock).filter(Stock.name == 'Test')
print('stocks(filter): {}'.format(list(stocks)))
# filter_byでも絞り込めます。
stocks = session.query(Stock).filter_by(name='Test')
print('stocks(filter_by): {}'.format(list(stocks)))
# 「.」で繋いでプロパティにアクセスすることができます。
print('name: {}'.format(stocks[0].name)) # name: Test
print('symbol: {}'.format(stocks[0].symbol)) # symbol: test
# historiesで紐づくhistoryを取得できます。
histories = stocks[0].histories
print('histories: {}'.format(histories))
# COUNTやEXISTSも実行できます。
print('count: {}'.format(session.query(Stock).filter_by(name='Test').count())) # count: 1
query = session.query(Stock).filter_by(name='Test')
print('exists: {}'.format(session.query(query.exists()).scalar())) # exists: True
session.close()
# stockテーブルのレコードを更新します。
session = Session()
stock = session.query(Stock).get(stock_id)
stock.name = 'Mod Test'
session.commit()
session.close()
# 更新されたことを確認します。
session = Session()
print('count(Test): {}'.format(session.query(Stock).filter_by(name='Test').count())) # count: 0
print('count(Mod Test): {}'.format(session.query(Stock).filter_by(name='Mod Test').count())) # count: 1
session.close()
# stockテーブルのレコードを削除します。
session = Session()
stock = session.query(Stock).get(stock_id)
session.delete(stock) # レコードを削除します。
print('count(stocks): {}'.format(session.query(Stock).count())) # count(stocks): 0
# historyテーブルのレコードも削除されていることを確認します。
# ※ relationshipに「cascade='delete'」が指定されているから削除されます。
print('count(histories) {}'.format(session.query(History).count())) # count(histories): 0
session.commit() # COMMITします。
SQLAlchemyの便利な使い方
トランザクションを管理
SQLAlchemyを使っていくと、トランザクションの開始(Session()
やsession.close()
など)が面倒になってきます。また、closeの書き忘れのような、トラブルとなりかねないことも起こりえます。これらはwithを使ってトランザクションを開始できるようにすると解決します。
まずは、モデルを定義したファイルにstart_session
関数を追加します。
...
from contextlib import contextmanager
...
@contextmanager
def start_session(commit=False):
session = None
try:
# トランザクションを開始します。
# ※autocommit=Falseなので、自動的にトランザクションが開始されます。
session = Session()
try:
yield session
if commit:
session.commit()
except:
# 例外発生時はトランザクションをロールバックして、その例外をそのまま投げます。
session.rollback()
raise
finally:
if session is not None:
session.close()
後は、このstart_sessionとwithを使って以下のようにSQLを実行できます。
# coding: utf-8
from datetime import date
from sqlalchemy_samples import start_session, Stock
# commit=Trueでトランザクションを開始します。
with start_session(commit=True) as session:
stock = Stock()
stock.name = 'Test'
stock.symbol = 'test'
stock.date = date.today()
session.add(stock)
# commit=Falseでトランザクションを開始します。
with start_session() as session:
stock = session.query(Stock).all()[0]
print('name: {}'.format(stock.name)) # name: Test
テーブル名を動的に決定
テーブル名を動的に決定する方法を紹介します。まずは、カラム(に対応したフィールド)を定義したMixinを作成します。
class DynamicTableMixin(object):
id = Column(Integer, primary_key=True)
@declared_attr
def stock_id(cls):
return Column(Integer, ForeignKey('stock.id'))
このクラスで作成されるテーブルは、「id」と「stock_id」という2つのカラムを持つようにしてあります。
「stock_id」はForeign Keyで、stockと動的なテーブルで一対多になるようにしてあります。
declared_attrは、Foreign KeyなどをMixinで定義するときに付ける必要があるデコレータです。もし、「stock_id」をHistoryクラスのようにフィールドとしてを定義すると、以下のようなエラーが出力されます。
sqlalchemy.exc.InvalidRequestError: Columns with foreign keys to other columns must be declared as @declared_attr callables on declarative mixin classes.
Mixinを定義したら、それを継承してモデルを作成します。
from random import randint
from sqlalchemy_samples import engine, Base, DynamicTableMixin
tablename = 'new_table_{}'.format(randint(10, 99)) # テーブル名を決定します。
class NewTable(DynamicTableMixin, Base):
__tablename__ = tablename
Base.metadata.create_all(engine) # new_tableを作成します。
これで「new_table_xx」という名前が作成されます。このテーブル名の「xx」の部分は10から99までの数字がランダムで入ります。
また、type関数を使って新しいクラスを定義しても、テーブルを作成することができます。
# coding: utf-8
from random import randint
from sqlalchemy_samples import engine, Base, DynamicTableMixin
tablename = 'new_table_{}'.format(randint(10, 99)) # テーブル名を決定します。
klass = type('NewTable', (DynamicTableMixin, Base), {'__tablename__': tablename})
Base.metadata.create_all(engine) # new_tableを作成します。
モデルのソースとテストケース
モデルの定義周りと、基本的なクエリ実行のテストケースを掲載します。