diff options
author | Navan Chauhan <navanchauhan@gmail.com> | 2025-04-27 22:41:14 -0600 |
---|---|---|
committer | Navan Chauhan <navanchauhan@gmail.com> | 2025-04-27 22:41:14 -0600 |
commit | f32142947b853076889801913d47b8c2c0f4f456 (patch) | |
tree | 20981c0c2b79c2ba9cc58eece69591cfbe5b21ff | |
parent | ba700c31fceb4554ccbb3181f0e5747fcf5c3259 (diff) |
format using black
-rw-r--r-- | server/api/bids.py | 43 | ||||
-rw-r--r-- | server/api/market.py | 33 | ||||
-rw-r--r-- | server/api/pnl.py | 8 | ||||
-rw-r--r-- | server/create_db.py | 27 | ||||
-rw-r--r-- | server/models/auth.py | 1 | ||||
-rw-r--r-- | server/models/bid.py | 5 | ||||
-rw-r--r-- | server/models/market.py | 7 | ||||
-rw-r--r-- | server/services/get_data.py | 67 | ||||
-rw-r--r-- | server/services/process_bids.py | 13 |
9 files changed, 139 insertions, 65 deletions
diff --git a/server/api/bids.py b/server/api/bids.py index 6139987..c08c74b 100644 --- a/server/api/bids.py +++ b/server/api/bids.py @@ -8,8 +8,6 @@ from datetime import datetime, timezone from zoneinfo import ZoneInfo from typing import List, Optional -# TODO: Can you submit a bid for 2AM the next day after 11AM? - router = APIRouter() MARKET_TIMEZONES = { @@ -18,6 +16,7 @@ MARKET_TIMEZONES = { "MISO": ZoneInfo("America/Chicago"), } + def get_db(): db = SessionLocal() try: @@ -25,16 +24,21 @@ def get_db(): finally: db.close() + class BidBase(BaseModel): timestamp: datetime quantity: float price: float - user_id: int # In production the user_id should be obtained from the authenticated user + user_id: ( + int # In production the user_id should be obtained from the authenticated user + ) market: str + class BidCreate(BidBase): pass + class BidResponse(BidBase): id: int status: str @@ -43,14 +47,19 @@ class BidResponse(BidBase): class Config: from_attributes = True + @router.get("/", response_model=List[BidResponse]) def get_bids(db: Session = Depends(get_db)): return db.query(BidModel).all() + @router.post("/", response_model=BidResponse) def submit_bid(bid: BidCreate, db: Session = Depends(get_db)): if bid.market not in MARKET_TIMEZONES: - raise HTTPException(status_code=400, detail=f"Invalid market. Supported markets: {list(MARKET_TIMEZONES.keys())}") + raise HTTPException( + status_code=400, + detail=f"Invalid market. Supported markets: {list(MARKET_TIMEZONES.keys())}", + ) market_tz = MARKET_TIMEZONES[bid.market] @@ -68,19 +77,29 @@ def submit_bid(bid: BidCreate, db: Session = Depends(get_db)): if bid_day == today: if now > cutoff_time: - raise HTTPException(status_code=400, detail="Cannot submit bids for today after 11AM local time.") + raise HTTPException( + status_code=400, + detail="Cannot submit bids for today after 11AM local time.", + ) start_of_hour = bid.timestamp.replace(minute=0, second=0, microsecond=0) end_of_hour = start_of_hour.replace(minute=59, second=59, microsecond=999999) - bid_count = db.query(BidModel).filter( - BidModel.timestamp >= start_of_hour, - BidModel.timestamp <= end_of_hour, - BidModel.market == bid.market - ).count() + bid_count = ( + db.query(BidModel) + .filter( + BidModel.timestamp >= start_of_hour, + BidModel.timestamp <= end_of_hour, + BidModel.market == bid.market, + ) + .count() + ) if bid_count >= 10: - raise HTTPException(status_code=400, detail="Cannot submit more than 10 bids for this hour in this market.") + raise HTTPException( + status_code=400, + detail="Cannot submit more than 10 bids for this hour in this market.", + ) db_bid = BidModel( timestamp=bid.timestamp, @@ -89,7 +108,7 @@ def submit_bid(bid: BidCreate, db: Session = Depends(get_db)): user_id=bid.user_id, market=bid.market, status="Submitted", - pnl=None + pnl=None, ) db.add(db_bid) db.commit() diff --git a/server/api/market.py b/server/api/market.py index 1d17857..7aaaa29 100644 --- a/server/api/market.py +++ b/server/api/market.py @@ -7,28 +7,34 @@ from db import SessionLocal router = APIRouter() -# Only allow these markets SUPPORTED_MARKETS: Dict[str, str] = { "ISONE": "ISONE", "MISO": "MISO", "NYISO": "NYISO", } + def check_market_supported(market: str): market = market.upper() if market not in SUPPORTED_MARKETS: - raise HTTPException(status_code=400, detail=f"Unsupported market '{market}'. Supported: {list(SUPPORTED_MARKETS.keys())}") + raise HTTPException( + status_code=400, + detail=f"Unsupported market '{market}'. Supported: {list(SUPPORTED_MARKETS.keys())}", + ) return market + @router.get("/day-ahead", response_model=List[MarketData]) def get_day_ahead_data(market: str = Query("ISONE")): db: Session = SessionLocal() market = check_market_supported(market) - records = db.query(MarketDataDB)\ - .filter(MarketDataDB.market == market, MarketDataDB.type == "DAYAHEAD")\ - .order_by(MarketDataDB.timestamp)\ + records = ( + db.query(MarketDataDB) + .filter(MarketDataDB.market == market, MarketDataDB.type == "DAYAHEAD") + .order_by(MarketDataDB.timestamp) .all() + ) db.close() return [ @@ -38,23 +44,27 @@ def get_day_ahead_data(market: str = Query("ISONE")): energy=r.energy, congestion=r.congestion, loss=r.loss, - ) for r in records + ) + for r in records ] + @router.get("/real-time", response_model=List[MarketData]) def get_real_time_data(market: str = Query("ISONE")): db: Session = SessionLocal() market = check_market_supported(market) start_time = datetime.utcnow() - timedelta(days=1) - records = db.query(MarketDataDB)\ + records = ( + db.query(MarketDataDB) .filter( MarketDataDB.market == market, MarketDataDB.type == "REALTIME", - MarketDataDB.timestamp >= start_time - )\ - .order_by(MarketDataDB.timestamp)\ + MarketDataDB.timestamp >= start_time, + ) + .order_by(MarketDataDB.timestamp) .all() + ) db.close() return [ @@ -64,5 +74,6 @@ def get_real_time_data(market: str = Query("ISONE")): energy=r.energy, congestion=r.congestion, loss=r.loss, - ) for r in records + ) + for r in records ] diff --git a/server/api/pnl.py b/server/api/pnl.py deleted file mode 100644 index 7565f96..0000000 --- a/server/api/pnl.py +++ /dev/null @@ -1,8 +0,0 @@ -from fastapi import APIRouter - -router = APIRouter() - -@router.get("/pnl") -def get_pnl(): - # TODO: Real logic - return {"profit": 42.0} diff --git a/server/create_db.py b/server/create_db.py index ca2f818..4142acf 100644 --- a/server/create_db.py +++ b/server/create_db.py @@ -8,6 +8,7 @@ from zoneinfo import ZoneInfo NEW_ENGLAND_TZ = ZoneInfo("America/New_York") + def init_db(): Base.metadata.create_all(bind=engine) db = SessionLocal() @@ -23,10 +24,15 @@ def init_db(): print("Default user already exists.") # Insert dummy bids for 2025-04-25 - existing_bids = db.query(Bid).filter(Bid.timestamp.between( - datetime(2025, 4, 25, 0, 0), - datetime(2025, 4, 25, 23, 59) - )).all() + existing_bids = ( + db.query(Bid) + .filter( + Bid.timestamp.between( + datetime(2025, 4, 25, 0, 0), datetime(2025, 4, 25, 23, 59) + ) + ) + .all() + ) if not existing_bids: print("Inserting dummy bids for 2025-04-25...") @@ -47,7 +53,7 @@ def init_db(): user_id=user.id, market="ISONE", status="Submitted", - pnl=None + pnl=None, ) dummy_bids.append(bid) @@ -59,13 +65,17 @@ def init_db(): # Insert one dummy bid for today at 11:00PM local time today_local = datetime.now(NEW_ENGLAND_TZ).date() - bid_time_local = datetime.combine(today_local, datetime.min.time(), tzinfo=NEW_ENGLAND_TZ).replace(hour=23) + bid_time_local = datetime.combine( + today_local, datetime.min.time(), tzinfo=NEW_ENGLAND_TZ + ).replace(hour=23) bid_time_utc = bid_time_local.astimezone(timezone.utc) existing_bid_today = db.query(Bid).filter(Bid.timestamp == bid_time_utc).first() if not existing_bid_today: - print(f"Inserting dummy bid for today at {bid_time_local.strftime('%Y-%m-%d %I:%M %p')} local time...") + print( + f"Inserting dummy bid for today at {bid_time_local.strftime('%Y-%m-%d %I:%M %p')} local time..." + ) today_bid = Bid( timestamp=bid_time_utc, quantity=20.0, @@ -73,7 +83,7 @@ def init_db(): user_id=user.id, market="ISONE", status="Submitted", - pnl=None + pnl=None, ) db.add(today_bid) db.commit() @@ -83,5 +93,6 @@ def init_db(): db.close() + if __name__ == "__main__": init_db() diff --git a/server/models/auth.py b/server/models/auth.py index 0bf1d18..b20a84f 100644 --- a/server/models/auth.py +++ b/server/models/auth.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer, String from sqlalchemy.orm import relationship from db import Base + class User(Base): __tablename__ = "users" diff --git a/server/models/bid.py b/server/models/bid.py index e717011..121020d 100644 --- a/server/models/bid.py +++ b/server/models/bid.py @@ -2,14 +2,15 @@ from sqlalchemy import Column, Integer, Float, String, DateTime, ForeignKey from sqlalchemy.orm import relationship from db import Base + class Bid(Base): __tablename__ = "bids" id = Column(Integer, primary_key=True, index=True) timestamp = Column(DateTime, index=True, nullable=False) # Bid target time quantity = Column(Float, nullable=False) # MWh - price = Column(Float, nullable=False) # $/MWh - market = Column(String, nullable=False) # Market name: ISONE / MISO / NYISO for now + price = Column(Float, nullable=False) # $/MWh + market = Column(String, nullable=False) # Market name: ISONE / MISO / NYISO for now status = Column(String, default="Submitted") # Submitted / Success / Fail pnl = Column(Float, nullable=True) # Profit/loss value, nullable initially diff --git a/server/models/market.py b/server/models/market.py index 8606075..0127141 100644 --- a/server/models/market.py +++ b/server/models/market.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from sqlalchemy import Column, Integer, String, Float, DateTime from db import Base + class MarketData(BaseModel): timestamp: datetime lmp: float @@ -10,7 +11,7 @@ class MarketData(BaseModel): congestion: float loss: float -# New DB model + class MarketDataDB(Base): __tablename__ = "market_data" @@ -20,5 +21,5 @@ class MarketDataDB(Base): energy = Column(Float) congestion = Column(Float) loss = Column(Float) - market = Column(String, index=True) # eg. "ISONE" - type = Column(String, index=True) # "REALTIME" or "DAYAHEAD" + market = Column(String, index=True) # eg. "ISONE" + type = Column(String, index=True) # "REALTIME" or "DAYAHEAD" diff --git a/server/services/get_data.py b/server/services/get_data.py index 1b446c6..3014ea1 100644 --- a/server/services/get_data.py +++ b/server/services/get_data.py @@ -11,12 +11,16 @@ MARKET_CLASSES: Dict[str, Type] = { "NYISO": NYISO, } + def get_iso_instance(market: str): market = market.upper() if market not in MARKET_CLASSES: - raise ValueError(f"Unsupported market '{market}'. Supported: {list(MARKET_CLASSES.keys())}") + raise ValueError( + f"Unsupported market '{market}'. Supported: {list(MARKET_CLASSES.keys())}" + ) return MARKET_CLASSES[market]() + def update_market_data(): db: Session = SessionLocal() @@ -24,20 +28,30 @@ def update_market_data(): print(f"Processing {market_name}") iso = get_iso_instance(market_name) - # --- Real-Time Data (5 min) --- - last_realtime = db.query(MarketDataDB)\ - .filter(MarketDataDB.market == market_name, MarketDataDB.type == "REALTIME")\ - .order_by(MarketDataDB.timestamp.desc())\ + last_realtime = ( + db.query(MarketDataDB) + .filter(MarketDataDB.market == market_name, MarketDataDB.type == "REALTIME") + .order_by(MarketDataDB.timestamp.desc()) .first() + ) - if not last_realtime or (datetime.now(UTC) - last_realtime.timestamp.replace(tzinfo=UTC) > timedelta(minutes=5)): + if not last_realtime or ( + datetime.now(UTC) - last_realtime.timestamp.replace(tzinfo=UTC) + > timedelta(minutes=5) + ): print(f"Getting realtime data for {market_name}") df = iso.get_lmp(date="latest", market="REAL_TIME_5_MIN", locations="ALL") - df["Interval Start"] = df["Interval Start"].dt.tz_convert('UTC') - grouped = df.groupby("Interval Start")[["LMP", "Energy", "Congestion", "Loss"]].mean().reset_index() + df["Interval Start"] = df["Interval Start"].dt.tz_convert("UTC") + grouped = ( + df.groupby("Interval Start")[["LMP", "Energy", "Congestion", "Loss"]] + .mean() + .reset_index() + ) for _, row in grouped.iterrows(): - if last_realtime and row["Interval Start"] <= last_realtime.timestamp.replace(tzinfo=UTC): + if last_realtime and row[ + "Interval Start" + ] <= last_realtime.timestamp.replace(tzinfo=UTC): continue # Skip old data entry = MarketDataDB( timestamp=row["Interval Start"], @@ -50,24 +64,38 @@ def update_market_data(): ) db.add(entry) - # --- Day-Ahead Hourly Data (1 hour) --- - last_dayahead = db.query(MarketDataDB)\ - .filter(MarketDataDB.market == market_name, MarketDataDB.type == "DAYAHEAD")\ - .order_by(MarketDataDB.timestamp.desc())\ + last_dayahead = ( + db.query(MarketDataDB) + .filter(MarketDataDB.market == market_name, MarketDataDB.type == "DAYAHEAD") + .order_by(MarketDataDB.timestamp.desc()) .first() + ) - if not last_dayahead or (datetime.now(UTC) - last_dayahead.timestamp.replace(tzinfo=UTC) > timedelta(hours=1)): + if not last_dayahead or ( + datetime.now(UTC) - last_dayahead.timestamp.replace(tzinfo=UTC) + > timedelta(hours=1) + ): print(f"Getting day-ahead data for {market_name}") now_utc = datetime.now(UTC) day_ahead_date = now_utc.date() - if now_utc.hour >= 18: # After 6PM UTC, markets usually publish next day's data + if ( + now_utc.hour >= 18 + ): # After 6PM UTC, markets usually publish next day's data day_ahead_date += timedelta(days=1) - df = iso.get_lmp(date=day_ahead_date, market="DAY_AHEAD_HOURLY", locations="ALL") - df["Interval Start"] = df["Interval Start"].dt.tz_convert('UTC') - grouped = df.groupby("Interval Start")[["LMP", "Energy", "Congestion", "Loss"]].mean().reset_index() + df = iso.get_lmp( + date=day_ahead_date, market="DAY_AHEAD_HOURLY", locations="ALL" + ) + df["Interval Start"] = df["Interval Start"].dt.tz_convert("UTC") + grouped = ( + df.groupby("Interval Start")[["LMP", "Energy", "Congestion", "Loss"]] + .mean() + .reset_index() + ) for _, row in grouped.iterrows(): - if last_dayahead and row["Interval Start"] <= last_dayahead.timestamp.replace(tzinfo=UTC): + if last_dayahead and row[ + "Interval Start" + ] <= last_dayahead.timestamp.replace(tzinfo=UTC): continue entry = MarketDataDB( timestamp=row["Interval Start"], @@ -83,5 +111,6 @@ def update_market_data(): db.commit() db.close() + if __name__ == "__main__": update_market_data() diff --git a/server/services/process_bids.py b/server/services/process_bids.py index 4ecbef9..a600128 100644 --- a/server/services/process_bids.py +++ b/server/services/process_bids.py @@ -18,18 +18,25 @@ MARKET_TIMEZONES = { "MISO": ZoneInfo("America/Chicago"), } + def get_day_ahead_price(market: str, target_time: datetime) -> float: """Fetch the Day Ahead clearing price for the hour of target_time.""" iso = MARKET_ISOS[market] - df = iso.get_lmp(date=target_time.date(), market="DAY_AHEAD_HOURLY", locations="ALL") + df = iso.get_lmp( + date=target_time.date(), market="DAY_AHEAD_HOURLY", locations="ALL" + ) df = df.groupby("Interval Start")["LMP"].mean().reset_index() for _, row in df.iterrows(): - if abs(row["Interval Start"] - target_time.replace(minute=0, second=0, microsecond=0)) < timedelta(minutes=30): + if abs( + row["Interval Start"] + - target_time.replace(minute=0, second=0, microsecond=0) + ) < timedelta(minutes=30): return row["LMP"] raise ValueError(f"No day ahead price found for {target_time} in {market}") + def get_real_time_prices(market: str, target_time: datetime) -> list[float]: """Fetch the Real Time 5-min prices during the hour of target_time.""" iso = MARKET_ISOS[market] @@ -47,6 +54,7 @@ def get_real_time_prices(market: str, target_time: datetime) -> list[float]: return prices + def process_bids(): db: Session = SessionLocal() @@ -121,5 +129,6 @@ def process_bids(): db.close() + if __name__ == "__main__": process_bids() |