Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita721c90

Browse files
committed
[feat]: add model.py modifications for lstm model class implementation
1 parentf5e559f commita721c90

File tree

1 file changed

+184
-1
lines changed

1 file changed

+184
-1
lines changed

‎src/model.py‎

Lines changed: 184 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,25 @@
1010
fromtorch.optim.lr_schedulerimportLambdaLR
1111
fromtypingimportOptional
1212
fromtqdm.autoimporttqdm
13-
importnumpyasnp
1413
importmatplotlib.pyplotasplt
1514
fromtypeimportModel
1615
importos
1716
fromdotenvimportload_dotenv
1817

18+
importnumpyasnp
19+
importpandasaspd
20+
fromtensorflow.keras.modelsimportSequential
21+
fromtensorflow.keras.layersimportLSTM,Dense,Dropout
22+
fromtensorflow.keras.optimizersimportAdam
23+
fromtensorflow.keras.utilsimportto_categorical
24+
fromtensorflow.keras.callbacksimportEarlyStopping
25+
fromtensorflow.keras.regularizersimportl2
26+
fromsklearn.metricsimportprecision_score,recall_score,f1_score
27+
fromsklearn.model_selectionimporttrain_test_split
28+
fromsklearn.preprocessingimportStandardScaler
29+
fromsklearn.utilsimportresample
30+
fromta.momentumimportRSIIndicator
31+
1932

2033
# Load environment variables from .env file
2134
load_dotenv()
@@ -314,3 +327,173 @@ def plot_roc_curve(path, labels, probs):
314327
# Save the figure to the output directory with a unique name
315328
plt.savefig(path)
316329
plt.close()
330+
331+
332+
classLSTMModel(Model):
333+
def__init__(self,name,average_duration,test_size=0.2,batch_size=64,epochs=200,learning_rate=0.001):
334+
super().__init__(name)
335+
self.average_duration=average_duration
336+
self.test_size=test_size
337+
self.batch_size=batch_size
338+
self.epochs=epochs
339+
self.learning_rate=learning_rate
340+
self.model=None
341+
self.scaler=StandardScaler()
342+
343+
defpreprocess_data(self,prices):
344+
"""
345+
Preprocess the data: calculate ROC, RSI, Momentum, normalize features, and split into train/test sets.
346+
347+
Args:
348+
prices (DataFrame): The dataframe containing price and label information.
349+
350+
Returns:
351+
tuple: Training and test datasets along with their labels.
352+
"""
353+
# Calculate ROC, RSI, and Momentum
354+
prices['ROC']=prices['close'].pct_change()
355+
rsi_indicator=RSIIndicator(close=prices['close'],window=self.average_duration)
356+
prices['RSI']=rsi_indicator.rsi()
357+
prices['Momentum']=prices['close'].diff(periods=self.average_duration)
358+
359+
# Shift labels to create y_true and drop NaN rows
360+
prices['y_true']=prices['labels'].shift(-1).dropna()
361+
prices=prices.dropna()
362+
363+
# Define features and labels
364+
features= ['ROC','RSI','Momentum','rise_over_trend','previous_window_trend']
365+
labels=prices['y_true'].astype(int)
366+
367+
# Normalize features
368+
prices[features]=self.scaler.fit_transform(prices[features])
369+
370+
# Split the data
371+
X_train,X_test,y_train,y_test=train_test_split(prices[features].values,labels.values,test_size=self.test_size,shuffle=False)
372+
373+
# Reshape for LSTM
374+
X_train=X_train.reshape((X_train.shape[0],1,X_train.shape[1]))
375+
X_test=X_test.reshape((X_test.shape[0],1,X_test.shape[1]))
376+
377+
returnX_train,X_test,y_train,y_test
378+
379+
defbalance_data(self,X_train,y_train):
380+
"""
381+
Balance the training data using oversampling.
382+
383+
Args:
384+
X_train (array): The training feature data.
385+
y_train (array): The training labels.
386+
387+
Returns:
388+
tuple: Balanced training data and labels.
389+
"""
390+
# Combine features and labels into a single DataFrame for balancing
391+
train_data=np.hstack((X_train.reshape(X_train.shape[0],X_train.shape[2]),y_train.reshape(-1,1)))
392+
train_df=pd.DataFrame(train_data,columns=['ROC','RSI','Momentum','rise_over_trend','previous_window_trend','y_true'])
393+
394+
# Separate classes and oversample
395+
class_0=train_df[train_df['y_true']==0]
396+
class_1=train_df[train_df['y_true']==1]
397+
class_2=train_df[train_df['y_true']==2]
398+
399+
class_1_over=resample(class_1,replace=True,n_samples=len(class_0),random_state=42)
400+
class_2_over=resample(class_2,replace=True,n_samples=len(class_0),random_state=42)
401+
402+
# Combine the oversampled classes
403+
balanced_train_df=pd.concat([class_0,class_1_over,class_2_over])
404+
405+
# Sort by index to maintain order and separate features and labels
406+
balanced_train_df.sort_index(inplace=True)
407+
X_train_balanced=balanced_train_df.iloc[:, :-1].values.reshape(-1,1,len(balanced_train_df.columns)-1)
408+
y_train_balanced=to_categorical(balanced_train_df['y_true'].values,num_classes=3)
409+
410+
returnX_train_balanced,y_train_balanced
411+
412+
defbuild_model(self,input_shape):
413+
"""
414+
Build and compile the LSTM model.
415+
416+
Args:
417+
input_shape (tuple): The shape of the input data.
418+
"""
419+
self.model=Sequential([
420+
LSTM(64,input_shape=input_shape,return_sequences=True),
421+
Dropout(0.3),
422+
LSTM(32,return_sequences=False),
423+
Dropout(0.3),
424+
Dense(32,activation='relu',kernel_regularizer=l2(0.001)),
425+
Dense(3,activation='softmax')
426+
])
427+
self.model.compile(loss='categorical_crossentropy',
428+
optimizer=Adam(learning_rate=self.learning_rate),
429+
metrics=['accuracy'])
430+
431+
deftrain(self,X_train,y_train,X_val,y_val):
432+
"""
433+
Train the LSTM model with early stopping.
434+
435+
Args:
436+
X_train (array): The training feature data.
437+
y_train (array): The training labels.
438+
X_val (array): The validation feature data.
439+
y_val (array): The validation labels.
440+
"""
441+
early_stopping=EarlyStopping(monitor='val_loss',patience=10,restore_best_weights=True)
442+
self.history=self.model.fit(X_train,y_train,epochs=self.epochs,batch_size=self.batch_size,
443+
validation_data=(X_val,y_val),callbacks=[early_stopping],verbose=1)
444+
445+
defpredict(self,X_test):
446+
"""
447+
Predict using the trained model.
448+
449+
Args:
450+
X_test (array): The test feature data.
451+
452+
Returns:
453+
array: The predicted labels.
454+
"""
455+
y_pred=self.model.predict(X_test)
456+
returnnp.argmax(y_pred,axis=1)
457+
458+
defevaluate(self,X_test,y_test):
459+
"""
460+
Evaluate the model performance on the test data.
461+
462+
Args:
463+
X_test (array): The test feature data.
464+
y_test (array): The true labels for the test data.
465+
466+
Returns:
467+
float: The accuracy of the model.
468+
"""
469+
loss,accuracy=self.model.evaluate(X_test,y_test,verbose=0)
470+
print(f'Test Accuracy:{accuracy:.4f}')
471+
returnaccuracy
472+
473+
defcompute_metrics(self,y_pred,y_true):
474+
"""
475+
Compute accuracy, precision, recall, and F1 score.
476+
477+
Args:
478+
y_pred (array): The predicted labels.
479+
y_true (array): The true labels.
480+
481+
Returns:
482+
dict: The computed metrics.
483+
"""
484+
accuracy=accuracy_score(y_true,y_pred)
485+
f1=f1_score(y_true,y_pred,average='weighted')
486+
recall=recall_score(y_true,y_pred,average='weighted')
487+
precision=precision_score(y_true,y_pred,average='weighted')
488+
489+
print(f"Accuracy:{accuracy}")
490+
print(f"F1 Score:{f1}")
491+
print(f"Recall:{recall}")
492+
print(f"Precision:{precision}")
493+
494+
return {
495+
'accuracy':accuracy,
496+
'f1_score':f1,
497+
'recall':recall,
498+
'precision':precision
499+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp