This page is a code-heavy cookbook for building production Flask services, with a bias toward ML-serving use cases that AI/ML data engineers hit day to day: loading a scikit-learn or PyTorch model at startup, validating request payloads, offloading heavy inference to Celery, and exposing Kubernetes-ready liveness and readiness probes. Generic CRUD, auth, uploads, and pagination are included because a real ML service needs them too.
Prerequisites:
pip install flask flask-sqlalchemy flask-marshmallow marshmallow-sqlalchemy \
  flask-jwt-extended flask-socketio celery[redis] redis scikit-learn joblib pydantic torch
The Celery and Socket.IO sections also need a local Redis (docker run -p 6379:6379 redis:7). Every snippet below is runnable as shown. Imports are complete. No "..." placeholders.
The smallest Flask app — a single file, a single route, dev server via
flask run.
# app.py
from flask import Flask, jsonify
app = Flask(__name__)
@app.route("/")
def index():
return jsonify(message="hello from flask", status="ok")
if __name__ == "__main__":
# dev only; use gunicorn in prod
app.run(host="0.0.0.0", port=5000, debug=True)
export FLASK_APP=app.py
flask run --host 0.0.0.0 --port 5000
# or
python app.py
For production use a WSGI server — never app.run():
gunicorn -w 4 -b 0.0.0.0:5000 app:app
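When the flag list grows, a gunicorn config file is cleaner. A minimal sketch (the values are illustrative; tune workers and timeout to your inference latency):
# gunicorn.conf.py
bind = "0.0.0.0:5000"
workers = 4
timeout = 60       # raise if a single inference call can legitimately run long
accesslog = "-"    # access log to stdout, container-friendly
gunicorn -c gunicorn.conf.py app:app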
A full User resource with SQLite for dev, SQLAlchemy ORM, and
Marshmallow schemas for validation and serialization. Notice the explicit 404/400
handling and the use of db.session.get (SQLAlchemy 2.x style).
# crud_app.py
from datetime import datetime, timezone
from flask import Flask, jsonify, request, abort
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import fields, validate, ValidationError
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///users.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(app)
ma = Marshmallow(app)
class User(db.Model):
__tablename__ = "users"
id = db.Column(db.Integer, primary_key=True)
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
name = db.Column(db.String(120), nullable=False)
    created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc),
                           nullable=False)  # datetime.utcnow is deprecated in Python 3.12
class UserSchema(ma.SQLAlchemyAutoSchema):
class Meta:
model = User
load_instance = True
sqla_session = db.session
email = fields.Email(required=True)
name = fields.Str(required=True, validate=validate.Length(min=1, max=120))
user_schema = UserSchema()
users_schema = UserSchema(many=True)
@app.errorhandler(ValidationError)
def handle_validation(err):
return jsonify(errors=err.messages), 400
@app.route("/users", methods=["POST"])
def create_user():
data = request.get_json(silent=True) or {}
user = user_schema.load(data) # validates + hydrates User instance
db.session.add(user)
db.session.commit()
return user_schema.dump(user), 201
@app.route("/users", methods=["GET"])
def list_users():
q = User.query.order_by(User.id.asc()).all()
return jsonify(users_schema.dump(q))
@app.route("/users/<int:user_id>", methods=["GET"])
def get_user(user_id):
user = db.session.get(User, user_id) or abort(404)
return user_schema.dump(user)
@app.route("/users/<int:user_id>", methods=["PUT"])
def update_user(user_id):
user = db.session.get(User, user_id) or abort(404)
data = request.get_json(silent=True) or {}
updated = user_schema.load(data, instance=user, partial=True)
db.session.commit()
return user_schema.dump(updated)
@app.route("/users/<int:user_id>", methods=["DELETE"])
def delete_user(user_id):
user = db.session.get(User, user_id) or abort(404)
db.session.delete(user)
db.session.commit()
return "", 204
if __name__ == "__main__":
with app.app_context():
db.create_all()
app.run(debug=True)
The underlying SQLite schema SQLAlchemy emits is equivalent to:
CREATE TABLE users (
    id INTEGER NOT NULL,
    email VARCHAR(255) NOT NULL,
    name VARCHAR(120) NOT NULL,
    created_at DATETIME NOT NULL,
    PRIMARY KEY (id)
);
CREATE UNIQUE INDEX ix_users_email ON users (email);
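Quick smoke test of the create endpoint (example payload):
curl -s -XPOST localhost:5000/users -H 'content-type: application/json' \
  -d '{"email": "ada@example.com", "name": "Ada"}'
# {"created_at":"...","email":"ada@example.com","id":1,"name":"Ada"}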
flask-jwt-extended with access + refresh tokens and refresh rotation.
Access tokens are short-lived (15 min), refresh tokens longer (30 days); every
/refresh mints a new refresh token and the old one should be denylisted in
a real deployment (Redis set with TTL).
# auth_app.py
from datetime import timedelta
from flask import Flask, jsonify, request
from flask_sqlalchemy import SQLAlchemy
from flask_jwt_extended import (
JWTManager, create_access_token, create_refresh_token,
jwt_required, get_jwt_identity, get_jwt,
)
from werkzeug.security import generate_password_hash, check_password_hash
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///auth.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["JWT_SECRET_KEY"] = "change-me-in-prod"
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
db = SQLAlchemy(app)
jwt = JWTManager(app)
# simple in-memory refresh-token denylist; use Redis in prod
REFRESH_DENYLIST: set[str] = set()
class Account(db.Model):
id = db.Column(db.Integer, primary_key=True)
email = db.Column(db.String(255), unique=True, nullable=False)
pw_hash = db.Column(db.String(255), nullable=False)
@jwt.token_in_blocklist_loader
def is_revoked(jwt_header, jwt_payload):
return jwt_payload["jti"] in REFRESH_DENYLIST
@app.post("/register")
def register():
data = request.get_json() or {}
if Account.query.filter_by(email=data["email"]).first():
return jsonify(error="email exists"), 409
acct = Account(email=data["email"], pw_hash=generate_password_hash(data["password"]))
db.session.add(acct)
db.session.commit()
return jsonify(id=acct.id), 201
@app.post("/login")
def login():
data = request.get_json() or {}
acct = Account.query.filter_by(email=data.get("email")).first()
if not acct or not check_password_hash(acct.pw_hash, data.get("password", "")):
return jsonify(error="bad credentials"), 401
return jsonify(
access_token=create_access_token(identity=str(acct.id)),
refresh_token=create_refresh_token(identity=str(acct.id)),
)
@app.post("/refresh")
@jwt_required(refresh=True)
def refresh():
# rotate: deny the old refresh jti and mint a fresh pair
old = get_jwt()
REFRESH_DENYLIST.add(old["jti"])
identity = get_jwt_identity()
return jsonify(
access_token=create_access_token(identity=identity),
refresh_token=create_refresh_token(identity=identity),
)
@app.get("/me")
@jwt_required()
def me():
return jsonify(user_id=get_jwt_identity())
if __name__ == "__main__":
with app.app_context():
db.create_all()
app.run(debug=True)
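Swapping the in-memory set for Redis is mechanical. A sketch of the production version the comment above gestures at (db index 3 and the key prefix are arbitrary choices); entries expire together with the token, so no cleanup job is needed:
# redis_denylist.py
import time
import redis

r = redis.Redis(host="localhost", port=6379, db=3)

def deny(jwt_payload: dict) -> None:
    # key TTL matches the token's remaining lifetime
    ttl = max(int(jwt_payload["exp"] - time.time()), 1)
    r.setex(f"denylist:{jwt_payload['jti']}", ttl, 1)

def is_denied(jwt_payload: dict) -> bool:
    return r.exists(f"denylist:{jwt_payload['jti']}") == 1
Wire deny() into /refresh in place of REFRESH_DENYLIST.add() and is_denied() into the token_in_blocklist_loader.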
Safe upload with secure_filename, a 16 MB ceiling via
MAX_CONTENT_LENGTH, extension + magic-byte MIME validation. Never trust the
client-supplied Content-Type — sniff the bytes.
# uploads.py
import os
import magic # pip install python-magic (needs libmagic)
from pathlib import Path
from flask import Flask, request, jsonify, abort
from werkzeug.utils import secure_filename
UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
ALLOWED_EXT = {".png", ".jpg", ".jpeg", ".pdf", ".csv"}
ALLOWED_MIME = {"image/png", "image/jpeg", "application/pdf", "text/csv", "text/plain"}
app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024 # 16 MB
@app.errorhandler(413)
def too_big(_):
return jsonify(error="file too large, max 16MB"), 413
@app.post("/upload")
def upload():
if "file" not in request.files:
abort(400, "no file part")
f = request.files["file"]
if not f.filename:
abort(400, "empty filename")
safe = secure_filename(f.filename)
ext = os.path.splitext(safe)[1].lower()
if ext not in ALLOWED_EXT:
abort(400, f"extension {ext} not allowed")
head = f.stream.read(2048)
f.stream.seek(0)
mime = magic.from_buffer(head, mime=True)
if mime not in ALLOWED_MIME:
abort(400, f"mime {mime} not allowed")
dest = UPLOAD_DIR / safe
f.save(dest)
return jsonify(filename=safe, mime=mime, size=dest.stat().st_size), 201
if __name__ == "__main__":
app.run(debug=True)
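Exercise it with multipart form data (any local CSV will do):
curl -s -XPOST localhost:5000/upload -F file=@data.csv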
Two files: a training script that persists the model to disk, and a serving app that
loads it once at startup (via the app factory — before_first_request
is removed in Flask 2.3+). Input is validated by Pydantic; the response includes both
the predicted class and a calibrated confidence from predict_proba.
# train.py (run offline)
import joblib
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
X, y = load_iris(return_X_y=True)
pipe = Pipeline([
("scaler", StandardScaler()),
("clf", LogisticRegression(max_iter=500, multi_class="multinomial")),
])
pipe.fit(X, y)
joblib.dump({"model": pipe, "labels": ["setosa", "versicolor", "virginica"]},
"iris_model.joblib")
print("saved iris_model.joblib")
# serve_sklearn.py
import joblib
import numpy as np
from flask import Flask, jsonify, request
from pydantic import BaseModel, Field, ValidationError, conlist
class PredictIn(BaseModel):
# 4 iris features: sepal length/width, petal length/width (cm)
features: conlist(float, min_length=4, max_length=4) = Field(...)
def create_app(model_path: str = "iris_model.joblib") -> Flask:
app = Flask(__name__)
bundle = joblib.load(model_path) # loaded ONCE at process start
app.config["MODEL"] = bundle["model"]
app.config["LABELS"] = bundle["labels"]
@app.post("/predict")
def predict():
try:
payload = PredictIn(**(request.get_json() or {}))
except ValidationError as e:
return jsonify(errors=e.errors()), 400
x = np.asarray(payload.features, dtype=np.float64).reshape(1, -1)
proba = app.config["MODEL"].predict_proba(x)[0]
idx = int(np.argmax(proba))
return jsonify(
label=app.config["LABELS"][idx],
class_index=idx,
confidence=float(proba[idx]),
proba={app.config["LABELS"][i]: float(p) for i, p in enumerate(proba)},
)
return app
app = create_app()
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
curl -s -XPOST localhost:5000/predict \
-H 'content-type: application/json' \
-d '{"features": [5.1, 3.5, 1.4, 0.2]}'
# {"label":"setosa","class_index":0,"confidence":0.97,...}
A minimal torch.nn.Module, eval mode, torch.no_grad(),
device selection, and a batched endpoint that accepts an array of input vectors and
returns an array of predictions. Batching is critical for GPU throughput: a single
forward pass per HTTP call leaves the GPU idle most of the time.
# serve_torch.py
import torch
import torch.nn as nn
from flask import Flask, jsonify, request
from pydantic import BaseModel, ValidationError, conlist
class MLP(nn.Module):
def __init__(self, in_dim=10, hidden=64, out_dim=3):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden), nn.ReLU(),
nn.Linear(hidden, hidden), nn.ReLU(),
nn.Linear(hidden, out_dim),
)
def forward(self, x):
return self.net(x)
class BatchIn(BaseModel):
# list of 10-dim vectors, max 512 per request
inputs: conlist(conlist(float, min_length=10, max_length=10),
min_length=1, max_length=512)
def create_app(weights_path: str | None = None) -> Flask:
app = Flask(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP().to(device)
if weights_path:
        model.load_state_dict(torch.load(weights_path, map_location=device,
                                         weights_only=True))
model.eval()
app.config["MODEL"] = model
app.config["DEVICE"] = device
@app.post("/predict")
def predict():
try:
payload = BatchIn(**(request.get_json() or {}))
except ValidationError as e:
return jsonify(errors=e.errors()), 400
x = torch.tensor(payload.inputs, dtype=torch.float32,
device=app.config["DEVICE"])
with torch.no_grad():
logits = app.config["MODEL"](x)
probs = torch.softmax(logits, dim=-1)
conf, cls = probs.max(dim=-1)
return jsonify(
device=str(app.config["DEVICE"]),
batch_size=x.shape[0],
predictions=[
{"class": int(c), "confidence": float(p)}
for c, p in zip(cls.tolist(), conf.tolist())
],
)
return app
app = create_app()
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
Flask should never block a gunicorn worker on a 30-second inference call. Push the job to Celery, return a task id, and let the client poll. Redis is both broker and result backend here.
# celery_app.py -- shared Celery instance
from celery import Celery
celery = Celery(
"ml_jobs",
broker="redis://localhost:6379/0",
backend="redis://localhost:6379/1",
)
celery.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
task_time_limit=300,
task_soft_time_limit=270,
)
# tasks.py -- worker-side
import time
import joblib
import numpy as np
from celery_app import celery
_BUNDLE = joblib.load("iris_model.joblib") # loaded per worker process
@celery.task(name="tasks.batch_predict", bind=True, max_retries=2)
def batch_predict(self, rows: list[list[float]]) -> list[dict]:
try:
X = np.asarray(rows, dtype=np.float64)
proba = _BUNDLE["model"].predict_proba(X)
idx = proba.argmax(axis=1)
return [
{"label": _BUNDLE["labels"][i], "confidence": float(proba[r, i])}
for r, i in enumerate(idx)
]
except Exception as exc:
raise self.retry(exc=exc, countdown=5)
# flask_dispatch.py -- web-side
from flask import Flask, jsonify, request
from celery.result import AsyncResult
from celery_app import celery
# send_task() dispatches by name, so the web process never needs to import tasks.py
app = Flask(__name__)
@app.post("/jobs/predict")
def enqueue():
rows = (request.get_json() or {}).get("rows", [])
async_res = celery.send_task("tasks.batch_predict", args=[rows])
return jsonify(task_id=async_res.id), 202
@app.get("/jobs/<task_id>")
def status(task_id):
res = AsyncResult(task_id, app=celery)
body = {"task_id": task_id, "state": res.state}
if res.successful():
body["result"] = res.result
elif res.failed():
body["error"] = str(res.result)
return jsonify(body)
if __name__ == "__main__":
app.run(debug=True)
# run worker
celery -A celery_app.celery worker --loglevel=info --concurrency=4
# run flask
python flask_dispatch.py
# enqueue
curl -s -XPOST localhost:5000/jobs/predict -H 'content-type: application/json' \
-d '{"rows": [[5.1,3.5,1.4,0.2], [6.2,3.4,5.4,2.3]]}'
Bi-directional messaging for chat, live prediction streams, or training progress. Use the Redis message queue so multiple gunicorn/eventlet workers share state.
# socket_chat.py
from flask import Flask, render_template_string
from flask_socketio import SocketIO, emit, join_room, leave_room
app = Flask(__name__)
app.config["SECRET_KEY"] = "dev"
socketio = SocketIO(app, message_queue="redis://localhost:6379/2",
cors_allowed_origins="*")
PAGE = """
<script src="https://cdn.socket.io/4.7.5/socket.io.min.js"></script>
<script>
const s = io();
s.on('connect', () => s.emit('join', {room: 'lobby'}));
s.on('message', m => console.log(m));
function send(text){ s.emit('chat', {room:'lobby', text}); }
</script>
"""
@app.get("/")
def page():
return render_template_string(PAGE)
@socketio.on("join")
def on_join(data):
join_room(data["room"])
emit("message", {"sys": f"joined {data['room']}"}, room=data["room"])
@socketio.on("leave")
def on_leave(data):
leave_room(data["room"])
@socketio.on("chat")
def on_chat(data):
emit("message", {"text": data["text"]}, room=data["room"])
if __name__ == "__main__":
socketio.run(app, host="0.0.0.0", port=5000)
Offset pagination breaks down at scale (deep pages scan huge amounts of data and shift under concurrent writes). Cursor pagination uses the last row's sort key, base64-encoded as an opaque token.
# paginate.py
import base64
import json
from datetime import datetime
from flask import Flask, jsonify, request
from flask_sqlalchemy import SQLAlchemy
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///feed.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(app)
class Post(db.Model):
id = db.Column(db.Integer, primary_key=True)
created_at = db.Column(db.DateTime, nullable=False, index=True)
title = db.Column(db.String(200), nullable=False)
def encode_cursor(created_at, pk) -> str:
raw = json.dumps({"ts": created_at.isoformat(), "id": pk}).encode()
return base64.urlsafe_b64encode(raw).decode().rstrip("=")
def decode_cursor(token: str) -> tuple[datetime, int]:
    pad = "=" * (-len(token) % 4)
    raw = base64.urlsafe_b64decode(token + pad)
    d = json.loads(raw)
    # parse the timestamp back into a datetime so the SQL comparison below
    # works on any backend, not just SQLite's string-stored datetimes
    return datetime.fromisoformat(d["ts"]), d["id"]
@app.get("/posts")
def feed():
limit = min(int(request.args.get("limit", 20)), 100)
cursor = request.args.get("cursor")
q = Post.query.order_by(Post.created_at.desc(), Post.id.desc())
if cursor:
ts, pk = decode_cursor(cursor)
q = q.filter(
(Post.created_at < ts)
| ((Post.created_at == ts) & (Post.id < pk))
)
rows = q.limit(limit + 1).all()
has_more = len(rows) > limit
rows = rows[:limit]
next_cursor = (
encode_cursor(rows[-1].created_at, rows[-1].id) if has_more and rows else None
)
return jsonify(
items=[{"id": r.id, "title": r.title,
"created_at": r.created_at.isoformat()} for r in rows],
next_cursor=next_cursor,
)
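Client flow (the cursor value is the next_cursor field from the previous response):
curl -s 'localhost:5000/posts?limit=20'
curl -s "localhost:5000/posts?limit=20&cursor=$NEXT_CURSOR"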
Liveness and readiness are not the same probe. Liveness says
"the process is alive — don't restart me." Readiness says "I've loaded my model
and my dependencies are reachable — send me traffic." In Kubernetes, returning
200 on /ready before the model is loaded causes the service mesh to route
traffic that will 500.
# health.py
import time
from flask import Flask, jsonify
from sqlalchemy import text
from flask_sqlalchemy import SQLAlchemy
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///app.db"
db = SQLAlchemy(app)
STATE = {"model_loaded": False, "started_at": time.time()}
def _load_model_at_startup():
# simulate a slow warmup (downloading weights, compiling, etc.)
time.sleep(2)
STATE["model_loaded"] = True
with app.app_context():
_load_model_at_startup()
@app.get("/health")
def liveness():
# cheap, no external deps — just confirms the process responds
return jsonify(status="alive", uptime_s=int(time.time() - STATE["started_at"]))
@app.get("/ready")
def readiness():
checks = {"model": STATE["model_loaded"], "db": False}
try:
db.session.execute(text("SELECT 1"))
checks["db"] = True
except Exception:
pass
ok = all(checks.values())
return jsonify(ready=ok, checks=checks), (200 if ok else 503)
Matching Kubernetes probe config:
livenessProbe:
httpGet:
path: /health
port: 5000
initialDelaySeconds: 10
periodSeconds: 15
failureThreshold: 3
readinessProbe:
httpGet:
path: /ready
port: 5000
initialDelaySeconds: 5
periodSeconds: 5
failureThreshold: 2
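If warmup genuinely takes minutes (large weights, JIT compilation), add a startupProbe so liveness cannot kill the pod mid-load; timings below are illustrative:
startupProbe:
  httpGet:
    path: /ready
    port: 5000
  periodSeconds: 10
  failureThreshold: 30   # tolerate up to ~5 minutes of warmup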
The real-world layout: a create_app() factory, config classes per
environment, blueprints for modularity, extensions instantiated at module scope and
bound inside the factory. This is the pattern every non-trivial Flask service should
start from.
# config.py
import os
class BaseConfig:
SECRET_KEY = os.environ.get("SECRET_KEY", "dev-secret")
SQLALCHEMY_TRACK_MODIFICATIONS = False
    JSON_SORT_KEYS = False  # honored through Flask 2.2; on 2.3+ set app.json.sort_keys instead
class DevConfig(BaseConfig):
DEBUG = True
SQLALCHEMY_DATABASE_URI = "sqlite:///dev.db"
class TestConfig(BaseConfig):
TESTING = True
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
class ProdConfig(BaseConfig):
    DEBUG = False
    # .get() so merely importing config.py in dev doesn't KeyError when
    # prod-only env vars are unset; fail fast on these in create_app() if needed
    SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URL", "")
    SECRET_KEY = os.environ.get("SECRET_KEY", "")
CONFIG_MAP = {"dev": DevConfig, "test": TestConfig, "prod": ProdConfig}
# extensions.py
from flask_sqlalchemy import SQLAlchemy
from flask_jwt_extended import JWTManager
db = SQLAlchemy()
jwt = JWTManager()
# blueprints/health_bp.py
from flask import Blueprint, jsonify
health_bp = Blueprint("health", __name__)
@health_bp.get("/health")
def health():
return jsonify(status="ok")
# blueprints/predict_bp.py
from flask import Blueprint, jsonify, request, current_app
import numpy as np
predict_bp = Blueprint("predict", __name__, url_prefix="/api/v1")
@predict_bp.post("/predict")
def predict():
payload = request.get_json() or {}
x = np.asarray(payload["features"], dtype=np.float64).reshape(1, -1)
model = current_app.config["MODEL"]
labels = current_app.config["LABELS"]
proba = model.predict_proba(x)[0]
idx = int(np.argmax(proba))
return jsonify(label=labels[idx], confidence=float(proba[idx]))
# app/__init__.py
import os
import joblib
from flask import Flask
from config import CONFIG_MAP
from extensions import db, jwt
from blueprints.health_bp import health_bp
from blueprints.predict_bp import predict_bp
def create_app(env: str | None = None) -> Flask:
env = env or os.environ.get("APP_ENV", "dev")
app = Flask(__name__)
app.config.from_object(CONFIG_MAP[env])
# extensions
db.init_app(app)
jwt.init_app(app)
# ML model: loaded once at startup, stored on app.config
if env != "test":
bundle = joblib.load(os.environ.get("MODEL_PATH", "iris_model.joblib"))
app.config["MODEL"] = bundle["model"]
app.config["LABELS"] = bundle["labels"]
# blueprints
app.register_blueprint(health_bp)
app.register_blueprint(predict_bp)
with app.app_context():
db.create_all()
return app
# wsgi.py -- entry point for gunicorn
from app import create_app
app = create_app()
# dev
APP_ENV=dev python -m flask --app wsgi:app run
# prod
APP_ENV=prod DATABASE_URL=postgresql://user:pw@db/app \
SECRET_KEY=$(openssl rand -hex 32) \
gunicorn -w 4 -k gthread --threads 8 -b 0.0.0.0:5000 wsgi:app
From here you bolt on the pieces above as blueprints: auth routes, upload routes, Celery dispatch, Socket.IO, pagination. The factory stays small; the blueprints grow.