# HuggingFace Space: PDF Flashcard Generator (Gradio app)
# --- Imports: stdlib, third-party, then local modules ------------------------
import asyncio
import logging
import os
from pathlib import Path
from typing import Optional, Tuple

import google.generativeai as genai
import gradio as gr
from dotenv import load_dotenv

from chat_agent import ChatDeps, ChatResponse, chat_agent
from flashcard import FlashcardSet

# Load environment variables from a local .env file (if present) and configure
# the Gemini client.  Indexing os.environ (rather than .get) makes a missing
# GEMINI_API_KEY fail fast at import time with a clear KeyError.
load_dotenv()
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
async def process_message(
    message: dict,
    history: list,
    current_flashcards: Optional[FlashcardSet],
) -> Tuple[str, list, Optional[FlashcardSet]]:
    """Handle one chat turn: a PDF upload and/or a plain text message.

    Args:
        message: MultimodalTextbox payload with optional "text" and "files".
        history: Chatbot history as [user, assistant] pairs; mutated in place.
        current_flashcards: Flashcard set carried over from earlier turns.

    Returns:
        A tuple of ("", updated history, possibly-updated flashcards); the
        empty string clears the input textbox.
    """
    # "text" may be present but None in the payload; coalesce before stripping.
    user_text = (message.get("text") or "").strip()

    deps = ChatDeps(
        message=user_text,
        current_flashcards=current_flashcards,
    )

    # --- File uploads (only the first file is processed before returning) ---
    for file_path in message.get("files") or []:
        display_name = Path(file_path).name
        # Case-insensitive extension check so ".PDF" uploads are accepted too.
        if Path(file_path).suffix.lower() != ".pdf":
            history.append([f"Uploaded: {display_name}", "Please upload a PDF file."])
            return "", history, current_flashcards
        try:
            with open(file_path, "rb") as pdf_file:
                deps.pdf_data = pdf_file.read()
            # Any text sent with the upload becomes generation instructions.
            deps.system_prompt = user_text if user_text else None
            # Let the chat agent handle the PDF upload.
            result = await chat_agent.run("Process this PDF upload", deps=deps)
            if result.data.should_generate_flashcards:
                # Replace the session's flashcards with the new set.
                current_flashcards = result.data.flashcards
            history.append([
                f"Uploaded: {display_name}"
                + (f"\nWith instructions: {user_text}" if user_text else ""),
                result.data.response,
            ])
        except Exception as e:
            # Surface the failure in the chat rather than crashing the app.
            error_msg = f"Error processing PDF: {str(e)}"
            logging.error(error_msg)
            history.append([f"Uploaded: {display_name}", error_msg])
        return "", history, current_flashcards

    # --- Plain text message --------------------------------------------------
    if user_text:
        try:
            result = await chat_agent.run(user_text, deps=deps)
            # Update flashcards if the agent modified them.
            if result.data.should_modify_flashcards:
                current_flashcards = result.data.flashcards
            history.append([user_text, result.data.response])
        except Exception as e:
            error_msg = f"Error processing request: {str(e)}"
            logging.error(error_msg)
            history.append([user_text, error_msg])
        return "", history, current_flashcards

    # Neither a file nor any text was provided.
    history.append(["", "Please upload a PDF file or send a message."])
    return "", history, current_flashcards
async def clear_chat():
    """Reset the UI: clear the input box, the chat history, and the flashcards.

    Returns one ``None`` per wired output component
    (chat_input, chatbot, current_flashcards).
    """
    return (None, None, None)
# --- Gradio interface --------------------------------------------------------
# NOTE: the emoji below were mojibake in the scraped source ("π", "ποΈ" —
# UTF-8 decoded as ISO-8859-7); restored to the intended 📚 / 🗑️.
with gr.Blocks(title="PDF Flashcard Generator") as demo:
    gr.Markdown("""
    # 📚 PDF Flashcard Generator
    Upload a PDF document and get AI-generated flashcards to help you study!
    You can provide custom instructions along with your PDF upload to guide the flashcard generation.
    Powered by Google's Gemini AI
    """)

    # Conversation display: pairs of (user, assistant) messages.
    chatbot = gr.Chatbot(
        label="Flashcard Generation Chat",
        bubble_full_width=False,
        show_copy_button=True,
        height=600,
    )

    # Per-session flashcard state (not shared between concurrent users).
    current_flashcards = gr.State(value=None)

    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            label="Upload PDF or type a message",
            placeholder="Drop a PDF file here. You can also add instructions for how the flashcards should be generated...",
            file_types=[".pdf", "application/pdf", "pdf"],
            show_label=False,
            sources=["upload"],
            scale=20,
            min_width=100,
        )
        clear_btn = gr.Button("🗑️", variant="secondary", scale=1, min_width=50)

    # Submitting the textbox runs a chat turn; the button resets everything.
    chat_input.submit(
        fn=process_message,
        inputs=[chat_input, chatbot, current_flashcards],
        outputs=[chat_input, chatbot, current_flashcards],
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_input, chatbot, current_flashcards],
    )
if __name__ == "__main__":
    # Configure root logging before serving so handler errors are visible.
    logging.basicConfig(level=logging.INFO)
    # Bind to all interfaces on 7860 — the standard HuggingFace Spaces port.
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )