22
33import argparse
44import json
5+ import sys
56from pathlib import Path
67
78from langchain_core .messages import AIMessage ,HumanMessage
1112from dive_mcp_host .host .conf import HostConfig
1213from dive_mcp_host .host .host import DiveMcpHost
1314
15+ # Default paths for CLI
16+ CLI_DATA_DIR = Path .home ()/ ".dive_mcp_host"
17+ CHECKPOINTER_PATH = CLI_DATA_DIR / "checkpoints.db"
18+
1419
1520def parse_query (args :type [CLIArgs ])-> HumanMessage :
1621"""Parse the query from the command line arguments."""
@@ -75,7 +80,16 @@ def setup_argument_parser() -> type[CLIArgs]:
7580def load_config (config_path :str )-> HostConfig :
7681"""Load the configuration."""
7782with Path (config_path ).open ("r" )as f :
78- return HostConfig .model_validate_json (f .read ())
83+ config_data = json .load (f )
84+
85+ # Add default checkpointer if not present
86+ if "checkpointer" not in config_data :
87+ CHECKPOINTER_PATH .parent .mkdir (parents = True ,exist_ok = True )
88+ config_data ["checkpointer" ]= {
89+ "uri" :f"sqlite:///{ CHECKPOINTER_PATH } "
90+ }
91+
92+ return HostConfig .model_validate (config_data )
7993
8094
8195def load_merged_config (mcp_config_path :str ,model_config_path :str )-> HostConfig :
@@ -105,10 +119,16 @@ def load_merged_config(mcp_config_path: str, model_config_path: str) -> HostConf
105119server_config_with_name = {** server_config ,"name" :server_name }
106120mcp_servers [server_name ]= server_config_with_name
107121
122+ # Setup default checkpointer for CLI (use sqlite in home directory)
123+ CHECKPOINTER_PATH .parent .mkdir (parents = True ,exist_ok = True )
124+
108125# Merge configs
109126merged_config = {
110127"llm" :active_config ,
111- "mcp_servers" :mcp_servers
128+ "mcp_servers" :mcp_servers ,
129+ "checkpointer" : {
130+ "uri" :f"sqlite:///{ CHECKPOINTER_PATH } "
131+ }
112132 }
113133
114134return HostConfig .model_validate (merged_config )
@@ -117,7 +137,11 @@ def load_merged_config(mcp_config_path: str, model_config_path: str) -> HostConf
117137async def run ()-> None :
118138"""dive_mcp_host CLI entrypoint."""
119139args = setup_argument_parser ()
120- query = parse_query (args )
140+
141+ # Get initial query if provided
142+ initial_query = None
143+ if args .query :
144+ initial_query = parse_query (args )
121145
122146# Load config based on provided arguments
123147if args .config_path :
@@ -151,23 +175,66 @@ async def run() -> None:
151175system_prompt = f .read ()
152176
153177output_parser = StrOutputParser ()
154- async with DiveMcpHost (config )as mcp_host :
155- print ("Waiting for tools to initialize..." )
156- await mcp_host .tools_initialized_event .wait ()
157- print ("Tools initialized" )
158- chat = mcp_host .chat (chat_id = current_chat_id ,system_prompt = system_prompt )
159- current_chat_id = chat .chat_id
160- async with chat :
161- async for response in chat .query (query ,stream_mode = "messages" ):
162- assert isinstance (response ,tuple )
163- msg = response [0 ]
164- if isinstance (msg ,AIMessage ):
165- content = output_parser .invoke (msg )
166- print (content ,end = "" )
167- continue
168- print (f"\n \n ==== Start Of{ type (msg )} ===" )
169- print (msg )
170- print (f"==== End Of{ type (msg )} ===\n " )
171-
172- print ()
173- print (f"Chat ID:{ current_chat_id } " )
178+
179+ try :
180+ async with DiveMcpHost (config )as mcp_host :
181+ print ("Waiting for tools to initialize..." )
182+ await mcp_host .tools_initialized_event .wait ()
183+ print ("Tools initialized" )
184+ print ("=" * 60 )
185+
186+ chat = mcp_host .chat (chat_id = current_chat_id ,system_prompt = system_prompt )
187+ current_chat_id = chat .chat_id
188+
189+ async with chat :
190+ # Process initial query if provided
191+ if initial_query :
192+ await process_query (chat ,initial_query ,output_parser )
193+
194+ # Start interactive chat loop
195+ print ("\n Chat started. Type 'exit' or press Ctrl-C to quit." )
196+ print (f"Chat ID:{ current_chat_id } " )
197+ print ("=" * 60 )
198+
199+ while True :
200+ try :
201+ # Read user input
202+ user_input = input ("\n You: " ).strip ()
203+
204+ if not user_input :
205+ continue
206+
207+ # Check for exit commands
208+ if user_input .lower ()in ["exit" ,"quit" ]:
209+ print ("\n Goodbye!" )
210+ break
211+
212+ # Process the query
213+ query = HumanMessage (content = user_input )
214+ print ("\n Assistant: " ,end = "" )
215+ await process_query (chat ,query ,output_parser )
216+
217+ except EOFError :
218+ # Handle Ctrl-D
219+ print ("\n \n Goodbye!" )
220+ break
221+
222+ except KeyboardInterrupt :
223+ # Handle Ctrl-C
224+ print ("\n \n Goodbye!" )
225+ sys .exit (0 )
226+
227+
228+ async def process_query (chat ,query :HumanMessage ,output_parser :StrOutputParser )-> None :
229+ """Process a single query and print the response."""
230+ async for response in chat .query (query ,stream_mode = "messages" ):
231+ assert isinstance (response ,tuple )
232+ msg = response [0 ]
233+ if isinstance (msg ,AIMessage ):
234+ content = output_parser .invoke (msg )
235+ print (content ,end = "" )
236+ continue
237+ print (f"\n \n ==== Start Of{ type (msg )} ===" )
238+ print (msg )
239+ print (f"==== End Of{ type (msg )} ===\n " )
240+ print ()# Add newline after response