class AdvancedSQLiteSession(SQLiteSession):
    """Enhanced SQLite session with conversation branching and usage analytics.

    On top of the base session's message table, this class maintains two extra
    tables:

    * ``message_structure`` - one row per (message, branch) carrying an explicit
      ``sequence_number`` for ordering, the user-turn numbers the message belongs
      to, its classified ``message_type`` and, for tool calls, a ``tool_name``.
      Branches share the underlying message rows; a branch is the set of
      ``message_structure`` rows with its ``branch_id``.
    * ``turn_usage`` - request/token usage per (branch, user turn), including the
      JSON-serialized token detail objects.
    """

    def __init__(
        self,
        *,
        session_id: str,
        db_path: str | Path = ":memory:",
        create_tables: bool = False,
        logger: logging.Logger | None = None,
        **kwargs,
    ):
        """Initialize the AdvancedSQLiteSession.

        Args:
            session_id: The ID of the session
            db_path: The path to the SQLite database file. Defaults to `:memory:` for in-memory storage
            create_tables: Whether to create the structure tables. Every query in this class reads `message_structure` / `turn_usage`, so leave this False only if they already exist
            logger: The logger to use. Defaults to the module logger
            **kwargs: Additional keyword arguments to pass to the superclass
        """  # noqa: E501
        super().__init__(session_id, db_path, **kwargs)
        self._current_branch_id = "main"
        self._logger = logger or logging.getLogger(__name__)
        if create_tables:
            self._init_structure_tables()

    def _init_structure_tables(self):
        """Add structure and usage tracking tables.

        Creates the message_structure and turn_usage tables with appropriate
        indexes for conversation branching and usage analytics. All statements use
        ``IF NOT EXISTS`` so calling this on an existing database is harmless.
        """
        conn = self._get_connection()

        # Message structure with branch support.
        # NOTE(review): the sessions table name is hardcoded here; confirm it matches
        # the base class's configured sessions table.
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS message_structure (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_id INTEGER NOT NULL,
                branch_id TEXT NOT NULL DEFAULT 'main',
                message_type TEXT NOT NULL,
                sequence_number INTEGER NOT NULL,
                user_turn_number INTEGER,
                branch_turn_number INTEGER,
                tool_name TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
                FOREIGN KEY (message_id) REFERENCES {self.messages_table}(id) ON DELETE CASCADE
            )
            """  # noqa: E501
        )

        # Turn-level usage tracking with branch support and full JSON details
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS turn_usage (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                branch_id TEXT NOT NULL DEFAULT 'main',
                user_turn_number INTEGER NOT NULL,
                requests INTEGER DEFAULT 0,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0,
                total_tokens INTEGER DEFAULT 0,
                input_tokens_details JSON,
                output_tokens_details JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
                UNIQUE(session_id, branch_id, user_turn_number)
            )
            """  # noqa: E501
        )

        # Indexes
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_structure_session_seq
            ON message_structure(session_id, sequence_number)
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_structure_branch
            ON message_structure(session_id, branch_id)
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_structure_turn
            ON message_structure(session_id, branch_id, user_turn_number)
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_structure_branch_seq
            ON message_structure(session_id, branch_id, sequence_number)
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn
            ON turn_usage(session_id, branch_id, user_turn_number)
            """
        )

        conn.commit()

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add items to the session.

        The items are first stored by the base class, then structure metadata
        (branch, sequence, turn numbers) is recorded for them.

        Args:
            items: The items to add to the session
        """
        # Add to base table first
        await super().add_items(items)

        # Extract structure metadata with precise sequencing
        if items:
            await self._add_structure_metadata(items)

    async def get_items(
        self,
        limit: int | None = None,
        branch_id: str | None = None,
    ) -> list[TResponseInputItem]:
        """Get items from current or specified branch.

        Args:
            limit: Maximum number of items to return. If None, returns all items.
                When set, the most recent ``limit`` items are returned, still in
                chronological order.
            branch_id: Branch to get items from. If None, uses current branch.

        Returns:
            List of conversation items from the specified branch, ordered by
            sequence number. Rows whose stored JSON cannot be decoded are skipped.
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_items_sync():
            """Synchronous helper to get the items of one branch."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    if limit is None:
                        cursor.execute(
                            f"""
                            SELECT m.message_data
                            FROM {self.messages_table} m
                            JOIN message_structure s ON m.id = s.message_id
                            WHERE m.session_id = ? AND s.branch_id = ?
                            ORDER BY s.sequence_number ASC
                            """,
                            (self.session_id, branch_id),
                        )
                    else:
                        # Take the newest `limit` rows; chronological order is restored below.
                        cursor.execute(
                            f"""
                            SELECT m.message_data
                            FROM {self.messages_table} m
                            JOIN message_structure s ON m.id = s.message_id
                            WHERE m.session_id = ? AND s.branch_id = ?
                            ORDER BY s.sequence_number DESC
                            LIMIT ?
                            """,
                            (self.session_id, branch_id, limit),
                        )
                    rows = cursor.fetchall()

            if limit is not None:
                rows.reverse()

            items = []
            for (message_data,) in rows:
                try:
                    items.append(json.loads(message_data))
                except json.JSONDecodeError:
                    continue
            return items

        return await asyncio.to_thread(_get_items_sync)

    async def store_run_usage(self, result: RunResult) -> None:
        """Store usage data for the current conversation turn.

        This is designed to be called after `Runner.run()` completes.
        Session-level usage can be aggregated from turn data when needed.
        Failures are logged, not raised.

        Args:
            result: The result from the run
        """
        try:
            usage = result.context_wrapper.usage
            if usage is not None:
                # The turn lookup is a blocking SQLite query; keep it off the event loop.
                current_turn = await asyncio.to_thread(self._get_current_turn_number)
                # Only update turn-level usage - session usage is aggregated on demand
                await self._update_turn_usage_internal(current_turn, usage)
        except Exception as e:
            self._logger.error(f"Failed to store usage for session {self.session_id}: {e}")

    def _get_next_turn_number(self, branch_id: str) -> int:
        """Get the next turn number for a specific branch.

        Args:
            branch_id: The branch ID to get the next turn number for.

        Returns:
            The next available turn number for the specified branch.
        """
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT COALESCE(MAX(user_turn_number), 0)
                FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                """,
                (self.session_id, branch_id),
            )
            result = cursor.fetchone()
            max_turn = result[0] if result else 0
            return max_turn + 1

    def _get_next_branch_turn_number(self, branch_id: str) -> int:
        """Get the next branch turn number for a specific branch.

        Args:
            branch_id: The branch ID to get the next branch turn number for.

        Returns:
            The next available branch turn number for the specified branch.
        """
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT COALESCE(MAX(branch_turn_number), 0)
                FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                """,
                (self.session_id, branch_id),
            )
            result = cursor.fetchone()
            max_turn = result[0] if result else 0
            return max_turn + 1

    def _get_current_turn_number(self) -> int:
        """Get the current turn number for the current branch.

        Returns:
            The highest user turn number recorded on the active branch, or 0.
        """
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT COALESCE(MAX(user_turn_number), 0)
                FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                """,
                (self.session_id, self._current_branch_id),
            )
            result = cursor.fetchone()
            return result[0] if result else 0

    async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
        """Extract structure metadata with branch-aware turn tracking.

        This method:
        - Assigns turn numbers per branch (not globally)
        - Assigns explicit sequence numbers for precise ordering
        - Links messages to their database IDs for structure tracking
        - Handles multiple user messages in a single batch correctly

        Failures are logged (not raised) and trigger a cleanup of messages that
        ended up without structure rows.

        Args:
            items: The items to add to the session
        """

        def _add_structure_sync():
            """Synchronous helper to add structure metadata to database."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                # Get the IDs of messages we just inserted, in order. This assumes the
                # newest len(items) rows for this session are exactly these items.
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        f"SELECT id FROM {self.messages_table} "
                        "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
                        (self.session_id, len(items)),
                    )
                    message_ids = [row[0] for row in cursor.fetchall()]
                message_ids.reverse()  # Match order of items

                # Get current max sequence number (global across branches)
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        """
                        SELECT COALESCE(MAX(sequence_number), 0)
                        FROM message_structure
                        WHERE session_id = ?
                        """,
                        (self.session_id,),
                    )
                    seq_start = cursor.fetchone()[0]

                # Get current turn numbers atomically with a single query
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        """
                        SELECT
                            COALESCE(MAX(user_turn_number), 0) as max_global_turn,
                            COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
                        FROM message_structure
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, self._current_branch_id),
                    )
                    result = cursor.fetchone()
                    current_turn = result[0] if result else 0
                    current_branch_turn = result[1] if result else 0

                # Each user message opens a new turn; every other message inherits the
                # turn of the most recent user message before it.
                structure_data = []
                user_message_count = 0
                for i, (item, msg_id) in enumerate(zip(items, message_ids)):
                    if self._is_user_message(item):
                        user_message_count += 1
                    structure_data.append(
                        (
                            self.session_id,
                            msg_id,
                            self._current_branch_id,
                            self._classify_message_type(item),
                            seq_start + i + 1,  # Global sequence
                            current_turn + user_message_count,  # Global turn number
                            current_branch_turn + user_message_count,  # Branch turn number
                            self._extract_tool_name(item),
                        )
                    )

                try:
                    with closing(conn.cursor()) as cursor:
                        cursor.executemany(
                            """
                            INSERT INTO message_structure
                            (session_id, message_id, branch_id, message_type, sequence_number,
                             user_turn_number, branch_turn_number, tool_name)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                            """,
                            structure_data,
                        )
                    conn.commit()
                except Exception:
                    # executemany() may have inserted some rows before failing; discard them
                    # so a later, unrelated commit on this connection cannot persist them.
                    conn.rollback()
                    raise

        try:
            await asyncio.to_thread(_add_structure_sync)
        except Exception as e:
            self._logger.error(
                f"Failed to add structure metadata for session {self.session_id}: {e}"
            )
            # Try to clean up any orphaned messages to maintain consistency
            try:
                await self._cleanup_orphaned_messages()
            except Exception as cleanup_error:
                self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
            # Don't re-raise - structure metadata is supplementary

    async def _cleanup_orphaned_messages(self) -> int:
        """Remove this session's messages that have no message_structure row.

        This can happen if _add_structure_metadata fails after super().add_items()
        succeeds. Used for maintaining data consistency.

        Returns:
            The number of deleted message rows.
        """

        def _cleanup_sync() -> int:
            """Synchronous helper to cleanup orphaned messages."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    # A single statement avoids binding one parameter per orphaned id,
                    # which could exceed SQLite's bound-variable limit.
                    # message_structure.message_id is NOT NULL, so NOT IN is safe here.
                    cursor.execute(
                        f"""
                        DELETE FROM {self.messages_table}
                        WHERE session_id = ?
                        AND id NOT IN (SELECT message_id FROM message_structure)
                        """,
                        (self.session_id,),
                    )
                    deleted_count = cursor.rowcount
                conn.commit()

            if deleted_count > 0:
                self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
            return deleted_count

        return await asyncio.to_thread(_cleanup_sync)

    def _classify_message_type(self, item: TResponseInputItem) -> str:
        """Classify the type of a message item.

        Args:
            item: The message item to classify.

        Returns:
            "user" or "assistant" by role, otherwise the item's "type" value,
            otherwise "other".
        """
        if isinstance(item, dict):
            if item.get("role") == "user":
                return "user"
            elif item.get("role") == "assistant":
                return "assistant"
            elif item.get("type"):
                return str(item.get("type"))
        return "other"

    def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
        """Extract tool name if this is a tool call/output.

        Args:
            item: The message item to extract tool name from.

        Returns:
            Tool name if item is a tool call, None otherwise.
        """
        if isinstance(item, dict):
            item_type = item.get("type")

            # For MCP tools, try to extract from server_label if available
            if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
                server_label = item.get("server_label")
                tool_name = item.get("name")
                if tool_name and server_label:
                    return f"{server_label}.{tool_name}"
                elif server_label:
                    return str(server_label)
                elif tool_name:
                    return str(tool_name)

            # For tool types without a 'name' field, derive from the type
            elif item_type in {
                "computer_call",
                "file_search_call",
                "web_search_call",
                "code_interpreter_call",
            }:
                return item_type

            # Most other tool calls have a 'name' field
            elif "name" in item:
                name = item.get("name")
                return str(name) if name is not None else None

        return None

    def _is_user_message(self, item: TResponseInputItem) -> bool:
        """Check if this is a user message.

        Args:
            item: The message item to check.

        Returns:
            True if the item is a user message, False otherwise.
        """
        return isinstance(item, dict) and item.get("role") == "user"

    @staticmethod
    def _content_text(content: Any) -> str:
        """Return the searchable text of a message ``content`` value.

        Plain strings are returned unchanged. For a list of content parts, the
        string ``text`` fields of dict parts are joined with newlines. Any other
        value yields an empty string.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        return ""

    @staticmethod
    def _loads_json_or_none(raw: str | None) -> Any:
        """Decode a stored JSON details column; None if it is empty or invalid."""
        if not raw:
            return None
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return None

    async def create_branch_from_turn(
        self, turn_number: int, branch_name: str | None = None
    ) -> str:
        """Create a new branch starting from a specific user message turn.

        Messages of the current branch before ``turn_number`` are shared with the
        new branch, and the session switches to the new branch.

        Args:
            turn_number: The branch turn number of the user message to branch from
            branch_name: Optional name for the branch (auto-generated if None)

        Returns:
            The branch_id of the newly created branch

        Raises:
            ValueError: If turn doesn't exist or doesn't contain a user message
        """
        import time

        # Validate the turn exists and contains a user message
        def _validate_turn():
            """Synchronous helper to validate turn exists and contains user message."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"""
                    SELECT am.message_data
                    FROM message_structure ms
                    JOIN {self.messages_table} am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.branch_turn_number = ? AND ms.message_type = 'user'
                    """,
                    (self.session_id, self._current_branch_id, turn_number),
                )

                result = cursor.fetchone()
                if not result:
                    raise ValueError(
                        f"Turn {turn_number} does not contain a user message "
                        f"in branch '{self._current_branch_id}'"
                    )

                message_data = result[0]
                try:
                    content = json.loads(message_data).get("content", "")
                    return content[:50] + "..." if len(content) > 50 else content
                except Exception:
                    return "Unable to parse content"

        turn_content = await asyncio.to_thread(_validate_turn)

        # Generate branch name if not provided
        if branch_name is None:
            timestamp = int(time.time())
            branch_name = f"branch_from_turn_{turn_number}_{timestamp}"

        # Copy messages before the branch point to the new branch
        await self._copy_messages_to_new_branch(branch_name, turn_number)

        # Switch to new branch
        old_branch = self._current_branch_id
        self._current_branch_id = branch_name

        self._logger.debug(
            f"Created branch '{branch_name}' from turn {turn_number} ('{turn_content}') in '{old_branch}'"  # noqa: E501
        )
        return branch_name

    async def create_branch_from_content(
        self, search_term: str, branch_name: str | None = None
    ) -> str:
        """Create branch from the first user turn matching the search term.

        Args:
            search_term: Text to search for in user messages.
            branch_name: Optional name for the branch (auto-generated if None).

        Returns:
            The branch_id of the newly created branch.

        Raises:
            ValueError: If no matching turns are found.
        """
        matching_turns = await self.find_turns_by_content(search_term)

        if not matching_turns:
            raise ValueError(f"No user turns found containing '{search_term}'")

        # Use the first (earliest) match
        turn_number = matching_turns[0]["turn"]
        return await self.create_branch_from_turn(turn_number, branch_name)

    async def switch_to_branch(self, branch_id: str) -> None:
        """Switch to a different branch.

        Args:
            branch_id: The branch to switch to.

        Raises:
            ValueError: If the branch has no message_structure rows in this session.
        """

        # Validate branch exists
        def _validate_branch():
            """Synchronous helper to validate branch exists."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT COUNT(*) FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    (self.session_id, branch_id),
                )

                count = cursor.fetchone()[0]
                if count == 0:
                    raise ValueError(f"Branch '{branch_id}' does not exist")

        await asyncio.to_thread(_validate_branch)

        old_branch = self._current_branch_id
        self._current_branch_id = branch_id
        self._logger.info(f"Switched from branch '{old_branch}' to '{branch_id}'")

    async def delete_branch(self, branch_id: str, force: bool = False) -> None:
        """Delete a branch and all its associated data.

        Only the branch's structure and usage rows are deleted; the shared message
        rows in the messages table are left untouched.

        Args:
            branch_id: The branch to delete.
            force: If True, allows deleting the current branch (will switch to 'main').

        Raises:
            ValueError: If branch doesn't exist, is 'main', or is current branch without force.
        """
        if not branch_id or not branch_id.strip():
            raise ValueError("Branch ID cannot be empty")

        branch_id = branch_id.strip()

        # Protect main branch
        if branch_id == "main":
            raise ValueError("Cannot delete the 'main' branch")

        # Check if trying to delete current branch
        if branch_id == self._current_branch_id:
            if not force:
                raise ValueError(
                    f"Cannot delete current branch '{branch_id}'. Use force=True or switch branches first"  # noqa: E501
                )
            else:
                # Switch to main before deleting
                await self.switch_to_branch("main")

        def _delete_sync():
            """Synchronous helper to delete branch and associated data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    # First verify the branch exists
                    cursor.execute(
                        """
                        SELECT COUNT(*) FROM message_structure
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    count = cursor.fetchone()[0]
                    if count == 0:
                        raise ValueError(f"Branch '{branch_id}' does not exist")

                    # Delete from turn_usage first (foreign key constraint)
                    cursor.execute(
                        """
                        DELETE FROM turn_usage
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    usage_deleted = cursor.rowcount

                    # Delete from message_structure
                    cursor.execute(
                        """
                        DELETE FROM message_structure
                        WHERE session_id = ? AND branch_id = ?
                        """,
                        (self.session_id, branch_id),
                    )

                    structure_deleted = cursor.rowcount

                    conn.commit()

                    return usage_deleted, structure_deleted

        usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)

        self._logger.info(
            f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
        )

    async def list_branches(self) -> list[dict[str, Any]]:
        """List all branches in this session.

        Returns:
            List of dicts with branch info containing:
                - 'branch_id': Branch identifier
                - 'message_count': Number of messages in branch
                - 'user_turns': Number of user turns in branch
                - 'is_current': Whether this is the current branch
                - 'created_at': When the branch was first created
        """

        def _list_branches_sync():
            """Synchronous helper to list all branches."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        ms.branch_id,
                        COUNT(*) as message_count,
                        COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
                        MIN(ms.created_at) as created_at
                    FROM message_structure ms
                    WHERE ms.session_id = ?
                    GROUP BY ms.branch_id
                    ORDER BY created_at
                    """,
                    (self.session_id,),
                )

                branches = []
                for row in cursor.fetchall():
                    branch_id, msg_count, user_turns, created_at = row
                    branches.append(
                        {
                            "branch_id": branch_id,
                            "message_count": msg_count,
                            "user_turns": user_turns,
                            "is_current": branch_id == self._current_branch_id,
                            "created_at": created_at,
                        }
                    )

                return branches

        return await asyncio.to_thread(_list_branches_sync)

    async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
        """Copy messages before the branch point to the new branch.

        Only structure rows are copied; the new rows point at the same message ids.

        Args:
            new_branch_id: The ID of the new branch to copy messages to.
            from_turn_number: The turn number to copy messages up to (exclusive).
        """

        def _copy_sync():
            """Synchronous helper to copy messages to new branch."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                with closing(conn.cursor()) as cursor:
                    # Get all messages before the branch point
                    cursor.execute(
                        """
                        SELECT
                            ms.message_id,
                            ms.message_type,
                            ms.sequence_number,
                            ms.user_turn_number,
                            ms.branch_turn_number,
                            ms.tool_name
                        FROM message_structure ms
                        WHERE ms.session_id = ? AND ms.branch_id = ?
                        AND ms.branch_turn_number < ?
                        ORDER BY ms.sequence_number
                        """,
                        (self.session_id, self._current_branch_id, from_turn_number),
                    )

                    messages_to_copy = cursor.fetchall()

                    if messages_to_copy:
                        # Get the max sequence number for the new inserts
                        cursor.execute(
                            """
                            SELECT COALESCE(MAX(sequence_number), 0)
                            FROM message_structure
                            WHERE session_id = ?
                            """,
                            (self.session_id,),
                        )

                        seq_start = cursor.fetchone()[0]

                        # Insert copied messages with new branch_id
                        new_structure_data = []
                        for i, (
                            msg_id,
                            msg_type,
                            _,
                            user_turn,
                            branch_turn,
                            tool_name,
                        ) in enumerate(messages_to_copy):
                            new_structure_data.append(
                                (
                                    self.session_id,
                                    msg_id,  # Same message_id (sharing the actual message data)
                                    new_branch_id,
                                    msg_type,
                                    seq_start + i + 1,  # New sequence number
                                    user_turn,  # Keep same global turn number
                                    branch_turn,  # Keep same branch turn number
                                    tool_name,
                                )
                            )

                        cursor.executemany(
                            """
                            INSERT INTO message_structure
                            (session_id, message_id, branch_id, message_type, sequence_number,
                             user_turn_number, branch_turn_number, tool_name)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                            """,
                            new_structure_data,
                        )

                        conn.commit()

        await asyncio.to_thread(_copy_sync)

    async def get_conversation_turns(self, branch_id: str | None = None) -> list[dict[str, Any]]:
        """Get user turns with content for easy browsing and branching decisions.

        Args:
            branch_id: Branch to get turns from (current branch if None).

        Returns:
            List of dicts with turn info containing:
                - 'turn': Branch turn number
                - 'content': User message content (truncated)
                - 'full_content': Full user message content
                - 'timestamp': When the turn was created
                - 'can_branch': Always True (all user messages can branch)
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_turns_sync():
            """Synchronous helper to get conversation turns."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"""
                    SELECT
                        ms.branch_turn_number,
                        am.message_data,
                        ms.created_at
                    FROM message_structure ms
                    JOIN {self.messages_table} am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.message_type = 'user'
                    ORDER BY ms.branch_turn_number
                    """,
                    (self.session_id, branch_id),
                )

                turns = []
                for row in cursor.fetchall():
                    turn_num, message_data, created_at = row
                    try:
                        content = json.loads(message_data).get("content", "")
                        turns.append(
                            {
                                "turn": turn_num,
                                "content": content[:100] + "..." if len(content) > 100 else content,
                                "full_content": content,
                                "timestamp": created_at,
                                "can_branch": True,
                            }
                        )
                    except (json.JSONDecodeError, AttributeError):
                        continue

                return turns

        return await asyncio.to_thread(_get_turns_sync)

    async def find_turns_by_content(
        self, search_term: str, branch_id: str | None = None
    ) -> list[dict[str, Any]]:
        """Find user turns containing specific content.

        The match is a case-insensitive substring test against the decoded message
        content (plain string content, or the ``text`` of list content parts). It
        deliberately does not search the raw stored JSON, so JSON keys such as
        "role" never match, ``%``/``_`` are not wildcards, and non-ASCII text
        matches even though it is stored escaped.

        Args:
            search_term: Text to search for in user messages.
            branch_id: Branch to search in (current branch if None).

        Returns:
            List of matching turns with same format as get_conversation_turns().
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        needle = search_term.casefold()

        def _search_sync():
            """Synchronous helper to search turns by content."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"""
                    SELECT
                        ms.branch_turn_number,
                        am.message_data,
                        ms.created_at
                    FROM message_structure ms
                    JOIN {self.messages_table} am ON ms.message_id = am.id
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.message_type = 'user'
                    ORDER BY ms.branch_turn_number
                    """,
                    (self.session_id, branch_id),
                )
                rows = cursor.fetchall()

            matches = []
            for turn_num, message_data, created_at in rows:
                try:
                    content = json.loads(message_data).get("content", "")
                except (json.JSONDecodeError, AttributeError):
                    continue
                if needle not in self._content_text(content).casefold():
                    continue
                matches.append(
                    {
                        "turn": turn_num,
                        "content": content,
                        "full_content": content,
                        "timestamp": created_at,
                        "can_branch": True,
                    }
                )

            return matches

        return await asyncio.to_thread(_search_sync)

    async def get_conversation_by_turns(
        self, branch_id: str | None = None
    ) -> dict[int, list[dict[str, str | None]]]:
        """Get conversation grouped by user turns for specified branch.

        Args:
            branch_id: Branch to get conversation from (current branch if None).

        Returns:
            Dictionary mapping user turn numbers to lists of
            {"type": message_type, "tool_name": tool_name} dicts, in sequence order.
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_conversation_sync():
            """Synchronous helper to get conversation by turns."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT user_turn_number, message_type, tool_name
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    ORDER BY sequence_number
                    """,
                    (self.session_id, branch_id),
                )

                turns: dict[int, list[dict[str, str | None]]] = {}
                for row in cursor.fetchall():
                    turn_num, msg_type, tool_name = row
                    if turn_num not in turns:
                        turns[turn_num] = []
                    turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
                return turns

        return await asyncio.to_thread(_get_conversation_sync)

    async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
        """Get all tool usage by turn for specified branch.

        Args:
            branch_id: Branch to get tool usage from (current branch if None).

        Returns:
            List of tuples containing (tool_name, usage_count, turn_number).
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_tool_usage_sync():
            """Synchronous helper to get tool usage statistics."""
            conn = self._get_connection()
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT tool_name, COUNT(*), user_turn_number
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ? AND message_type IN (
                        'tool_call', 'function_call', 'computer_call', 'file_search_call',
                        'web_search_call', 'code_interpreter_call', 'custom_tool_call',
                        'mcp_call', 'mcp_approval_request'
                    )
                    GROUP BY tool_name, user_turn_number
                    ORDER BY user_turn_number
                    """,
                    (self.session_id, branch_id),
                )
                return cursor.fetchall()

        return await asyncio.to_thread(_get_tool_usage_sync)

    async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
        """Get cumulative usage for session or specific branch.

        Args:
            branch_id: If provided, only get usage for that branch. If None, get all branches.

        Returns:
            Dictionary with usage statistics or None if no usage data found.
        """

        def _get_usage_sync():
            """Synchronous helper to get session usage data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                if branch_id:
                    # Branch-specific usage
                    query = """
                        SELECT
                            SUM(requests) as total_requests,
                            SUM(input_tokens) as total_input_tokens,
                            SUM(output_tokens) as total_output_tokens,
                            SUM(total_tokens) as total_total_tokens,
                            COUNT(*) as total_turns
                        FROM turn_usage
                        WHERE session_id = ? AND branch_id = ?
                    """
                    params: tuple[str, ...] = (self.session_id, branch_id)
                else:
                    # All branches
                    query = """
                        SELECT
                            SUM(requests) as total_requests,
                            SUM(input_tokens) as total_input_tokens,
                            SUM(output_tokens) as total_output_tokens,
                            SUM(total_tokens) as total_total_tokens,
                            COUNT(*) as total_turns
                        FROM turn_usage
                        WHERE session_id = ?
                    """
                    params = (self.session_id,)

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, params)
                    row = cursor.fetchone()

                    # SUM() is NULL when no usage rows matched
                    if row and row[0] is not None:
                        return {
                            "requests": row[0] or 0,
                            "input_tokens": row[1] or 0,
                            "output_tokens": row[2] or 0,
                            "total_tokens": row[3] or 0,
                            "total_turns": row[4] or 0,
                        }
                    return None

        result = await asyncio.to_thread(_get_usage_sync)

        return cast(Union[dict[str, int], None], result)

    async def get_turn_usage(
        self,
        user_turn_number: int | None = None,
        branch_id: str | None = None,
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Get usage statistics by turn with full JSON token details.

        Args:
            user_turn_number: Specific turn to get usage for. If None, returns all turns.
            branch_id: Branch to get usage from (current branch if None).

        Returns:
            Dictionary with usage data for a specific turn ({} if it has no usage
            row), or a list of dictionaries for all turns. Token detail columns
            that are empty or not valid JSON are reported as None.
        """
        if branch_id is None:
            branch_id = self._current_branch_id

        def _get_turn_usage_sync():
            """Synchronous helper to get turn usage statistics."""
            conn = self._get_connection()

            if user_turn_number is not None:
                query = """
                    SELECT requests, input_tokens, output_tokens, total_tokens,
                           input_tokens_details, output_tokens_details
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ? AND user_turn_number = ?
                """

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, (self.session_id, branch_id, user_turn_number))
                    row = cursor.fetchone()

                    if row:
                        return {
                            "requests": row[0],
                            "input_tokens": row[1],
                            "output_tokens": row[2],
                            "total_tokens": row[3],
                            "input_tokens_details": self._loads_json_or_none(row[4]),
                            "output_tokens_details": self._loads_json_or_none(row[5]),
                        }
                    return {}
            else:
                query = """
                    SELECT user_turn_number, requests, input_tokens, output_tokens,
                           total_tokens, input_tokens_details, output_tokens_details
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                    ORDER BY user_turn_number
                """

                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, (self.session_id, branch_id))
                    return [
                        {
                            "user_turn_number": row[0],
                            "requests": row[1],
                            "input_tokens": row[2],
                            "output_tokens": row[3],
                            "total_tokens": row[4],
                            "input_tokens_details": self._loads_json_or_none(row[5]),
                            "output_tokens_details": self._loads_json_or_none(row[6]),
                        }
                        for row in cursor.fetchall()
                    ]

        result = await asyncio.to_thread(_get_turn_usage_sync)

        return cast(Union[list[dict[str, Any]], dict[str, Any]], result)

    async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
        """Internal method to update usage for a specific turn with full JSON details.

        Uses INSERT OR REPLACE, so an existing row for (session, current branch,
        turn) is overwritten rather than accumulated.

        Args:
            user_turn_number: The turn number to update usage for.
            usage_data: The usage data to store.
        """

        def _update_sync():
            """Synchronous helper to update turn usage data."""
            conn = self._get_connection()
            # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
            with self._lock if self._is_memory_db else threading.Lock():
                # Serialize token details as JSON
                input_details_json = None
                output_details_json = None

                if hasattr(usage_data, "input_tokens_details") and usage_data.input_tokens_details:
                    try:
                        input_details_json = json.dumps(usage_data.input_tokens_details.__dict__)
                    except (TypeError, ValueError) as e:
                        self._logger.warning(f"Failed to serialize input tokens details: {e}")
                        input_details_json = None

                if (
                    hasattr(usage_data, "output_tokens_details")
                    and usage_data.output_tokens_details
                ):
                    try:
                        output_details_json = json.dumps(
                            usage_data.output_tokens_details.__dict__
                        )
                    except (TypeError, ValueError) as e:
                        self._logger.warning(f"Failed to serialize output tokens details: {e}")
                        output_details_json = None

                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        """
                        INSERT OR REPLACE INTO turn_usage
                        (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
                         total_tokens, input_tokens_details, output_tokens_details)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """,  # noqa: E501
                        (
                            self.session_id,
                            self._current_branch_id,
                            user_turn_number,
                            usage_data.requests or 0,
                            usage_data.input_tokens or 0,
                            usage_data.output_tokens or 0,
                            usage_data.total_tokens or 0,
                            input_details_json,
                            output_details_json,
                        ),
                    )
                    conn.commit()

        await asyncio.to_thread(_update_sync)