Making developers awesome at machine learning

Image by Author | Canva
Machine learning models deliver real value only when they reach users, and APIs are the bridge that makes it happen. But exposing your model isn’t enough; you need a secure, scalable, and efficient API to ensure reliability. In this guide, we’ll build a production-ready ML API with FastAPI, adding authentication, input validation, and rate limiting. This way, your model doesn’t just work but works safely at scale.
In this guide, I’ll walk you through building a secure machine learning API. We’ll cover:

- training and serializing a simple model
- writing the prediction logic
- validating input with Pydantic
- JWT authentication
- rate limiting
- containerizing the API with Docker
The project structure will look somewhat like this:
```
secure-ml-API/
├── app/
│   ├── main.py          # FastAPI entry point
│   ├── model.py         # Model training and serialization
│   ├── predict.py       # Prediction logic
│   ├── jwt.py           # JWT authentication logic
│   ├── rate_limit.py    # Rate limiting logic
│   └── validation.py    # Input validation logic
├── Dockerfile           # Docker setup
├── requirements.txt     # Python dependencies
└── README.md            # Documentation for the project
```
Let’s do everything step by step.
To keep things simple, we’ll use a RandomForestClassifier on the Iris dataset. A random forest is a machine learning model that classifies examples (e.g., flowers, emails, customers) by combining many decision trees. In the Iris dataset, each sample consists of four measurements (sepal length, sepal width, petal length, and petal width) plus a label for one of three species. The RandomForestClassifier looks for patterns in those four numbers across its decision trees and returns the species that is most likely.
```python
import pickle

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Function to train the model and save it as a pickle file
def train_model():
    iris = load_iris()
    X, y = iris.data, iris.target
    clf = RandomForestClassifier()
    clf.fit(X, y)
    # Save the trained model
    with open("app/model.pkl", "wb") as f:
        pickle.dump(clf, f)

if __name__ == "__main__":
    train_model()
```
Run this script to generate the model.pkl file.
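As an aside, the save/load mechanics here are plain pickle serialization. A minimal round-trip sketch, using a stand-in dictionary rather than a real classifier, shows what `pickle.dump` and `pickle.load` do with the model object:

```python
import pickle

# Any picklable Python object stands in for the trained classifier here.
model_stub = {"name": "demo-model", "classes": [0, 1, 2]}

# Serialize to bytes (pickle.dump writes the same bytes to a file).
blob = pickle.dumps(model_stub)

# Deserialize; the restored object is an equal copy of the original.
restored = pickle.loads(blob)
print(restored["name"])  # demo-model
```

One caveat worth knowing: unpickling the real `model.pkl` requires scikit-learn to be installed in the serving environment, since pickle stores references to the classifier’s classes rather than the code itself.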
Now let’s create a helper that loads the model and makes predictions from input data.
```python
import pickle

import numpy as np

# Load the model
with open("app/model.pkl", "rb") as f:
    model = pickle.load(f)

# Make predictions
def make_prediction(data):
    arr = np.array(data).reshape(1, -1)  # Reshape input to 2D
    return int(model.predict(arr)[0])    # Return the predicted flower
```
The function expects a list of 4 features (like [5.1, 3.5, 1.4, 0.2]).
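To illustrate the call shape without loading the real pickled model, here is a sketch using a hypothetical `StubModel` whose `predict` mimics scikit-learn’s interface (a 2D batch of rows in, one label per row out):

```python
class StubModel:
    """Hypothetical stand-in for the pickled RandomForestClassifier."""
    def predict(self, rows):
        # Real models expect a 2D structure: one row per sample.
        return [0 for _ in rows]

def make_prediction_stub(model, data):
    # Wrapping the flat feature list in another list gives shape (1, 4),
    # mirroring np.array(data).reshape(1, -1) in predict.py.
    return int(model.predict([data])[0])

print(make_prediction_stub(StubModel(), [5.1, 3.5, 1.4, 0.2]))  # 0
```

The reshape matters because scikit-learn treats a flat list as many single-feature samples rather than one four-feature sample.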
FastAPI provides automatic input validation through Pydantic models. The model below verifies that incoming requests contain exactly four properly formatted numeric features before any prediction runs.
```python
from pydantic import BaseModel, field_validator
from typing import List

# Define a Pydantic model
class PredictionInput(BaseModel):
    data: List[float]

    # Validator to check that the input list contains exactly 4 values
    @field_validator("data")
    @classmethod
    def check_length(cls, v):
        if len(v) != 4:
            raise ValueError("data must contain exactly 4 float values")
        return v

    # Provide an example schema for the documentation
    class Config:
        json_schema_extra = {
            "example": {
                "data": [5.1, 3.5, 1.4, 0.2],
            }
        }
```
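For intuition, the same checks can be written in plain Python. `validate_features` below is a hypothetical helper, not part of the project, that mirrors what the Pydantic model automates (type coercion plus the length check):

```python
def validate_features(data):
    """Reject inputs that are not exactly 4 numeric values."""
    if not isinstance(data, list) or len(data) != 4:
        raise ValueError("data must contain exactly 4 float values")
    # float() raises ValueError/TypeError on non-numeric entries,
    # much like Pydantic's List[float] coercion
    return [float(v) for v in data]

print(validate_features([5.1, 3.5, 1.4, 0.2]))  # [5.1, 3.5, 1.4, 0.2]
```

The advantage of doing this in Pydantic instead is that FastAPI turns a failed check into a structured 422 response and documents the expected schema automatically.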
Note: Steps 4 and 5 are optional; they add the security layers (authentication and rate limiting) on top of the basic API.
JWT (JSON Web Tokens) offers safer authentication than a simple static token: claims such as user data and an expiration time are embedded in the token itself, and a shared secret or public/private key pair is used to verify it.
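Under the hood, a JWT signed with HS256 is just two base64url-encoded JSON segments (header and payload) joined by dots, plus an HMAC-SHA256 signature over them. A stdlib-only sketch (`sign_hs256` is an illustrative helper, not part of the project; real code should use a library like pyjwt):

```python
import base64
import hashlib
import hmac
import json

def b64url(raw: bytes) -> bytes:
    # JWTs use unpadded, URL-safe base64
    return base64.urlsafe_b64encode(raw).rstrip(b"=")

def sign_hs256(payload: dict, secret: str) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = header + b"." + body
    # The signature covers header + "." + payload, keyed by the shared secret
    sig = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return (signing_input + b"." + sig).decode()

token = sign_hs256({"sub": "user"}, "mysecretkey")
print(token.count("."))  # 2 -- header.payload.signature
```

Verification simply recomputes the signature with the same secret and compares; anyone who tampers with the payload invalidates the signature.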
We will use the pyjwt library to handle JWTs.
```python
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

SECRET_KEY = os.getenv("SECRET_KEY", "mysecretkey")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def verify_token(token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )
```
You’ll need to create a route to get the JWT.
Rate limiting protects your API from overuse by capping how many requests each IP address can send per minute. I added this using middleware: the RateLimitMiddleware records the timestamp of every request per IP, counts how many arrived in the last 60 seconds, and rejects anything beyond the limit (the throttle rate, 60 requests per minute by default) with a “429 Too Many Requests” error.
```python
import time

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, throttle_rate: int = 60):
        super().__init__(app)
        self.throttle_rate = throttle_rate
        self.request_log = {}  # Track request timestamps per IP

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host
        now = time.time()
        # Drop request timestamps older than 60 seconds
        self.request_log = {
            ip: [ts for ts in times if ts > now - 60]
            for ip, times in self.request_log.items()
        }
        ip_history = self.request_log.get(client_ip, [])
        if len(ip_history) >= self.throttle_rate:
            # Exceptions raised inside middleware bypass FastAPI's handlers,
            # so return the 429 response directly
            return JSONResponse(status_code=429, content={"detail": "Too many requests"})
        ip_history.append(now)
        self.request_log[client_ip] = ip_history
        return await call_next(request)
```
This is a simple, memory-based approach that works well for small projects.
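The sliding-window idea can also be isolated from the middleware for easier testing. `SlidingWindowLimiter` below is a hypothetical, framework-free sketch of the same logic; the optional `now` parameter makes it deterministic in tests:

```python
import time

class SlidingWindowLimiter:
    """In-memory sliding-window limiter, same idea as the middleware above."""
    def __init__(self, limit=60, window=60.0):
        self.limit = limit
        self.window = window
        self.log = {}  # ip -> list of request timestamps

    def allow(self, ip, now=None):
        now = time.time() if now is None else now
        # Keep only timestamps inside the current window
        recent = [t for t in self.log.get(ip, []) if t > now - self.window]
        if len(recent) >= self.limit:
            self.log[ip] = recent
            return False
        recent.append(now)
        self.log[ip] = recent
        return True

limiter = SlidingWindowLimiter(limit=2, window=60)
print(limiter.allow("1.2.3.4", now=0.0))   # True
print(limiter.allow("1.2.3.4", now=1.0))   # True
print(limiter.allow("1.2.3.4", now=2.0))   # False (limit of 2 reached)
print(limiter.allow("1.2.3.4", now=65.0))  # True (old entries expired)
```

For anything beyond a single process, you would back this with a shared store such as Redis, since each worker process otherwise keeps its own independent counts.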
Combine all the components into the main FastAPI app. This will include the routes for health checks, token generation, and prediction.
```python
from datetime import timedelta

from fastapi import Depends, FastAPI

from app.jwt import ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, verify_token
from app.predict import make_prediction
from app.rate_limit import RateLimitMiddleware
from app.validation import PredictionInput

# Initialize the FastAPI app
app = FastAPI()

# Skip this middleware if you are not implementing step 5
# Add rate limiting middleware to limit requests to 5 per minute
app.add_middleware(RateLimitMiddleware, throttle_rate=5)

# Root endpoint to confirm the API is running
@app.get("/")
def root():
    return {"message": "Welcome to the Secure Machine Learning API"}

# Skip this route if you are not implementing step 4
# This endpoint issues a token when valid credentials are provided
@app.post("/token")
def login():
    # Define the expiration time for the token (e.g., 30 minutes)
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Generate the JWT token
    access_token = create_access_token(data={"sub": "user"}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

# Prediction endpoint: requires a valid JWT token for authentication,
# and the input is validated with the PredictionInput model
@app.post("/predict")
def predict(input_data: PredictionInput, token: str = Depends(verify_token)):
    prediction = make_prediction(input_data.data)
    return {"prediction": prediction}
```
Create a Dockerfile to package the app and all dependencies.
```dockerfile
# Use the official Python image
FROM python:3.10-slim

# Set the working directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Copy the app code
COPY ./app ./app

# Run the FastAPI app with Uvicorn
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
And a simple requirements.txt as:
```
fastapi
uvicorn
starlette
pydantic
scikit-learn
numpy
pyjwt
python-dotenv
```
Use the following commands to run your API:
```shell
# Build the Docker image and run it
docker build -t secure-ml-api .
docker run -p 8000:8000 secure-ml-api
```
Your machine learning API will now be available at http://localhost:8000.
To test the prediction endpoint, first get a JWT by running:
```shell
curl -X POST http://localhost:8000/token
```
Copy the access token and run the following command:
```shell
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer PASTE-TOKEN-HERE" \
  -d '{"data": [1.5, 2.3, 3.1, 0.7]}'
```
You should receive a prediction like:
```json
{"prediction": 0}
```
You can try different inputs to test the API.
Deploying ML models as secure APIs requires careful attention to authentication, validation, and scalability. By leveraging FastAPI’s speed and simplicity alongside Docker’s portability, you can create robust endpoints that safely expose your model’s predictions while protecting against misuse. This approach ensures your ML solutions are not just accurate but also reliable and secure in real-world applications.

Jason Brownlee, PhD, helps developers get results with machine learning.