250 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import io
 | |
| import json
 | |
| import os
 | |
| import pickle
 | |
| import threading
 | |
| import time
 | |
| import uuid
 | |
| from enum import Enum
 | |
| 
 | |
| import pandas as pd
 | |
| import requests
 | |
| from flask import Flask, jsonify, render_template, request, redirect, make_response
 | |
| 
 | |
| 
 | |
# The Flask application object; every route in this module registers on it.
app = Flask(import_name=__name__)
| 
 | |
# CSV download of the 2024 daily Treasury yield-curve rates.
TREASURY_URL = (
    "https://home.treasury.gov/resource-center/data-chart-center/interest-rates/"
    "daily-treasury-rates.csv/2024/all"
    "?type=daily_treasury_yield_curve&field_tdr_date_value=2024&page&_format=csv"
)

# YYYYMMDD; pinned to Friday 2024-05-24 because the market is closed.
# Not referenced elsewhere in this module at present.
CURRENT_DATE = 20240524

# Yield-curve DataFrame, filled lazily by get_treasury_rates_for_year().
RATES = None
| 
 | |
| 
 | |
class Side(str, Enum):
    """Direction of a trade; each member is also a plain ``str``."""

    BUY = "BUY"
    SELL = "SELL"
| 
 | |
| 
 | |
# The set of Treasury terms is fixed, so enumerate them. Each value matches the
# corresponding column header in the Treasury yield-curve CSV.
class Instrument(str, Enum):
    """A Treasury maturity, valued by its CSV column label (e.g. ``"10 Yr"``)."""

    I_1MO = "1 Mo"
    I_2MO = "2 Mo"
    I_3MO = "3 Mo"
    I_4MO = "4 Mo"
    I_6MO = "6 Mo"
    I_1YR = "1 Yr"
    I_2YR = "2 Yr"
    I_3YR = "3 Yr"
    I_5YR = "5 Yr"
    I_7YR = "7 Yr"
    I_10YR = "10 Yr"
    I_20YR = "20 Yr"
    I_30YR = "30 Yr"

    @classmethod
    def from_text(cls, txt):
        """Return the member whose value equals ``txt``.

        ``"12 Yr"`` (the old, mistyped value of ``I_20YR``) is still accepted
        and maps to ``I_20YR`` for backward compatibility.

        Raises:
            ValueError: ``txt`` is not a known instrument label. Unknown text
                used to be silently mapped to ``I_30YR``.
        """
        if txt == "12 Yr":
            return cls.I_20YR
        # Enum value lookup; raises ValueError for unknown labels.
        return cls(txt)
| 
 | |
| 
 | |
class Transaction:
    """One recorded trade.

    The ``instrument`` and ``side`` enum arguments are stored as their plain
    string values rather than as enum members.
    """

    def __init__(
        self,
        user: str,
        instrument: Instrument,
        quantity: int,
        date: int,
        side: Side,
        rate: float,
    ):
        instrument_name = str(instrument.value)
        side_name = str(side.value)

        self.user = user
        self.instrument: str = instrument_name
        self.side: str = side_name
        self.quantity = quantity
        # buy() passes time.time() here (a float epoch timestamp) despite the int hint.
        self.date = date
        self.rate = rate
| 
 | |
| 
 | |
class User:
    """An account holder identified by a random UUID4 string ``ID``, with a cash balance."""

    def __init__(self, balance: int = 100000) -> None:
        self.ID = f"{uuid.uuid4()}"
        self.balance = balance
| 
 | |
| 
 | |
# In-memory stores, both keyed by user ID; persisted by background_saver().
USERS: dict[str, User] = dict()
TRANSACTIONS: dict[str, list[Transaction]] = dict()
| 
 | |
| 
 | |
def _write_pickle_atomically(obj, path: str) -> None:
    """Pickle ``obj`` to ``path`` via a temporary file and an atomic rename.

    ``os.replace`` moves the finished file into place in one step. A crash or
    kill mid-write therefore leaves the previous ``path`` intact, instead of a
    truncated pickle that would fail to load on the next start.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as fp:
        pickle.dump(obj, fp)
    os.replace(tmp_path, path)


def background_saver():
    """Persist ``USERS`` and ``TRANSACTIONS`` to disk every 5 seconds, forever.

    Intended to run in its own thread. It writes ``users.p`` and
    ``transactions.p`` in the working directory. A failed save is reported and
    retried on the next tick, rather than killing the thread and silently
    stopping persistence.
    """
    while True:
        try:
            # Snapshot first: request handlers mutate these containers from
            # other threads, and pickling a dict that changes size mid-dump
            # raises RuntimeError. Copying is far quicker than pickling, which
            # shrinks that window considerably.
            users_snapshot = dict(USERS)
            transactions_snapshot = {
                uid: list(txs) for uid, txs in dict(TRANSACTIONS).items()
            }
            _write_pickle_atomically(users_snapshot, "users.p")
            _write_pickle_atomically(transactions_snapshot, "transactions.p")
        except Exception as exc:  # keep the saver alive; retry next tick
            print(f"background_saver: failed to persist state: {exc!r}")

        time.sleep(5)
| 
 | |
| 
 | |
def get_treasury_rates_for_year(year: int) -> tuple[dict[str, list], list]:
    """Return the Treasury yield curve as per-column lists plus the matching dates.

    Data source, in order of preference:

    1. the in-memory module-level ``RATES`` DataFrame;
    2. ``./cache.csv``, if it was modified less than 60000 seconds ago;
    3. a fresh download from ``TREASURY_URL``, which is then written to
       ``./cache.csv``.

    Whichever source is used is stored in ``RATES`` so that other handlers
    (e.g. ``buy``) can read it.

    Args:
        year: Currently ignored. ``TREASURY_URL`` and ``./cache.csv`` are both
            fixed to 2024; the parameter is kept for API compatibility.

    Returns:
        ``(rates, dates)``. ``rates`` maps every column except ``Date``
        (e.g. ``"1 Mo"``) to its values as a list. ``dates`` is the ``Date``
        column as a list, row-aligned with each list in ``rates``.

    Raises:
        requests.HTTPError: The download returned an error status.
        requests.Timeout: The download did not respond in time.
        KeyError: The data has no ``Date`` column.
    """
    global RATES

    if RATES is None:
        cache_path = "./cache.csv"
        cache_is_fresh = (
            os.path.isfile(cache_path)
            and time.time() - os.path.getmtime(cache_path) < 60000
        )
        if cache_is_fresh:
            RATES = pd.read_csv(cache_path)
        else:
            # Without a timeout a stalled connection would hang the request forever.
            response = requests.get(TREASURY_URL, timeout=30)
            response.raise_for_status()
            frame = pd.read_csv(io.StringIO(response.content.decode("UTF-8")))
            frame.to_csv(cache_path, index=False)
            RATES = frame
    # Once RATES is loaded it is reused directly, with no per-call disk re-read.

    rates = {column: RATES[column].tolist() for column in RATES.columns}
    dates = rates.pop("Date")
    return rates, dates
| 
 | |
| 
 | |
@app.route("/api/transactions/<id>")
def get_transactions_for_user(id):
    """Return user ``id``'s trade history as a JSON list (empty for unknown users)."""
    history = TRANSACTIONS.get(id) or []

    payload = []
    for tx in history:
        payload.append(
            {
                "term": tx.instrument,
                "quantity": tx.quantity,
                "side": tx.side,
                "date": tx.date,
                "rate": tx.rate,
            }
        )
    return jsonify(payload)
| 
 | |
| 
 | |
@app.route("/api/buy", methods=["POST"])
def buy():
    """Buy treasuries for a user at the instrument's most recent quoted rate.

    Expects a JSON body ``{"user_id": str, "instrument": str, "quantity": int}``,
    where ``instrument`` is a rate column label such as ``"10 Yr"``. Each unit
    costs 100, which is deducted from the user's balance. The trade is then
    appended to ``TRANSACTIONS[user_id]``.

    Returns:
        On success, the recorded transaction as JSON (``term``, ``quantity``,
        ``side``, ``date``, ``rate``). ``400`` for a malformed request, a
        non-positive quantity, an unknown instrument, or insufficient funds.
        ``404`` if the user does not exist.
    """
    # get_json(silent=True) yields None for a missing or malformed body
    # instead of raising; treat that as a client error.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return "Request body must be a JSON object", 400

    try:
        q = int(content["quantity"])
    except (KeyError, TypeError, ValueError):
        return "quantity must be an integer", 400
    i = content.get("instrument")
    user_id = content.get("user_id")
    if not isinstance(i, str) or not isinstance(user_id, str):
        return "instrument and user_id must be strings", 400

    # A zero or negative quantity gives cost <= 0. That would pass the funds
    # check and then *credit* the account.
    if q <= 0:
        return "quantity must be positive", 400

    cost = q * 100
    user = USERS.get(user_id)
    if user is None:
        return "User does not exist", 404
    if cost > user.balance:
        return "You do not have the funds for this transaction", 400

    if RATES is None:
        # RATES is otherwise only populated when the index page is rendered.
        get_treasury_rates_for_year(2024)
    if i == "Date" or i not in RATES.columns:
        return "Unknown instrument", 400
    try:
        instrument = Instrument.from_text(i)
    except ValueError:
        return "Unknown instrument", 400

    quotes = RATES[["Date", i]].dropna()
    if quotes.empty:
        return "No rate available for this instrument", 400
    # Select the newest date explicitly instead of trusting row position. The
    # row order comes from the Treasury download, which appears to list the
    # newest day first (TODO confirm), so the last row need not be current.
    latest = pd.to_datetime(quotes["Date"]).idxmax()
    current_rate = float(quotes.at[latest, i])

    # NOTE(review): the funds check above and this debit are not atomic;
    # concurrent requests for the same user could both pass the check.
    user.balance -= cost
    t = Transaction(user_id, instrument, q, time.time(), Side.BUY, current_rate)
    TRANSACTIONS.setdefault(user_id, []).append(t)

    return jsonify(
        {
            "term": t.instrument,
            "quantity": t.quantity,
            "side": t.side,
            "date": t.date,
            "rate": t.rate,
        }
    )
| 
 | |
| 
 | |
@app.route("/api/login", methods=["POST"])
def login():
    """Log in an existing user by ID.

    Expects a JSON body ``{"id": "<user id>"}``. On success, sets the
    ``userid`` cookie and redirects to ``/``.

    Returns:
        A redirect to ``/`` on success. ``400`` if the body is not a JSON
        object or the ID is missing, empty, or not a string. ``404`` if no
        user has that ID.
    """
    # get_json(silent=True) yields None for a missing or malformed body
    # instead of raising; treat that as a client error.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return "Request body must be a JSON object", 400

    user_id = content.get("id")
    if user_id is None or user_id == "":
        return "ID cannot be empty", 400
    if not isinstance(user_id, str):
        # JSON lists/objects are unhashable for the dict lookup; numbers can never match.
        return "ID must be a string", 400

    user = USERS.get(user_id)
    if user is None:
        return "User does not exist", 404

    # NOTE(review): the user ID is the only credential here; consider
    # httponly/samesite on this cookie if the frontend does not read it via JS.
    resp = make_response(redirect("/"))
    resp.set_cookie("userid", user.ID)
    return resp
| 
 | |
| 
 | |
@app.route("/api/createuser", methods=["POST"])
def createuser():
    """Create a user with the default balance, store it, and log them in via cookie."""
    new_user = User()
    new_id = new_user.ID
    USERS[new_id] = new_user

    response = make_response(redirect("/"))
    response.set_cookie("userid", new_id)
    return response
| 
 | |
| 
 | |
@app.route("/register")
def register():
    """Show the registration page, or send users who already have a cookie home."""
    already_logged_in = "userid" in request.cookies
    return redirect("/") if already_logged_in else render_template("register.html")
| 
 | |
| 
 | |
@app.route("/logout")
def logout():
    """Clear the ``userid`` cookie, if present, and redirect to /register."""
    target = "/register"
    if "userid" in request.cookies:
        resp = make_response(redirect(target))
        resp.set_cookie("userid", "", expires=0)
        return resp
    return redirect(target)
| 
 | |
| 
 | |
@app.route("/")
def index():
    """Render the dashboard for the logged-in user.

    Visitors without a ``userid`` cookie are sent to ``/register``. A cookie
    naming an unknown user is cleared on the way there.
    """
    uid = request.cookies.get("userid")
    if uid is None:
        return redirect("/register")

    if uid in USERS:
        rates, dates = get_treasury_rates_for_year(2024)
        return render_template(
            "index.html",
            rates=[rates],
            dates=dates,
            userid=uid,
            balance=USERS[uid].balance,
        )

    # The cookie refers to a user that is not in USERS: drop it before redirecting.
    stale = make_response(redirect("/register"))
    stale.set_cookie("userid", "", expires=0)
    return stale
| 
 | |
| 
 | |
def _load_persisted_state() -> None:
    """Replace USERS / TRANSACTIONS with the contents of the pickle files, if present."""
    global USERS, TRANSACTIONS
    if os.path.isfile("./users.p"):
        print("Loading users")
        with open("./users.p", "rb") as fp:
            USERS = pickle.load(fp)

    if os.path.isfile("./transactions.p"):
        print("Loading transactions")
        with open("./transactions.p", "rb") as fp:
            TRANSACTIONS = pickle.load(fp)


# Restore state before the saver starts, on every entry path. Previously an
# imported module (not run as __main__) started the saver without loading, so
# its first pass overwrote users.p / transactions.p with empty dicts.
_load_persisted_state()

# Start persistence: exactly one saver per process. daemon=True so stopping
# the server (e.g. Ctrl+C) is not blocked by the saver's infinite loop.
thread = threading.Thread(target=background_saver, daemon=True)
thread.start()

if __name__ == "__main__":
    # Start Server
    app.run(host="0.0.0.0", port=8008)
 |