10 | 10 | from torch.optim.lr_scheduler import LambdaLR
11 | 11 | from typing import Optional
12 | 12 | from tqdm.auto import tqdm
13 |    | -import numpy as np
14 | 13 | import matplotlib.pyplot as plt
15 | 14 | from type import Model
16 | 15 | import os
17 | 16 | from dotenv import load_dotenv
18 | 17 |
| 18 | +import numpy as np
| 19 | +import pandas as pd
| 20 | +from tensorflow.keras.models import Sequential
| 21 | +from tensorflow.keras.layers import LSTM, Dense, Dropout
| 22 | +from tensorflow.keras.optimizers import Adam
| 23 | +from tensorflow.keras.utils import to_categorical
| 24 | +from tensorflow.keras.callbacks import EarlyStopping
| 25 | +from tensorflow.keras.regularizers import l2
| 26 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
| 27 | +from sklearn.model_selection import train_test_split
| 28 | +from sklearn.preprocessing import StandardScaler
| 29 | +from sklearn.utils import resample
| 30 | +from ta.momentum import RSIIndicator
| 31 | + |
19 | 32 |
20 | 33 | # Load environment variables from .env file |
21 | 34 | load_dotenv() |
@@ -314,3 +327,173 @@ def plot_roc_curve(path, labels, probs): |
314 | 327 |     # Save the figure to the output directory with a unique name
315 | 328 |     plt.savefig(path)
316 | 329 |     plt.close()
| 330 | + |
| 331 | + |
| 332 | +class LSTMModel(Model):
| 333 | +    def __init__(self, name, average_duration, test_size=0.2, batch_size=64, epochs=200, learning_rate=0.001):
| 334 | +        super().__init__(name)
| 335 | +        self.average_duration = average_duration
| 336 | +        self.test_size = test_size
| 337 | +        self.batch_size = batch_size
| 338 | +        self.epochs = epochs
| 339 | +        self.learning_rate = learning_rate
| 340 | +        self.model = None
| 341 | +        self.scaler = StandardScaler()
| 342 | + |
| 343 | +    def preprocess_data(self, prices):
| 344 | +        """
| 345 | +        Preprocess the data: calculate ROC, RSI, Momentum, normalize features, and split into train/test sets.
| 346 | +
| 347 | +        Args:
| 348 | +            prices (DataFrame): The dataframe containing price and label information.
| 349 | +
| 350 | +        Returns:
| 351 | +            tuple: Training and test datasets along with their labels.
| 352 | +        """
| 353 | +        # Calculate ROC, RSI, and Momentum
| 354 | +        prices['ROC'] = prices['close'].pct_change()
| 355 | +        rsi_indicator = RSIIndicator(close=prices['close'], window=self.average_duration)
| 356 | +        prices['RSI'] = rsi_indicator.rsi()
| 357 | +        prices['Momentum'] = prices['close'].diff(periods=self.average_duration)
| 358 | +
| 359 | +        # Shift labels to create y_true and drop NaN rows
| 360 | +        prices['y_true'] = prices['labels'].shift(-1).dropna()
| 361 | +        prices = prices.dropna()
| 362 | +
| 363 | +        # Define features and labels
| 364 | +        features = ['ROC', 'RSI', 'Momentum', 'rise_over_trend', 'previous_window_trend']
| 365 | +        labels = prices['y_true'].astype(int)
| 366 | +
| 367 | +        # Normalize features
| 368 | +        prices[features] = self.scaler.fit_transform(prices[features])
| 369 | +
| 370 | +        # Split the data
| 371 | +        X_train, X_test, y_train, y_test = train_test_split(prices[features].values, labels.values, test_size=self.test_size, shuffle=False)
| 372 | +
| 373 | +        # Reshape for LSTM
| 374 | +        X_train = X_train.reshape((X_train.shape[0], 1, X_train.shape[1]))
| 375 | +        X_test = X_test.reshape((X_test.shape[0], 1, X_test.shape[1]))
| 376 | +
| 377 | +        return X_train, X_test, y_train, y_test
| 378 | + |
| 379 | +    def balance_data(self, X_train, y_train):
| 380 | +        """
| 381 | +        Balance the training data using oversampling.
| 382 | +
| 383 | +        Args:
| 384 | +            X_train (array): The training feature data.
| 385 | +            y_train (array): The training labels.
| 386 | +
| 387 | +        Returns:
| 388 | +            tuple: Balanced training data and labels.
| 389 | +        """
| 390 | +        # Combine features and labels into a single DataFrame for balancing
| 391 | +        train_data = np.hstack((X_train.reshape(X_train.shape[0], X_train.shape[2]), y_train.reshape(-1, 1)))
| 392 | +        train_df = pd.DataFrame(train_data, columns=['ROC', 'RSI', 'Momentum', 'rise_over_trend', 'previous_window_trend', 'y_true'])
| 393 | +
| 394 | +        # Separate classes and oversample
| 395 | +        class_0 = train_df[train_df['y_true'] == 0]
| 396 | +        class_1 = train_df[train_df['y_true'] == 1]
| 397 | +        class_2 = train_df[train_df['y_true'] == 2]
| 398 | +
| 399 | +        class_1_over = resample(class_1, replace=True, n_samples=len(class_0), random_state=42)
| 400 | +        class_2_over = resample(class_2, replace=True, n_samples=len(class_0), random_state=42)
| 401 | +
| 402 | +        # Combine the oversampled classes
| 403 | +        balanced_train_df = pd.concat([class_0, class_1_over, class_2_over])
| 404 | +
| 405 | +        # Sort by index to maintain order and separate features and labels
| 406 | +        balanced_train_df.sort_index(inplace=True)
| 407 | +        X_train_balanced = balanced_train_df.iloc[:, :-1].values.reshape(-1, 1, len(balanced_train_df.columns) - 1)
| 408 | +        y_train_balanced = to_categorical(balanced_train_df['y_true'].values, num_classes=3)
| 409 | +
| 410 | +        return X_train_balanced, y_train_balanced
| 411 | + |
| 412 | +    def build_model(self, input_shape):
| 413 | +        """
| 414 | +        Build and compile the LSTM model.
| 415 | +
| 416 | +        Args:
| 417 | +            input_shape (tuple): The shape of the input data.
| 418 | +        """
| 419 | +        self.model = Sequential([
| 420 | +            LSTM(64, input_shape=input_shape, return_sequences=True),
| 421 | +            Dropout(0.3),
| 422 | +            LSTM(32, return_sequences=False),
| 423 | +            Dropout(0.3),
| 424 | +            Dense(32, activation='relu', kernel_regularizer=l2(0.001)),
| 425 | +            Dense(3, activation='softmax')
| 426 | +        ])
| 427 | +        self.model.compile(loss='categorical_crossentropy',
| 428 | +                           optimizer=Adam(learning_rate=self.learning_rate),
| 429 | +                           metrics=['accuracy'])
| 430 | + |
| 431 | +    def train(self, X_train, y_train, X_val, y_val):
| 432 | +        """
| 433 | +        Train the LSTM model with early stopping.
| 434 | +
| 435 | +        Args:
| 436 | +            X_train (array): The training feature data.
| 437 | +            y_train (array): The training labels.
| 438 | +            X_val (array): The validation feature data.
| 439 | +            y_val (array): The validation labels.
| 440 | +        """
| 441 | +        early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
| 442 | +        self.history = self.model.fit(X_train, y_train, epochs=self.epochs, batch_size=self.batch_size,
| 443 | +                                      validation_data=(X_val, y_val), callbacks=[early_stopping], verbose=1)
| 444 | + |
| 445 | +    def predict(self, X_test):
| 446 | +        """
| 447 | +        Predict using the trained model.
| 448 | +
| 449 | +        Args:
| 450 | +            X_test (array): The test feature data.
| 451 | +
| 452 | +        Returns:
| 453 | +            array: The predicted labels.
| 454 | +        """
| 455 | +        y_pred = self.model.predict(X_test)
| 456 | +        return np.argmax(y_pred, axis=1)
| 457 | + |
| 458 | +    def evaluate(self, X_test, y_test):
| 459 | +        """
| 460 | +        Evaluate the model performance on the test data.
| 461 | +
| 462 | +        Args:
| 463 | +            X_test (array): The test feature data.
| 464 | +            y_test (array): The true labels for the test data.
| 465 | +
| 466 | +        Returns:
| 467 | +            float: The accuracy of the model.
| 468 | +        """
| 469 | +        loss, accuracy = self.model.evaluate(X_test, y_test, verbose=0)
| 470 | +        print(f'Test Accuracy: {accuracy:.4f}')
| 471 | +        return accuracy
| 472 | + |
| 473 | +    def compute_metrics(self, y_pred, y_true):
| 474 | +        """
| 475 | +        Compute accuracy, precision, recall, and F1 score.
| 476 | +
| 477 | +        Args:
| 478 | +            y_pred (array): The predicted labels.
| 479 | +            y_true (array): The true labels.
| 480 | +
| 481 | +        Returns:
| 482 | +            dict: The computed metrics.
| 483 | +        """
| 484 | +        accuracy = accuracy_score(y_true, y_pred)
| 485 | +        f1 = f1_score(y_true, y_pred, average='weighted')
| 486 | +        recall = recall_score(y_true, y_pred, average='weighted')
| 487 | +        precision = precision_score(y_true, y_pred, average='weighted')
| 488 | +
| 489 | +        print(f"Accuracy: {accuracy}")
| 490 | +        print(f"F1 Score: {f1}")
| 491 | +        print(f"Recall: {recall}")
| 492 | +        print(f"Precision: {precision}")
| 493 | +
| 494 | +        return {
| 495 | +            'accuracy': accuracy,
| 496 | +            'f1_score': f1,
| 497 | +            'recall': recall,
| 498 | +            'precision': precision
| 499 | +        }
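
For reference, a minimal usage sketch of the new LSTMModel class added in this diff. The prices.csv path, the required columns ('close', 'labels', 'rise_over_trend', 'previous_window_trend'), and reusing the held-out split as validation data are illustrative assumptions, not part of this change:

import pandas as pd
from tensorflow.keras.utils import to_categorical

# Load price data; the file path and label columns are hypothetical and must
# match what preprocess_data expects ('close', 'labels', 'rise_over_trend',
# 'previous_window_trend').
prices = pd.read_csv("prices.csv")

# average_duration is the window used for RSI and Momentum in preprocess_data.
model = LSTMModel(name="lstm_trend", average_duration=14)

# End-to-end pipeline as defined by the class methods above.
X_train, X_test, y_train, y_test = model.preprocess_data(prices)
X_train_bal, y_train_bal = model.balance_data(X_train, y_train)

# Input shape is (timesteps, features); the class reshapes inputs to 1 timestep.
model.build_model(input_shape=(X_train_bal.shape[1], X_train_bal.shape[2]))

# categorical_crossentropy expects one-hot targets, so the integer test labels
# are one-hot encoded before being used as validation data (a simplification
# for this sketch; a separate validation split would avoid leakage).
model.train(X_train_bal, y_train_bal, X_test, to_categorical(y_test, num_classes=3))

# Predictions come back as integer class indices, which compute_metrics compares
# against the integer test labels.
y_pred = model.predict(X_test)
metrics = model.compute_metrics(y_pred, y_test)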