KATOエンジニヤリング開発日誌

「アウトプット無きエンジニアにインプットもチャンスも無い」の精神で書いています

KABU+で取得したCSVファイルをデータベースにインポートする

前々回KABU+で株価データのCSVファイルを取得しましたが、CSVファイルのままだとデータの加工や集計に不便なのでデータベースを作成することにしました。

www.kato-eng.info

現在CSVファイルの取得はConoHaのVPSを利用しています。VPS上にデータベースソフトをインストールして使用することも出来ますが、後学のためDBサーバを別途用意してVPS(APサーバ)から接続する構成にしました。

f:id:masayuki_kato:20181229134705p:plain
サーバ構成 ※ConoHaのHPから引用

www.conoha.jp

事前準備

はじめにVPS(AP)サーバからDBサーバに接続できるようにする必要があります。グローバルネットワークを利用してDBサーバに接続することもできますが、後々のパフォーマンスのことも考え、プライベートネットワークで接続できるようにします。

www.conoha.jp

上記の公式サイトを参考に設定を行いました。

その後、作成したDBサーバにテーブルを作成します。DDLは次のようになります。

CREATE TABLE rwsoa_japan_stock.japan_all_stock_prices(
    security_code INTEGER,
    dt DATE,
    company_name VARCHAR(50),
    stock_exchange_code INTEGER,
    industry_type INTEGER,
    opening_price DOUBLE,
    closing_price DOUBLE,
    high_price DOUBLE,
    low_price DOUBLE,
    day_before_ratio DOUBLE,
    day_before_ratio_percentage DOUBLE,
    last_day_closing_price DOUBLE,
    volume INTEGER,
    trading_value INTEGER,
    market_capitalization INTEGER,
    price_range_lower_limit DOUBLE,
    price_range_upper_limit DOUBLE,
    PRIMARY KEY (security_code, dt)
);

ディレクトリ構成

ディレクトリ構成は下記の通りです。

  • スクリプトファイル
    • /[適当なrootディレクトリ]/script/import_japan_all_stock.py
      • 日次で毎日インポートするスクリプト
    • /[適当なrootディレクトリ]/script/import_japan_all_stock_prices_monthly.py
      • 月次でひと月分のCSVファイルを一括でインポートするスクリプト
  • ダウンロードしたCSVファイルが格納されている場所
    • /[適当なrootディレクトリ]/data/japan_all_stock_prices/japan-all-stock-prices_YYYYmmdd.csv

スクリプト

今回もスクリプトはPythonで作成しました。

日次スクリプト

最初に日次で毎日インポートするスクリプトです。これはcronで毎日実行して取得したCSVファイルをインポートする用途で作成しました。

#coding:utf-8

import datetime
import csv
import mysql.connector
import sys

# Constants
DB_USER = '※ここに作成したDB操作用のユーザ名を指定'
DB_PASSWORD = '※DB操作ユーザのパスワードを指定'
DB_HOST = '※ConoHaのプライベートネットワークのホストを指定'
DB_DATABASE = '※DBサーバに作成したデータベーススキーマ名を指定'
TABLE = 'japan_all_stock_prices'
BASE_DIR = '/usr/local/script/'
CSV_FILE_DIR = BASE_DIR + "/../data/japan_all_stock_prices/"

args = sys.argv

def get_stock_exchange_code(stock_exchange_name):
    if stock_exchange_name == '東証一部':
        return 1
    elif stock_exchange_name == '東証二部':
        return 2
    elif stock_exchange_name == 'JQS':
        return 3
    elif stock_exchange_name == 'JQG':
        return 4
    elif stock_exchange_name == '東証マザ':
        return 5
    elif stock_exchange_name == '名証一部':
        return 6
    elif stock_exchange_name == '名証二部':
        return 7
    elif stock_exchange_name == '名証セント':
        return 8
    elif stock_exchange_name == '札証':
        return 9
    elif stock_exchange_name == '札証アンビ':
        return 10
    elif stock_exchange_name == '福証':
        return 11
    elif stock_exchange_name == '福証QB':
        return 12
    else:
        return 'null'


def get_industry_type(industry_name):
    if industry_name == '水産・農林':
        return 1
    elif industry_name == '鉱業':
        return 2
    elif industry_name == '建設':
        return 3
    elif industry_name == '食料品':
        return 4
    elif industry_name == '繊維製品':
        return 5
    elif industry_name == 'パルプ・紙':
        return 6
    elif industry_name == '化学':
        return 7
    elif industry_name == '医薬品':
        return 8
    elif industry_name == '石油・石炭':
        return 9
    elif industry_name == 'ゴム製品':
        return 10
    elif industry_name == 'ガラス土石':
        return 11
    elif industry_name == '鉄鋼':
        return 12
    elif industry_name == '非鉄金属':
        return 13
    elif industry_name == '金属製品':
        return 14
    elif industry_name == '機械':
        return 15
    elif industry_name == '電気機器':
        return 16
    elif industry_name == '輸送用機器':
        return 17
    elif industry_name == '精密機器':
        return 18
    elif industry_name == 'その他製品':
        return 19
    elif industry_name == '電気・ガス':
        return 20
    elif industry_name == '陸運':
        return 21
    elif industry_name =='海運':
        return 22
    elif industry_name == '空運':
        return 23
    elif industry_name == '倉庫・運輸':
        return 24
    elif industry_name == '情報通信':
        return 25
    elif industry_name == '卸売':
        return 26
    elif industry_name == '小売':
        return 27
    elif industry_name == '銀行':
        return 28
    elif industry_name == '証券・先物':
        return 29
    elif industry_name == '保険':
        return 30
    elif industry_name == 'その他金融':
        return 31
    elif industry_name == '不動産':
        return 32
    elif industry_name == 'サービス':
        return 33
    else:
        return 'null'


if __name__ == '__main__':
    # 対象の日付を設定(引数でYYYYMMDD形式で日付を入れるとその日付のファイルを対象とする)
    if len(args) < 2:
        TODAY = datetime.date.today()
    else:
        TARGET_DAY = args[1]
        TODAY = datetime.datetime(int(TARGET_DAY[:4]), int(TARGET_DAY[4:6]), int(TARGET_DAY[-2:]))

    # 対象のファイル名を取得する
    file_name_date_part = str(TODAY.year) + '{:0=2}'.format(TODAY.month) + '{:0=2}'.format(TODAY.day)
    file_name = 'japan-all-stock-prices_' + file_name_date_part + '.csv'

    with open (CSV_FILE_DIR + file_name) as csvfile:
        reader = csv.reader(csvfile)
        # headerと日経225、TOPIXをスキップする
        # KABU+では日本株全銘柄ファイルの最初の2行に日経平均株価とTOPIXが含まれているため、これとヘッダーをスキップする
        for i in range(3):
            next(reader, None)

        # MariaDB connect
        try:
            conn = mysql.connector.connect(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, database=DB_DATABASE)
            cursor = conn.cursor()

            for row in reader:
                security_code = row[0] if row[0] != '-' else 'null'
                dt = file_name_date_part
                company_name = row[1] if row[1] != '-' else 'null'
                stock_exchange_code = get_stock_exchange_code(row[2])
                industry_type = get_industry_type(row[3])
                opening_price = row[9] if row[9] != '-' else 'null'
                closing_price = row[5] if row[5] != '-' else 'null'
                high_price = row[10] if row[10] != '-' else 'null'
                low_price = row[11] if row[11] != '-' else 'null'
                day_before_ratio = row[6] if row[6] != '-' else 'null'
                day_before_ratio_percentage = row[7] if row[7] != '-' else 'null'
                last_day_closing_price = row[8] if row[8] != '-' else 'null'
                volume = row[12] if row[12] != '-' else 'null'
                trading_value = row[13] if row[13] != '-' else 'null'
                market_capitalization = row[14] if row[14] != '-' else 'null'
                price_range_lower_limit = row[15] if row[15] != '-' else'null'
                price_range_upper_limit = row[16] if row[16] != '-' else 'null'
                cursor.execute('''INSERT INTO %s.%s (security_code, dt, company_name,
                               stock_exchange_code, industry_type, opening_price, closing_price, high_price, low_price,
                               day_before_ratio, day_before_ratio_percentage, last_day_closing_price, volume,
                               trading_value, market_capitalization, price_range_lower_limit, price_range_upper_limit)
                               VALUES(%s, '%s', "%s", %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
                               ''' % (DB_DATABASE, TABLE, security_code, dt, company_name, stock_exchange_code, industry_type, opening_price,
                                      closing_price, high_price, low_price, day_before_ratio, day_before_ratio_percentage,
                                      last_day_closing_price, volume, trading_value, market_capitalization,
                                      price_range_lower_limit, price_range_upper_limit))
        except mysql.connector.Error as e:
            print(e)
            conn.close()

        conn.commit()
        conn.close()

ひと月分を一括でインポートするスクリプト

次にひと月分のCSVファイルを一括でインポートするスクリプトです。これは

#coding:utf-8

import sys
import datetime
import os
import csv
import mysql.connector

# Constants
DB_USER = '※ここに作成したDB操作用のユーザ名を指定'
DB_PASSWORD = '※DB操作ユーザのパスワードを指定'
DB_HOST = '※ConoHaのプライベートネットワークのホストを指定'
DB_DATABASE = '※DBサーバに作成したデータベーススキーマ名を指定'
TABLE = 'japan_all_stock_prices'
BASE_DIR = '/usr/local/script/'
CSV_FILE_DIR = BASE_DIR + "/../data/japan_all_stock_prices/"


def get_stock_exchange_code(stock_exchange_name):
    if stock_exchange_name == '東証一部':
        return 1
    elif stock_exchange_name == '東証二部':
        return 2
    elif stock_exchange_name == 'JQS':
        return 3
    elif stock_exchange_name == 'JQG':
        return 4
    elif stock_exchange_name == '東証マザ':
        return 5
    elif stock_exchange_name == '名証一部':
        return 6
    elif stock_exchange_name == '名証二部':
        return 7
    elif stock_exchange_name == '名証セント':
        return 8
    elif stock_exchange_name == '札証':
        return 9
    elif stock_exchange_name == '札証アンビ':
        return 10
    elif stock_exchange_name == '福証':
        return 11
    elif stock_exchange_name == '福証QB':
        return 12
    else:
        return 'null'


def get_industry_type(industry_name):
    if industry_name == '水産・農林':
        return 1
    elif industry_name == '鉱業':
        return 2
    elif industry_name == '建設':
        return 3
    elif industry_name == '食料品':
        return 4
    elif industry_name == '繊維製品':
        return 5
    elif industry_name == 'パルプ・紙':
        return 6
    elif industry_name == '化学':
        return 7
    elif industry_name == '医薬品':
        return 8
    elif industry_name == '石油・石炭':
        return 9
    elif industry_name == 'ゴム製品':
        return 10
    elif industry_name == 'ガラス土石':
        return 11
    elif industry_name == '鉄鋼':
        return 12
    elif industry_name == '非鉄金属':
        return 13
    elif industry_name == '金属製品':
        return 14
    elif industry_name == '機械':
        return 15
    elif industry_name == '電気機器':
        return 16
    elif industry_name == '輸送用機器':
        return 17
    elif industry_name == '精密機器':
        return 18
    elif industry_name == 'その他製品':
        return 19
    elif industry_name == '電気・ガス':
        return 20
    elif industry_name == '陸運':
        return 21
    elif industry_name =='海運':
        return 22
    elif industry_name == '空運':
        return 23
    elif industry_name == '倉庫・運輸':
        return 24
    elif industry_name == '情報通信':
        return 25
    elif industry_name == '卸売':
        return 26
    elif industry_name == '小売':
        return 27
    elif industry_name == '銀行':
        return 28
    elif industry_name == '証券・先物':
        return 29
    elif industry_name == '保険':
        return 30
    elif industry_name == 'その他金融':
        return 31
    elif industry_name == '不動産':
        return 32
    elif industry_name == 'サービス':
        return 33
    else:
        return 'null'


def check_leap_year(year):
    # うるう年判定
    if int(year) % 400 == 0:
        return True
    elif int(year) % 4 == 0 and int(year) % 100 ==0:
        return False
    elif int(year) % 4 == 0:
        return True
    else:
        return False


if __name__ == '__main__':
    args = sys.argv
    if len(args[1]) != 6:
        print("コマンドライン引数が不正")
        sys.exit(1)

    target_month = args[1]
    year = target_month[:4]
    month = target_month[-2:]
    day_list = list()
    if month == '01':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '02' and check_leap_year(year):
        for i in range(1, 30):
            day_list.append(str(i).zfill(2))
    elif month == '02' and not check_leap_year(year):
        for i in range(1, 29):
            day_list.append(str(i).zfill(2))
    elif month == '03':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '04':
        for i in range(1, 31):
            day_list.append(str(i).zfill(2))
    elif month == '05':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '06':
        for i in range(1, 31):
            day_list.append(str(i).zfill(2))
    elif month == '07':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '08':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '09':
        for i in range(1, 31):
            day_list.append(str(i).zfill(2))
    elif month == '10':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))
    elif month == '11':
        for i in range(1, 31):
            day_list.append(str(i).zfill(2))
    elif month == '12':
        for i in range(1, 32):
            day_list.append(str(i).zfill(2))

    for day in day_list:
        file_name_date_part = (year + month + day)

        # インポート対象のファイル名を指定
        file_name = 'japan-all-stock-prices_' + file_name_date_part + '.csv'

        # ファイルの存在チェック
        if not os.path.exists(CSV_FILE_DIR + file_name):
            continue

        with open (CSV_FILE_DIR + file_name) as csvfile:
            reader = csv.reader(csvfile)
            # headerと日経225、TOPIXをスキップする
            for i in range(3):
                next(reader, None)

            # MariaDB connect
            try:
                conn = mysql.connector.connect(user=DB_USER, password=DB_PASSWORD, host=DB_HOST, database=DB_DATABASE)
                cursor = conn.cursor()

                for row in reader:
                    security_code = row[0] if row[0] != '-' else 'null'
                    dt = file_name_date_part
                    company_name = row[1] if row[1] != '-' else 'null'
                    stock_exchange_code = get_stock_exchange_code(row[2])
                    industry_type = get_industry_type(row[3])
                    opening_price = row[9] if row[9] != '-' else 'null'
                    closing_price = row[5] if row[5] != '-' else 'null'
                    high_price = row[10] if row[10] != '-' else 'null'
                    low_price = row[11] if row[11] != '-' else 'null'
                    day_before_ratio = row[6] if row[6] != '-' else 'null'
                    day_before_ratio_percentage = row[7] if row[7] != '-' else 'null'
                    last_day_closing_price = row[8] if row[8] != '-' else 'null'
                    volume = row[12] if row[12] != '-' else 'null'
                    trading_value = row[13] if row[13] != '-' else 'null'
                    market_capitalization = row[14] if row[14] != '-' else 'null'
                    price_range_lower_limit = row[15] if row[15] != '-' else'null'
                    price_range_upper_limit = row[16] if row[16] != '-' else 'null'
                    query = '''INSERT INTO %s.%s (security_code, dt, company_name,
                                                   stock_exchange_code, industry_type, opening_price, closing_price, high_price, low_price,
                                                   day_before_ratio, day_before_ratio_percentage, last_day_closing_price, volume,
                                                   trading_value, market_capitalization, price_range_lower_limit, price_range_upper_limit)
                                                   VALUES(%s, '%s', '%s', %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
                                                   ''' % (
                    DB_DATABASE, TABLE, security_code, dt, company_name, stock_exchange_code, industry_type, opening_price,
                    closing_price, high_price, low_price, day_before_ratio, day_before_ratio_percentage,
                    last_day_closing_price, volume, trading_value, market_capitalization,
                    price_range_lower_limit, price_range_upper_limit)
                    cursor.execute(query)
            except mysql.connector.Error as e:
                print(e)
                conn.close()

        conn.commit()
        conn.close()

起動時に第1引数に「201812」のような形式で年月を指定します。引数を指定しない場合は処理が中止されるようにしています。

スクリプト内で対象年月の最終日を取得している処理がありますが、対象日が無い場合(土日祝日等)はスキップされるので、31日でしておいても良いです。そうすればうるう年を計算する処理も不要なのでコードがよりスッキリしますね。