MachineLearningMastery.com

Making developers awesome at machine learning


Creating a Secure Machine Learning API with FastAPI and Docker


Machine learning models deliver real value only when they reach users, and APIs are the bridge that makes it happen. But exposing your model isn’t enough; you need a secure, scalable, and efficient API to ensure reliability. In this guide, we’ll build a production-ready ML API with FastAPI, adding authentication, input validation, and rate limiting. This way, your model doesn’t just work but works safely at scale.

In this guide, I’ll walk you through building a secure machine learning API. We’ll cover:

  • Building a fast, efficient API using FastAPI
  • Protecting your endpoints using JWT (JSON Web Token) authentication
  • Ensuring the inputs to your model are valid and safe
  • Adding rate limiting to your API endpoints to guard against misuse or overload
  • Packaging everything neatly with Docker for consistent deployment

The project structure will look something like this:

secure-ml-API/
├── app/
│   ├── main.py          # FastAPI entry point
│   ├── model.py         # Model training and serialization
│   ├── predict.py       # Prediction logic
│   ├── jwt.py           # JWT authentication logic
│   ├── rate_limit.py    # Rate limiting logic
│   └── validation.py    # Input validation logic
├── Dockerfile           # Docker setup
├── requirements.txt     # Python dependencies
└── README.md            # Documentation for the project

Let’s do everything step by step.

Step 1: Train & Serialize the Model (app/model.py)

To keep things simple, we’ll use a RandomForestClassifier on the Iris dataset. The RandomForestClassifier is a machine-learning model that classifies things (e.g., flowers, emails, customers). In the Iris flower dataset:

  • Input: 4 numbers (sepal & petal length/width)
  • Output: Species (0=Setosa, 1=Versicolor, or 2=Virginica)

RandomForest looks for patterns in the input numbers using many decision trees and returns the most likely flower species based on those patterns.

import pickle

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Function to train the model and save it as a pickle file
def train_model():
    iris = load_iris()
    X, y = iris.data, iris.target
    clf = RandomForestClassifier()
    clf.fit(X, y)
    # Save the trained model
    with open("app/model.pkl", "wb") as f:
        pickle.dump(clf, f)

if __name__ == "__main__":
    train_model()

Run this script to generate the model.pkl file.

Step 2: Define Prediction Logic (app/predict.py)

Now let’s create a helper that loads the model and makes predictions from input data.

import pickle

import numpy as np

# Load the model
with open("app/model.pkl", "rb") as f:
    model = pickle.load(f)

# Make predictions
def make_prediction(data):
    arr = np.array(data).reshape(1, -1)  # Reshape input to 2D
    return int(model.predict(arr)[0])  # Return the predicted flower

The function expects a list of 4 features (like [5.1, 3.5, 1.4, 0.2]).
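The reshape step matters because scikit-learn estimators expect a 2D array of shape (n_samples, n_features), so a single sample must become one row. A quick sketch of what happens to the input:

```python
import numpy as np

# A single sample arrives as a flat list of 4 features
features = [5.1, 3.5, 1.4, 0.2]

# reshape(1, -1) produces one row; -1 tells NumPy to infer the column count
arr = np.array(features).reshape(1, -1)
print(arr.shape)  # (1, 4)
```

Passing the flat 1D list directly to `model.predict` would raise an error, which is why the helper reshapes first.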

Step 3: Validate the Input (app/validation.py)

FastAPI provides automatic input validation through Pydantic models. The model below verifies that incoming features are numeric and that exactly four values are supplied before any prediction is made.

from pydantic import BaseModel, field_validator
from typing import List

# Define a Pydantic model
class PredictionInput(BaseModel):
    data: List[float]

    # Validator to check that the input list contains exactly 4 values
    @field_validator("data")
    @classmethod
    def check_length(cls, v):
        if len(v) != 4:
            raise ValueError("data must contain exactly 4 float values")
        return v

    # Provide an example schema for documentation
    class Config:
        json_schema_extra = {
            "example": {
                "data": [5.1, 3.5, 1.4, 0.2],
            }
        }

Note: Steps 4 and 5 are optional and exist purely for security hardening.

Step 4: Add JWT Authentication (app/jwt.py)

JWT (JSON Web Token) offers safer authentication than simple token-based schemes: claims (user data, expiration, etc.) are embedded in the token itself, and a shared secret or public/private key pair is used to verify that the token has not been tampered with.
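To make the mechanism concrete, here is a minimal, stdlib-only sketch of how an HS256 token is assembled. This is purely illustrative; pyjwt does all of this for you, plus expiry checks and validation:

```python
import base64
import hashlib
import hmac
import json

def b64url(raw: bytes) -> str:
    # JWTs use unpadded, URL-safe base64
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

def sign_hs256(payload: dict, secret: str) -> str:
    # A JWT is header.payload.signature, each segment base64url-encoded
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = f"{header}.{body}".encode()
    signature = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return f"{header}.{body}.{signature}"

token = sign_hs256({"sub": "user"}, "mysecretkey")
print(token.count("."))  # 2: three dot-separated segments
```

The server can recompute the HMAC over the first two segments and compare it to the third; any change to the claims invalidates the signature.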

We will use the pyjwt library to handle JWTs.

import os
import jwt
from datetime import datetime, timedelta
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

SECRET_KEY = os.getenv("SECRET_KEY", "mysecretkey")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def verify_token(token: str = Depends(oauth2_scheme)):
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )

You’ll need to create a route to get the JWT.

Step 5: Protect Your API with Rate Limiting (app/rate_limit.py)

Rate limiting protects your API from being overused by capping how many requests each IP address can send per minute. I added it using middleware.

The RateLimitMiddleware checks the client IP of each request, counts how many requests arrived from that IP in the last 60 seconds, and rejects anything beyond the limit (the throttle rate, default 60 per minute) with a “429 Too Many Requests” error.

import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, throttle_rate: int = 60):
        super().__init__(app)
        self.throttle_rate = throttle_rate
        self.request_log = {}  # Track timestamps per IP

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host
        now = time.time()

        # Drop request timestamps older than 60 seconds
        self.request_log = {
            ip: [ts for ts in times if ts > now - 60]
            for ip, times in self.request_log.items()
        }

        ip_history = self.request_log.get(client_ip, [])

        if len(ip_history) >= self.throttle_rate:
            # Exceptions raised inside middleware bypass FastAPI's
            # exception handlers, so return the 429 response directly
            return JSONResponse(status_code=429, content={"detail": "Too many requests"})

        ip_history.append(now)
        self.request_log[client_ip] = ip_history

        return await call_next(request)

This is a simple, memory-based approach that works well for small projects.
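The core sliding-window bookkeeping can also be exercised outside of HTTP entirely, which makes it easy to test. Here is a minimal, dependency-free sketch of the same idea (the class and method names are illustrative, not part of the project):

```python
import time

class SlidingWindowLimiter:
    """Allow at most `limit` hits per `window` seconds for each key."""

    def __init__(self, limit: int = 60, window: float = 60.0):
        self.limit = limit
        self.window = window
        self.hits = {}  # key -> list of hit timestamps

    def allow(self, key: str, now: float = None) -> bool:
        now = time.time() if now is None else now
        # Keep only hits that fall inside the current window
        recent = [t for t in self.hits.get(key, []) if t > now - self.window]
        if len(recent) >= self.limit:
            self.hits[key] = recent
            return False
        recent.append(now)
        self.hits[key] = recent
        return True

limiter = SlidingWindowLimiter(limit=2, window=60.0)
print(limiter.allow("1.2.3.4", now=0.0))   # True
print(limiter.allow("1.2.3.4", now=1.0))   # True
print(limiter.allow("1.2.3.4", now=2.0))   # False: third hit inside the window
print(limiter.allow("1.2.3.4", now=61.5))  # True: earlier hits have expired
```

For multi-worker or multi-host deployments, this per-process dictionary would need to move to shared storage such as Redis, since each worker otherwise keeps its own counts.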

Step 6: Build the FastAPI Application

Combine all the components into the main FastAPI app. This will include the routes for health checks, token generation, and prediction.

from datetime import timedelta

from fastapi import Depends, FastAPI

from app.jwt import ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, verify_token
from app.predict import make_prediction
from app.rate_limit import RateLimitMiddleware
from app.validation import PredictionInput

# Initialize FastAPI app
app = FastAPI()

# Skip this middleware if you are not implementing step 5
# Add rate limiting middleware to limit requests to 5 per minute
app.add_middleware(RateLimitMiddleware, throttle_rate=5)

# Root endpoint to confirm the API is running
@app.get("/")
def root():
    return {"message": "Welcome to the Secure Machine Learning API"}

# Skip this route if you are not implementing step 4
# This endpoint issues a token when valid credentials are provided
@app.post("/token")
def login():
    # Define the expiration time for the token (e.g., 30 minutes)
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Generate the JWT token
    access_token = create_access_token(data={"sub": "user"}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

# Prediction endpoint; requires a valid JWT token for authentication.
# The input data is validated using the PredictionInput model.
@app.post("/predict")
def predict(input_data: PredictionInput, token: str = Depends(verify_token)):
    prediction = make_prediction(input_data.data)
    return {"prediction": prediction}

Step 7: Dockerize the Application

Create a Dockerfile to package the app and all dependencies.

# Use the official Python image
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Copy the app code
COPY ./app ./app

# Run FastAPI app with Uvicorn
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

And a simple requirements.txt:

fastapi
uvicorn
starlette
pydantic
scikit-learn
numpy
pyjwt
python-dotenv

Step 8: Build and Run the Docker Container

Use the following commands to run your API:

# Build the Docker image and run it
docker build -t secure-ml-api .
docker run -p 8000:8000 secure-ml-api

Now your machine learning API will be available at http://localhost:8000.

Step 9: Test your API with Curl

First, get a JWT by running the following command:

curl -X POST http://localhost:8000/token

Copy the access token and run the following command:

curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer PASTE-TOKEN-HERE" \
  -d '{"data": [1.5, 2.3, 3.1, 0.7]}'

You should receive a prediction like:

{"prediction": 0}

You can try different inputs to test the API.

Conclusion

Deploying ML models as secure APIs requires careful attention to authentication, validation, and scalability. By leveraging FastAPI’s speed and simplicity alongside Docker’s portability, you can create robust endpoints that safely expose your model’s predictions while protecting against misuse. This approach ensures your ML solutions are not just accurate but also reliable and secure in real-world applications.

Kanwal Mehreen