import asyncio

import streamlit as st
import tokonomics

from utils import create_model_hierarchy
from utils_on import analyze_hf_model  # New import for On Premise Estimator functionality

st.set_page_config(page_title="LLM Pricing Calculator", layout="wide")
# --------------------------
# Async Data Loading Function
# --------------------------
async def load_data():
    """Fetch the available models and their per-token pricing asynchronously."""
    AVAILABLE_MODELS = await tokonomics.get_available_models()
    hierarchy = create_model_hierarchy(AVAILABLE_MODELS)
    FILTERED_MODELS = []
    MODEL_PRICING = {}
    PROVIDERS = list(hierarchy.keys())
    # Walk provider -> family -> version -> region to collect every model ID
    # and look up its pricing.
    for provider in PROVIDERS:
        for model_family in hierarchy[provider]:
            for model_version in hierarchy[provider][model_family]:
                for region in hierarchy[provider][model_family][model_version]:
                    model_id = hierarchy[provider][model_family][model_version][region]
                    MODEL_PRICING[model_id] = await tokonomics.get_model_costs(model_id)
                    FILTERED_MODELS.append(model_id)
    return FILTERED_MODELS, MODEL_PRICING, PROVIDERS
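# A minimal sketch of the assumed hierarchy shape (illustrative IDs, not real
# tokonomics output), showing what the nested loops above flatten:
#
#   hierarchy = {
#       "azure": {"gpt-4": {"turbo": {"eastus": "azure/gpt-4-turbo"}}},
#   }
#
# which would yield FILTERED_MODELS == ["azure/gpt-4-turbo"] and MODEL_PRICING
# keyed by model ID, e.g. {"azure/gpt-4-turbo": {"input_cost_per_token": ...}}.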
# --------------------------
# Provider Change Function
# --------------------------
def provider_change(provider, selected_type, all_types=("text", "vision", "video", "image")):
    """Filter models based on the selected provider and type."""
    all_models = st.session_state.get("models", [])
    new_models = []
    others = [a_type for a_type in all_types if selected_type != a_type]
    for model_name in all_models:
        if provider in model_name:
            if selected_type in model_name:
                new_models.append(model_name)
            elif any(other in model_name for other in others):
                continue  # model belongs to a different modality
            else:
                new_models.append(model_name)  # untyped match for this provider
    return new_models if new_models else all_models
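# Illustrative example (hypothetical model IDs, for illustration only):
#   st.session_state["models"] = ["azure/gpt-4-text", "azure/dalle-image", "openai/gpt-4-text"]
#   provider_change("azure", "text")  ->  ["azure/gpt-4-text"]
# IDs mentioning another modality are skipped, and if nothing matches at all
# the full model list is returned as a fallback.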
# --------------------------
# Estimate Cost Function
# --------------------------
def estimate_cost(num_alerts, input_size, output_size, model_id):
    """Estimate daily, monthly, and yearly cost for the selected model."""
    pricing = st.session_state.get("pricing", {})
    cost_token = pricing.get(model_id)
    if not cost_token:
        return "NA"
    # Simple heuristic used by the app: ~1.3 tokens per unit of input size.
    input_tokens = round(input_size * 1.3)
    output_tokens = round(output_size * 1.3)
    price_per_alert = cost_token.get("input_cost_per_token", 0) * input_tokens + \
                      cost_token.get("output_cost_per_token", 0) * output_tokens
    price_day = price_per_alert * num_alerts
    return f"""## Estimated Cost:

Day Price: {price_day:0.2f} USD

Month Price: {price_day * 31:0.2f} USD

Year Price: {price_day * 365:0.2f} USD
"""
# --------------------------
# Load Data into Session State (only once)
# --------------------------
if "data_loaded" not in st.session_state:
    with st.spinner("Loading pricing data..."):
        models, pricing, providers = asyncio.run(load_data())
    st.session_state["models"] = models
    st.session_state["pricing"] = pricing
    st.session_state["providers"] = providers
    st.session_state["data_loaded"] = True
# --------------------------
# Sidebar
# --------------------------
with st.sidebar:
    st.image(
        "https://cdn.prod.website-files.com/630f558f2a15ca1e88a2f774/631f1436ad7a0605fecc5e15_Logo.svg",
        use_container_width=True,
    )
    st.markdown("Visit: [https://www.priam.ai](https://www.priam.ai)")
    st.divider()
    st.title("LLM Pricing Calculator")
# --------------------------
# Pills Navigation (Using st.pills)
# --------------------------
# st.pills creates a pill-style selection widget.
page = st.pills(
    "Head",
    options=["Model Selection", "On Premise Estimator", "About"],
    selection_mode="single",
    default="Model Selection",  # change this if you want a different default page
    label_visibility="hidden",
)
# --------------------------
# Helper: Format Analysis Report
# --------------------------
def format_analysis_report(analysis_result: dict) -> str:
    """Convert the raw analysis_result dict into a human-readable report."""
    if "error" in analysis_result:
        return f"**Error:** {analysis_result['error']}"
    lines = []
    lines.append(f"### Model Analysis Report for `{analysis_result.get('model_id', 'Unknown Model')}`\n")
    lines.append(f"**Parameter Size:** {analysis_result.get('parameter_size', 'N/A')} Billion parameters\n")
    lines.append(f"**Precision:** {analysis_result.get('precision', 'N/A')}\n")
    vram = analysis_result.get("vram_requirements", {})
    lines.append("#### VRAM Requirements:")
    lines.append(f"- Model Size: {vram.get('model_size_gb', 0):.2f} GB")
    lines.append(f"- KV Cache: {vram.get('kv_cache_gb', 0):.2f} GB")
    lines.append(f"- Activations: {vram.get('activations_gb', 0):.2f} GB")
    lines.append(f"- Overhead: {vram.get('overhead_gb', 0):.2f} GB")
    lines.append(f"- **Total VRAM:** {vram.get('total_vram_gb', 0):.2f} GB\n")
    compatible_gpus = analysis_result.get("compatible_gpus", [])
    lines.append("#### Compatible GPUs:")
    if compatible_gpus:
        for gpu in compatible_gpus:
            lines.append(f"- {gpu}")
    else:
        lines.append("- None found")
    lines.append(f"\n**Largest Compatible GPU:** {analysis_result.get('largest_compatible_gpu', 'N/A')}\n")
    # GPU performance reporting is currently disabled:
    # gpu_perf = analysis_result.get("gpu_performance", {})
    # if gpu_perf:
    #     lines.append("#### GPU Performance:")
    #     for gpu, perf in gpu_perf.items():
    #         lines.append(f"**{gpu}:**")
    #         lines.append(f" - Tokens per Second: {perf.get('tokens_per_second', 0):.2f}")
    #         lines.append(f" - FLOPs per Token: {perf.get('flops_per_token', 0):.2f}")
    #         lines.append(f" - Effective TFLOPS: {perf.get('effective_tflops', 0):.2f}\n")
    # else:
    #     lines.append("#### GPU Performance: N/A\n")
    return "\n".join(lines)
# --------------------------
# Render Content Based on Selected Pill
# --------------------------
if page == "Model Selection":
    st.divider()
    st.header("LLM Pricing App")

    # --- Row 1: Provider/Type and Model Selection ---
    col_left, col_right = st.columns(2)
    with col_left:
        selected_provider = st.selectbox(
            "Select a provider",
            st.session_state["providers"],
            index=st.session_state["providers"].index("azure") if "azure" in st.session_state["providers"] else 0,
        )
        selected_type = st.radio("Select type", options=["text", "image"], index=0)
    with col_right:
        filtered_models = provider_change(selected_provider, selected_type)
        if filtered_models:
            default_model = "o1" if "o1" in filtered_models else filtered_models[0]
            selected_model = st.selectbox(
                "Select a model",
                options=filtered_models,
                index=filtered_models.index(default_model),
            )
        else:
            selected_model = None
            st.write("No models available")
    # --- Row 2: Alert Stats ---
    col1, col2, col3 = st.columns(3)
    with col1:
        num_alerts = st.number_input(
            "Security Alerts Per Day", value=100, min_value=1, step=1,
            help="Number of security alerts to analyze daily",
        )
    with col2:
        input_size = st.number_input(
            "Alert Content Size (characters)", value=1000, min_value=1, step=1,
            help="Include logs, metadata, and context per alert",
        )
    with col3:
        output_size = st.number_input(
            "Analysis Output Size (characters)", value=500, min_value=1, step=1,
            help="Expected length of security analysis and recommendations",
        )
    # --- Row 3: Buttons ---
    btn_col1, btn_col2 = st.columns(2)
    with btn_col1:
        if st.button("Estimate"):
            if selected_model:
                st.session_state["result"] = estimate_cost(num_alerts, input_size, output_size, selected_model)
            else:
                st.session_state["result"] = "No model selected."
    with btn_col2:
        if st.button("Refresh Pricing Data"):
            with st.spinner("Refreshing pricing data..."):
                models, pricing, providers = asyncio.run(load_data())
            st.session_state["models"] = models
            st.session_state["pricing"] = pricing
            st.session_state["providers"] = providers
            st.success("Pricing data refreshed!")
    st.divider()
    st.markdown("### Results")
    if "result" in st.session_state:
        st.write(st.session_state["result"])
    else:
        st.write("Use the buttons above to estimate costs.")
    if st.button("Clear"):
        st.session_state.pop("result", None)
        st.rerun()  # rerun immediately so the cleared result disappears
elif page == "On Premise Estimator":
    st.divider()
    st.header("On Premise Estimator")
    st.markdown("Enter a Hugging Face model ID to perform an on-premise analysis using the provided estimator.")
    hf_model_id = st.text_input("Hugging Face Model ID", value="meta-llama/Llama-4-Scout-17B-16E")
    if st.button("Analyze Model"):
        with st.spinner("Analyzing model..."):
            analysis_result = analyze_hf_model(hf_model_id)
        st.session_state["analysis_result"] = analysis_result
    if "analysis_result" in st.session_state:
        report = format_analysis_report(st.session_state["analysis_result"])
        st.markdown(report)
elif page == "About":
    st.divider()
    st.markdown(
        """
## About This App

This app is built on the tokonomics package.

- Downloads the latest pricing data from the LiteLLM repository.
- Uses simple math to estimate the total token counts.
- Helps you estimate hardware requirements for running open-source large language models (LLMs) on-premise, using only a Hugging Face model ID.
- Latest version: v1.1

---

### 📌 Version History

| Version | Release Date | Key Feature Updates |
|---------|--------------|---------------------|
| `v1.1`  | 2025-04-06   | Added On Premise Estimator feature |
| `v1.0`  | 2025-03-26   | Initial release with basic total token estimation |

---

Website: [https://www.priam.ai](https://www.priam.ai)
"""
    )
    st.markdown(
        """
### Found a Bug?

If you encounter any issues or have feedback, please email **[email protected]**.
Your input helps us improve the app!
"""
    )