more lm eval harness
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- lm-evaluation-harness/docs/API_guide.md +211 -0
- lm-evaluation-harness/docs/CONTRIBUTING.md +83 -0
- lm-evaluation-harness/docs/README.md +11 -0
- lm-evaluation-harness/docs/chat-template-readme.md +31 -0
- lm-evaluation-harness/docs/decontamination.md +76 -0
- lm-evaluation-harness/docs/footguns.md +58 -0
- lm-evaluation-harness/docs/img/fewshot_example_gpt3.png +3 -0
- lm-evaluation-harness/docs/interface.md +170 -0
- lm-evaluation-harness/docs/model_guide.md +192 -0
- lm-evaluation-harness/docs/new_task_guide.md +521 -0
- lm-evaluation-harness/docs/task_guide.md +335 -0
- lm-evaluation-harness/examples/lm-eval-overview.ipynb +1240 -0
- lm-evaluation-harness/examples/transformer-lens.py +59 -0
- lm-evaluation-harness/examples/visualize-wandb.ipynb +172 -0
- lm-evaluation-harness/examples/visualize-zeno.ipynb +115 -0
- lm-evaluation-harness/lm_eval/__init__.py +21 -0
- lm-evaluation-harness/lm_eval/__main__.py +536 -0
- lm-evaluation-harness/lm_eval/api/__init__.py +0 -0
- lm-evaluation-harness/lm_eval/api/filter.py +56 -0
- lm-evaluation-harness/lm_eval/api/group.py +115 -0
- lm-evaluation-harness/lm_eval/api/instance.py +38 -0
- lm-evaluation-harness/lm_eval/api/metrics.py +629 -0
- lm-evaluation-harness/lm_eval/api/model.py +502 -0
- lm-evaluation-harness/lm_eval/api/registry.py +196 -0
- lm-evaluation-harness/lm_eval/api/samplers.py +232 -0
- lm-evaluation-harness/lm_eval/api/task.py +1885 -0
- lm-evaluation-harness/lm_eval/caching/__init__.py +0 -0
- lm-evaluation-harness/lm_eval/caching/cache.py +59 -0
- lm-evaluation-harness/lm_eval/decontamination/__init__.py +0 -0
- lm-evaluation-harness/lm_eval/decontamination/archiver.py +174 -0
- lm-evaluation-harness/lm_eval/decontamination/decontaminate.py +166 -0
- lm-evaluation-harness/lm_eval/decontamination/janitor.py +329 -0
- lm-evaluation-harness/lm_eval/evaluator.py +787 -0
- lm-evaluation-harness/lm_eval/evaluator_utils.py +554 -0
- lm-evaluation-harness/lm_eval/filters/__init__.py +25 -0
- lm-evaluation-harness/lm_eval/filters/custom.py +17 -0
- lm-evaluation-harness/lm_eval/filters/decontamination.py +25 -0
- lm-evaluation-harness/lm_eval/filters/extraction.py +233 -0
- lm-evaluation-harness/lm_eval/filters/selection.py +61 -0
- lm-evaluation-harness/lm_eval/filters/transformation.py +122 -0
- lm-evaluation-harness/lm_eval/loggers/__init__.py +2 -0
- lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py +537 -0
- lm-evaluation-harness/lm_eval/loggers/utils.py +149 -0
- lm-evaluation-harness/lm_eval/loggers/wandb_logger.py +358 -0
- lm-evaluation-harness/lm_eval/models/__init__.py +35 -0
- lm-evaluation-harness/lm_eval/models/anthropic_llms.py +382 -0
- lm-evaluation-harness/lm_eval/models/api_models.py +810 -0
- lm-evaluation-harness/lm_eval/models/dummy.py +41 -0
- lm-evaluation-harness/lm_eval/models/gguf.py +132 -0
.gitattributes
CHANGED
|
@@ -45,3 +45,5 @@ records/012625_BatchSize/ablations.png filter=lfs diff=lfs merge=lfs -text
|
|
| 45 |
records/102924_Optimizers/nanogpt_speedrun81w.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
records/102924_Optimizers/nanogpt_speedrun82w.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
records/110624_ShortcutsTweaks/nanogpt_speedrun111.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 45 |
records/102924_Optimizers/nanogpt_speedrun81w.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
records/102924_Optimizers/nanogpt_speedrun82w.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
records/110624_ShortcutsTweaks/nanogpt_speedrun111.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
lm-evaluation-harness/docs/img/fewshot_example_gpt3.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
lm-evaluation-harness/lm_eval/tasks/noreval/noreval.jpg filter=lfs diff=lfs merge=lfs -text
|
lm-evaluation-harness/docs/API_guide.md
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TemplateAPI Usage Guide
|
| 2 |
+
|
| 3 |
+
The `TemplateAPI` class is a versatile superclass designed to facilitate the integration of various API-based language models into the lm-evaluation-harness framework. This guide will explain how to use and extend the `TemplateAPI` class to implement your own API models. If your API implements the OpenAI API you can use the `local-completions` or the `local-chat-completions` (defined [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py)) model types, which can also serve as examples of how to effectively subclass this template.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The `TemplateAPI` class provides a template for creating API-based model implementations. It handles common functionalities such as:
|
| 8 |
+
|
| 9 |
+
- Tokenization (optional)
|
| 10 |
+
- Batch processing
|
| 11 |
+
- Caching
|
| 12 |
+
- Retrying failed requests
|
| 13 |
+
- Parsing API responses
|
| 14 |
+
|
| 15 |
+
To use this class, you typically need to subclass it and implement specific methods for your API.
|
| 16 |
+
|
| 17 |
+
## Key Methods to Implement
|
| 18 |
+
|
| 19 |
+
When subclassing `TemplateAPI`, you need to implement the following methods:
|
| 20 |
+
|
| 21 |
+
1. `_create_payload`: Creates the JSON payload for API requests.
|
| 22 |
+
2. `parse_logprobs`: Parses log probabilities from API responses.
|
| 23 |
+
3. `parse_generations`: Parses generated text from API responses.
|
| 24 |
+
|
| 25 |
+
Optional Properties:
|
| 26 |
+
|
| 27 |
+
4. `header`: Returns the headers for the API request.
|
| 28 |
+
5. `api_key`: Returns the API key for authentication (if required).
|
| 29 |
+
|
| 30 |
+
You may also need to override other methods or properties depending on your API's specific requirements.
|
| 31 |
+
|
| 32 |
+
> [!NOTE]
|
| 33 |
+
> Currently loglikelihood and MCQ based tasks (such as MMLU) are only supported for completion endpoints. Not for chat-completion — those that expect a list of dicts — endpoints! Completion APIs which support instruct tuned models can be evaluated with the `--apply_chat_template` option in order to simultaneously evaluate models using a chat template format while still being able to access the model logits needed for loglikelihood-based tasks.
|
| 34 |
+
|
| 35 |
+
## TemplateAPI Arguments
|
| 36 |
+
|
| 37 |
+
When initializing a `TemplateAPI` instance or a subclass, you can provide several arguments to customize its behavior. Here's a detailed explanation of some important arguments:
|
| 38 |
+
|
| 39 |
+
- `model` or `pretrained` (str):
|
| 40 |
+
- The name or identifier of the model to use.
|
| 41 |
+
- `model` takes precedence over `pretrained` when both are provided.
|
| 42 |
+
|
| 43 |
+
- `base_url` (str):
|
| 44 |
+
- The base URL for the API endpoint.
|
| 45 |
+
|
| 46 |
+
- `tokenizer` (str, optional):
|
| 47 |
+
- The name or path of the tokenizer to use.
|
| 48 |
+
- If not provided, it defaults to using the same tokenizer name as the model.
|
| 49 |
+
|
| 50 |
+
- `num_concurrent` (int):
|
| 51 |
+
- Number of concurrent requests to make to the API.
|
| 52 |
+
- Useful for APIs that support parallel processing.
|
| 53 |
+
- Default is 1 (sequential processing).
|
| 54 |
+
|
| 55 |
+
- `timeout` (int, optional):
|
| 56 |
+
- Timeout for API requests in seconds.
|
| 57 |
+
- Default is 30.
|
| 58 |
+
|
| 59 |
+
- `tokenized_requests` (bool):
|
| 60 |
+
- Determines whether the input is pre-tokenized. Defaults to `True`.
|
| 61 |
+
- Requests can be sent in either tokenized form (`list[list[int]]`) or as text (`list[str]`, or `str` for batch_size=1).
|
| 62 |
+
- For loglikelihood-based tasks, prompts require tokenization to calculate the context length. If `False` prompts are decoded back to text before being sent to the API.
|
| 63 |
+
- Not as important for `generate_until` tasks.
|
| 64 |
+
- Ignored for chat formatted inputs (list[dict...]) or if tokenizer_backend is None.
|
| 65 |
+
|
| 66 |
+
- `tokenizer_backend` (str, optional):
|
| 67 |
+
- Required for loglikelihood-based or MCQ tasks.
|
| 68 |
+
- Specifies the tokenizer library to use. Options are "tiktoken", "huggingface", or None.
|
| 69 |
+
- Default is "huggingface".
|
| 70 |
+
|
| 71 |
+
- `max_length` (int, optional):
|
| 72 |
+
- Maximum length of input + output.
|
| 73 |
+
- Default is 2048.
|
| 74 |
+
|
| 75 |
+
- `max_retries` (int, optional):
|
| 76 |
+
- Maximum number of retries for failed API requests.
|
| 77 |
+
- Default is 3.
|
| 78 |
+
|
| 79 |
+
- `max_gen_toks` (int, optional):
|
| 80 |
+
- Maximum number of tokens to generate in completion tasks.
|
| 81 |
+
- Default is 256 or set in task yaml.
|
| 82 |
+
|
| 83 |
+
- `batch_size` (int or str, optional):
|
| 84 |
+
- Number of requests to batch together (if the API supports batching).
|
| 85 |
+
- Can be an integer or "auto" (which defaults to 1 for API models).
|
| 86 |
+
- Default is 1.
|
| 87 |
+
|
| 88 |
+
- `seed` (int, optional):
|
| 89 |
+
- Random seed for reproducibility.
|
| 90 |
+
- Default is 1234.
|
| 91 |
+
|
| 92 |
+
- `add_bos_token` (bool, optional):
|
| 93 |
+
- Whether to add the beginning-of-sequence token to inputs (when tokenizing).
|
| 94 |
+
- Default is False.
|
| 95 |
+
|
| 96 |
+
- `custom_prefix_token_id` (int, optional):
|
| 97 |
+
- Custom token ID to use as a prefix for inputs.
|
| 98 |
+
- If not provided, uses the model's default BOS or EOS token (if `add_bos_token` is True).
|
| 99 |
+
|
| 100 |
+
- `verify_certificate` (bool, optional):
|
| 101 |
+
- Whether to validate the certificate of the API endpoint (if HTTPS).
|
| 102 |
+
- Default is True.
|
| 103 |
+
|
| 104 |
+
- `header` (dict, optional):
|
| 105 |
+
- Custom headers for API requests.
|
| 106 |
+
- If not provided, uses `{"Authorization": f"Bearer {self.api_key}"}` by default.
|
| 107 |
+
|
| 108 |
+
Example usage:
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
class MyAPIModel(TemplateAPI):
|
| 112 |
+
def __init__(self, **kwargs):
|
| 113 |
+
super().__init__(
|
| 114 |
+
model="my-model",
|
| 115 |
+
base_url="https://api.mymodel.com/v1/completions",
|
| 116 |
+
tokenizer_backend="huggingface",
|
| 117 |
+
num_concurrent=5,
|
| 118 |
+
max_retries=5,
|
| 119 |
+
batch_size=10,
|
| 120 |
+
**kwargs
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Implement other required methods...
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
When subclassing `TemplateAPI`, you can override these arguments in your `__init__` method to set default values specific to your API. You can also add additional (potentially user-specified) arguments as needed for your specific implementation.
|
| 127 |
+
|
| 128 |
+
## Example Implementation: OpenAI API
|
| 129 |
+
|
| 130 |
+
The `OpenAICompletionsAPI` and `OpenAIChatCompletion` ([here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py) classes demonstrate how to implement API models using the `TemplateAPI` class. Here's a breakdown of the key components:
|
| 131 |
+
|
| 132 |
+
### 1. Subclassing and Initialization
|
| 133 |
+
|
| 134 |
+
```python
|
| 135 |
+
@register_model("openai-completions")
|
| 136 |
+
class OpenAICompletionsAPI(LocalCompletionsAPI):
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
base_url="https://api.openai.com/v1/completions",
|
| 140 |
+
tokenizer_backend="tiktoken",
|
| 141 |
+
**kwargs,
|
| 142 |
+
):
|
| 143 |
+
super().__init__(
|
| 144 |
+
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
|
| 145 |
+
)
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### 2. Implementing API Key Retrieval
|
| 149 |
+
|
| 150 |
+
```python
|
| 151 |
+
@cached_property
|
| 152 |
+
def api_key(self):
|
| 153 |
+
key = os.environ.get("OPENAI_API_KEY", None)
|
| 154 |
+
if key is None:
|
| 155 |
+
raise ValueError(
|
| 156 |
+
"API key not found. Please set the OPENAI_API_KEY environment variable."
|
| 157 |
+
)
|
| 158 |
+
return key
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### 3. Creating the Payload
|
| 162 |
+
|
| 163 |
+
```python
|
| 164 |
+
def _create_payload(
|
| 165 |
+
self,
|
| 166 |
+
messages: Union[List[List[int]], List[dict], List[str], str],
|
| 167 |
+
generate=False,
|
| 168 |
+
gen_kwargs: Optional[dict] = None,
|
| 169 |
+
**kwargs,
|
| 170 |
+
) -> dict:
|
| 171 |
+
if generate:
|
| 172 |
+
# ... (implementation for generation)
|
| 173 |
+
else:
|
| 174 |
+
# ... (implementation for log likelihood)
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### 4. Parsing API Responses
|
| 178 |
+
|
| 179 |
+
```python
|
| 180 |
+
@staticmethod
|
| 181 |
+
def parse_logprobs(
|
| 182 |
+
outputs: Union[Dict, List[Dict]],
|
| 183 |
+
tokens: List[List[int]] = None,
|
| 184 |
+
ctxlens: List[int] = None,
|
| 185 |
+
**kwargs,
|
| 186 |
+
) -> List[Tuple[float, bool]]:
|
| 187 |
+
# ... (implementation)
|
| 188 |
+
|
| 189 |
+
@staticmethod
|
| 190 |
+
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
|
| 191 |
+
# ... (implementation)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
The requests are initiated in the `model_call` or the `amodel_call` methods.
|
| 195 |
+
|
| 196 |
+
## Implementing Your Own API Model
|
| 197 |
+
|
| 198 |
+
To implement your own API model:
|
| 199 |
+
|
| 200 |
+
1. Subclass `TemplateAPI` or one of its subclasses (e.g., `LocalCompletionsAPI`).
|
| 201 |
+
2. Override the `__init__` method if you need to set specific parameters.
|
| 202 |
+
3. Implement the `_create_payload` and `header` methods to create the appropriate payload for your API.
|
| 203 |
+
4. Implement the `parse_logprobs` and `parse_generations` methods to parse your API's responses.
|
| 204 |
+
5. Override the `api_key` property if your API requires authentication.
|
| 205 |
+
6. Override any other methods as necessary to match your API's behavior.
|
| 206 |
+
|
| 207 |
+
## Best Practices
|
| 208 |
+
|
| 209 |
+
1. Use the `@register_model` decorator to register your model with the framework (and import it in `lm_eval/models/__init__.py`!).
|
| 210 |
+
2. Use environment variables for sensitive information like API keys.
|
| 211 |
+
3. Properly handle batching and concurrent requests if supported by your API.
|
lm-evaluation-harness/docs/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to LM Evaluation Harness
|
| 2 |
+
|
| 3 |
+
Welcome and thank you for your interest in the LM Evaluation Harness! We welcome contributions and feedback and appreciate your time spent with our library, and hope you find it useful!
|
| 4 |
+
|
| 5 |
+
## Important Resources
|
| 6 |
+
|
| 7 |
+
There are several places information about LM Evaluation Harness is located:
|
| 8 |
+
|
| 9 |
+
- Our [documentation pages](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs)
|
| 10 |
+
- We occasionally use [GitHub Milestones](https://github.com/EleutherAI/lm-evaluation-harness/milestones) to track progress toward specific near-term version releases.
|
| 11 |
+
- We maintain a [Project Board](https://github.com/orgs/EleutherAI/projects/25) for tracking current work items and PRs, and for future roadmap items or feature requests.
|
| 12 |
+
- Further discussion and support conversations are located in the #lm-thunderdome channel of the [EleutherAI discord](https://discord.gg/eleutherai).
|
| 13 |
+
|
| 14 |
+
## Code Style
|
| 15 |
+
|
| 16 |
+
LM Evaluation Harness uses [ruff](https://github.com/astral-sh/ruff) for linting via [pre-commit](https://pre-commit.com/).
|
| 17 |
+
|
| 18 |
+
You can install linters and dev tools via
|
| 19 |
+
|
| 20 |
+
```pip install lm_eval[dev]``` or ```pip install -e ".[dev]"```
|
| 21 |
+
|
| 22 |
+
Then, run
|
| 23 |
+
|
| 24 |
+
```pre-commit install```
|
| 25 |
+
|
| 26 |
+
in order to ensure linters and other checks will be run upon committing.
|
| 27 |
+
|
| 28 |
+
## Testing
|
| 29 |
+
|
| 30 |
+
We use [pytest](https://docs.pytest.org/en/latest/) for running unit tests. All library unit tests can be run via:
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_openvino.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Contributor License Agreement
|
| 37 |
+
|
| 38 |
+
We ask that new contributors agree to a Contributor License Agreement affirming that EleutherAI has the rights to use your contribution to our library.
|
| 39 |
+
First-time pull requests will have a reply added by @CLAassistant containing instructions for how to confirm this, and we require it before merging your PR.
|
| 40 |
+
|
| 41 |
+
## Contribution Best Practices
|
| 42 |
+
|
| 43 |
+
We recommend a few best practices to make your contributions or reported errors easier to assist with.
|
| 44 |
+
|
| 45 |
+
**For Pull Requests:**
|
| 46 |
+
|
| 47 |
+
- PRs should be titled descriptively, and be opened with a brief description of the scope and intent of the new contribution.
|
| 48 |
+
- New features should have appropriate documentation added alongside them.
|
| 49 |
+
- Aim for code maintainability, and minimize code copying.
|
| 50 |
+
- If opening a task, try to share test results on the task using a publicly-available model, and if any public results are available on the task, compare to them.
|
| 51 |
+
|
| 52 |
+
**For Feature Requests:**
|
| 53 |
+
|
| 54 |
+
- Provide a short paragraph's worth of description. What is the feature you are requesting? What is its motivation, and an example use case of it? How does this differ from what is currently supported?
|
| 55 |
+
|
| 56 |
+
**For Bug Reports**:
|
| 57 |
+
|
| 58 |
+
- Provide a short description of the bug.
|
| 59 |
+
- Provide a *reproducible example*--what is the command you run with our library that results in this error? Have you tried any other steps to resolve it?
|
| 60 |
+
- Provide a *full error traceback* of the error that occurs, if applicable. A one-line error message or small screenshot snippet is unhelpful without the surrounding context.
|
| 61 |
+
- Note what version of the codebase you are using, and any specifics of your environment and setup that may be relevant.
|
| 62 |
+
|
| 63 |
+
**For Requesting New Tasks**:
|
| 64 |
+
|
| 65 |
+
- Provide a 1-2 sentence description of what the task is and what it evaluates.
|
| 66 |
+
- Provide a link to the paper introducing the task.
|
| 67 |
+
- Provide a link to where the dataset can be found.
|
| 68 |
+
- Provide a link to a paper containing results on an open-source model on the task, for use in comparisons and implementation validation.
|
| 69 |
+
- If applicable, link to any codebase that has implemented the task (especially the original publication's codebase, if existent).
|
| 70 |
+
|
| 71 |
+
## How Can I Get Involved?
|
| 72 |
+
|
| 73 |
+
To quickly get started, we maintain a list of good first issues, which can be found [on our project board](https://github.com/orgs/EleutherAI/projects/25/views/8) or by [filtering GH Issues](https://github.com/EleutherAI/lm-evaluation-harness/issues?q=is%3Aopen+label%3A%22good+first+issue%22+label%3A%22help+wanted%22). These are typically smaller code changes or self-contained features which can be added without extensive familiarity with library internals, and we recommend new contributors consider taking a stab at one of these first if they are feeling uncertain where to begin.
|
| 74 |
+
|
| 75 |
+
There are a number of distinct ways to contribute to LM Evaluation Harness, and all are extremely helpful! A sampling of ways to contribute include:
|
| 76 |
+
|
| 77 |
+
- **Implementing and verifying new evaluation tasks**: Is there a task you'd like to see LM Evaluation Harness support? Consider opening an issue requesting it, or helping add it! Verifying and cross-checking task implementations with their original versions is also a very valuable form of assistance in ensuring standardized evaluation.
|
| 78 |
+
- **Improving documentation** - Improvements to the documentation, or noting pain points / gaps in documentation, are helpful in order for us to improve the user experience of the library and clarity + coverage of documentation.
|
| 79 |
+
- **Testing and devops** - We are very grateful for any assistance in adding tests for the library that can be run for new PRs, and other devops workflows.
|
| 80 |
+
- **Adding new modeling / inference library integrations** - We hope to support a broad range of commonly-used inference libraries popular among the community, and welcome PRs for new integrations, so long as they are documented properly and maintainable.
|
| 81 |
+
- **Proposing or Contributing New Features** - We want LM Evaluation Harness to support a broad range of evaluation usecases. If you have a feature that is not currently supported but desired, feel free to open an issue describing the feature and, if applicable, how you intend to implement it. We would be happy to give feedback on the cleanest way to implement new functionalities and are happy to coordinate with interested contributors via GH discussions or via discord.
|
| 82 |
+
|
| 83 |
+
We hope that this has been helpful, and appreciate your interest in contributing! Further questions can be directed to [our Discord](discord.gg/eleutherai).
|
lm-evaluation-harness/docs/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eval Harness Documentation
|
| 2 |
+
|
| 3 |
+
Welcome to the docs for the LM Evaluation Harness!
|
| 4 |
+
|
| 5 |
+
## Table of Contents
|
| 6 |
+
|
| 7 |
+
* To learn about the public interface of the library, as well as how to evaluate via the command line or as integrated into an external library, see the [Interface](./interface.md).
|
| 8 |
+
* To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](./model_guide.md).
|
| 9 |
+
* For an extended description of how to extend the library to new model classes served over an API, see the [API Guide](./API_guide.md).
|
| 10 |
+
* For a crash course on adding new tasks to the library, see our [New Task Guide](./new_task_guide.md).
|
| 11 |
+
* To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](./task_guide.md).
|
lm-evaluation-harness/docs/chat-template-readme.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Chat Template Delimiter Handling Update
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
This change modifies how delimiters are handled when applying chat templates in the request construction process for likelihood and multiple-choice based tasks. When `apply_chat_template` is set to `True`, the target delimiter is now set to an empty string instead of using the configured delimiter.
|
| 6 |
+
|
| 7 |
+
## Background
|
| 8 |
+
|
| 9 |
+
By default, the system uses a target delimiter (typically a whitespace " ") between the context and target text when constructing prompts. The full string is constructed as:
|
| 10 |
+
|
| 11 |
+
```text
|
| 12 |
+
doc_to_text(doc) + target_delimiter + doc_to_target(doc)
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
While this worked well for base models where we wanted the model to predict a single whitespace followed by the answer, chat models have their own formatting conventions that handle spacing differently.
|
| 16 |
+
|
| 17 |
+
## The Change
|
| 18 |
+
|
| 19 |
+
- When `apply_chat_template=True`, the target delimiter is now empty ("") instead of the default whitespace
|
| 20 |
+
- This prevents interference between chat template formatting and the default delimiter system
|
| 21 |
+
- Particularly important for multiple choice tasks where the template itself handles spacing
|
| 22 |
+
|
| 23 |
+
## Example
|
| 24 |
+
|
| 25 |
+
```text
|
| 26 |
+
# Before (with default delimiter " ")
|
| 27 |
+
<user>Question: What color is the sky?\nAnswer:<assistant> blue
|
| 28 |
+
|
| 29 |
+
# After
|
| 30 |
+
<user>Question: What color is the sky?\nAnswer:<assistant>blue
|
| 31 |
+
```
|
lm-evaluation-harness/docs/decontamination.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Decontamination
|
| 2 |
+
|
| 3 |
+
## Usage
|
| 4 |
+
|
| 5 |
+
The provided directory should contain
|
| 6 |
+
the ngram files and info.json produced in "Pile Ngram Generation" further down.
|
| 7 |
+
|
| 8 |
+
```bash
|
| 9 |
+
python -m lm_eval \
|
| 10 |
+
--model gpt2 \
|
| 11 |
+
--device 0 \
|
| 12 |
+
--tasks sciq
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## Background
|
| 16 |
+
|
| 17 |
+
Downstream evaluations test model generalization, and are less useful when test set data also exists in the training set, referred to as leakage or contamination.
|
| 18 |
+
|
| 19 |
+
Filtering your training set against the test set is a good first step, however this isn't always possible, as in the case of a new benchmark or one that wasn't considered prior to model training. When training set filtering isn't possible, it is useful to measure the impact of test set leakage by detecting the contaminated test examples and producing a clean version of the benchmark.
|
| 20 |
+
|
| 21 |
+
The basis for our decontamination procedure can be found in Appendix C of "Language Models are Few-Shot Learners". OpenAI defined a test document as contaminated if any N-gram overlap existed with any training document. They used a range of N values between 8 and 13 depending on dataset, while we just used 13 for simplicity.
|
| 22 |
+
|
| 23 |
+
## Implementation
|
| 24 |
+
|
| 25 |
+
Contamination detection can be found in `lm_eval/decontaminate.py` with supporting code in `lm_eval/decontamination/`.
|
| 26 |
+
|
| 27 |
+
decontaminate.py does the following:
|
| 28 |
+
|
| 29 |
+
1. Build dictionaries of all ngrams and their corresponding evaluation/document ids.
|
| 30 |
+
2. Scan through sorted files containing training set n-grams.
|
| 31 |
+
3. If a match is found, the corresponding evaluation/document combinations are marked as contaminated.
|
| 32 |
+
|
| 33 |
+
`lm_eval/evaluator.py` can then produce a clean version of the benchmark by excluding the results of contaminated documents. For each metric, a clean version will be shown in the results with a "decontaminate" suffix.
|
| 34 |
+
|
| 35 |
+
This is disabled by default for new tasks, to support decontamination on a task override the "should_decontaminate" and "doc_to_decontamination_query" methods. For more details see the [task guide](task_guide.md).
|
| 36 |
+
|
| 37 |
+
## Pile Ngram Generation
|
| 38 |
+
|
| 39 |
+
The relevant scripts can be found in `scripts/clean_training_data`, which also import from
|
| 40 |
+
`lm_eval/decontamination/`
|
| 41 |
+
|
| 42 |
+
1. git clone https://github.com/EleutherAI/lm-evaluation-harness.git
|
| 43 |
+
2. pip install -r requirements.txt
|
| 44 |
+
3. Download The Pile from [The Eye](https://the-eye.eu/public/AI/pile/train/)
|
| 45 |
+
4. Place pile files in "pile" directory under "lm-evaluation-harness" (or create a symlink)
|
| 46 |
+
5. Run generate_13_grams.
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
export PYTHONHASHSEED=0
|
| 50 |
+
python -m scripts/clean_training_data/generate_13_grams \
|
| 51 |
+
-dir path/to/working/directory \
|
| 52 |
+
-n 13 \
|
| 53 |
+
-buckets 500
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Took approximately 4 days for us. We had the time to wait, but this could be scaled out by doing partial pile scans on multiple instances of this script and merging the relevant buckets. We fixed PYTHONHASHSEED to ensure reproducibility of bucket hashing in case you need to stop and start.
|
| 57 |
+
|
| 58 |
+
6. Sort the generated 13-grams.
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python -m scripts/clean_training_data/sort_13_gram_buckets \
|
| 62 |
+
-dir path/to/working/directory/output
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Took approximately 5 days for us. You could speed this up by spreading the files around to different machines and running the sort script before gathering them together.
|
| 66 |
+
|
| 67 |
+
7. Compress the sorted 13 grams files and place them together with info.json.
|
| 68 |
+
|
| 69 |
+
This step only takes a few hours.
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python -m scripts/clean_training_data/compress_and_package \
|
| 73 |
+
-dir path/to/working/directory \
|
| 74 |
+
-output path/to/final/directory \
|
| 75 |
+
-procs 8
|
| 76 |
+
```
|
lm-evaluation-harness/docs/footguns.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Common Pitfalls and Troubleshooting Guide
|
| 2 |
+
|
| 3 |
+
This document highlights common pitfalls and troubleshooting tips when using this library. We'll continue to add more tips as we discover them.
|
| 4 |
+
|
| 5 |
+
## YAML Configuration Issues
|
| 6 |
+
|
| 7 |
+
### Newline Characters in YAML (`\n`)
|
| 8 |
+
|
| 9 |
+
**Problem:** When specifying newline characters in YAML, they may be interpreted incorrectly depending on how you format them.
|
| 10 |
+
|
| 11 |
+
```yaml
|
| 12 |
+
# ❌ WRONG: Single quotes don't process escape sequences
|
| 13 |
+
generation_kwargs:
|
| 14 |
+
until: ['\n'] # Gets parsed as the literal characters '\' and 'n' i.e "\\n"
|
| 15 |
+
|
| 16 |
+
```
|
| 17 |
+
```yaml
|
| 18 |
+
# ✅ RIGHT: Use double quotes for escape sequences
|
| 19 |
+
generation_kwargs:
|
| 20 |
+
until: ["\n"] # Gets parsed as an actual newline character
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
**Solutions:**
|
| 25 |
+
- Use double quotes for strings containing escape sequences
|
| 26 |
+
- For multiline content, use YAML's block scalars (`|` or `>`)
|
| 27 |
+
- When generating YAML programmatically, be careful with how template engines handle escape sequences
|
| 28 |
+
|
| 29 |
+
### Quoting in YAML
|
| 30 |
+
|
| 31 |
+
**When to use different types of quotes:**
|
| 32 |
+
|
| 33 |
+
- **No quotes**: Simple values (numbers, booleans, alphanumeric strings without special characters)
|
| 34 |
+
```yaml
|
| 35 |
+
simple_value: plain text
|
| 36 |
+
number: 42
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
- **Single quotes (')**:
|
| 41 |
+
- Preserves literal values
|
| 42 |
+
- Use when you need special characters to be treated literally
|
| 43 |
+
- Escape single quotes by doubling them: `'It''s working'`
|
| 44 |
+
```yaml
|
| 45 |
+
literal_string: 'The newline character \n is not processed here'
|
| 46 |
+
path: 'C:\Users\name' # Backslashes preserved
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
- **Double quotes (")**:
|
| 51 |
+
- Processes escape sequences like `\n`, `\t`, etc.
|
| 52 |
+
- Use for strings that need special characters interpreted
|
| 53 |
+
- Escape double quotes with backslash: `"He said \"Hello\""`
|
| 54 |
+
```yaml
|
| 55 |
+
processed_string: "First line\nSecond line" # Creates actual newline
|
| 56 |
+
unicode: "Copyright symbol: \u00A9" # Unicode character
|
| 57 |
+
|
| 58 |
+
```
|
lm-evaluation-harness/docs/img/fewshot_example_gpt3.png
ADDED
|
Git LFS Details
|
lm-evaluation-harness/docs/interface.md
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# User Guide
|
| 2 |
+
|
| 3 |
+
This document details the interface exposed by `lm-eval` and provides details on what flags are available to users.
|
| 4 |
+
|
| 5 |
+
## Command-line Interface
|
| 6 |
+
|
| 7 |
+
A majority of users run the library by cloning it from Github, installing the package as editable, and running the `python -m lm_eval` script.
|
| 8 |
+
|
| 9 |
+
Equivalently, running the library can be done via the `lm-eval` entrypoint at the command line.
|
| 10 |
+
|
| 11 |
+
This mode supports a number of command-line arguments, the details of which can also be seen via running with `-h` or `--help`:
|
| 12 |
+
|
| 13 |
+
- `--model` : Selects which model type or provider is evaluated. Must be a string corresponding to the name of the model type/provider being used. See [the main README](https://github.com/EleutherAI/lm-evaluation-harness/tree/main#model-apis-and-inference-servers) for a full list of enabled model names and supported libraries or APIs.
|
| 14 |
+
|
| 15 |
+
- `--model_args` : Controls parameters passed to the model constructor. Accepts a string containing comma-separated keyword arguments to the model class of the format `"arg1=val1,arg2=val2,..."`, such as, for example `--model_args pretrained=EleutherAI/pythia-160m,dtype=float32`. For a full list of what keyword arguments, see the initialization of the `lm_eval.api.model.LM` subclass, e.g. [`HFLM`](https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/models/huggingface.py#L66)
|
| 16 |
+
|
| 17 |
+
- `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups. A list of supported tasks can be viewed with `--tasks list`.
|
| 18 |
+
|
| 19 |
+
- `--num_fewshot` : Sets the number of few-shot examples to place in context. Must be an integer.
|
| 20 |
+
|
| 21 |
+
- `--gen_kwargs` : takes an arg string in same format as `--model_args` and creates a dictionary of keyword arguments. These will be passed to the models for all called `generate_until` (free-form or greedy generation task) tasks, to set options such as the sampling temperature or `top_p` / `top_k`. For a list of what args are supported for each model type, reference the respective library's documentation (for example, the documentation for `transformers.AutoModelForCausalLM.generate()`.) These kwargs will be applied to all `generate_until` tasks called--we do not currently support unique gen_kwargs or batch_size values per task in a single run of the library. To control these on a per-task level, set them in that task's YAML file.
|
| 22 |
+
|
| 23 |
+
- `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
|
| 24 |
+
|
| 25 |
+
- `--max_batch_size` : Sets the maximum batch size to try to fit in memory, if `--batch_size auto` is passed.
|
| 26 |
+
|
| 27 |
+
- `--device` : Sets which device to place the model onto. Must be a string, for example, `"cuda", "cuda:0", "cpu", "mps"`. Defaults to "cuda", and can be ignored if running multi-GPU or running a non-local model type.
|
| 28 |
+
|
| 29 |
+
- `--output_path` : A string of the form `dir/file.jsonl` or `dir/`. Provides a path where high-level results will be saved, either into the file named or into the directory named. If `--log_samples` is passed as well, then per-document outputs and metrics will be saved into the directory as well.
|
| 30 |
+
|
| 31 |
+
- `--log_samples` : If this flag is passed, then the model's outputs, and the text fed into the model, will be saved at per-document granularity. Must be used with `--output_path`.
|
| 32 |
+
|
| 33 |
+
- `--limit` : Accepts an integer, or a float between 0.0 and 1.0 . If passed, will limit the number of documents to evaluate to the first X documents (if an integer) per task or first X% of documents per task. Useful for debugging, especially on costly API models.
|
| 34 |
+
|
| 35 |
+
- `--use_cache` : Should be a path where a sqlite db file can be written to. Takes a string of format `/path/to/sqlite_cache_` in order to create a cache db at `/path/to/sqlite_cache_rank{i}.db` for each process (0-NUM_GPUS). This allows results of prior runs to be cached, so that there is no need to re-run results in order to re-score or re-run a given (model, task) pair again.
|
| 36 |
+
|
| 37 |
+
- `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
|
| 38 |
+
|
| 39 |
+
- `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
|
| 40 |
+
|
| 41 |
+
- `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
|
| 42 |
+
|
| 43 |
+
- `--show_config` : If used, prints the full `lm_eval.api.task.TaskConfig` contents (non-default settings the task YAML file) for each task which was run, at the completion of an evaluation. Useful for when one is modifying a task's configuration YAML locally to transmit the exact configurations used for debugging or for reproducibility purposes.
|
| 44 |
+
|
| 45 |
+
- `--include_path` : Accepts a path to a folder. If passed, then all YAML files containing `lm-eval` compatible task configurations will be added to the task registry as available tasks. Used for when one is writing config files for their own task in a folder other than `lm_eval/tasks/`.
|
| 46 |
+
|
| 47 |
+
- `--system_instruction`: Specifies a system instruction string to prepend to the prompt.
|
| 48 |
+
|
| 49 |
+
- `--apply_chat_template` : This flag specifies whether to apply a chat template to the prompt. It can be used in the following ways:
|
| 50 |
+
- `--apply_chat_template` : When used without an argument, applies the only available chat template to the prompt. For Hugging Face models, if no dedicated chat template exists, the default chat template will be applied.
|
| 51 |
+
- `--apply_chat_template template_name` : If the model has multiple chat templates, apply the specified template to the prompt.
|
| 52 |
+
|
| 53 |
+
For Hugging Face models, the default chat template can be found in the [`default_chat_template`](https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1912) property of the Transformers Tokenizer.
|
| 54 |
+
|
| 55 |
+
- `--fewshot_as_multiturn` : If this flag is on, the Fewshot examples are treated as a multi-turn conversation. Questions are provided as user content and answers are provided as assistant responses. Requires `--num_fewshot` to be set to be greater than 0, and `--apply_chat_template` to be on.
|
| 56 |
+
|
| 57 |
+
- `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
|
| 58 |
+
|
| 59 |
+
- `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
|
| 60 |
+
|
| 61 |
+
- `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```. Also allows for the passing of the step to log things at (passed to `wandb.run.log`), e.g., `--wandb_args step=123`.
|
| 62 |
+
|
| 63 |
+
- `--hf_hub_log_args` : Logs evaluation results to Hugging Face Hub. Accepts a string with the arguments separated by commas. Available arguments:
|
| 64 |
+
- `hub_results_org` - organization name on Hugging Face Hub, e.g., `EleutherAI`. If not provided, the results will be pushed to the owner of the Hugging Face token,
|
| 65 |
+
- `hub_repo_name` - repository name on Hugging Face Hub (deprecated, `details_repo_name` and `results_repo_name` should be used instead), e.g., `lm-eval-results`,
|
| 66 |
+
- `details_repo_name` - repository name on Hugging Face Hub to store details, e.g., `lm-eval-results`,
|
| 67 |
+
- `results_repo_name` - repository name on Hugging Face Hub to store results, e.g., `lm-eval-results`,
|
| 68 |
+
- `push_results_to_hub` - whether to push results to Hugging Face Hub, can be `True` or `False`,
|
| 69 |
+
- `push_samples_to_hub` - whether to push samples results to Hugging Face Hub, can be `True` or `False`. Requires `--log_samples` to be set,
|
| 70 |
+
- `public_repo` - whether the repository is public, can be `True` or `False`,
|
| 71 |
+
- `leaderboard_url` - URL to the leaderboard, e.g., `https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard`.
|
| 72 |
+
- `point_of_contact` - Point of contact for the results dataset, e.g., `[email protected]`.
|
| 73 |
+
- `gated` - whether to gate the details dataset, can be `True` or `False`.
|
| 74 |
+
|
| 75 |
+
- `--metadata`: JSON string to pass to TaskConfig. Used for some tasks which require additional metadata to be passed for processing. E.g., `--metadata '{"key": "value"}'`.
|
| 76 |
+
|
| 77 |
+
## External Library Usage
|
| 78 |
+
|
| 79 |
+
We also support using the library's external API for use within model training loops or other scripts.
|
| 80 |
+
|
| 81 |
+
`lm_eval` supplies two functions for external import and use: `lm_eval.evaluate()` and `lm_eval.simple_evaluate()`.
|
| 82 |
+
|
| 83 |
+
`simple_evaluate()` can be used by simply creating an `lm_eval.api.model.LM` subclass that implements the methods described in the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs/model_guide.md), and wrapping your custom model in that class as follows:
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
import lm_eval
|
| 87 |
+
from lm_eval.utils import setup_logging
|
| 88 |
+
...
|
| 89 |
+
# initialize logging
|
| 90 |
+
setup_logging("DEBUG") # optional, but recommended; or you can set up logging yourself
|
| 91 |
+
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
|
| 92 |
+
...
|
| 93 |
+
# instantiate an LM subclass that takes your initialized model and can run
|
| 94 |
+
# - `Your_LM.loglikelihood()`
|
| 95 |
+
# - `Your_LM.loglikelihood_rolling()`
|
| 96 |
+
# - `Your_LM.generate_until()`
|
| 97 |
+
lm_obj = Your_LM(model=my_model, batch_size=16)
|
| 98 |
+
|
| 99 |
+
# indexes all tasks from the `lm_eval/tasks` subdirectory.
|
| 100 |
+
# Alternatively, you can set `TaskManager(include_path="path/to/my/custom/task/configs")`
|
| 101 |
+
# to include a set of tasks in a separate directory.
|
| 102 |
+
task_manager = lm_eval.tasks.TaskManager()
|
| 103 |
+
|
| 104 |
+
# Setting `task_manager` to the one above is optional and should generally be done
|
| 105 |
+
# if you want to include tasks from paths other than ones in `lm_eval/tasks`.
|
| 106 |
+
# `simple_evaluate` will instantiate its own task_manager if it is set to None here.
|
| 107 |
+
results = lm_eval.simple_evaluate( # call simple_evaluate
|
| 108 |
+
model=lm_obj,
|
| 109 |
+
tasks=["taskname1", "taskname2"],
|
| 110 |
+
num_fewshot=0,
|
| 111 |
+
task_manager=task_manager,
|
| 112 |
+
...
|
| 113 |
+
)
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
See the `simple_evaluate()` and `evaluate()` functions in [lm_eval/evaluator.py](../lm_eval/evaluator.py#:~:text=simple_evaluate) for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
|
| 117 |
+
|
| 118 |
+
Additionally, the `evaluate()` function offers the core evaluation functionality provided by the library, but without some of the special handling and simplification + abstraction provided by `simple_evaluate()`.
|
| 119 |
+
|
| 120 |
+
As a brief example usage of `evaluate()`:
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
import lm_eval
|
| 124 |
+
|
| 125 |
+
# suppose you've defined a custom lm_eval.api.Task subclass in your own external codebase
|
| 126 |
+
from my_tasks import MyTask1
|
| 127 |
+
...
|
| 128 |
+
|
| 129 |
+
# create your model (could be running finetuning with some custom modeling code)
|
| 130 |
+
my_model = initialize_my_model()
|
| 131 |
+
...
|
| 132 |
+
|
| 133 |
+
# instantiate an LM subclass that takes your initialized model and can run
|
| 134 |
+
# - `Your_LM.loglikelihood()`
|
| 135 |
+
# - `Your_LM.loglikelihood_rolling()`
|
| 136 |
+
# - `Your_LM.generate_until()`
|
| 137 |
+
lm_obj = Your_LM(model=my_model, batch_size=16)
|
| 138 |
+
|
| 139 |
+
# optional: the task_manager indexes tasks including ones
|
| 140 |
+
# specified by the user through `include_path`.
|
| 141 |
+
task_manager = lm_eval.tasks.TaskManager(
|
| 142 |
+
include_path="/path/to/custom/yaml"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# To get a task dict for `evaluate`
|
| 146 |
+
task_dict = lm_eval.tasks.get_task_dict(
|
| 147 |
+
[
|
| 148 |
+
"mmlu", # A stock task
|
| 149 |
+
"my_custom_task", # A custom task
|
| 150 |
+
{
|
| 151 |
+
"task": ..., # A dict that configures a task
|
| 152 |
+
"doc_to_text": ...,
|
| 153 |
+
},
|
| 154 |
+
MyTask1 # A task object from `lm_eval.task.Task`
|
| 155 |
+
],
|
| 156 |
+
task_manager # A task manager that allows lm_eval to
|
| 157 |
+
# load the task during evaluation.
|
| 158 |
+
# If none is provided, `get_task_dict`
|
| 159 |
+
# will instantiate one itself, but this
|
| 160 |
+
# only includes the stock tasks so users
|
| 161 |
+
# will need to set this if including
|
| 162 |
+
# custom paths is required.
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
results = evaluate(
|
| 166 |
+
lm=lm_obj,
|
| 167 |
+
task_dict=task_dict,
|
| 168 |
+
...
|
| 169 |
+
)
|
| 170 |
+
```
|
lm-evaluation-harness/docs/model_guide.md
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# New Model Guide
|
| 2 |
+
|
| 3 |
+
This guide may be of special interest to users who are using the library outside of the repository, via installing the library via pypi and calling `lm_eval.evaluator.evaluate()` to evaluate an existing model.
|
| 4 |
+
|
| 5 |
+
In order to properly evaluate a given LM, we require implementation of a wrapper class subclassing the `lm_eval.api.model.LM` class, that defines how the Evaluation Harness should interface with your model. This guide walks through how to write this `LM` subclass via adding it to the library!
|
| 6 |
+
|
| 7 |
+
## Setup
|
| 8 |
+
|
| 9 |
+
To get started contributing, go ahead and fork the main repo, clone it, create a branch with the name of your model, and install the project requirements in your environment:
|
| 10 |
+
|
| 11 |
+
```sh
|
| 12 |
+
# After forking...
|
| 13 |
+
git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
|
| 14 |
+
cd lm-evaluation-harness
|
| 15 |
+
git checkout -b <model-type>
|
| 16 |
+
pip install -e ".[dev]"
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
Now, we'll create a new file where we'll be adding our model:
|
| 20 |
+
|
| 21 |
+
```sh
|
| 22 |
+
touch lm_eval/models/<my_model_filename>.py
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
**Tip: this filename should not shadow package names! For example, naming your file `anthropic.py` is disallowed since the API's name on pypi is `anthropic`, but naming it `anthropic_llms.py` works with no problems.**
|
| 26 |
+
|
| 27 |
+
## Interface
|
| 28 |
+
|
| 29 |
+
All models must subclass the `lm_eval.api.model.LM` class.
|
| 30 |
+
|
| 31 |
+
The LM class enforces a common interface via which we can extract responses from a model:
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
class MyCustomLM(LM):
|
| 35 |
+
#...
|
| 36 |
+
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
| 37 |
+
#...
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
| 41 |
+
#...
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def generate_until(self, requests: list[Instance]) -> list[str]:
|
| 45 |
+
#...
|
| 46 |
+
#...
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/instance.py) with property `args` of request-dependent type signature described below.
|
| 50 |
+
|
| 51 |
+
We support three types of requests, consisting of different interactions / measurements with an autoregressive LM.
|
| 52 |
+
|
| 53 |
+
All three request types take as input `requests` of type `list[Instance]` that have a matching `Instance.request_type` to the method name.
|
| 54 |
+
|
| 55 |
+
- `generate_until`
|
| 56 |
+
- Each request contains `Instance.args : Tuple[str, dict]` containing 1. an input string to the LM and 2. a dictionary of keyword arguments used to control generation parameters.
|
| 57 |
+
- Using this input and these generation parameters, text will be sampled from the language model (typically until a maximum output length or specific stopping string sequences--for example, `{"until": ["\n\n", "."], "max_gen_toks": 128}`).
|
| 58 |
+
- The generated output text from the model will then be returned.
|
| 59 |
+
|
| 60 |
+
- `loglikelihood`
|
| 61 |
+
- Each request contains `Instance.args : Tuple[str, str]` containing 1. an input string to the LM and 2. a target string on which the loglikelihood of the LM producing this target, conditioned on the input, will be returned.
|
| 62 |
+
- Each request will have, as result, `(ll, is_greedy): Tuple[float, int]` returned, where `ll` is a floating point number representing the log probability of generating the target string conditioned on the input, and `is_greedy` being either the value `0` or `1`, with it being `1` if and only if the target string *would be generated by greedy sampling from the LM* (that is, if the target string is the *most likely* N-token string to be output by the LM given the input. )
|
| 63 |
+
|
| 64 |
+
- `loglikelihood_rolling`
|
| 65 |
+
- Each request contains `Instance.args : Tuple[str]`, which is an input string to the model whose *entire* loglikelihood, conditioned on purely the EOT token, will be calculated.
|
| 66 |
+
- This is used to evaluate *perplexity* on a data distribution.
|
| 67 |
+
- It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
|
| 68 |
+
|
| 69 |
+
To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` ! Additionally, check out `lm_eval.api.model.TemplateLM` for a class that abstracts away some commonly used functions across LM subclasses, or see if your model would lend itself well to subclassing the `lm_eval.models.huggingface.HFLM` class and overriding just the initialization or a couple methods!
|
| 70 |
+
|
| 71 |
+
**Tip: be careful of indexing in loglikelihood!**
|
| 72 |
+
|
| 73 |
+
LMs take in tokens in position `[0 1 2 ... N]` and output a probability distribution for token position `N+1`. We provide a simplified graphic here, excerpted from `huggingface.py`:
|
| 74 |
+
|
| 75 |
+
```text
|
| 76 |
+
# how this all works (illustrated on a causal decoder-only setup):
|
| 77 |
+
# CTX CONT
|
| 78 |
+
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
| 79 |
+
# model \ \
|
| 80 |
+
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
| 81 |
+
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
The final token of the target is not passed into the LM, because we want the LM's predictions *up to but not past* that final target token. For more information, check out https://github.com/EleutherAI/lm-evaluation-harness/issues/942 .
|
| 85 |
+
|
| 86 |
+
## Registration
|
| 87 |
+
|
| 88 |
+
Congrats on implementing your model! Now it's time to test it out.
|
| 89 |
+
|
| 90 |
+
To make your model usable via the command line interface to `lm-eval` using `python -m lm_eval`, you'll need to tell `lm-eval` what your model's name is.
|
| 91 |
+
|
| 92 |
+
This is done via a *decorator*, `lm_eval.api.registry.register_model`. Using `register_model()`, one can both tell the package what the model's name(s) to be used are when invoking it with `python -m lm_eval --model <name>` and alert `lm-eval` to the model's existence.
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
from lm_eval.api.registry import register_model
|
| 96 |
+
|
| 97 |
+
@register_model("<name1>", "<name2>")
|
| 98 |
+
class MyCustomLM(LM):
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
|
| 102 |
+
|
| 103 |
+
**Tip: be sure to import your model in `lm_eval/models/__init__.py!`**
|
| 104 |
+
|
| 105 |
+
## Testing
|
| 106 |
+
|
| 107 |
+
We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py .
|
| 108 |
+
|
| 109 |
+
## Chat Templating
|
| 110 |
+
|
| 111 |
+
Many models are fine-tuned with a [Chat Template](https://huggingface.co/docs/transformers/main/en/chat_templating) in order to enable back-and-forth interaction between a "User"'s queries and the model (often called "Assistant")'s responses. It can be desirable to evaluate fine-tuned models on evaluation tasks while wrapped in the conversational format they expect.
|
| 112 |
+
|
| 113 |
+
In order to make your model optionally compatible with a chat format, three additional methods must be implemented:
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
class MyCustomLM(LM):
|
| 117 |
+
#...
|
| 118 |
+
@property
|
| 119 |
+
def tokenizer_name(self) -> str:
|
| 120 |
+
"""
|
| 121 |
+
Return the name of the model's tokenizer and/or the accompanying chat template.
|
| 122 |
+
The returned string is used to cache requests.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
str: The name of the model's tokenizer and/or chat template.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def chat_template(self, chat_template: Union[bool, str] = False) -> str:
|
| 129 |
+
"""
|
| 130 |
+
Get the appropriate chat template for the model based on the `chat_template` argument.
|
| 131 |
+
|
| 132 |
+
This method returns the chat template string to build the prompt from a chat history.
|
| 133 |
+
The chat template is saved in the evaluation results for reproducibility.
|
| 134 |
+
Boolean arguments should be used with models that have only one chat template,
|
| 135 |
+
while string arguments are used with models that have multiple chat templates.
|
| 136 |
+
For the reference implementation, see HFLM class in `lm_eval.models.huggingface`.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
chat_template (Union[bool, str]): Specifies whether to apply a chat template:
|
| 140 |
+
- If False: Do not apply any chat template.
|
| 141 |
+
- If True: Apply the default chat template.
|
| 142 |
+
- If str: Apply the specified chat template by name.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
str: The selected chat template in Jinja format.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
|
| 149 |
+
"""
|
| 150 |
+
Process a chat history to create a string that can be tokenized and input into the model.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
chat_history (List[Dict[str, str]]): A list of dictionaries representing the chat history,
|
| 154 |
+
where each dictionary has "role" and "content" keys.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
str: A string representing the chat history that can be tokenized and fed into the model.
|
| 158 |
+
"""
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
- `apply_chat_template`
|
| 162 |
+
- This method performs the bulk of the work required for chat-formatting.
|
| 163 |
+
- As input, a `chat_history: List[Dict[str, str]]` is passed in. This is a transcript of a conversation of a form similar to
|
| 164 |
+
|
| 165 |
+
```text
|
| 166 |
+
[
|
| 167 |
+
{"system": <user-provided system message such as "You are a helpful math-focused chatbot">},
|
| 168 |
+
{"user": <task example - a few-shot example 'input'>}
|
| 169 |
+
{"assistant": <correct response to the above example>},
|
| 170 |
+
# ... more few-shot examples, potentially
|
| 171 |
+
{"user": <test set query--response on which we will evaluate>},
|
| 172 |
+
]
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
which can then be converted into a string input.
|
| 176 |
+
- The output is a string representing this conversation that can be fed into the model.
|
| 177 |
+
- For example, this consists of simply calling `tokenizer.apply_chat_template` for HFLM--see the implementation there for reference.
|
| 178 |
+
- `tokenizer_name`
|
| 179 |
+
- LM Eval Harness supports [caching requests](https://github.com/EleutherAI/lm-evaluation-harness/blob/4902aaaf1f374682f95ac25fe2e13b23faddc91a/lm_eval/__main__.py#L140) that are sent to a model, for faster setup when repeating an already-performed evaluation.
|
| 180 |
+
- However, we don't want to use the cache of chat transcripts rendered using one chat template or system prompt to send to a model with a different template! So, we use this `lm.tokenizer_name` string to distinguish caches for a given model (and chat template) from one another.
|
| 181 |
+
- `chat_template`
|
| 182 |
+
- Chat templates are typically provided as a Jinja template string or a string formatted with str.format to include user and assistant messages in a single prompt. This template string is saved in the evaluation results to ensure reproducibility.
|
| 183 |
+
|
| 184 |
+
If not implemented for a given model type, the flags `--apply_chat_template` , `--fewshot_as_multiturn`, and `--system_instruction` cannot be used.
|
| 185 |
+
|
| 186 |
+
## Other
|
| 187 |
+
|
| 188 |
+
**Pro tip**: In order to make the Evaluation Harness overestimate total runtimes rather than underestimate it, HuggingFace models come in-built with the ability to provide responses on data points in *descending order by total input length* via `lm_eval.utils.Reorderer`. Take a look at `lm_eval.models.hf_causal.HFLM` to see how this is done, and see if you can implement it in your own model!
|
| 189 |
+
|
| 190 |
+
## Conclusion
|
| 191 |
+
|
| 192 |
+
After reading this guide, you should be able to add new model APIs or implementations to the Eval Harness library!
|
lm-evaluation-harness/docs/new_task_guide.md
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# New Task Guide
|
| 2 |
+
|
| 3 |
+
`lm-evaluation-harness` is a framework that strives to support a wide range of zero- and few-shot evaluation tasks on autoregressive language models (LMs).
|
| 4 |
+
|
| 5 |
+
This documentation page provides a walkthrough to get started creating your own task, in `lm-eval` versions v0.4.0 and later.
|
| 6 |
+
|
| 7 |
+
A more interactive tutorial is available as a Jupyter notebook [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/examples/lm-eval-overview.ipynb).
|
| 8 |
+
|
| 9 |
+
## Setup
|
| 10 |
+
|
| 11 |
+
If you haven't already, go ahead and fork the main repo, clone it, create a branch with the name of your task, and install the project requirements in your environment:
|
| 12 |
+
|
| 13 |
+
```sh
|
| 14 |
+
# After forking...
|
| 15 |
+
git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
|
| 16 |
+
cd lm-evaluation-harness
|
| 17 |
+
git checkout -b <task-name>
|
| 18 |
+
pip install -e ".[dev]"
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
In this document, we'll walk through the basics of implementing a static benchmark evaluation in two formats: a *generative* task which requires sampling text from a model, such as [`gsm8k`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k.yaml), and a *discriminative*, or *multiple choice*, task where the model picks the most likely of several fixed answer choices, such as [`sciq`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/sciq/sciq.yaml).
|
| 22 |
+
|
| 23 |
+
## Creating a YAML file
|
| 24 |
+
|
| 25 |
+
To implement a new standard task, we'll need to write a YAML file which configures our task logic. We start by making a new empty YAML file. This file can have any name, but we recommend placing it in a subfolder of `lm_eval/tasks` titled by the dataset or task's shorthand name: for example,
|
| 26 |
+
|
| 27 |
+
```sh
|
| 28 |
+
touch lm_eval/tasks/<dataset_name>/<my_new_task_name>.yaml
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Or, copy the template subfolder we provide from `templates/new_yaml_task`:
|
| 32 |
+
|
| 33 |
+
```sh
|
| 34 |
+
cp -r templates/new_yaml_task lm_eval/tasks/
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
and rename the folders and YAML file(s) as desired.
|
| 38 |
+
|
| 39 |
+
### Selecting and configuring a dataset
|
| 40 |
+
|
| 41 |
+
All data downloading and management is handled through the HuggingFace (**HF**) [`datasets`](https://github.com/huggingface/datasets) API. So, the first thing you should do is check to see if your task's dataset is already provided in their catalog [here](https://huggingface.co/datasets). If it's not in there, please consider adding it to their Hub to make it accessible to a wider user base by following their [new dataset guide](https://github.com/huggingface/datasets/blob/main/ADD_NEW_DATASET.md)
|
| 42 |
+
.
|
| 43 |
+
> [!TIP]
|
| 44 |
+
> To test your task, we recommend using verbose logging using `export LOGLEVEL = DEBUG` in your shell before running the evaluation script. This will help you debug any issues that may arise.
|
| 45 |
+
Once you have a HuggingFace dataset prepared for your task, we want to assign our new YAML to use this dataset:
|
| 46 |
+
|
| 47 |
+
```yaml
|
| 48 |
+
dataset_path: ... # the name of the dataset on the HF Hub.
|
| 49 |
+
dataset_name: ... # the dataset configuration to use. Leave `null` if your dataset does not require a config to be passed. See https://huggingface.co/docs/datasets/load_hub#configurations for more info.
|
| 50 |
+
dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`.
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist:
|
| 54 |
+
|
| 55 |
+
```yaml
|
| 56 |
+
training_split: <split name of training set, or `null`>
|
| 57 |
+
validation_split: <split name of val. set, or `null`>
|
| 58 |
+
test_split: <split name of test set, or `null`>
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Tests will run on the `test_split` if it is available, and otherwise evaluate on the `validation_split`.
|
| 62 |
+
|
| 63 |
+
We can also specify from which split the task should retrieve few-shot examples via:
|
| 64 |
+
|
| 65 |
+
```yaml
|
| 66 |
+
fewshot_split: <split name to draw fewshot examples from, or `null`>
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
or by hardcoding them, either using the following in the yaml file:
|
| 70 |
+
|
| 71 |
+
```yaml
|
| 72 |
+
fewshot_config:
|
| 73 |
+
sampler: first_n
|
| 74 |
+
samples: [
|
| 75 |
+
{<sample 1>},
|
| 76 |
+
{<sample 2>},
|
| 77 |
+
]
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
or by adding the function `list_fewshot_samples` in the associated utils.py file:
|
| 81 |
+
|
| 82 |
+
```python
|
| 83 |
+
def list_fewshot_samples() -> list[dict]:
|
| 84 |
+
return [{<sample 1>}, {<sample 2>}]
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
See `lm_eval/tasks/minerva_math/minerva_math_algebra.yaml` for an example of the latter, and `lm_eval/tasks/gsm8k/gsm8k-cot.yaml` for an example of the former.
|
| 88 |
+
|
| 89 |
+
In this case, each sample must contain the same fields as the samples in the above sets--for example, if `doc_to_text` expects an `input` field when rendering input prompts, these provided samples must include an `input` key.
|
| 90 |
+
|
| 91 |
+
If neither above options are not set, we will default to train/validation/test sets, in that order.
|
| 92 |
+
|
| 93 |
+
Finally, our dataset may not be already in the exact format we want. Maybe we have to strip whitespace and special characters via a regex from our dataset's "question" field! Or maybe we just want to rename its columns to match a convention we'll be using for our prompts.
|
| 94 |
+
|
| 95 |
+
Let's create a python file in the directory where we're writing our YAML file:
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
touch lm_eval/tasks/<dataset_name>/utils.py
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Now, in `utils.py` we'll write a function to process each split of our dataset (the following example is drawn from [the `hellaswag` task](../lm_eval/tasks/hellaswag/utils.py)):
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
|
| 105 |
+
def _process_doc(doc):
|
| 106 |
+
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
|
| 107 |
+
out_doc = {
|
| 108 |
+
"query": preprocess(doc["activity_label"] + ": " + ctx),
|
| 109 |
+
"choices": [preprocess(ending) for ending in doc["endings"]],
|
| 110 |
+
"gold": int(doc["label"]),
|
| 111 |
+
}
|
| 112 |
+
return out_doc
|
| 113 |
+
|
| 114 |
+
return dataset.map(_process_doc)
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Now, in our YAML config file we'll use the `!function` constructor, and tell the config where our imported Python function will come from. At runtime, before doing anything else we will preprocess our dataset according to this function!
|
| 118 |
+
|
| 119 |
+
```yaml
|
| 120 |
+
process_docs: !function utils.process_docs
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Using Local Datasets
|
| 124 |
+
|
| 125 |
+
To load a local dataset for evaluation, you can specify data files in the `dataset_kwargs` field, such as the following for JSON files:
|
| 126 |
+
|
| 127 |
+
```yaml
|
| 128 |
+
dataset_path: json
|
| 129 |
+
dataset_name: null
|
| 130 |
+
dataset_kwargs:
|
| 131 |
+
data_files: /path/to/my/json
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Or with files already split into separate directories:
|
| 135 |
+
|
| 136 |
+
```yaml
|
| 137 |
+
dataset_path: arrow
|
| 138 |
+
dataset_kwargs:
|
| 139 |
+
data_files:
|
| 140 |
+
train: /path/to/arrow/train/data-00000-of-00001.arrow
|
| 141 |
+
validation: /path/to/arrow/validation/data-00000-of-00001.arrow
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Alternatively, if you have previously downloaded a dataset from huggingface hub (using `save_to_disk()`) and wish to use the local files, you will need to use `data_dir` under `dataset_kwargs` to point to where the directory is.
|
| 145 |
+
|
| 146 |
+
```yaml
|
| 147 |
+
dataset_path: hellaswag
|
| 148 |
+
dataset_kwargs:
|
| 149 |
+
data_dir: hellaswag_local/
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
You can also set `dataset_path` as a directory path in your local system. This will assume that there is a loading script with the same name as the directory. [See datasets docs](https://huggingface.co/docs/datasets/loading#local-loading-script).
|
| 153 |
+
|
| 154 |
+
## Writing a Prompt Template
|
| 155 |
+
|
| 156 |
+
The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.
|
| 157 |
+
|
| 158 |
+
To write a prompt, users will use `doc_to_text`, `doc_to_target`, and `doc_to_choice` (Optional when certain conditions are met).
|
| 159 |
+
|
| 160 |
+
`doc_to_text` defines the input string a model will be given while `doc_to_target` and `doc_to_choice` will be used to generate the target text. `doc_to_target` can be either a text string that refers to the target string or an integer that refers to the index of the correct label. When it is set as an index, `doc_to_choice` must also be set with the appropriate list of possible choice strings.
|
| 161 |
+
|
| 162 |
+
### Basic prompts
|
| 163 |
+
|
| 164 |
+
If a dataset is straightforward enough, users can enter the feature name directly. This assumes that no preprocessing is required. For example in [Swag](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/swag/swag.yaml#L10-L11), `doc_to_text` and `doc_to_target` given the name of one of the feature each.
|
| 165 |
+
|
| 166 |
+
```yaml
|
| 167 |
+
doc_to_text: startphrase
|
| 168 |
+
doc_to_target: label
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
Hard-coding is also possible as is the case in [SciQ](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/sciq/sciq.yaml#L11).
|
| 172 |
+
|
| 173 |
+
```yaml
|
| 174 |
+
doc_to_target: 3
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
`doc_to_choice` can be directly given a list of text as option (See [Toxigen](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/toxigen/toxigen.yaml#L11))
|
| 178 |
+
|
| 179 |
+
```yaml
|
| 180 |
+
doc_to_choice: ['No', 'Yes']
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
if a dataset feature is already a list, you can set the name of the feature as `doc_to_choice` (See [Hellaswag](https://github.com/EleutherAI/lm-evaluation-harness/blob/e0eda4d3ffa10e5f65e0976161cd134bec61983a/lm_eval/tasks/hellaswag/hellaswag.yaml#L13))
|
| 184 |
+
|
| 185 |
+
```yaml
|
| 186 |
+
doc_to_choice: choices
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
### Writing a prompt with Jinja 2
|
| 190 |
+
|
| 191 |
+
We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
|
| 192 |
+
|
| 193 |
+
Take for example the dataset `super_glue/boolq`. As input, we'd like to use the features `passage` and `question` and string them together so that for a sample line `doc`, the model sees something in the format of:
|
| 194 |
+
|
| 195 |
+
```text
|
| 196 |
+
doc["passage"]
|
| 197 |
+
Question: doc["question"]?
|
| 198 |
+
Answer:
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
We do this by [writing](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/super_glue/boolq/default.yaml#L9C1-L9C61)
|
| 202 |
+
|
| 203 |
+
```yaml
|
| 204 |
+
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
Such that `{{passage}}` will be replaced by `doc["passage"]` and `{{question}}` with `doc["question"]` when rendering the prompt template.
|
| 208 |
+
|
| 209 |
+
Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
|
| 210 |
+
|
| 211 |
+
```yaml
|
| 212 |
+
doc_to_target: "{{answer}}"
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
#### Multiple choice format
|
| 216 |
+
|
| 217 |
+
For tasks which are multiple choice (a fixed, finite set of label words per each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type) we enforce a particular convention on prompt format.
|
| 218 |
+
|
| 219 |
+
> [!WARNING]
|
| 220 |
+
> We add `target_delimiter` between input and target which defaults to " ", such that the full input-output string is `doc_to_text(doc) + target_delimiter + doc_to_target(doc)`. `doc_to_text` and `doc_to_target` should not contain trailing right or left whitespace, respectively. For multiple choice the target will be each choice index concatenated with the delimiter.
|
| 221 |
+
|
| 222 |
+
An annotated example in the case of SciQ is as follows:
|
| 223 |
+
|
| 224 |
+
```yaml
|
| 225 |
+
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
|
| 226 |
+
doc_to_target: 3 # this contains the index into the answer choice list of the correct answer.
|
| 227 |
+
doc_to_choice: "{{[distractor1, distractor2, distractor3, correct_answer]}}"
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
|
| 231 |
+
|
| 232 |
+
The label index can also be sourced from a feature directly. For example in `superglue/boolq`, the label index if defined in the feature `label`. We can set `doc_to_target` as simply `label`. The options or verbalizers can be written in the form of a list `["no", "yes"]` that will correspond to the label index.
|
| 233 |
+
|
| 234 |
+
```yaml
|
| 235 |
+
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
|
| 236 |
+
doc_to_target: label
|
| 237 |
+
doc_to_choice: ["no", "yes"]
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
### Using Python Functions for Prompts
|
| 241 |
+
|
| 242 |
+
There may be cases where the prompt we want to implement is easier expressed in Python instead of Jinja 2. For this, we can use Python helper functions that are defined in the YAML config. It should be noted that the function script must be in the same directory as the yaml.
|
| 243 |
+
|
| 244 |
+
A good example is WikiText that requires a lot of regex rules to clean the samples.
|
| 245 |
+
|
| 246 |
+
```python
|
| 247 |
+
def wikitext_detokenizer(doc):
|
| 248 |
+
string = doc["page"]
|
| 249 |
+
# contractions
|
| 250 |
+
string = string.replace("s '", "s'")
|
| 251 |
+
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
|
| 252 |
+
...
|
| 253 |
+
string = string.replace(" 's", "'s")
|
| 254 |
+
|
| 255 |
+
return string
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
We can load this function in `doc_to_target` by using a `!function` operator after `doc_to_target` and followed by `<file name>.<function name>`. In the file [wikitext.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/wikitext/wikitext.yaml) we write:
|
| 259 |
+
|
| 260 |
+
```yaml
|
| 261 |
+
doc_to_target: !function preprocess_wikitext.wikitext_detokenizer
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
### Importing a Prompt from Promptsource
|
| 265 |
+
|
| 266 |
+
[Promptsource](https://github.com/bigscience-workshop/promptsource/tree/main/promptsource) is a great repository for crowdsourced prompts for many datasets. We can load these prompts easily by using the `use_prompt` argument and filling it with the format `"promptsource:<name of prompt template>"`. To use this, `doc_to_text` and `doc_to_target` should be left undefined. This will fetch the template of the dataset defined in the YAML file.
|
| 267 |
+
|
| 268 |
+
For example, For Super Glue BoolQ, if we want to use the prompt template `GPT-3 Style` we can add this to the YAML file.
|
| 269 |
+
|
| 270 |
+
```yaml
|
| 271 |
+
use_prompt: "promptsource:GPT-3 Style"
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
If you would like to run evaluation on all prompt templates, you can simply call it this way.
|
| 275 |
+
|
| 276 |
+
```yaml
|
| 277 |
+
use_prompt: "promptsource:*"
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
### Setting metrics
|
| 281 |
+
|
| 282 |
+
You're almost done! Now we need to choose how to score our task.
|
| 283 |
+
|
| 284 |
+
- *If this is a multiple choice task:* do you just want to check your model's accuracy in choosing the correct answer choice?
|
| 285 |
+
- *If this is a generation task:* do you just want to check how often your model outputs *exactly the ground-truth output string provided*?
|
| 286 |
+
|
| 287 |
+
If the answer to the above is no: you'll need to record what scoring metrics to use! Metrics can be listed in the following format:
|
| 288 |
+
|
| 289 |
+
```yaml
|
| 290 |
+
metric_list:
|
| 291 |
+
- metric: <name of the metric here>
|
| 292 |
+
aggregation: <name of the aggregation fn here>
|
| 293 |
+
higher_is_better: <true or false>
|
| 294 |
+
- metric: !function script.function
|
| 295 |
+
aggregation: ...
|
| 296 |
+
higher_is_better: ...
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
`aggregation` and `higher_is_better` can optionally be left out to default to the manually-set defaults if using a natively supported metric, otherwise it must be defined explicitly (for example, when using a custom metric implemented as a function).
|
| 300 |
+
|
| 301 |
+
For a full list of natively supported metrics and aggregation functions see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval` or `hf_evaluate` is set to `true`.
|
| 302 |
+
|
| 303 |
+
### Optional, More Advanced Setup
|
| 304 |
+
|
| 305 |
+
Some tasks may require more advanced processing logic than is described in this guide.
|
| 306 |
+
|
| 307 |
+
As a heuristic check:
|
| 308 |
+
|
| 309 |
+
- Does your task require generating multiple free-form outputs per input document?
|
| 310 |
+
- Does your task require complex, multi-step post-processing of generated model outputs?
|
| 311 |
+
- Does your task require subsetting documents on the fly based on their content?
|
| 312 |
+
- Do you expect to compute metrics after applying multiple such processing steps on your model outputs?
|
| 313 |
+
- Does your task rely on metrics that need a custom implementation?
|
| 314 |
+
|
| 315 |
+
For more detail on the task system and advanced features, see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). If none of the above sounds like they apply to your task, it's time to continue onto checking your task performance!
|
| 316 |
+
|
| 317 |
+
### Task name + tags (registering a task)
|
| 318 |
+
|
| 319 |
+
To test a task conveniently, it helps to *register* the task--that is, to give it a name and make the `lm-eval` library aware it exists!
|
| 320 |
+
|
| 321 |
+
If you're writing your YAML file inside the `lm_eval/tasks` folder, you just need to give your task a name! You can do this inside your YAML file:
|
| 322 |
+
|
| 323 |
+
```yaml
|
| 324 |
+
task: <name of the task>
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
Including a task name is mandatory.
|
| 328 |
+
|
| 329 |
+
It is often also convenient to label your task with several `tag` values, though this field is optional:
|
| 330 |
+
|
| 331 |
+
```yaml
|
| 332 |
+
tag:
|
| 333 |
+
- tag1
|
| 334 |
+
- tag2
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
This will add your task to the `tag1` and `tag2` tags, enabling people to know how to categorize your task, and if desired run all tasks in one of these groups at once, your task along with them.
|
| 338 |
+
|
| 339 |
+
If your task is not in the `lm_eval/tasks` folder, you'll need to tell the Eval Harness where to look for YAML files.
|
| 340 |
+
|
| 341 |
+
You can do this via the `--include_path` argument in `__main__.py`. This command will be used to initialize the `TaskManager` object which you can also use for your custom scripts.
|
| 342 |
+
|
| 343 |
+
```python
|
| 344 |
+
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
|
| 345 |
+
```
|
| 346 |
+
|
| 347 |
+
Passing `--tasks /path/to/yaml/file` is also accepted.
|
| 348 |
+
|
| 349 |
+
### Advanced Group Configs
|
| 350 |
+
|
| 351 |
+
While `tag` values are helpful when you want to be able to quickly and conveniently run a set of related tasks via `--tasks my_tag_name`, often, we wish to implement more complex logic. For example, the MMLU benchmark contains 57 *subtasks* that must all be *averaged* together in order to report a final 'MMLU score'.
|
| 352 |
+
|
| 353 |
+
Groupings of tasks might also use particular variants of a task--for example, we might want to default to evaluating a task as 5-shot when called as part of a given grouping, but not have a preference for number of shots when evaluating it as a standalone.
|
| 354 |
+
|
| 355 |
+
We implement this via **groups**, which are distinct from tags. Groups can be implemented via *group config* YAML files, which are laid out similarly but slightly differently to tasks' YAML configs.
|
| 356 |
+
|
| 357 |
+
The most basic form of group can be defined via a YAML config similar to the following:
|
| 358 |
+
|
| 359 |
+
```yaml
|
| 360 |
+
group: nli_tasks
|
| 361 |
+
task:
|
| 362 |
+
- cb
|
| 363 |
+
- anli_r1
|
| 364 |
+
- rte
|
| 365 |
+
metadata:
|
| 366 |
+
version: 1.0
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
This will behave almost identically to a `tag` that includes these 3 tasks, but with one key distinction: we'll print the `nli_tasks` group as a row (with no associated metrics) in our table of outputs, and visually show that these 3 tasks appear under its subheader.
|
| 370 |
+
|
| 371 |
+
Now, let's assume we actually want to report an aggregate score for `nli_tasks`. We would instead use a YAML config like the following:
|
| 372 |
+
|
| 373 |
+
```yaml
|
| 374 |
+
group: nli_tasks
|
| 375 |
+
task:
|
| 376 |
+
- cb
|
| 377 |
+
- anli_r1
|
| 378 |
+
- rte
|
| 379 |
+
aggregate_metric_list:
|
| 380 |
+
- metric: acc
|
| 381 |
+
aggregation: mean
|
| 382 |
+
weight_by_size: true # defaults to `true`. Set this to `false` to do a "macro" average (taking each subtask's average accuracy, and summing those accuracies and dividing by 3)--by default we do a "micro" average (retain all subtasks' per-document accuracies, and take the mean over all documents' accuracies to get our aggregate mean).
|
| 383 |
+
metadata:
|
| 384 |
+
version: 1.0
|
| 385 |
+
```
|
| 386 |
+
|
| 387 |
+
Similar to our `metric_list` for listing out the metrics we want to calculate for a given task, we use an `aggregate_metric_list` field to specify which metric name to aggregate across subtasks, what aggregation function to use, and whether we should micro- or macro- average these metrics. See [./task_guide.md](./task_guide.md) for a full list of related sub-keys.
|
| 388 |
+
|
| 389 |
+
**[!Tip]: currently, we predominantly only support the aggregation of group metrics that use `mean` (either micro- or macro- averaged) over their subtasks. If you require even more complex aggregation rules, you may want to perform aggregation offline.**
|
| 390 |
+
|
| 391 |
+
Group configs can be fairly complex! We can do various operations, such as defining new subtask(s) inline in our group YAML, overriding an existing task's specific config value, or nesting existing groups within our
|
| 392 |
+
|
| 393 |
+
For example, let's build a config for evaluating MMLU and a few natural language inference tasks. For MMLU, we can write the name for the benchmark as a subtask written under `task`. You can configure the parameters such as `num_fewshot`. If the task being configured is a group such as `mmlu` or `super_glue`, the parameter set will be applied to all of the subtasks.
|
| 394 |
+
|
| 395 |
+
```yaml
|
| 396 |
+
group: nli_and_mmlu
|
| 397 |
+
task:
|
| 398 |
+
- group: nli_tasks
|
| 399 |
+
task:
|
| 400 |
+
- cb
|
| 401 |
+
- anli_r1
|
| 402 |
+
- rte
|
| 403 |
+
aggregate_metric_list:
|
| 404 |
+
- metric: acc
|
| 405 |
+
aggregation: mean
|
| 406 |
+
higher_is_better: true
|
| 407 |
+
- task: mmlu
|
| 408 |
+
num_fewshot: 2
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
### Configuring python classes
|
| 412 |
+
|
| 413 |
+
There can be occasions when yaml-based tasks cannot accommodate how a task is handled. LM-Eval supports the manually implementing tasks as was previously done before `0.4.x`. To register the task, you can simply make a yaml with the name of the task in `task` and the class object in `class` using the `!function` prefix.
|
| 414 |
+
|
| 415 |
+
```yaml
|
| 416 |
+
task: squadv2
|
| 417 |
+
class: !function task.SQuAD2
|
| 418 |
+
```
|
| 419 |
+
|
| 420 |
+
This also applies to building group configurations with subtasks that are python classes.
|
| 421 |
+
|
| 422 |
+
```yaml
|
| 423 |
+
group: scrolls
|
| 424 |
+
task:
|
| 425 |
+
- task: scrolls_qasper
|
| 426 |
+
class: !function task.Qasper
|
| 427 |
+
- task: scrolls_quality
|
| 428 |
+
class: !function task.QuALITY
|
| 429 |
+
- task: scrolls_narrativeqa
|
| 430 |
+
class: !function task.NarrativeQA
|
| 431 |
+
...
|
| 432 |
+
```
|
| 433 |
+
|
| 434 |
+
You can also pass a custom argument to your class by accepting `config` in the custom class constructor.
|
| 435 |
+
Here's how to do it:
|
| 436 |
+
|
| 437 |
+
```yaml
|
| 438 |
+
task: 20_newsgroups
|
| 439 |
+
class: !function task.Unitxt
|
| 440 |
+
recipe: card=cards.20_newsgroups,template=templates.classification.multi_class.title
|
| 441 |
+
```
|
| 442 |
+
|
| 443 |
+
In this example, `recipe` is the custom argument for the `Unitxt` class.
|
| 444 |
+
|
| 445 |
+
## Beautifying Table Display
|
| 446 |
+
|
| 447 |
+
To avoid conflict, each task needs to be registered with a unique name. Because of this, slight variations of task are still counted as unique tasks and need to be named uniquely. This could be done by appending an additional naming that may refer to the variation such as in MMLU where the template used to evaluated for flan are differentiated from the default by the prefix `mmlu_flan_*`. Printing the full task names can easily clutter the results table at the end of the evaluation especially when you have a long list of tasks or are using a benchmark that comprises of many tasks. To make it more legible, you can use `task_alias` and `group_alias` to provide an alternative task name and group name that will be printed. For example in `mmlu_abstract_algebra.yaml` we set `task_alias` to `abstract_algebra`. In group configs, a `group_alias` for a group can also be set.
|
| 448 |
+
|
| 449 |
+
```yaml
|
| 450 |
+
"dataset_name": "abstract_algebra"
|
| 451 |
+
"description": "The following are multiple choice questions (with answers) about abstract\
|
| 452 |
+
\ algebra.\n\n"
|
| 453 |
+
"include": "_default_template_yaml"
|
| 454 |
+
"task": "mmlu_abstract_algebra"
|
| 455 |
+
"task_alias": "abstract_algebra"
|
| 456 |
+
```
|
| 457 |
+
|
| 458 |
+
## Checking validity
|
| 459 |
+
|
| 460 |
+
After registering your task, you can now check on your data downloading and verify that the few-shot samples look as intended. Run the following command with your desired args:
|
| 461 |
+
|
| 462 |
+
```bash
|
| 463 |
+
python -m scripts.write_out \
|
| 464 |
+
--output_base_path <path> \
|
| 465 |
+
--tasks <your-task-name> \
|
| 466 |
+
--sets <train | val | test> \
|
| 467 |
+
--num_fewshot K \
|
| 468 |
+
--num_examples N \
|
| 469 |
+
```
|
| 470 |
+
|
| 471 |
+
Open the file specified at the `--output_base_path <path>` and ensure it passes
|
| 472 |
+
a simple eye test.
|
| 473 |
+
|
| 474 |
+
## Versioning
|
| 475 |
+
|
| 476 |
+
One key feature in LM Evaluation Harness is the ability to version tasks and groups--that is, mark them with a specific version number that can be bumped whenever a breaking change is made.
|
| 477 |
+
|
| 478 |
+
This version info can be provided by adding the following to your new task or group config file:
|
| 479 |
+
|
| 480 |
+
```yaml
|
| 481 |
+
metadata:
|
| 482 |
+
version: 0
|
| 483 |
+
```
|
| 484 |
+
|
| 485 |
+
Now, whenever a change needs to be made to your task in the future, please increase the version number by 1 so that users can differentiate the different task iterations and versions.
|
| 486 |
+
|
| 487 |
+
If you are incrementing a task's version, please also consider adding a changelog to the task's README.md noting the date, PR number, what version you have updated to, and a one-liner describing the change.
|
| 488 |
+
|
| 489 |
+
for example,
|
| 490 |
+
|
| 491 |
+
- \[Dec 25, 2023\] (PR #999) Version 0.0 -> 1.0: Fixed a bug with answer extraction that led to underestimated performance.
|
| 492 |
+
|
| 493 |
+
## Checking performance + equivalence
|
| 494 |
+
|
| 495 |
+
It's now time to check models' performance on your task! In the evaluation harness, we intend to support a wide range of evaluation tasks and setups, but prioritize the inclusion of already-proven benchmarks following the precise evaluation setups in the literature where possible.
|
| 496 |
+
|
| 497 |
+
To enable this, we provide a checklist that should be completed when contributing a new task, to enable accurate book-keeping and to ensure that tasks added to the library are well-tested and, where applicable, precedented.
|
| 498 |
+
|
| 499 |
+
### Task Validity Checklist
|
| 500 |
+
|
| 501 |
+
The checklist is the following:
|
| 502 |
+
|
| 503 |
+
For adding novel benchmarks/datasets to the library:
|
| 504 |
+
|
| 505 |
+
- [ ] Is the task an existing benchmark in the literature?
|
| 506 |
+
- [ ] Have you referenced the original paper that introduced the task?
|
| 507 |
+
- [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
|
| 508 |
+
|
| 509 |
+
If other tasks on this dataset are already supported:
|
| 510 |
+
|
| 511 |
+
- [ ] Is the "Main" variant of this task clearly denoted?
|
| 512 |
+
- [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
|
| 513 |
+
- [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
|
| 514 |
+
|
| 515 |
+
It is recommended to include a filled-out copy of this checklist in the README.md for the subfolder you are creating, if you have created a new subfolder in `lm_eval/tasks`.
|
| 516 |
+
|
| 517 |
+
**Finally, please add a short description of your task(s), along with a link to its subfolder in lm_eval/tasks, to [`lm_eval/tasks/README.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md) so that users can discover your task in the library, and follow the link to your README for more information about the variants supported, their task names, and the original source of the dataset and/or evaluation setup.**
|
| 518 |
+
|
| 519 |
+
## Submitting your task
|
| 520 |
+
|
| 521 |
+
You're all set! Now push your work and make a pull request to the `main` branch! Thanks for the contribution :). If there are any questions, please leave a message in the `#lm-thunderdome` channel on the EAI discord!
|
lm-evaluation-harness/docs/task_guide.md
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task Configuration
|
| 2 |
+
|
| 3 |
+
The `lm-evaluation-harness` is meant to be an extensible and flexible framework within which many different evaluation tasks can be defined. All tasks in the new version of the harness are built around a YAML configuration file format.
|
| 4 |
+
|
| 5 |
+
These YAML configuration files, along with the current codebase commit hash, are intended to be shareable such that providing the YAML config enables another researcher to precisely replicate the evaluation setup used by another, in the case that the prompt or setup differs from standard `lm-eval` task implementations.
|
| 6 |
+
|
| 7 |
+
While adding a standard evaluation task on a new dataset can be occasionally as simple as swapping out a Hugging Face dataset path in an existing file, more specialized evaluation setups also exist. Here we'll provide a crash course on the more advanced logic implementable in YAML form available to users.
|
| 8 |
+
|
| 9 |
+
If your intended task relies on features beyond what is described in this guide, we'd love to hear about it! Feel free to open an issue describing the scenario on Github, create a PR to the project with a proposed implementation, or ask in the `#lm-thunderdome` channel on the EleutherAI discord.
|
| 10 |
+
|
| 11 |
+
## Configurations
|
| 12 |
+
|
| 13 |
+
Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
|
| 14 |
+
|
| 15 |
+
### Parameters
|
| 16 |
+
|
| 17 |
+
Task naming + registration:
|
| 18 |
+
|
| 19 |
+
- **task** (`str`, defaults to None) — name of the task.
|
| 20 |
+
- **task_alias** (`str`, defaults to None) - Alias of the task name that will be printed in the final table results.
|
| 21 |
+
- **tag** (`str`, *optional*) — name of the task tags(s) a task belongs to. Enables one to run all tasks with a specified tag name at once.
|
| 22 |
+
|
| 23 |
+
Dataset configuration options:
|
| 24 |
+
|
| 25 |
+
- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
|
| 26 |
+
- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
|
| 27 |
+
- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
|
| 28 |
+
- **custom_dataset** (`Callable`, *optional) - A function that returns a `dict[str, datasets.Dataset]` (<split_name>, dataset) object. This can be used to load a dataset from a custom source or to preprocess the dataset in a way that is not supported by the `datasets` library. Will have access to `metadata` field if defined (from config and passed to TaskManager), and `model_args` from runtime (if using `evaluate`).
|
| 29 |
+
- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
|
| 30 |
+
- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
|
| 31 |
+
- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
|
| 32 |
+
- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. assert that this not None if num_fewshot > 0.
|
| 33 |
+
- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to the expected format expected by a prompt template.
|
| 34 |
+
|
| 35 |
+
Prompting / in-context formatting options:
|
| 36 |
+
|
| 37 |
+
- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. if defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
|
| 38 |
+
- **description** (`str`, *optional*) — An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as `"The following are questions (with answers) about {{subject}}.\n\n"`. No delimiters or spacing are inserted between the description and the first few-shot example.
|
| 39 |
+
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate input for the model.
|
| 40 |
+
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the answer choice list of the correct answer.
|
| 41 |
+
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
|
| 42 |
+
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
|
| 43 |
+
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
|
| 44 |
+
- **gen_prefix** (`str`, *optional*) — String to append after the <|assistant|> token. For example, if the task is to generate a question, the gen_prefix could be "The answer is: " to prompt the model to generate an answer to the question. If not using a chat template then this string will be appended to the end of the prompt.
|
| 45 |
+
|
| 46 |
+
Runtime configuration options:
|
| 47 |
+
|
| 48 |
+
- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
|
| 49 |
+
- **batch_size** (`int`, *optional*, defaults to 1) — Batch size.
|
| 50 |
+
|
| 51 |
+
Scoring details:
|
| 52 |
+
|
| 53 |
+
- **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See docs for expected format.
|
| 54 |
+
- **output_type** (`str`, *optional*, defaults to "generate_until") — Selects the type of model output for the given task. Options are `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
|
| 55 |
+
- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
|
| 56 |
+
- **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through model for each sample. Can be used for cases such as self-consistency.
|
| 57 |
+
- **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.
|
| 58 |
+
- **should_decontaminate** (`bool`, *optional*, defaults to False) - Whether to decontaminate or not.
|
| 59 |
+
- **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
|
| 60 |
+
|
| 61 |
+
Other:
|
| 62 |
+
|
| 63 |
+
- **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field that is used to denote the version of the yaml config. Other special metadata keys are: `num_fewshot`, to override the printed `n-shot` table column for a task. Will also be passed to the `custom_dataset` function if defined.
|
| 64 |
+
|
| 65 |
+
## Filters
|
| 66 |
+
|
| 67 |
+
A key component of the `lm-evaluation-harness` library is the `Filter` object. In a typical evaluation run of the harness, we take the formatted inputs and run them through our LM, with the appropriate output type (greedy or free-form generation, or loglikelihood-based comparative scoring).
|
| 68 |
+
|
| 69 |
+
After getting scores or output text from our LM on each `Instance` or document in the dataset, we then need to feed these responses into a metric or scoring function to return scores to a user.
|
| 70 |
+
|
| 71 |
+
However, certain tasks may require more complex behavior than directly turning over model outputs to a metric function. For example, we may want to post-process our output text by truncating it or extracting a model's answer, we may want to ensemble over multiple "takes" on a different document, et cetera.
|
| 72 |
+
|
| 73 |
+
**Detailed Aside**:
|
| 74 |
+
We do such post-processing by operating on *responses*, which are stored after running an LM on an `Instance` from the task in `Instance.resps`.
|
| 75 |
+
|
| 76 |
+
`resps` is a `List[str]` for each instance, and we pass a `List[List[<expected return type from model>]]` to our filters that is a list of `[instance.resps for instance in instances]`.
|
| 77 |
+
|
| 78 |
+
Our filters, after completing a pipeline, must return a `List[<expected return type from model>]` which we then unpack and store each element of in `Instance.filtered_resps` for the corresponding instance. Thus, we take as input a list of returns from our model for each doc, and must return a return from our model *without it being wrapped in a list* for each doc.
|
| 79 |
+
**End Aside**
|
| 80 |
+
|
| 81 |
+
A full list of supported filter operations can be found in `lm_eval/filters/__init__.py`. Contributions of new filter types are welcome!
|
| 82 |
+
|
| 83 |
+
### Multiple Filter Pipelines
|
| 84 |
+
|
| 85 |
+
Tasks need not be limited to a single filter pipeline. We enable users to run multiple, distinct, filter pipelines on *the same model outputs* generated in one run on a task.
|
| 86 |
+
|
| 87 |
+
As a case study, let's look at an implementation of solving the Gsm8k math word problem benchmark in `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`. Here, we are emulating the setup used by [Self-Consistency Improves Chain of Thought Prompting](https://arxiv.org/abs/2203.11171), in which evaluation is performed by generating N chain-of-thought outputs from a model via temperature-based sampling, then selecting the answers output by the model at the end of the chains of thought, then majority voting across all those numeric answers.
|
| 88 |
+
|
| 89 |
+
Within our YAML file:
|
| 90 |
+
|
| 91 |
+
```yaml
|
| 92 |
+
...
|
| 93 |
+
repeats: 64
|
| 94 |
+
filter_list:
|
| 95 |
+
- name: "score-first"
|
| 96 |
+
filter:
|
| 97 |
+
- function: "regex"
|
| 98 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 99 |
+
- function: "take_first"
|
| 100 |
+
- name: "maj@64"
|
| 101 |
+
filter:
|
| 102 |
+
- function: "regex"
|
| 103 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 104 |
+
- function: "majority_vote"
|
| 105 |
+
- function: "take_first"
|
| 106 |
+
- name: "maj@8"
|
| 107 |
+
filter:
|
| 108 |
+
- function: "take_first_k"
|
| 109 |
+
k: 8
|
| 110 |
+
- function: "regex"
|
| 111 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 112 |
+
- function: "majority_vote"
|
| 113 |
+
- function: "take_first"
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
We are able to provide multiple different filter pipelines, each with their own name and list of filters to apply in sequence.
|
| 117 |
+
|
| 118 |
+
Our first filter pipeline implements
|
| 119 |
+
|
| 120 |
+
- applying a regex to the model generations (extracting the number within the phrase "The answer is (number)")
|
| 121 |
+
- selecting only the first out of the 64 model answers
|
| 122 |
+
|
| 123 |
+
Then scoring this single answer.
|
| 124 |
+
|
| 125 |
+
```yaml
|
| 126 |
+
- name: "score-first"
|
| 127 |
+
filter:
|
| 128 |
+
- function: "regex"
|
| 129 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 130 |
+
- function: "take_first"
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
Our second filter pipeline, "maj@64", does majority voting across all 64 answers via:
|
| 134 |
+
|
| 135 |
+
- applying the same regex to all responses, to get the numerical answer from the model for each of the 64 responses per problem
|
| 136 |
+
- applying majority voting to all responses, which then returns a length-1 `[<majority answer>]` list for each
|
| 137 |
+
- taking the first element of this length-1 list, to then score the sole response `<majority answer>` for each document.
|
| 138 |
+
|
| 139 |
+
```yaml
|
| 140 |
+
- name: "maj@64"
|
| 141 |
+
filter:
|
| 142 |
+
- function: "regex"
|
| 143 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 144 |
+
- function: "majority_vote"
|
| 145 |
+
- function: "take_first"
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Our final filter pipeline, "maj@8", does majority voting across the first 8 of the model's responses per document via:
|
| 149 |
+
|
| 150 |
+
- subsetting the len-64 list of responses `[answer1, answer2, ..., answer64]` to `[answer1, answer2, ..., answer8]` for each document
|
| 151 |
+
- performing the same sequence of filters on these new sets of 8 responses, for each document.
|
| 152 |
+
|
| 153 |
+
```yaml
|
| 154 |
+
- name: "maj@8"
|
| 155 |
+
filter:
|
| 156 |
+
- function: "take_first_k"
|
| 157 |
+
k: 8
|
| 158 |
+
- function: "regex"
|
| 159 |
+
regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
|
| 160 |
+
- function: "majority_vote"
|
| 161 |
+
- function: "take_first"
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
Thus, given the 64 responses from our LM on each document, we can report metrics on these responses in these 3 different ways, as defined by our filter pipelines.
|
| 165 |
+
|
| 166 |
+
### Adding a custom filter
|
| 167 |
+
|
| 168 |
+
Just like adding a custom model with `register_model` decorator one is able to do the same with filters, for example
|
| 169 |
+
|
| 170 |
+
```python
|
| 171 |
+
from lm_eval.api.filter import Filter
|
| 172 |
+
from lm_eval.api.registry import register_filter
|
| 173 |
+
|
| 174 |
+
@register_filter("new_filter")
|
| 175 |
+
class NewFilter(Filter)
|
| 176 |
+
...
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
## Embedded Python Code
|
| 180 |
+
|
| 181 |
+
Use can use python functions for certain arguments by using the `!function` operator after the argument name followed by `<filename>.<pythonfunctionname>`. This feature can be used for the following arguments:
|
| 182 |
+
|
| 183 |
+
1. `doc_to_text`
|
| 184 |
+
2. `doc_to_target`
|
| 185 |
+
3. `doc_to_choice`
|
| 186 |
+
4. `aggregation` for a `metric` in `metric_list`
|
| 187 |
+
|
| 188 |
+
## (No Longer Recommended) Direct `Task` Subclassing
|
| 189 |
+
|
| 190 |
+
The prior implementation method of new tasks was to subclass `Task`. While we intend to migrate all tasks to the new YAML implementation option going forward, it remains possible to subclass the Task class and implement custom logic. For more information, see `docs/task_guide.md` in v0.3.0 of the `lm-evaluation-harness`.
|
| 191 |
+
|
| 192 |
+
## Including a Base YAML
|
| 193 |
+
|
| 194 |
+
You can base a YAML on another YAML file as a template. This can be handy when you need to just change the prompt for `doc_to_text` but keep the rest the same or change `filters` to compare which is better. Simply use `include` in the YAML file and write the name of the template you want to base from. This assumes that the base template is in the same directory. Otherwise, You will need to define the full path.
|
| 195 |
+
|
| 196 |
+
```yaml
|
| 197 |
+
include: <YAML filename or with full path>
|
| 198 |
+
...
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
You can find an example of how to use this feature at [gsm8k-cot-self-consistency.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml) where it is based off [gsm8k-cot.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot.yaml)
|
| 202 |
+
|
| 203 |
+
## Passing Arguments to Metrics
|
| 204 |
+
|
| 205 |
+
Metrics can be defined in the `metric_list` argument when building the YAML config. Multiple metrics can be listed along with any auxiliary arguments. For example, setting the [`exact_match` metric](https://github.com/huggingface/evaluate/tree/main/metrics/exact_match), auxiliary arguments such as `ignore_case`, `ignore_punctuation`, `regexes_to_ignore` can be listed as well. They will be added to the metric function as `kwargs`. Some metrics have predefined values for `aggregation` and `higher_is_better` so listing the metric name only can be sufficient.
|
| 206 |
+
|
| 207 |
+
```yaml
|
| 208 |
+
metric_list:
|
| 209 |
+
- metric: acc
|
| 210 |
+
- metric: exact_match
|
| 211 |
+
aggregation: mean
|
| 212 |
+
higher_is_better: true
|
| 213 |
+
ignore_case: true
|
| 214 |
+
ignore_punctuation: false
|
| 215 |
+
regexes_to_ignore:
|
| 216 |
+
- ","
|
| 217 |
+
- "\\$"
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
### Natively Supported Metrics
|
| 221 |
+
|
| 222 |
+
Here we list all metrics currently supported natively in `lm-eval`:
|
| 223 |
+
|
| 224 |
+
Metrics:
|
| 225 |
+
|
| 226 |
+
- `acc` (accuracy)
|
| 227 |
+
- `acc_norm` (length-normalized accuracy)
|
| 228 |
+
- `acc_mutual_info` (baseline loglikelihood - normalized accuracy)
|
| 229 |
+
- `perplexity`
|
| 230 |
+
- `word_perplexity` (perplexity per word)
|
| 231 |
+
- `byte_perplexity` (perplexity per byte)
|
| 232 |
+
- `bits_per_byte`
|
| 233 |
+
- `matthews_corrcoef` (Matthews correlation coefficient)
|
| 234 |
+
- `f1` (F1 score)
|
| 235 |
+
- `bleu`
|
| 236 |
+
- `chrf`
|
| 237 |
+
- `ter`
|
| 238 |
+
|
| 239 |
+
Aggregation functions:
|
| 240 |
+
|
| 241 |
+
- `mean`
|
| 242 |
+
- `median`
|
| 243 |
+
- `perplexity`
|
| 244 |
+
- `weighted_perplexity`
|
| 245 |
+
- `bits_per_byte`
|
| 246 |
+
|
| 247 |
+
### Adding a Multiple Choice Metric
|
| 248 |
+
|
| 249 |
+
Adding a multiple choice metric has a few steps. To get it working you need to:
|
| 250 |
+
|
| 251 |
+
1. register a metric function
|
| 252 |
+
2. register an aggregation function
|
| 253 |
+
3. update the `Task` definition to make sure the correct arguments are passed
|
| 254 |
+
|
| 255 |
+
The default metric and aggregation functions are in `lm_eval/api/metrics.py`, and you can add a function there if it's for general use. The metrics are towards the bottom of the file and look like this:
|
| 256 |
+
|
| 257 |
+
```python
|
| 258 |
+
@register_metric(
|
| 259 |
+
metric="mcc",
|
| 260 |
+
higher_is_better=True,
|
| 261 |
+
output_type="multiple_choice",
|
| 262 |
+
aggregation="matthews_corrcoef",
|
| 263 |
+
)
|
| 264 |
+
def mcc_fn(items): # This is a passthrough function
|
| 265 |
+
return items
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
Note that many of these are passthrough functions, and for multiple choice (at least) this function is never actually called.
|
| 269 |
+
|
| 270 |
+
Aggregation functions are defined towards the top of the file, here's an example:
|
| 271 |
+
|
| 272 |
+
```python
|
| 273 |
+
@register_aggregation("matthews_corrcoef")
|
| 274 |
+
def matthews_corrcoef(items):
|
| 275 |
+
unzipped_list = list(zip(*items))
|
| 276 |
+
golds = unzipped_list[0]
|
| 277 |
+
preds = unzipped_list[1]
|
| 278 |
+
return sklearn.metrics.matthews_corrcoef(golds, preds)
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
This function returns a single numeric value. The input is defined in `Task.process_results` in `lm_eval/api/task.py`. There's a section that looks like this:
|
| 282 |
+
|
| 283 |
+
```python
|
| 284 |
+
result_dict = {
|
| 285 |
+
**({"acc": acc} if "acc" in use_metric else {}),
|
| 286 |
+
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
|
| 287 |
+
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
|
| 288 |
+
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
|
| 289 |
+
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
|
| 290 |
+
}
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
The value here determines the input to the aggregation function, though the name used matches the metric function. These metrics all have simple needs and just need the accuracy or gold and predicted values, but immediately below this there are examples of metrics with more complicated needs you can use as reference.
|
| 294 |
+
|
| 295 |
+
## Good Reference Tasks
|
| 296 |
+
|
| 297 |
+
Contributing a new task can be daunting! Luckily, much of the work has often been done for you in a different, similarly evaluated task. Good examples of task implementations to study include:
|
| 298 |
+
|
| 299 |
+
Multiple choice tasks:
|
| 300 |
+
|
| 301 |
+
- SciQ (`lm_eval/tasks/sciq/sciq.yaml`)
|
| 302 |
+
|
| 303 |
+
Corpus perplexity evaluations:
|
| 304 |
+
|
| 305 |
+
- Wikitext (`lm_eval/tasks/wikitext/wikitext.yaml`)
|
| 306 |
+
|
| 307 |
+
Generative tasks:
|
| 308 |
+
|
| 309 |
+
- GSM8k (`lm_eval/tasks/gsm8k/gsm8k.yaml`)
|
| 310 |
+
|
| 311 |
+
Tasks using complex filtering:
|
| 312 |
+
|
| 313 |
+
- GSM8k with CoT (+ with Self-Consistency): (`lm_eval/tasks/gsm8k/gsm8k-cot.yaml` ; `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`)
|
| 314 |
+
|
| 315 |
+
# Group Configuration
|
| 316 |
+
|
| 317 |
+
When evaluating a language model, it is not unusual to test across a number of tasks that may not be related to one another in order to assess a variety of capabilities. To this end, it may be cumbersome to have to list the set of tasks or add a new group name to each yaml of each individual task.
|
| 318 |
+
|
| 319 |
+
To solve this, we can create a **group** yaml config. This is a config that contains the names of the tasks that should be included in a particular group. The config consists of two main keys: a `group` key which denotes the name of the group (as it would be called from the command line, e.g. `mmlu`) and a `task` key which is where we can list the tasks. The tasks listed in `task` are the task names that have been registered. A good example of a group yaml config can be found at [../lm_eval/tasks/mmlu/default/_mmlu.yaml]. See also the [New Task Guide](./new_task_guide.md) for a more in-depth and tutorial-esque explanation of how to write complex GroupConfigs.
|
| 320 |
+
|
| 321 |
+
## Configurations
|
| 322 |
+
|
| 323 |
+
Groups are configured via the `GroupConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
|
| 324 |
+
|
| 325 |
+
### Parameters
|
| 326 |
+
|
| 327 |
+
- **group** (`str`, defaults to `None`) — name of the group. Used to invoke it from the command line.
|
| 328 |
+
- **group_alias** (`str`, defaults to `None`) - Alternative name for the group that will be printed in the table output.
|
| 329 |
+
- **task** (`Union[str, list]`, defaults to `None`) - List of tasks that constitute the group.
|
| 330 |
+
- **aggregate_metric_list** (`list`, defaults to `None`) - similar to `metric_list` in TaskConfigs, provide a list of configurations for metrics that should be aggregated across subtasks. Leaving empty will result in no aggregation being performed for this group. Keys for each list entry are:
|
| 331 |
+
- `metric: str` - the name of the metric to aggregate over (all subtasks must report a metric holding this name.)
|
| 332 |
+
- `aggregation: str` - what aggregation function to apply to aggregate these per-subtask metrics. **currently, only `mean` is supported.**
|
| 333 |
+
- `weight_by_size: bool = True` whether to perform micro- averaging (`True`) or macro- (`False`) averaging of subtasks' accuracy scores when reporting the group's metric. MMLU, for example, averages over per-document accuracies (the *micro average*), resulting in the same accuracy as if one simply concatenated all 57 subjects into a single dataset and evaluated accuracy on that dataset.
|
| 334 |
+
- `filter_list: Union[str, List[str]] = "none"` - what filter keys one should match on to aggregate results. For example, if trying to aggregate over the `exact_match` metric using `strict-match` filter for `bbh_cot_zeroshot`, then set this to be `filter_list: "strict-match"`.
|
| 335 |
+
- **metadata** (`dict`, *optional*) - As with TaskConfigs, a field where extra config metadata can be passed. set the `num_fewshot` key within this to override the printed n_shot value in a results table for your group, for example.
|
lm-evaluation-harness/examples/lm-eval-overview.ipynb
ADDED
|
@@ -0,0 +1,1240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "Qw83KAePAhaS"
|
| 7 |
+
},
|
| 8 |
+
"source": [
|
| 9 |
+
"# Releasing LM-Evaluation-Harness v0.4.0"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "markdown",
|
| 14 |
+
"metadata": {
|
| 15 |
+
"id": "Z7k2vq1iAdqr"
|
| 16 |
+
},
|
| 17 |
+
"source": [
|
| 18 |
+
"With the vast amount of work done in the field today, it helps to have a tool that people can use easily to share their results and use to check others to ensure reported numbers are valid. The LM Evaluation Harness is one such tool the community has used extensively. We want to continue to support the community and with that in mind, we’re excited to announce a major update on the LM Evaluation Harness to further our goal for open and accessible AI research."
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "markdown",
|
| 23 |
+
"metadata": {
|
| 24 |
+
"id": "0gDoM0AJAvEc"
|
| 25 |
+
},
|
| 26 |
+
"source": [
|
| 27 |
+
"Our refactor stems from our desires to make the following believed best practices easier to carry out. \n",
|
| 28 |
+
"\n",
|
| 29 |
+
"1. Never copy results from other papers\n",
|
| 30 |
+
"2. Always share your exact prompts\n",
|
| 31 |
+
"3. Always provide model outputs\n",
|
| 32 |
+
"4. Qualitatively review a small batch of outputs before running evaluation jobs at scale\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"We also wanted to make the library a better experience to use and to contribute or design evaluations within. New features in the new release that serve this purpose include:\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"1. Faster Evaluation Runtimes (accelerated data-parallel inference with HF Transformers + Accelerate, and commonly used or faster inference libraries such as vLLM and Llama-CPP)\n",
|
| 37 |
+
"2. Easier addition and sharing of new tasks (YAML-based task config formats, allowing single-file sharing of custom tasks)\n",
|
| 38 |
+
"3. More configurability, for more advanced workflows and easier operation with modifying prompts\n",
|
| 39 |
+
"4. Better logging of data at runtime and post-hoc"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "markdown",
|
| 44 |
+
"metadata": {
|
| 45 |
+
"id": "nnwsOpjda_YW"
|
| 46 |
+
},
|
| 47 |
+
"source": [
|
| 48 |
+
"In this notebook we will be going through a short tutorial on how things work."
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "markdown",
|
| 53 |
+
"metadata": {
|
| 54 |
+
"id": "zAov81vTbL2K"
|
| 55 |
+
},
|
| 56 |
+
"source": [
|
| 57 |
+
"## Install LM-Eval"
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": 1,
|
| 63 |
+
"metadata": {
|
| 64 |
+
"colab": {
|
| 65 |
+
"base_uri": "https://localhost:8080/"
|
| 66 |
+
},
|
| 67 |
+
"id": "8hiosGzq_qZg",
|
| 68 |
+
"outputId": "6ab73e5e-1f54-417e-a388-07e0d870b132"
|
| 69 |
+
},
|
| 70 |
+
"outputs": [
|
| 71 |
+
{
|
| 72 |
+
"name": "stdout",
|
| 73 |
+
"output_type": "stream",
|
| 74 |
+
"text": [
|
| 75 |
+
"Collecting git+https://github.com/EleutherAI/lm-evaluation-harness.git@big-refactor\n",
|
| 76 |
+
" Cloning https://github.com/EleutherAI/lm-evaluation-harness.git (to revision big-refactor) to /tmp/pip-req-build-tnssql5s\n",
|
| 77 |
+
" Running command git clone --filter=blob:none --quiet https://github.com/EleutherAI/lm-evaluation-harness.git /tmp/pip-req-build-tnssql5s\n",
|
| 78 |
+
" Running command git checkout -b big-refactor --track origin/big-refactor\n",
|
| 79 |
+
" Switched to a new branch 'big-refactor'\n",
|
| 80 |
+
" Branch 'big-refactor' set up to track remote branch 'big-refactor' from 'origin'.\n",
|
| 81 |
+
" Resolved https://github.com/EleutherAI/lm-evaluation-harness.git to commit 42f486ee49b65926a444cb0620870a39a5b4b0a8\n",
|
| 82 |
+
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
|
| 83 |
+
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
|
| 84 |
+
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
|
| 85 |
+
"Collecting accelerate>=0.21.0 (from lm-eval==1.0.0)\n",
|
| 86 |
+
" Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)\n",
|
| 87 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m261.4/261.4 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 88 |
+
"\u001b[?25hCollecting evaluate (from lm-eval==1.0.0)\n",
|
| 89 |
+
" Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)\n",
|
| 90 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 91 |
+
"\u001b[?25hCollecting datasets>=2.0.0 (from lm-eval==1.0.0)\n",
|
| 92 |
+
" Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n",
|
| 93 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 94 |
+
"\u001b[?25hCollecting jsonlines (from lm-eval==1.0.0)\n",
|
| 95 |
+
" Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)\n",
|
| 96 |
+
"Requirement already satisfied: numexpr in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (2.8.7)\n",
|
| 97 |
+
"Collecting peft>=0.2.0 (from lm-eval==1.0.0)\n",
|
| 98 |
+
" Downloading peft-0.6.2-py3-none-any.whl (174 kB)\n",
|
| 99 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 100 |
+
"\u001b[?25hCollecting pybind11>=2.6.2 (from lm-eval==1.0.0)\n",
|
| 101 |
+
" Downloading pybind11-2.11.1-py3-none-any.whl (227 kB)\n",
|
| 102 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.7/227.7 kB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 103 |
+
"\u001b[?25hCollecting pytablewriter (from lm-eval==1.0.0)\n",
|
| 104 |
+
" Downloading pytablewriter-1.2.0-py3-none-any.whl (111 kB)\n",
|
| 105 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.1/111.1 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 106 |
+
"\u001b[?25hCollecting rouge-score>=0.0.4 (from lm-eval==1.0.0)\n",
|
| 107 |
+
" Downloading rouge_score-0.1.2.tar.gz (17 kB)\n",
|
| 108 |
+
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
| 109 |
+
"Collecting sacrebleu>=1.5.0 (from lm-eval==1.0.0)\n",
|
| 110 |
+
" Downloading sacrebleu-2.3.2-py3-none-any.whl (119 kB)\n",
|
| 111 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.7/119.7 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 112 |
+
"\u001b[?25hRequirement already satisfied: scikit-learn>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (1.2.2)\n",
|
| 113 |
+
"Collecting sqlitedict (from lm-eval==1.0.0)\n",
|
| 114 |
+
" Downloading sqlitedict-2.1.0.tar.gz (21 kB)\n",
|
| 115 |
+
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
| 116 |
+
"Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (2.1.0+cu118)\n",
|
| 117 |
+
"Collecting tqdm-multiprocess (from lm-eval==1.0.0)\n",
|
| 118 |
+
" Downloading tqdm_multiprocess-0.0.11-py3-none-any.whl (9.8 kB)\n",
|
| 119 |
+
"Requirement already satisfied: transformers>=4.1 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (4.35.2)\n",
|
| 120 |
+
"Collecting zstandard (from lm-eval==1.0.0)\n",
|
| 121 |
+
" Downloading zstandard-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n",
|
| 122 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m29.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 123 |
+
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (1.23.5)\n",
|
| 124 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (23.2)\n",
|
| 125 |
+
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (5.9.5)\n",
|
| 126 |
+
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (6.0.1)\n",
|
| 127 |
+
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (0.19.4)\n",
|
| 128 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (9.0.0)\n",
|
| 129 |
+
"Collecting pyarrow-hotfix (from datasets>=2.0.0->lm-eval==1.0.0)\n",
|
| 130 |
+
" Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n",
|
| 131 |
+
"Collecting dill<0.3.8,>=0.3.0 (from datasets>=2.0.0->lm-eval==1.0.0)\n",
|
| 132 |
+
" Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n",
|
| 133 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 134 |
+
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (1.5.3)\n",
|
| 135 |
+
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (2.31.0)\n",
|
| 136 |
+
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (4.66.1)\n",
|
| 137 |
+
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (3.4.1)\n",
|
| 138 |
+
"Collecting multiprocess (from datasets>=2.0.0->lm-eval==1.0.0)\n",
|
| 139 |
+
" Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n",
|
| 140 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m19.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 141 |
+
"\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (2023.6.0)\n",
|
| 142 |
+
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (3.8.6)\n",
|
| 143 |
+
"Collecting responses<0.19 (from evaluate->lm-eval==1.0.0)\n",
|
| 144 |
+
" Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
|
| 145 |
+
"Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft>=0.2.0->lm-eval==1.0.0) (0.4.0)\n",
|
| 146 |
+
"Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (1.4.0)\n",
|
| 147 |
+
"Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (3.8.1)\n",
|
| 148 |
+
"Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (1.16.0)\n",
|
| 149 |
+
"Collecting portalocker (from sacrebleu>=1.5.0->lm-eval==1.0.0)\n",
|
| 150 |
+
" Downloading portalocker-2.8.2-py3-none-any.whl (17 kB)\n",
|
| 151 |
+
"Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (2023.6.3)\n",
|
| 152 |
+
"Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (0.9.0)\n",
|
| 153 |
+
"Collecting colorama (from sacrebleu>=1.5.0->lm-eval==1.0.0)\n",
|
| 154 |
+
" Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
|
| 155 |
+
"Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (4.9.3)\n",
|
| 156 |
+
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (1.11.3)\n",
|
| 157 |
+
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (1.3.2)\n",
|
| 158 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (3.2.0)\n",
|
| 159 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.13.1)\n",
|
| 160 |
+
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (4.5.0)\n",
|
| 161 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (1.12)\n",
|
| 162 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.2.1)\n",
|
| 163 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.1.2)\n",
|
| 164 |
+
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (2.1.0)\n",
|
| 165 |
+
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.1->lm-eval==1.0.0) (0.15.0)\n",
|
| 166 |
+
"Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonlines->lm-eval==1.0.0) (23.1.0)\n",
|
| 167 |
+
"Requirement already satisfied: setuptools>=38.3.0 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm-eval==1.0.0) (67.7.2)\n",
|
| 168 |
+
"Collecting DataProperty<2,>=1.0.1 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 169 |
+
" Downloading DataProperty-1.0.1-py3-none-any.whl (27 kB)\n",
|
| 170 |
+
"Collecting mbstrdecoder<2,>=1.0.0 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 171 |
+
" Downloading mbstrdecoder-1.1.3-py3-none-any.whl (7.8 kB)\n",
|
| 172 |
+
"Collecting pathvalidate<4,>=2.3.0 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 173 |
+
" Downloading pathvalidate-3.2.0-py3-none-any.whl (23 kB)\n",
|
| 174 |
+
"Collecting tabledata<2,>=1.3.1 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 175 |
+
" Downloading tabledata-1.3.3-py3-none-any.whl (11 kB)\n",
|
| 176 |
+
"Collecting tcolorpy<1,>=0.0.5 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 177 |
+
" Downloading tcolorpy-0.1.4-py3-none-any.whl (7.9 kB)\n",
|
| 178 |
+
"Collecting typepy[datetime]<2,>=1.3.2 (from pytablewriter->lm-eval==1.0.0)\n",
|
| 179 |
+
" Downloading typepy-1.3.2-py3-none-any.whl (31 kB)\n",
|
| 180 |
+
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (3.3.2)\n",
|
| 181 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (6.0.4)\n",
|
| 182 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (4.0.3)\n",
|
| 183 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.9.2)\n",
|
| 184 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.4.0)\n",
|
| 185 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.3.1)\n",
|
| 186 |
+
"Requirement already satisfied: chardet<6,>=3.0.4 in /usr/local/lib/python3.10/dist-packages (from mbstrdecoder<2,>=1.0.0->pytablewriter->lm-eval==1.0.0) (5.2.0)\n",
|
| 187 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (3.4)\n",
|
| 188 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (2.0.7)\n",
|
| 189 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (2023.7.22)\n",
|
| 190 |
+
"Requirement already satisfied: python-dateutil<3.0.0,>=2.8.0 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm-eval==1.0.0) (2.8.2)\n",
|
| 191 |
+
"Requirement already satisfied: pytz>=2018.9 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm-eval==1.0.0) (2023.3.post1)\n",
|
| 192 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.8->lm-eval==1.0.0) (2.1.3)\n",
|
| 193 |
+
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge-score>=0.0.4->lm-eval==1.0.0) (8.1.7)\n",
|
| 194 |
+
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.8->lm-eval==1.0.0) (1.3.0)\n",
|
| 195 |
+
"Building wheels for collected packages: lm-eval, rouge-score, sqlitedict\n",
|
| 196 |
+
" Building wheel for lm-eval (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
|
| 197 |
+
" Created wheel for lm-eval: filename=lm_eval-1.0.0-py3-none-any.whl size=994254 sha256=88356155b19f2891981ecef948326ad6ce8ca40a6009378410ec20d0e225995a\n",
|
| 198 |
+
" Stored in directory: /tmp/pip-ephem-wheel-cache-9v6ye7h3/wheels/17/01/26/599c0779e9858a70a73fa8a306699b5b9a868f820c225457b0\n",
|
| 199 |
+
" Building wheel for rouge-score (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
| 200 |
+
" Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24933 sha256=6bb0d44e4881972c43ce194e7cb65233d309758cb15f0dec54590d3d2efcfc36\n",
|
| 201 |
+
" Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n",
|
| 202 |
+
" Building wheel for sqlitedict (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
| 203 |
+
" Created wheel for sqlitedict: filename=sqlitedict-2.1.0-py3-none-any.whl size=16863 sha256=5747f7dd73ddf3d8fbcebf51b5e4f718fabe1e94bccdf16d2f22a2e65ee7fdf4\n",
|
| 204 |
+
" Stored in directory: /root/.cache/pip/wheels/79/d6/e7/304e0e6cb2221022c26d8161f7c23cd4f259a9e41e8bbcfabd\n",
|
| 205 |
+
"Successfully built lm-eval rouge-score sqlitedict\n",
|
| 206 |
+
"Installing collected packages: sqlitedict, zstandard, tcolorpy, pybind11, pyarrow-hotfix, portalocker, pathvalidate, mbstrdecoder, jsonlines, dill, colorama, typepy, tqdm-multiprocess, sacrebleu, rouge-score, responses, multiprocess, accelerate, datasets, DataProperty, tabledata, peft, evaluate, pytablewriter, lm-eval\n",
|
| 207 |
+
"Successfully installed DataProperty-1.0.1 accelerate-0.24.1 colorama-0.4.6 datasets-2.15.0 dill-0.3.7 evaluate-0.4.1 jsonlines-4.0.0 lm-eval-1.0.0 mbstrdecoder-1.1.3 multiprocess-0.70.15 pathvalidate-3.2.0 peft-0.6.2 portalocker-2.8.2 pyarrow-hotfix-0.6 pybind11-2.11.1 pytablewriter-1.2.0 responses-0.18.0 rouge-score-0.1.2 sacrebleu-2.3.2 sqlitedict-2.1.0 tabledata-1.3.3 tcolorpy-0.1.4 tqdm-multiprocess-0.0.11 typepy-1.3.2 zstandard-0.22.0\n"
|
| 208 |
+
]
|
| 209 |
+
}
|
| 210 |
+
],
|
| 211 |
+
"source": [
|
| 212 |
+
"# Install LM-Eval\n",
|
| 213 |
+
"!pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "code",
|
| 218 |
+
"execution_count": 2,
|
| 219 |
+
"metadata": {
|
| 220 |
+
"colab": {
|
| 221 |
+
"base_uri": "https://localhost:8080/",
|
| 222 |
+
"height": 0,
|
| 223 |
+
"referenced_widgets": [
|
| 224 |
+
"a1d3a8aa016544a78e8821c8f6199e06",
|
| 225 |
+
"f61ed33fad754146bdd2ac9db1ba1c48",
|
| 226 |
+
"bfa0af6aeff344c6845e1080a878e92e",
|
| 227 |
+
"fd1ad9e0367d4004aae853b91c3a7617",
|
| 228 |
+
"6b2d90209ec14230b3d58a74ac9b83bf",
|
| 229 |
+
"a73f357065d34d7baf0453ae4a8d75e2",
|
| 230 |
+
"46f521b73fd943c081c648fd873ebc0a",
|
| 231 |
+
"7c5689bc13684db8a22681f41863dddd",
|
| 232 |
+
"48763b6233374554ae76035c0483066f",
|
| 233 |
+
"4986a21eb560448fa79f4b25cde48951",
|
| 234 |
+
"aed3acd2f2d74003b44079c333a0698e"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
"id": "uyO5MaKkZyah",
|
| 238 |
+
"outputId": "d46e8096-5086-4e49-967e-ea33d4a2a335"
|
| 239 |
+
},
|
| 240 |
+
"outputs": [
|
| 241 |
+
{
|
| 242 |
+
"data": {
|
| 243 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 244 |
+
"model_id": "a1d3a8aa016544a78e8821c8f6199e06",
|
| 245 |
+
"version_major": 2,
|
| 246 |
+
"version_minor": 0
|
| 247 |
+
},
|
| 248 |
+
"text/plain": [
|
| 249 |
+
"Downloading builder script: 0%| | 0.00/5.67k [00:00<?, ?B/s]"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"output_type": "display_data"
|
| 254 |
+
}
|
| 255 |
+
],
|
| 256 |
+
"source": []
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "markdown",
|
| 260 |
+
"metadata": {
|
| 261 |
+
"id": "8rfUeX6n_wkK"
|
| 262 |
+
},
|
| 263 |
+
"source": [
|
| 264 |
+
"## Create new evaluation tasks with config-based tasks\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"Even within the same task, many works have reported numbers based on different choices of evaluation. Some report on the test sets, validation sets, or even subset of the training sets. Others have specialized prompts and verbalizers. We introduce YAMLs to allow users to easily make different variations. By leveraging the YAML configs to configure evaluations, the refactored LM-Eval takes the methods of the `Task` object and makes them configurable by setting the appropriate attributes in the config file. There, users can set the tasks they want by setting the name of the HF dataset (local tasks are also possible), the dataset splits used, and much more. Key configurations relating to prompting, such as `doc_to_text`, previously implemented as a method of the same name, are now configurable with jinja2 to allow high-level scripting to transform a HF dataset to text string as input to the model.\n",
|
| 267 |
+
"\n"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"cell_type": "markdown",
|
| 272 |
+
"metadata": {
|
| 273 |
+
"id": "HYFUhhfOSJKe"
|
| 274 |
+
},
|
| 275 |
+
"source": [
|
| 276 |
+
"A core-feature to LM-Eval is to configure tasks with YAML configs. With configs, you can fill preset fields to easily set up a task.\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"Here, we write a demo YAML config for a multiple-choice evaluation of BoolQ:"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"cell_type": "code",
|
| 283 |
+
"execution_count": 3,
|
| 284 |
+
"metadata": {
|
| 285 |
+
"id": "bg3dGROW-V39"
|
| 286 |
+
},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"YAML_boolq_string = \"\"\"\n",
|
| 290 |
+
"task: demo_boolq\n",
|
| 291 |
+
"dataset_path: super_glue\n",
|
| 292 |
+
"dataset_name: boolq\n",
|
| 293 |
+
"output_type: multiple_choice\n",
|
| 294 |
+
"training_split: train\n",
|
| 295 |
+
"validation_split: validation\n",
|
| 296 |
+
"doc_to_text: \"{{passage}}\\nQuestion: {{question}}?\\nAnswer:\"\n",
|
| 297 |
+
"doc_to_target: label\n",
|
| 298 |
+
"doc_to_choice: [\"no\", \"yes\"]\n",
|
| 299 |
+
"should_decontaminate: true\n",
|
| 300 |
+
"doc_to_decontamination_query: passage\n",
|
| 301 |
+
"metric_list:\n",
|
| 302 |
+
" - metric: acc\n",
|
| 303 |
+
"\"\"\"\n",
|
| 304 |
+
"with open(\"boolq.yaml\", \"w\") as f:\n",
|
| 305 |
+
" f.write(YAML_boolq_string)"
|
| 306 |
+
]
|
| 307 |
+
},
|
| 308 |
+
{
|
| 309 |
+
"cell_type": "markdown",
|
| 310 |
+
"metadata": {},
|
| 311 |
+
"source": [
|
| 312 |
+
"And we can now run evaluation on this task, by pointing to the config file we've just created:"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": 4,
|
| 318 |
+
"metadata": {
|
| 319 |
+
"id": "LOUHK7PtQfq4"
|
| 320 |
+
},
|
| 321 |
+
"outputs": [
|
| 322 |
+
{
|
| 323 |
+
"name": "stdout",
|
| 324 |
+
"output_type": "stream",
|
| 325 |
+
"text": [
|
| 326 |
+
"2023-11-29:11:54:55,156 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
|
| 327 |
+
"2023-11-29 11:54:55.942051: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 328 |
+
"2023-11-29 11:54:55.942108: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 329 |
+
"2023-11-29 11:54:55.942142: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
| 330 |
+
"2023-11-29 11:54:57.066802: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
| 331 |
+
"2023-11-29:11:55:00,954 INFO [__main__.py:132] Verbosity set to INFO\n",
|
| 332 |
+
"2023-11-29:11:55:11,038 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
|
| 333 |
+
"2023-11-29:11:55:11,038 INFO [__main__.py:143] Including path: ./\n",
|
| 334 |
+
"2023-11-29:11:55:11,046 INFO [__main__.py:205] Selected Tasks: ['demo_boolq']\n",
|
| 335 |
+
"2023-11-29:11:55:11,047 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
|
| 336 |
+
"2023-11-29:11:55:11,110 INFO [huggingface.py:120] Using device 'cuda'\n",
|
| 337 |
+
"config.json: 100% 571/571 [00:00<00:00, 2.87MB/s]\n",
|
| 338 |
+
"model.safetensors: 100% 5.68G/5.68G [00:32<00:00, 173MB/s]\n",
|
| 339 |
+
"tokenizer_config.json: 100% 396/396 [00:00<00:00, 2.06MB/s]\n",
|
| 340 |
+
"tokenizer.json: 100% 2.11M/2.11M [00:00<00:00, 11.6MB/s]\n",
|
| 341 |
+
"special_tokens_map.json: 100% 99.0/99.0 [00:00<00:00, 555kB/s]\n",
|
| 342 |
+
"2023-11-29:11:56:18,658 WARNING [task.py:614] [Task: demo_boolq] metric acc is defined, but aggregation is not. using default aggregation=mean\n",
|
| 343 |
+
"2023-11-29:11:56:18,658 WARNING [task.py:626] [Task: demo_boolq] metric acc is defined, but higher_is_better is not. using default higher_is_better=True\n",
|
| 344 |
+
"Downloading builder script: 100% 30.7k/30.7k [00:00<00:00, 59.0MB/s]\n",
|
| 345 |
+
"Downloading metadata: 100% 38.7k/38.7k [00:00<00:00, 651kB/s]\n",
|
| 346 |
+
"Downloading readme: 100% 14.8k/14.8k [00:00<00:00, 37.3MB/s]\n",
|
| 347 |
+
"Downloading data: 100% 4.12M/4.12M [00:00<00:00, 55.1MB/s]\n",
|
| 348 |
+
"Generating train split: 100% 9427/9427 [00:00<00:00, 15630.89 examples/s]\n",
|
| 349 |
+
"Generating validation split: 100% 3270/3270 [00:00<00:00, 20002.56 examples/s]\n",
|
| 350 |
+
"Generating test split: 100% 3245/3245 [00:00<00:00, 20866.19 examples/s]\n",
|
| 351 |
+
"2023-11-29:11:56:22,315 INFO [task.py:355] Building contexts for task on rank 0...\n",
|
| 352 |
+
"2023-11-29:11:56:22,322 INFO [evaluator.py:319] Running loglikelihood requests\n",
|
| 353 |
+
"100% 20/20 [00:04<00:00, 4.37it/s]\n",
|
| 354 |
+
"fatal: not a git repository (or any of the parent directories): .git\n",
|
| 355 |
+
"hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
|
| 356 |
+
"| Tasks |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
|
| 357 |
+
"|----------|-------|------|-----:|------|----:|---|-----:|\n",
|
| 358 |
+
"|demo_boolq|Yaml |none | 0|acc | 1|± | 0|\n",
|
| 359 |
+
"\n"
|
| 360 |
+
]
|
| 361 |
+
}
|
| 362 |
+
],
|
| 363 |
+
"source": [
|
| 364 |
+
"%env LOGLEVEL=DEBUG\n",
|
| 365 |
+
"!lm_eval \\\n",
|
| 366 |
+
" --model hf \\\n",
|
| 367 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 368 |
+
" --include_path ./ \\\n",
|
| 369 |
+
" --tasks demo_boolq \\\n",
|
| 370 |
+
" --limit 10"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "markdown",
|
| 375 |
+
"metadata": {
|
| 376 |
+
"id": "LOUHK7PtQfq4"
|
| 377 |
+
},
|
| 378 |
+
"source": [
|
| 379 |
+
"Often, tasks are part of a larger group used to measure different capabilities. The dynamism of the field today means new dimensions of evaluation can come about which would mix and match new and older tasks alike. In LM-Eval, We can also group tasks and call that the group name to evaluate on a set of tasks easily. In this instance, let's evaluate the tag `yes_or_no_tasks` which comprise of the tasks `demo_boolq` and `demo_cola`; tasks which are multiple choice tasks with options `yes` and `no` as the name suggests.\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"<!-- making new groups is easier than ever, allowing user to work bottom-up by makiing individual tasks and linking them to a group or Top-Down, making a new group by listing existing tasks.\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"We also show the aggregate across samples besides only showing the aggregation between subtasks. This may come in handy when certain groups want to be aggregated as a single task. -->\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"\n"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"cell_type": "code",
|
| 390 |
+
"execution_count": 5,
|
| 391 |
+
"metadata": {
|
| 392 |
+
"id": "fthNg3ywO-kA"
|
| 393 |
+
},
|
| 394 |
+
"outputs": [],
|
| 395 |
+
"source": [
|
| 396 |
+
"YAML_cola_string = \"\"\"\n",
|
| 397 |
+
"tag: yes_or_no_tasks\n",
|
| 398 |
+
"task: demo_cola\n",
|
| 399 |
+
"dataset_path: glue\n",
|
| 400 |
+
"dataset_name: cola\n",
|
| 401 |
+
"output_type: multiple_choice\n",
|
| 402 |
+
"training_split: train\n",
|
| 403 |
+
"validation_split: validation\n",
|
| 404 |
+
"doc_to_text: \"{{sentence}}\\nQuestion: Does this sentence make sense?\\nAnswer:\"\n",
|
| 405 |
+
"doc_to_target: label\n",
|
| 406 |
+
"doc_to_choice: [\"no\", \"yes\"]\n",
|
| 407 |
+
"should_decontaminate: true\n",
|
| 408 |
+
"doc_to_decontamination_query: sentence\n",
|
| 409 |
+
"metric_list:\n",
|
| 410 |
+
" - metric: acc\n",
|
| 411 |
+
"\"\"\"\n",
|
| 412 |
+
"with open(\"cola.yaml\", \"w\") as f:\n",
|
| 413 |
+
" f.write(YAML_cola_string)"
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"cell_type": "code",
|
| 418 |
+
"execution_count": 6,
|
| 419 |
+
"metadata": {
|
| 420 |
+
"id": "XceRKCuuDtbn"
|
| 421 |
+
},
|
| 422 |
+
"outputs": [
|
| 423 |
+
{
|
| 424 |
+
"name": "stdout",
|
| 425 |
+
"output_type": "stream",
|
| 426 |
+
"text": [
|
| 427 |
+
"2023-11-29:11:56:33,016 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
|
| 428 |
+
"2023-11-29 11:56:33.852995: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 429 |
+
"2023-11-29 11:56:33.853050: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 430 |
+
"2023-11-29 11:56:33.853087: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
| 431 |
+
"2023-11-29 11:56:35.129047: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
| 432 |
+
"2023-11-29:11:56:38,546 INFO [__main__.py:132] Verbosity set to INFO\n",
|
| 433 |
+
"2023-11-29:11:56:47,509 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
|
| 434 |
+
"2023-11-29:11:56:47,509 INFO [__main__.py:143] Including path: ./\n",
|
| 435 |
+
"2023-11-29:11:56:47,517 INFO [__main__.py:205] Selected Tasks: ['yes_or_no_tasks']\n",
|
| 436 |
+
"2023-11-29:11:56:47,520 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
|
| 437 |
+
"2023-11-29:11:56:47,550 INFO [huggingface.py:120] Using device 'cuda'\n",
|
| 438 |
+
"2023-11-29:11:57:08,743 WARNING [task.py:614] [Task: demo_cola] metric acc is defined, but aggregation is not. using default aggregation=mean\n",
|
| 439 |
+
"2023-11-29:11:57:08,743 WARNING [task.py:626] [Task: demo_cola] metric acc is defined, but higher_is_better is not. using default higher_is_better=True\n",
|
| 440 |
+
"Downloading builder script: 100% 28.8k/28.8k [00:00<00:00, 52.7MB/s]\n",
|
| 441 |
+
"Downloading metadata: 100% 28.7k/28.7k [00:00<00:00, 51.9MB/s]\n",
|
| 442 |
+
"Downloading readme: 100% 27.9k/27.9k [00:00<00:00, 48.0MB/s]\n",
|
| 443 |
+
"Downloading data: 100% 377k/377k [00:00<00:00, 12.0MB/s]\n",
|
| 444 |
+
"Generating train split: 100% 8551/8551 [00:00<00:00, 19744.58 examples/s]\n",
|
| 445 |
+
"Generating validation split: 100% 1043/1043 [00:00<00:00, 27057.01 examples/s]\n",
|
| 446 |
+
"Generating test split: 100% 1063/1063 [00:00<00:00, 22705.17 examples/s]\n",
|
| 447 |
+
"2023-11-29:11:57:11,698 INFO [task.py:355] Building contexts for task on rank 0...\n",
|
| 448 |
+
"2023-11-29:11:57:11,704 INFO [evaluator.py:319] Running loglikelihood requests\n",
|
| 449 |
+
"100% 20/20 [00:03<00:00, 5.15it/s]\n",
|
| 450 |
+
"fatal: not a git repository (or any of the parent directories): .git\n",
|
| 451 |
+
"hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
|
| 452 |
+
"| Tasks |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
|
| 453 |
+
"|---------------|-------|------|-----:|------|----:|---|-----:|\n",
|
| 454 |
+
"|yes_or_no_tasks|N/A |none | 0|acc | 0.7|± |0.1528|\n",
|
| 455 |
+
"| - demo_cola |Yaml |none | 0|acc | 0.7|± |0.1528|\n",
|
| 456 |
+
"\n",
|
| 457 |
+
"| Groups |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
|
| 458 |
+
"|---------------|-------|------|-----:|------|----:|---|-----:|\n",
|
| 459 |
+
"|yes_or_no_tasks|N/A |none | 0|acc | 0.7|± |0.1528|\n",
|
| 460 |
+
"\n"
|
| 461 |
+
]
|
| 462 |
+
}
|
| 463 |
+
],
|
| 464 |
+
"source": [
|
| 465 |
+
"# !accelerate launch --no_python\n",
|
| 466 |
+
"%env LOGLEVEL=DEBUG\n",
|
| 467 |
+
"!lm_eval \\\n",
|
| 468 |
+
" --model hf \\\n",
|
| 469 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 470 |
+
" --include_path ./ \\\n",
|
| 471 |
+
" --tasks yes_or_no_tasks \\\n",
|
| 472 |
+
" --limit 10 \\\n",
|
| 473 |
+
" --output output/yes_or_no_tasks/ \\\n",
|
| 474 |
+
" --log_samples"
|
| 475 |
+
]
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"cell_type": "markdown",
|
| 479 |
+
"metadata": {
|
| 480 |
+
"id": "XceRKCuuDtbn"
|
| 481 |
+
},
|
| 482 |
+
"source": [
|
| 483 |
+
"## Edit Prompt Templates Quickly\n",
|
| 484 |
+
"\n",
|
| 485 |
+
"The following is a yaml made to evaluate the specific subtask of `high_school_geography` from MMLU. It uses the standard prompt where the we choose the letters from the options with most likelihood as the model's prediction."
|
| 486 |
+
]
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"cell_type": "code",
|
| 490 |
+
"execution_count": 7,
|
| 491 |
+
"metadata": {
|
| 492 |
+
"id": "GTFvdt9kSlBG"
|
| 493 |
+
},
|
| 494 |
+
"outputs": [],
|
| 495 |
+
"source": [
|
| 496 |
+
"YAML_mmlu_geo_string = \"\"\"\n",
|
| 497 |
+
"task: demo_mmlu_high_school_geography\n",
|
| 498 |
+
"dataset_path: cais/mmlu\n",
|
| 499 |
+
"dataset_name: high_school_geography\n",
|
| 500 |
+
"description: \"The following are multiple choice questions (with answers) about high school geography.\\n\\n\"\n",
|
| 501 |
+
"test_split: test\n",
|
| 502 |
+
"fewshot_split: dev\n",
|
| 503 |
+
"fewshot_config:\n",
|
| 504 |
+
" sampler: first_n\n",
|
| 505 |
+
"output_type: multiple_choice\n",
|
| 506 |
+
"doc_to_text: \"{{question.strip()}}\\nA. {{choices[0]}}\\nB. {{choices[1]}}\\nC. {{choices[2]}}\\nD. {{choices[3]}}\\nAnswer:\"\n",
|
| 507 |
+
"doc_to_choice: [\"A\", \"B\", \"C\", \"D\"]\n",
|
| 508 |
+
"doc_to_target: answer\n",
|
| 509 |
+
"metric_list:\n",
|
| 510 |
+
" - metric: acc\n",
|
| 511 |
+
" aggregation: mean\n",
|
| 512 |
+
" higher_is_better: true\n",
|
| 513 |
+
" - metric: acc_norm\n",
|
| 514 |
+
" aggregation: mean\n",
|
| 515 |
+
" higher_is_better: true\n",
|
| 516 |
+
"\"\"\"\n",
|
| 517 |
+
"with open(\"mmlu_high_school_geography.yaml\", \"w\") as f:\n",
|
| 518 |
+
" f.write(YAML_mmlu_geo_string)"
|
| 519 |
+
]
|
| 520 |
+
},
|
| 521 |
+
{
|
| 522 |
+
"cell_type": "code",
|
| 523 |
+
"execution_count": 8,
|
| 524 |
+
"metadata": {
|
| 525 |
+
"id": "jyKOfCsKb-xy"
|
| 526 |
+
},
|
| 527 |
+
"outputs": [
|
| 528 |
+
{
|
| 529 |
+
"name": "stdout",
|
| 530 |
+
"output_type": "stream",
|
| 531 |
+
"text": [
|
| 532 |
+
"2023-11-29:11:57:23,598 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
|
| 533 |
+
"2023-11-29 11:57:24.719750: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 534 |
+
"2023-11-29 11:57:24.719806: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 535 |
+
"2023-11-29 11:57:24.719847: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
| 536 |
+
"2023-11-29 11:57:26.656125: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
| 537 |
+
"2023-11-29:11:57:31,563 INFO [__main__.py:132] Verbosity set to INFO\n",
|
| 538 |
+
"2023-11-29:11:57:40,541 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
|
| 539 |
+
"2023-11-29:11:57:40,541 INFO [__main__.py:143] Including path: ./\n",
|
| 540 |
+
"2023-11-29:11:57:40,558 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography']\n",
|
| 541 |
+
"2023-11-29:11:57:40,559 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
|
| 542 |
+
"2023-11-29:11:57:40,589 INFO [huggingface.py:120] Using device 'cuda'\n",
|
| 543 |
+
"Downloading builder script: 100% 5.84k/5.84k [00:00<00:00, 17.7MB/s]\n",
|
| 544 |
+
"Downloading metadata: 100% 106k/106k [00:00<00:00, 892kB/s] \n",
|
| 545 |
+
"Downloading readme: 100% 39.7k/39.7k [00:00<00:00, 631kB/s]\n",
|
| 546 |
+
"Downloading data: 100% 166M/166M [00:01<00:00, 89.0MB/s]\n",
|
| 547 |
+
"Generating auxiliary_train split: 100% 99842/99842 [00:07<00:00, 12536.83 examples/s]\n",
|
| 548 |
+
"Generating test split: 100% 198/198 [00:00<00:00, 1439.20 examples/s]\n",
|
| 549 |
+
"Generating validation split: 100% 22/22 [00:00<00:00, 4181.76 examples/s]\n",
|
| 550 |
+
"Generating dev split: 100% 5/5 [00:00<00:00, 36.25 examples/s]\n",
|
| 551 |
+
"2023-11-29:11:58:09,798 INFO [task.py:355] Building contexts for task on rank 0...\n",
|
| 552 |
+
"2023-11-29:11:58:09,822 INFO [evaluator.py:319] Running loglikelihood requests\n",
|
| 553 |
+
"100% 40/40 [00:05<00:00, 7.86it/s]\n",
|
| 554 |
+
"fatal: not a git repository (or any of the parent directories): .git\n",
|
| 555 |
+
"hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
|
| 556 |
+
"| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
|
| 557 |
+
"|-------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
|
| 558 |
+
"|demo_mmlu_high_school_geography|Yaml |none | 0|acc | 0.3|± |0.1528|\n",
|
| 559 |
+
"| | |none | 0|acc_norm| 0.3|± |0.1528|\n",
|
| 560 |
+
"\n"
|
| 561 |
+
]
|
| 562 |
+
}
|
| 563 |
+
],
|
| 564 |
+
"source": [
|
| 565 |
+
"# !accelerate launch --no_python\n",
|
| 566 |
+
"%env LOGLEVEL=DEBUG\n",
|
| 567 |
+
"!lm_eval \\\n",
|
| 568 |
+
" --model hf \\\n",
|
| 569 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 570 |
+
" --include_path ./ \\\n",
|
| 571 |
+
" --tasks demo_mmlu_high_school_geography \\\n",
|
| 572 |
+
" --limit 10 \\\n",
|
| 573 |
+
" --output output/mmlu_high_school_geography/ \\\n",
|
| 574 |
+
" --log_samples"
|
| 575 |
+
]
|
| 576 |
+
},
|
| 577 |
+
{
|
| 578 |
+
"cell_type": "markdown",
|
| 579 |
+
"metadata": {
|
| 580 |
+
"id": "jyKOfCsKb-xy"
|
| 581 |
+
},
|
| 582 |
+
"source": [
|
| 583 |
+
"We could also evaluate this task in a different way. For example, instead of observing the loglikelihood of the letters, we can instead evaluate on the choices themselves as the continuation. This is done by simply changing `doc_to_choice` from a list of letters to the corresponding `choices` field from the HF dataset. We write `\"{{choices}}\"` so that the string field is interpreted as jinja string that acquires the list from the HF dataset directly.\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"Another convenient feature here is since we're only modifying the `doc_to_choice` and the rest of config is the same as the task above, we can use the above configuration as a template by using `include: mmlu_high_school_geography.yaml` to load the config from that file. We'll need to add a unique task name as to not colide with the existing yaml config we're including. For this case we'll simply name this one `mmlu_high_school_geography_continuation`. `doc_to_text` is added here just for sake of clarity."
|
| 586 |
+
]
|
| 587 |
+
},
|
| 588 |
+
{
|
| 589 |
+
"cell_type": "code",
|
| 590 |
+
"execution_count": 9,
|
| 591 |
+
"metadata": {
|
| 592 |
+
"id": "lqElwU54TaK-"
|
| 593 |
+
},
|
| 594 |
+
"outputs": [],
|
| 595 |
+
"source": [
|
| 596 |
+
"YAML_mmlu_geo_string = \"\"\"\n",
|
| 597 |
+
"include: mmlu_high_school_geography.yaml\n",
|
| 598 |
+
"task: demo_mmlu_high_school_geography_continuation\n",
|
| 599 |
+
"doc_to_text: \"{{question.strip()}}\\nA. {{choices[0]}}\\nB. {{choices[1]}}\\nC. {{choices[2]}}\\nD. {{choices[3]}}\\nAnswer:\"\n",
|
| 600 |
+
"doc_to_choice: \"{{choices}}\"\n",
|
| 601 |
+
"\"\"\"\n",
|
| 602 |
+
"with open(\"mmlu_high_school_geography_continuation.yaml\", \"w\") as f:\n",
|
| 603 |
+
" f.write(YAML_mmlu_geo_string)"
|
| 604 |
+
]
|
| 605 |
+
},
|
| 606 |
+
{
|
| 607 |
+
"cell_type": "code",
|
| 608 |
+
"execution_count": 10,
|
| 609 |
+
"metadata": {
|
| 610 |
+
"id": "-_CVnDirdy7j"
|
| 611 |
+
},
|
| 612 |
+
"outputs": [
|
| 613 |
+
{
|
| 614 |
+
"name": "stdout",
|
| 615 |
+
"output_type": "stream",
|
| 616 |
+
"text": [
|
| 617 |
+
"2023-11-29:11:58:21,284 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
|
| 618 |
+
"2023-11-29 11:58:22.850159: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 619 |
+
"2023-11-29 11:58:22.850219: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 620 |
+
"2023-11-29 11:58:22.850254: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
| 621 |
+
"2023-11-29 11:58:24.948103: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
| 622 |
+
"2023-11-29:11:58:28,460 INFO [__main__.py:132] Verbosity set to INFO\n",
|
| 623 |
+
"2023-11-29:11:58:37,935 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
|
| 624 |
+
"2023-11-29:11:58:37,935 INFO [__main__.py:143] Including path: ./\n",
|
| 625 |
+
"2023-11-29:11:58:37,969 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography_continuation']\n",
|
| 626 |
+
"2023-11-29:11:58:37,972 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
|
| 627 |
+
"2023-11-29:11:58:38,008 INFO [huggingface.py:120] Using device 'cuda'\n",
|
| 628 |
+
"2023-11-29:11:58:59,758 INFO [task.py:355] Building contexts for task on rank 0...\n",
|
| 629 |
+
"2023-11-29:11:58:59,777 INFO [evaluator.py:319] Running loglikelihood requests\n",
|
| 630 |
+
"100% 40/40 [00:02<00:00, 16.23it/s]\n",
|
| 631 |
+
"fatal: not a git repository (or any of the parent directories): .git\n",
|
| 632 |
+
"hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
|
| 633 |
+
"| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
|
| 634 |
+
"|--------------------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
|
| 635 |
+
"|demo_mmlu_high_school_geography_continuation|Yaml |none | 0|acc | 0.1|± |0.1000|\n",
|
| 636 |
+
"| | |none | 0|acc_norm| 0.2|± |0.1333|\n",
|
| 637 |
+
"\n"
|
| 638 |
+
]
|
| 639 |
+
}
|
| 640 |
+
],
|
| 641 |
+
"source": [
|
| 642 |
+
"# !accelerate launch --no_python\n",
|
| 643 |
+
"%env LOGLEVEL=DEBUG\n",
|
| 644 |
+
"!lm_eval \\\n",
|
| 645 |
+
" --model hf \\\n",
|
| 646 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 647 |
+
" --include_path ./ \\\n",
|
| 648 |
+
" --tasks demo_mmlu_high_school_geography_continuation \\\n",
|
| 649 |
+
" --limit 10 \\\n",
|
| 650 |
+
" --output output/mmlu_high_school_geography_continuation/ \\\n",
|
| 651 |
+
" --log_samples"
|
| 652 |
+
]
|
| 653 |
+
},
|
| 654 |
+
{
|
| 655 |
+
"cell_type": "markdown",
|
| 656 |
+
"metadata": {
|
| 657 |
+
"id": "-_CVnDirdy7j"
|
| 658 |
+
},
|
| 659 |
+
"source": [
|
| 660 |
+
"If we take a look at the samples, we can see that it is in fact evaluating the continuation based on the choices rather than the letters."
|
| 661 |
+
]
|
| 662 |
+
},
|
| 663 |
+
{
|
| 664 |
+
"cell_type": "code",
|
| 665 |
+
"execution_count": 11,
|
| 666 |
+
"metadata": {
|
| 667 |
+
"id": "duBDqC6PAdjL"
|
| 668 |
+
},
|
| 669 |
+
"outputs": [
|
| 670 |
+
{
|
| 671 |
+
"data": {
|
| 672 |
+
"application/javascript": "\n ((filepath) => {{\n if (!google.colab.kernel.accessAllowed) {{\n return;\n }}\n google.colab.files.view(filepath);\n }})(\"/content/output/mmlu_high_school_geography_continuation/pretrained__EleutherAI__pythia-2.8b_demo_mmlu_high_school_geography_continuation.jsonl\")",
|
| 673 |
+
"text/plain": [
|
| 674 |
+
"<IPython.core.display.Javascript object>"
|
| 675 |
+
]
|
| 676 |
+
},
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"output_type": "display_data"
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"source": [
|
| 682 |
+
"from google.colab import files\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"\n",
|
| 685 |
+
"files.view(\n",
|
| 686 |
+
" \"output/mmlu_high_school_geography_continuation/pretrained__EleutherAI__pythia-2.8b_demo_mmlu_high_school_geography_continuation.jsonl\"\n",
|
| 687 |
+
")"
|
| 688 |
+
]
|
| 689 |
+
},
|
| 690 |
+
{
|
| 691 |
+
"cell_type": "markdown",
|
| 692 |
+
"metadata": {
|
| 693 |
+
"id": "6p0-KPwAgK5j"
|
| 694 |
+
},
|
| 695 |
+
"source": [
|
| 696 |
+
"## Closer Look at YAML Fields\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"To prepare a task we can simply fill in a YAML config with the relevant information.\n",
|
| 699 |
+
"\n",
|
| 700 |
+
"`output_type`\n",
|
| 701 |
+
"The current provided evaluation types comprise of the following:\n",
|
| 702 |
+
"1. `loglikelihood`: Evaluates the loglikelihood of a continuation, conditioned on some input string.\n",
|
| 703 |
+
"2. `loglikelihood_rolling`: evaluate the loglikelihood of producing a string, conditioned on the empty string. (Used for perplexity evaluations)\n",
|
| 704 |
+
"3. `multiple_choice`: Evaluates loglikelihood among the a number of choices predicted by the model.\n",
|
| 705 |
+
"4. `greedy_until`: Model outputs greedy generation (can be configured to to use beam search and other generation-related parameters)\n",
|
| 706 |
+
"\n",
|
| 707 |
+
"The core prompt revolves around 3 fields.\n",
|
| 708 |
+
"1. `doc_to_text`: Denotes the prompt template that will be used as input to the model.\n",
|
| 709 |
+
"2. `doc_to_choice`: Available choices that will be used as continuation for the model. This is used when the `output_type` is `multiple_choice`, and otherwise can be left as `None`.\n",
|
| 710 |
+
"3. `doc_to_target`: When `output_type` is `multiple_choice`, this can be an index that corresponds to the correct answer, or the answer string itself (must be a subset of `doc_to_choice`). For other tasks, this is expected to be a string. You can fill this field with a feature name from the HF dataset so long as the resulting feature follows the conditioned described.\n",
|
| 711 |
+
"\n",
|
| 712 |
+
"These three fields can be expressed as strings, column names from the source dataset, or as Jinja2 templates that can use fields from the source dataset as variables.\n"
|
| 713 |
+
]
|
| 714 |
+
},
|
| 715 |
+
{
|
| 716 |
+
"cell_type": "markdown",
|
| 717 |
+
"metadata": {
|
| 718 |
+
"id": "6p0-KPwAgK5j"
|
| 719 |
+
},
|
| 720 |
+
"source": [
|
| 721 |
+
"## What if Jinja is not Sufficient?\n",
|
| 722 |
+
"\n",
|
| 723 |
+
"There can be times where the Jinja2 templating language is not enough to make the prompt we had in mind. There are a few ways to circumvent this limitation:\n",
|
| 724 |
+
"\n",
|
| 725 |
+
"1. Use `!function` operator for the prompt-related fields to pass a python function that takes as input the dataset row, and will output the prompt template component.\n",
|
| 726 |
+
"2. Perform a transformation on the dataset beforehand."
|
| 727 |
+
]
|
| 728 |
+
},
|
| 729 |
+
{
|
| 730 |
+
"cell_type": "markdown",
|
| 731 |
+
"metadata": {},
|
| 732 |
+
"source": [
|
| 733 |
+
"Below, we show an example of using `!function` to create `doc_to_text` from a python function:"
|
| 734 |
+
]
|
| 735 |
+
},
|
| 736 |
+
{
|
| 737 |
+
"cell_type": "code",
|
| 738 |
+
"execution_count": 12,
|
| 739 |
+
"metadata": {
|
| 740 |
+
"colab": {
|
| 741 |
+
"base_uri": "https://localhost:8080/"
|
| 742 |
+
},
|
| 743 |
+
"id": "DYZ5c0JhR1lJ",
|
| 744 |
+
"outputId": "ca945235-fb9e-4f17-8bfa-78e7d6ec1490"
|
| 745 |
+
},
|
| 746 |
+
"outputs": [
|
| 747 |
+
{
|
| 748 |
+
"name": "stdout",
|
| 749 |
+
"output_type": "stream",
|
| 750 |
+
"text": [
|
| 751 |
+
"2023-11-29:11:59:08,312 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
|
| 752 |
+
"2023-11-29 11:59:09.348327: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 753 |
+
"2023-11-29 11:59:09.348387: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 754 |
+
"2023-11-29 11:59:09.348421: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
| 755 |
+
"2023-11-29 11:59:10.573752: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
| 756 |
+
"2023-11-29:11:59:14,044 INFO [__main__.py:132] Verbosity set to INFO\n",
|
| 757 |
+
"2023-11-29:11:59:23,654 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
|
| 758 |
+
"2023-11-29:11:59:23,654 INFO [__main__.py:143] Including path: ./\n",
|
| 759 |
+
"2023-11-29:11:59:23,678 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography_function_prompt']\n",
|
| 760 |
+
"2023-11-29:11:59:23,679 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
|
| 761 |
+
"2023-11-29:11:59:23,708 INFO [huggingface.py:120] Using device 'cuda'\n",
|
| 762 |
+
"2023-11-29:11:59:44,516 INFO [task.py:355] Building contexts for task on rank 0...\n",
|
| 763 |
+
"2023-11-29:11:59:44,524 INFO [evaluator.py:319] Running loglikelihood requests\n",
|
| 764 |
+
"100% 40/40 [00:02<00:00, 15.41it/s]\n",
|
| 765 |
+
"fatal: not a git repository (or any of the parent directories): .git\n",
|
| 766 |
+
"hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
|
| 767 |
+
"| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
|
| 768 |
+
"|-----------------------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
|
| 769 |
+
"|demo_mmlu_high_school_geography_function_prompt|Yaml |none | 0|acc | 0.1|± |0.1000|\n",
|
| 770 |
+
"| | |none | 0|acc_norm| 0.2|± |0.1333|\n",
|
| 771 |
+
"\n"
|
| 772 |
+
]
|
| 773 |
+
}
|
| 774 |
+
],
|
| 775 |
+
"source": [
|
| 776 |
+
"YAML_mmlu_geo_string = \"\"\"\n",
|
| 777 |
+
"include: mmlu_high_school_geography.yaml\n",
|
| 778 |
+
"task: demo_mmlu_high_school_geography_function_prompt\n",
|
| 779 |
+
"doc_to_text: !function utils.doc_to_text\n",
|
| 780 |
+
"doc_to_choice: \"{{choices}}\"\n",
|
| 781 |
+
"\"\"\"\n",
|
| 782 |
+
"with open(\"demo_mmlu_high_school_geography_function_prompt.yaml\", \"w\") as f:\n",
|
| 783 |
+
" f.write(YAML_mmlu_geo_string)\n",
|
| 784 |
+
"\n",
|
| 785 |
+
"DOC_TO_TEXT = \"\"\"\n",
|
| 786 |
+
"def doc_to_text(x):\n",
|
| 787 |
+
" question = x[\"question\"].strip()\n",
|
| 788 |
+
" choices = x[\"choices\"]\n",
|
| 789 |
+
" option_a = choices[0]\n",
|
| 790 |
+
" option_b = choices[1]\n",
|
| 791 |
+
" option_c = choices[2]\n",
|
| 792 |
+
" option_d = choices[3]\n",
|
| 793 |
+
" return f\"{question}\\\\nA. {option_a}\\\\nB. {option_b}\\\\nC. {option_c}\\\\nD. {option_d}\\\\nAnswer:\"\n",
|
| 794 |
+
"\"\"\"\n",
|
| 795 |
+
"with open(\"utils.py\", \"w\") as f:\n",
|
| 796 |
+
" f.write(DOC_TO_TEXT)\n",
|
| 797 |
+
"\n",
|
| 798 |
+
"!lm_eval \\\n",
|
| 799 |
+
" --model hf \\\n",
|
| 800 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 801 |
+
" --include_path ./ \\\n",
|
| 802 |
+
" --tasks demo_mmlu_high_school_geography_function_prompt \\\n",
|
| 803 |
+
" --limit 10 \\\n",
|
| 804 |
+
" --output output/demo_mmlu_high_school_geography_function_prompt/ \\\n",
|
| 805 |
+
" --log_samples"
|
| 806 |
+
]
|
| 807 |
+
},
|
| 808 |
+
{
|
| 809 |
+
"cell_type": "markdown",
|
| 810 |
+
"metadata": {},
|
| 811 |
+
"source": [
|
| 812 |
+
"Next, we'll also show how to do this via preprocessing the dataset as necessary using the `process_docs` config field:\n",
|
| 813 |
+
"\n",
|
| 814 |
+
"We will write a function that will modify each document in our evaluation dataset's split to add a field that is suitable for us to use in `doc_to_text`."
|
| 815 |
+
]
|
| 816 |
+
},
|
| 817 |
+
{
|
| 818 |
+
"cell_type": "code",
|
| 819 |
+
"execution_count": null,
|
| 820 |
+
"metadata": {},
|
| 821 |
+
"outputs": [],
|
| 822 |
+
"source": [
|
| 823 |
+
"YAML_mmlu_geo_string = \"\"\"\n",
|
| 824 |
+
"include: mmlu_high_school_geography.yaml\n",
|
| 825 |
+
"task: demo_mmlu_high_school_geography_function_prompt_2\n",
|
| 826 |
+
"process_docs: !function utils_process_docs.process_docs\n",
|
| 827 |
+
"doc_to_text: \"{{input}}\"\n",
|
| 828 |
+
"doc_to_choice: \"{{choices}}\"\n",
|
| 829 |
+
"\"\"\"\n",
|
| 830 |
+
"with open(\"demo_mmlu_high_school_geography_process_docs.yaml\", \"w\") as f:\n",
|
| 831 |
+
" f.write(YAML_mmlu_geo_string)\n",
|
| 832 |
+
"\n",
|
| 833 |
+
"DOC_TO_TEXT = \"\"\"\n",
|
| 834 |
+
"def process_docs(dataset):\n",
|
| 835 |
+
" def _process_doc(x):\n",
|
| 836 |
+
" question = x[\"question\"].strip()\n",
|
| 837 |
+
" choices = x[\"choices\"]\n",
|
| 838 |
+
" option_a = choices[0]\n",
|
| 839 |
+
" option_b = choices[1]\n",
|
| 840 |
+
" option_c = choices[2]\n",
|
| 841 |
+
" option_d = choices[3]\n",
|
| 842 |
+
" doc[\"input\"] = f\"{question}\\\\nA. {option_a}\\\\nB. {option_b}\\\\nC. {option_c}\\\\nD. {option_d}\\\\nAnswer:\"\n",
|
| 843 |
+
" return out_doc\n",
|
| 844 |
+
"\n",
|
| 845 |
+
" return dataset.map(_process_doc)\n",
|
| 846 |
+
"\"\"\"\n",
|
| 847 |
+
"\n",
|
| 848 |
+
"with open(\"utils_process_docs.py\", \"w\") as f:\n",
|
| 849 |
+
" f.write(DOC_TO_TEXT)\n",
|
| 850 |
+
"\n",
|
| 851 |
+
"!lm_eval \\\n",
|
| 852 |
+
" --model hf \\\n",
|
| 853 |
+
" --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
|
| 854 |
+
" --include_path ./ \\\n",
|
| 855 |
+
" --tasks demo_mmlu_high_school_geography_function_prompt_2 \\\n",
|
| 856 |
+
" --limit 10 \\\n",
|
| 857 |
+
" --output output/demo_mmlu_high_school_geography_function_prompt_2/ \\\n",
|
| 858 |
+
" --log_samples"
|
| 859 |
+
]
|
| 860 |
+
},
|
| 861 |
+
{
|
| 862 |
+
"cell_type": "markdown",
|
| 863 |
+
"metadata": {},
|
| 864 |
+
"source": [
|
| 865 |
+
"We hope that this explainer gives you a sense of what can be done with and how to work with LM-Evaluation-Harnes v0.4.0 ! \n",
|
| 866 |
+
"\n",
|
| 867 |
+
"For more information, check out our documentation pages in the `docs/` folder, and if you have questions, please raise them in GitHub issues, or in #lm-thunderdome or #release-discussion on the EleutherAI discord server."
|
| 868 |
+
]
|
| 869 |
+
}
|
| 870 |
+
],
|
| 871 |
+
"metadata": {
|
| 872 |
+
"accelerator": "GPU",
|
| 873 |
+
"colab": {
|
| 874 |
+
"collapsed_sections": [
|
| 875 |
+
"zAov81vTbL2K"
|
| 876 |
+
],
|
| 877 |
+
"gpuType": "T4",
|
| 878 |
+
"provenance": []
|
| 879 |
+
},
|
| 880 |
+
"kernelspec": {
|
| 881 |
+
"display_name": "Python 3",
|
| 882 |
+
"name": "python3"
|
| 883 |
+
},
|
| 884 |
+
"language_info": {
|
| 885 |
+
"name": "python"
|
| 886 |
+
},
|
| 887 |
+
"widgets": {
|
| 888 |
+
"application/vnd.jupyter.widget-state+json": {
|
| 889 |
+
"state": {
|
| 890 |
+
"46f521b73fd943c081c648fd873ebc0a": {
|
| 891 |
+
"model_module": "@jupyter-widgets/controls",
|
| 892 |
+
"model_module_version": "1.5.0",
|
| 893 |
+
"model_name": "DescriptionStyleModel",
|
| 894 |
+
"state": {
|
| 895 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 896 |
+
"_model_module_version": "1.5.0",
|
| 897 |
+
"_model_name": "DescriptionStyleModel",
|
| 898 |
+
"_view_count": null,
|
| 899 |
+
"_view_module": "@jupyter-widgets/base",
|
| 900 |
+
"_view_module_version": "1.2.0",
|
| 901 |
+
"_view_name": "StyleView",
|
| 902 |
+
"description_width": ""
|
| 903 |
+
}
|
| 904 |
+
},
|
| 905 |
+
"48763b6233374554ae76035c0483066f": {
|
| 906 |
+
"model_module": "@jupyter-widgets/controls",
|
| 907 |
+
"model_module_version": "1.5.0",
|
| 908 |
+
"model_name": "ProgressStyleModel",
|
| 909 |
+
"state": {
|
| 910 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 911 |
+
"_model_module_version": "1.5.0",
|
| 912 |
+
"_model_name": "ProgressStyleModel",
|
| 913 |
+
"_view_count": null,
|
| 914 |
+
"_view_module": "@jupyter-widgets/base",
|
| 915 |
+
"_view_module_version": "1.2.0",
|
| 916 |
+
"_view_name": "StyleView",
|
| 917 |
+
"bar_color": null,
|
| 918 |
+
"description_width": ""
|
| 919 |
+
}
|
| 920 |
+
},
|
| 921 |
+
"4986a21eb560448fa79f4b25cde48951": {
|
| 922 |
+
"model_module": "@jupyter-widgets/base",
|
| 923 |
+
"model_module_version": "1.2.0",
|
| 924 |
+
"model_name": "LayoutModel",
|
| 925 |
+
"state": {
|
| 926 |
+
"_model_module": "@jupyter-widgets/base",
|
| 927 |
+
"_model_module_version": "1.2.0",
|
| 928 |
+
"_model_name": "LayoutModel",
|
| 929 |
+
"_view_count": null,
|
| 930 |
+
"_view_module": "@jupyter-widgets/base",
|
| 931 |
+
"_view_module_version": "1.2.0",
|
| 932 |
+
"_view_name": "LayoutView",
|
| 933 |
+
"align_content": null,
|
| 934 |
+
"align_items": null,
|
| 935 |
+
"align_self": null,
|
| 936 |
+
"border": null,
|
| 937 |
+
"bottom": null,
|
| 938 |
+
"display": null,
|
| 939 |
+
"flex": null,
|
| 940 |
+
"flex_flow": null,
|
| 941 |
+
"grid_area": null,
|
| 942 |
+
"grid_auto_columns": null,
|
| 943 |
+
"grid_auto_flow": null,
|
| 944 |
+
"grid_auto_rows": null,
|
| 945 |
+
"grid_column": null,
|
| 946 |
+
"grid_gap": null,
|
| 947 |
+
"grid_row": null,
|
| 948 |
+
"grid_template_areas": null,
|
| 949 |
+
"grid_template_columns": null,
|
| 950 |
+
"grid_template_rows": null,
|
| 951 |
+
"height": null,
|
| 952 |
+
"justify_content": null,
|
| 953 |
+
"justify_items": null,
|
| 954 |
+
"left": null,
|
| 955 |
+
"margin": null,
|
| 956 |
+
"max_height": null,
|
| 957 |
+
"max_width": null,
|
| 958 |
+
"min_height": null,
|
| 959 |
+
"min_width": null,
|
| 960 |
+
"object_fit": null,
|
| 961 |
+
"object_position": null,
|
| 962 |
+
"order": null,
|
| 963 |
+
"overflow": null,
|
| 964 |
+
"overflow_x": null,
|
| 965 |
+
"overflow_y": null,
|
| 966 |
+
"padding": null,
|
| 967 |
+
"right": null,
|
| 968 |
+
"top": null,
|
| 969 |
+
"visibility": null,
|
| 970 |
+
"width": null
|
| 971 |
+
}
|
| 972 |
+
},
|
| 973 |
+
"6b2d90209ec14230b3d58a74ac9b83bf": {
|
| 974 |
+
"model_module": "@jupyter-widgets/base",
|
| 975 |
+
"model_module_version": "1.2.0",
|
| 976 |
+
"model_name": "LayoutModel",
|
| 977 |
+
"state": {
|
| 978 |
+
"_model_module": "@jupyter-widgets/base",
|
| 979 |
+
"_model_module_version": "1.2.0",
|
| 980 |
+
"_model_name": "LayoutModel",
|
| 981 |
+
"_view_count": null,
|
| 982 |
+
"_view_module": "@jupyter-widgets/base",
|
| 983 |
+
"_view_module_version": "1.2.0",
|
| 984 |
+
"_view_name": "LayoutView",
|
| 985 |
+
"align_content": null,
|
| 986 |
+
"align_items": null,
|
| 987 |
+
"align_self": null,
|
| 988 |
+
"border": null,
|
| 989 |
+
"bottom": null,
|
| 990 |
+
"display": null,
|
| 991 |
+
"flex": null,
|
| 992 |
+
"flex_flow": null,
|
| 993 |
+
"grid_area": null,
|
| 994 |
+
"grid_auto_columns": null,
|
| 995 |
+
"grid_auto_flow": null,
|
| 996 |
+
"grid_auto_rows": null,
|
| 997 |
+
"grid_column": null,
|
| 998 |
+
"grid_gap": null,
|
| 999 |
+
"grid_row": null,
|
| 1000 |
+
"grid_template_areas": null,
|
| 1001 |
+
"grid_template_columns": null,
|
| 1002 |
+
"grid_template_rows": null,
|
| 1003 |
+
"height": null,
|
| 1004 |
+
"justify_content": null,
|
| 1005 |
+
"justify_items": null,
|
| 1006 |
+
"left": null,
|
| 1007 |
+
"margin": null,
|
| 1008 |
+
"max_height": null,
|
| 1009 |
+
"max_width": null,
|
| 1010 |
+
"min_height": null,
|
| 1011 |
+
"min_width": null,
|
| 1012 |
+
"object_fit": null,
|
| 1013 |
+
"object_position": null,
|
| 1014 |
+
"order": null,
|
| 1015 |
+
"overflow": null,
|
| 1016 |
+
"overflow_x": null,
|
| 1017 |
+
"overflow_y": null,
|
| 1018 |
+
"padding": null,
|
| 1019 |
+
"right": null,
|
| 1020 |
+
"top": null,
|
| 1021 |
+
"visibility": null,
|
| 1022 |
+
"width": null
|
| 1023 |
+
}
|
| 1024 |
+
},
|
| 1025 |
+
"7c5689bc13684db8a22681f41863dddd": {
|
| 1026 |
+
"model_module": "@jupyter-widgets/base",
|
| 1027 |
+
"model_module_version": "1.2.0",
|
| 1028 |
+
"model_name": "LayoutModel",
|
| 1029 |
+
"state": {
|
| 1030 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1031 |
+
"_model_module_version": "1.2.0",
|
| 1032 |
+
"_model_name": "LayoutModel",
|
| 1033 |
+
"_view_count": null,
|
| 1034 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1035 |
+
"_view_module_version": "1.2.0",
|
| 1036 |
+
"_view_name": "LayoutView",
|
| 1037 |
+
"align_content": null,
|
| 1038 |
+
"align_items": null,
|
| 1039 |
+
"align_self": null,
|
| 1040 |
+
"border": null,
|
| 1041 |
+
"bottom": null,
|
| 1042 |
+
"display": null,
|
| 1043 |
+
"flex": null,
|
| 1044 |
+
"flex_flow": null,
|
| 1045 |
+
"grid_area": null,
|
| 1046 |
+
"grid_auto_columns": null,
|
| 1047 |
+
"grid_auto_flow": null,
|
| 1048 |
+
"grid_auto_rows": null,
|
| 1049 |
+
"grid_column": null,
|
| 1050 |
+
"grid_gap": null,
|
| 1051 |
+
"grid_row": null,
|
| 1052 |
+
"grid_template_areas": null,
|
| 1053 |
+
"grid_template_columns": null,
|
| 1054 |
+
"grid_template_rows": null,
|
| 1055 |
+
"height": null,
|
| 1056 |
+
"justify_content": null,
|
| 1057 |
+
"justify_items": null,
|
| 1058 |
+
"left": null,
|
| 1059 |
+
"margin": null,
|
| 1060 |
+
"max_height": null,
|
| 1061 |
+
"max_width": null,
|
| 1062 |
+
"min_height": null,
|
| 1063 |
+
"min_width": null,
|
| 1064 |
+
"object_fit": null,
|
| 1065 |
+
"object_position": null,
|
| 1066 |
+
"order": null,
|
| 1067 |
+
"overflow": null,
|
| 1068 |
+
"overflow_x": null,
|
| 1069 |
+
"overflow_y": null,
|
| 1070 |
+
"padding": null,
|
| 1071 |
+
"right": null,
|
| 1072 |
+
"top": null,
|
| 1073 |
+
"visibility": null,
|
| 1074 |
+
"width": null
|
| 1075 |
+
}
|
| 1076 |
+
},
|
| 1077 |
+
"a1d3a8aa016544a78e8821c8f6199e06": {
|
| 1078 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1079 |
+
"model_module_version": "1.5.0",
|
| 1080 |
+
"model_name": "HBoxModel",
|
| 1081 |
+
"state": {
|
| 1082 |
+
"_dom_classes": [],
|
| 1083 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1084 |
+
"_model_module_version": "1.5.0",
|
| 1085 |
+
"_model_name": "HBoxModel",
|
| 1086 |
+
"_view_count": null,
|
| 1087 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1088 |
+
"_view_module_version": "1.5.0",
|
| 1089 |
+
"_view_name": "HBoxView",
|
| 1090 |
+
"box_style": "",
|
| 1091 |
+
"children": [
|
| 1092 |
+
"IPY_MODEL_f61ed33fad754146bdd2ac9db1ba1c48",
|
| 1093 |
+
"IPY_MODEL_bfa0af6aeff344c6845e1080a878e92e",
|
| 1094 |
+
"IPY_MODEL_fd1ad9e0367d4004aae853b91c3a7617"
|
| 1095 |
+
],
|
| 1096 |
+
"layout": "IPY_MODEL_6b2d90209ec14230b3d58a74ac9b83bf"
|
| 1097 |
+
}
|
| 1098 |
+
},
|
| 1099 |
+
"a73f357065d34d7baf0453ae4a8d75e2": {
|
| 1100 |
+
"model_module": "@jupyter-widgets/base",
|
| 1101 |
+
"model_module_version": "1.2.0",
|
| 1102 |
+
"model_name": "LayoutModel",
|
| 1103 |
+
"state": {
|
| 1104 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1105 |
+
"_model_module_version": "1.2.0",
|
| 1106 |
+
"_model_name": "LayoutModel",
|
| 1107 |
+
"_view_count": null,
|
| 1108 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1109 |
+
"_view_module_version": "1.2.0",
|
| 1110 |
+
"_view_name": "LayoutView",
|
| 1111 |
+
"align_content": null,
|
| 1112 |
+
"align_items": null,
|
| 1113 |
+
"align_self": null,
|
| 1114 |
+
"border": null,
|
| 1115 |
+
"bottom": null,
|
| 1116 |
+
"display": null,
|
| 1117 |
+
"flex": null,
|
| 1118 |
+
"flex_flow": null,
|
| 1119 |
+
"grid_area": null,
|
| 1120 |
+
"grid_auto_columns": null,
|
| 1121 |
+
"grid_auto_flow": null,
|
| 1122 |
+
"grid_auto_rows": null,
|
| 1123 |
+
"grid_column": null,
|
| 1124 |
+
"grid_gap": null,
|
| 1125 |
+
"grid_row": null,
|
| 1126 |
+
"grid_template_areas": null,
|
| 1127 |
+
"grid_template_columns": null,
|
| 1128 |
+
"grid_template_rows": null,
|
| 1129 |
+
"height": null,
|
| 1130 |
+
"justify_content": null,
|
| 1131 |
+
"justify_items": null,
|
| 1132 |
+
"left": null,
|
| 1133 |
+
"margin": null,
|
| 1134 |
+
"max_height": null,
|
| 1135 |
+
"max_width": null,
|
| 1136 |
+
"min_height": null,
|
| 1137 |
+
"min_width": null,
|
| 1138 |
+
"object_fit": null,
|
| 1139 |
+
"object_position": null,
|
| 1140 |
+
"order": null,
|
| 1141 |
+
"overflow": null,
|
| 1142 |
+
"overflow_x": null,
|
| 1143 |
+
"overflow_y": null,
|
| 1144 |
+
"padding": null,
|
| 1145 |
+
"right": null,
|
| 1146 |
+
"top": null,
|
| 1147 |
+
"visibility": null,
|
| 1148 |
+
"width": null
|
| 1149 |
+
}
|
| 1150 |
+
},
|
| 1151 |
+
"aed3acd2f2d74003b44079c333a0698e": {
|
| 1152 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1153 |
+
"model_module_version": "1.5.0",
|
| 1154 |
+
"model_name": "DescriptionStyleModel",
|
| 1155 |
+
"state": {
|
| 1156 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1157 |
+
"_model_module_version": "1.5.0",
|
| 1158 |
+
"_model_name": "DescriptionStyleModel",
|
| 1159 |
+
"_view_count": null,
|
| 1160 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1161 |
+
"_view_module_version": "1.2.0",
|
| 1162 |
+
"_view_name": "StyleView",
|
| 1163 |
+
"description_width": ""
|
| 1164 |
+
}
|
| 1165 |
+
},
|
| 1166 |
+
"bfa0af6aeff344c6845e1080a878e92e": {
|
| 1167 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1168 |
+
"model_module_version": "1.5.0",
|
| 1169 |
+
"model_name": "FloatProgressModel",
|
| 1170 |
+
"state": {
|
| 1171 |
+
"_dom_classes": [],
|
| 1172 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1173 |
+
"_model_module_version": "1.5.0",
|
| 1174 |
+
"_model_name": "FloatProgressModel",
|
| 1175 |
+
"_view_count": null,
|
| 1176 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1177 |
+
"_view_module_version": "1.5.0",
|
| 1178 |
+
"_view_name": "ProgressView",
|
| 1179 |
+
"bar_style": "success",
|
| 1180 |
+
"description": "",
|
| 1181 |
+
"description_tooltip": null,
|
| 1182 |
+
"layout": "IPY_MODEL_7c5689bc13684db8a22681f41863dddd",
|
| 1183 |
+
"max": 5669,
|
| 1184 |
+
"min": 0,
|
| 1185 |
+
"orientation": "horizontal",
|
| 1186 |
+
"style": "IPY_MODEL_48763b6233374554ae76035c0483066f",
|
| 1187 |
+
"value": 5669
|
| 1188 |
+
}
|
| 1189 |
+
},
|
| 1190 |
+
"f61ed33fad754146bdd2ac9db1ba1c48": {
|
| 1191 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1192 |
+
"model_module_version": "1.5.0",
|
| 1193 |
+
"model_name": "HTMLModel",
|
| 1194 |
+
"state": {
|
| 1195 |
+
"_dom_classes": [],
|
| 1196 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1197 |
+
"_model_module_version": "1.5.0",
|
| 1198 |
+
"_model_name": "HTMLModel",
|
| 1199 |
+
"_view_count": null,
|
| 1200 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1201 |
+
"_view_module_version": "1.5.0",
|
| 1202 |
+
"_view_name": "HTMLView",
|
| 1203 |
+
"description": "",
|
| 1204 |
+
"description_tooltip": null,
|
| 1205 |
+
"layout": "IPY_MODEL_a73f357065d34d7baf0453ae4a8d75e2",
|
| 1206 |
+
"placeholder": "",
|
| 1207 |
+
"style": "IPY_MODEL_46f521b73fd943c081c648fd873ebc0a",
|
| 1208 |
+
"value": "Downloading builder script: 100%"
|
| 1209 |
+
}
|
| 1210 |
+
},
|
| 1211 |
+
"fd1ad9e0367d4004aae853b91c3a7617": {
|
| 1212 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1213 |
+
"model_module_version": "1.5.0",
|
| 1214 |
+
"model_name": "HTMLModel",
|
| 1215 |
+
"state": {
|
| 1216 |
+
"_dom_classes": [],
|
| 1217 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1218 |
+
"_model_module_version": "1.5.0",
|
| 1219 |
+
"_model_name": "HTMLModel",
|
| 1220 |
+
"_view_count": null,
|
| 1221 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1222 |
+
"_view_module_version": "1.5.0",
|
| 1223 |
+
"_view_name": "HTMLView",
|
| 1224 |
+
"description": "",
|
| 1225 |
+
"description_tooltip": null,
|
| 1226 |
+
"layout": "IPY_MODEL_4986a21eb560448fa79f4b25cde48951",
|
| 1227 |
+
"placeholder": "",
|
| 1228 |
+
"style": "IPY_MODEL_aed3acd2f2d74003b44079c333a0698e",
|
| 1229 |
+
"value": " 5.67k/5.67k [00:00<00:00, 205kB/s]"
|
| 1230 |
+
}
|
| 1231 |
+
}
|
| 1232 |
+
},
|
| 1233 |
+
"version_major": 2,
|
| 1234 |
+
"version_minor": 0
|
| 1235 |
+
}
|
| 1236 |
+
}
|
| 1237 |
+
},
|
| 1238 |
+
"nbformat": 4,
|
| 1239 |
+
"nbformat_minor": 0
|
| 1240 |
+
}
|
lm-evaluation-harness/examples/transformer-lens.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformer_lens import HookedTransformer
|
| 6 |
+
from transformers import AutoConfig
|
| 7 |
+
|
| 8 |
+
from lm_eval import evaluator
|
| 9 |
+
from lm_eval.models.huggingface import HFLM
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
|
| 13 |
+
class HFLikeModelAdapter(nn.Module):
|
| 14 |
+
"""Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, model: HookedTransformer):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.model = model
|
| 19 |
+
self.tokenizer = model.tokenizer
|
| 20 |
+
self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
|
| 21 |
+
self.device = model.cfg.device
|
| 22 |
+
self.tie_weights = lambda: self
|
| 23 |
+
|
| 24 |
+
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
| 25 |
+
output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
|
| 26 |
+
# Make sure output has the expected .logits attribute
|
| 27 |
+
if not hasattr(output, "logits"):
|
| 28 |
+
if isinstance(output, torch.Tensor):
|
| 29 |
+
output.logits = output
|
| 30 |
+
return output
|
| 31 |
+
|
| 32 |
+
# Only delegate specific attributes we know we need
|
| 33 |
+
def to(self, *args, **kwargs):
|
| 34 |
+
return self.model.to(*args, **kwargs)
|
| 35 |
+
|
| 36 |
+
def eval(self):
|
| 37 |
+
self.model.eval()
|
| 38 |
+
return self
|
| 39 |
+
|
| 40 |
+
def train(self, mode=True):
|
| 41 |
+
self.model.train(mode)
|
| 42 |
+
return self
|
| 43 |
+
|
| 44 |
+
model = HFLikeModelAdapter(lens_model)
|
| 45 |
+
warnings.filterwarnings("ignore", message="Failed to get model SHA for")
|
| 46 |
+
results = evaluator.simple_evaluate(
|
| 47 |
+
model=HFLM(pretrained=model, tokenizer=model.tokenizer),
|
| 48 |
+
tasks=tasks,
|
| 49 |
+
verbosity="WARNING",
|
| 50 |
+
**kwargs,
|
| 51 |
+
)
|
| 52 |
+
return results
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
# Load base model
|
| 57 |
+
model = HookedTransformer.from_pretrained("pythia-70m")
|
| 58 |
+
res = evaluate_lm_eval(model, tasks=["arc_easy"])
|
| 59 |
+
print(res["results"])
|
lm-evaluation-harness/examples/visualize-wandb.ipynb
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "fc477b96-adee-4829-a9d7-a5eb990df358",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Visualizing Results in Weights and Biases\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"With the Weights and Biases integration, you can now spend more time extracting deeper insights into your evaluation results. The integration is designed to streamline the process of logging and visualizing experiment results using the Weights & Biases (W&B) platform.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"The integration provide functionalities\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"- to automatically log the evaluation results,\n",
|
| 15 |
+
"- log the samples as W&B Tables for easy visualization,\n",
|
| 16 |
+
"- log the `results.json` file as an artifact for version control,\n",
|
| 17 |
+
"- log the `<task_name>_eval_samples.json` file if the samples are logged,\n",
|
| 18 |
+
"- generate a comprehensive report for analysis and visualization with all the important metric,\n",
|
| 19 |
+
"- log task and cli configs,\n",
|
| 20 |
+
"- and more out of the box like the command used to run the evaluation, GPU/CPU counts, timestamp, etc.\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"The integration is super easy to use with the eval harness. Let's see how!"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": null,
|
| 28 |
+
"id": "3851439a-bff4-41f2-bf21-1b3d8704913b",
|
| 29 |
+
"metadata": {
|
| 30 |
+
"scrolled": true
|
| 31 |
+
},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"# Install this project if you did not already have it.\n",
|
| 35 |
+
"# This is all that is needed to be installed to start using Weights and Biases\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"!pip -qq install -e ..[wandb]"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "markdown",
|
| 42 |
+
"id": "8507fd7e-3b99-4a92-89fa-9eaada74ba91",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"source": [
|
| 45 |
+
"# Run the Eval Harness\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"Run the eval harness as usual with a `wandb_args` flag. This flag is used to provide arguments for initializing a wandb run ([wandb.init](https://docs.wandb.ai/ref/python/init)) as comma separated string arguments.\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"If `wandb_args` flag is used, the metrics and all other goodness will be automatically logged to Weights and Biases. In the stdout, you will find the link to the W&B run page as well as link to the generated report."
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "markdown",
|
| 54 |
+
"id": "eec5866e-f01e-42f8-8803-9d77472ef991",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"source": [
|
| 57 |
+
"## Set your API Key\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"Before you can use W&B, you need to authenticate your machine with an authentication key. Visit https://wandb.ai/authorize to get one."
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": null,
|
| 65 |
+
"id": "d824d163-71a9-4313-935d-f1d56397841c",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"import wandb\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"wandb.login()"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"id": "124e4a34-1547-4bed-bc09-db012bacbda6",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"source": [
|
| 80 |
+
"> Note that if you are using command line you can simply authenticate your machine by doing `wandb login` in your terminal. For more info check out the [documentation](https://docs.wandb.ai/quickstart#2-log-in-to-wb)."
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "markdown",
|
| 85 |
+
"id": "abc6f6b6-179a-4aff-ada9-f380fb74df6e",
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"source": [
|
| 88 |
+
"## Run and log to W&B"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": null,
|
| 94 |
+
"id": "bd0a8130-a97b-451a-acd2-3f9885b88643",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"!lm_eval \\\n",
|
| 99 |
+
" --model hf \\\n",
|
| 100 |
+
" --model_args pretrained=microsoft/phi-2,trust_remote_code=True \\\n",
|
| 101 |
+
" --tasks hellaswag,mmlu_abstract_algebra \\\n",
|
| 102 |
+
" --device cuda:0 \\\n",
|
| 103 |
+
" --batch_size 8 \\\n",
|
| 104 |
+
" --output_path output/phi-2 \\\n",
|
| 105 |
+
" --limit 10 \\\n",
|
| 106 |
+
" --wandb_args project=lm-eval-harness-integration \\\n",
|
| 107 |
+
" --log_samples"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "markdown",
|
| 112 |
+
"id": "e974cabdbe70b667",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"source": []
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "markdown",
|
| 118 |
+
"id": "5178ca9445b844e4",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"id": "c6a421b2cf3ddac5",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"import lm_eval\n",
|
| 132 |
+
"from lm_eval.loggers import WandbLogger\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"results = lm_eval.simple_evaluate(\n",
|
| 136 |
+
" model=\"hf\",\n",
|
| 137 |
+
" model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
|
| 138 |
+
" tasks=\"hellaswag,mmlu_abstract_algebra\",\n",
|
| 139 |
+
" log_samples=True,\n",
|
| 140 |
+
")\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"wandb_logger = WandbLogger(\n",
|
| 143 |
+
" project=\"lm-eval-harness-integration\", job_type=\"eval\"\n",
|
| 144 |
+
") # or empty if wandb.init(...) already called before\n",
|
| 145 |
+
"wandb_logger.post_init(results)\n",
|
| 146 |
+
"wandb_logger.log_eval_result()\n",
|
| 147 |
+
"wandb_logger.log_eval_samples(results[\"samples\"]) # if log_samples"
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"metadata": {
|
| 152 |
+
"kernelspec": {
|
| 153 |
+
"display_name": "Python 3 (ipykernel)",
|
| 154 |
+
"language": "python",
|
| 155 |
+
"name": "python3"
|
| 156 |
+
},
|
| 157 |
+
"language_info": {
|
| 158 |
+
"codemirror_mode": {
|
| 159 |
+
"name": "ipython",
|
| 160 |
+
"version": 3
|
| 161 |
+
},
|
| 162 |
+
"file_extension": ".py",
|
| 163 |
+
"mimetype": "text/x-python",
|
| 164 |
+
"name": "python",
|
| 165 |
+
"nbconvert_exporter": "python",
|
| 166 |
+
"pygments_lexer": "ipython3",
|
| 167 |
+
"version": "3.10.12"
|
| 168 |
+
}
|
| 169 |
+
},
|
| 170 |
+
"nbformat": 4,
|
| 171 |
+
"nbformat_minor": 5
|
| 172 |
+
}
|
lm-evaluation-harness/examples/visualize-zeno.ipynb
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Visualizing Results in Zeno\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Benchmarking your models is the first step towards making sure your model performs well.\n",
|
| 10 |
+
"However, looking at the data behind the benchmark, slicing the data into subsets, and comparing models on individual instances can help you even more in evaluating and quantifying the behavior of your AI system.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"All of this can be done in [Zeno](https://zenoml.com)!\n",
|
| 13 |
+
"Zeno is super easy to use with the eval harness, let's explore how you can easily upload and visualize your eval results.\n"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "code",
|
| 18 |
+
"execution_count": null,
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"# Install this project if you did not already do that. This is all that needs to be installed for you to be able to visualize your data in Zeno!\n",
|
| 23 |
+
"!pip install -e ..\n",
|
| 24 |
+
"!pip install -e ..[zeno]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"# Run the Eval Harness\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"To visualize the results, run the eval harness with the `log_samples` and `output_path` flags. We expect `output_path` to contain multiple folders that represent individual model names. You can thus run your evaluation on any number of tasks and models and upload all of the results as projects on Zeno.\n"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [],
|
| 41 |
+
"source": [
|
| 42 |
+
"!lm_eval \\\n",
|
| 43 |
+
" --model hf \\\n",
|
| 44 |
+
" --model_args pretrained=EleutherAI/gpt-neo-2.7B \\\n",
|
| 45 |
+
" --tasks hellaswag,wikitext \\\n",
|
| 46 |
+
" --batch_size 8 \\\n",
|
| 47 |
+
" --device mps \\\n",
|
| 48 |
+
" --log_samples \\\n",
|
| 49 |
+
" --output_path output/gpt-neo-2.7B \\\n",
|
| 50 |
+
" --limit 10"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "markdown",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"source": [
|
| 57 |
+
"# Set your API Key\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"This is so you can be authenticated with Zeno.\n",
|
| 60 |
+
"If you don't already have a Zeno account, first create an account on [Zeno Hub](https://hub.zenoml.com).\n",
|
| 61 |
+
"After logging in to Zeno Hub, generate your API key by clicking on your profile at the bottom left to navigate to your account page.\n"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": null,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"%env ZENO_API_KEY=YOUR_API_KEY"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"source": [
|
| 77 |
+
"# Visualize Eval Results\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"You can now use the `zeno_visualize` script to upload the results to Zeno.\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"This will use all subfolders in `data_path` as different models and upload all tasks within these model folders to Zeno. If you run the eval harness on multiple tasks, the `project_name` will be used as a prefix and one project will be created per task.\n"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": null,
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"!python ../scripts/zeno_visualize.py --data_path output --project_name \"Zeno Upload Test\""
|
| 91 |
+
]
|
| 92 |
+
}
|
| 93 |
+
],
|
| 94 |
+
"metadata": {
|
| 95 |
+
"kernelspec": {
|
| 96 |
+
"display_name": "zeno_projects",
|
| 97 |
+
"language": "python",
|
| 98 |
+
"name": "python3"
|
| 99 |
+
},
|
| 100 |
+
"language_info": {
|
| 101 |
+
"codemirror_mode": {
|
| 102 |
+
"name": "ipython",
|
| 103 |
+
"version": 3
|
| 104 |
+
},
|
| 105 |
+
"file_extension": ".py",
|
| 106 |
+
"mimetype": "text/x-python",
|
| 107 |
+
"name": "python",
|
| 108 |
+
"nbconvert_exporter": "python",
|
| 109 |
+
"pygments_lexer": "ipython3",
|
| 110 |
+
"version": "3.10.11"
|
| 111 |
+
}
|
| 112 |
+
},
|
| 113 |
+
"nbformat": 4,
|
| 114 |
+
"nbformat_minor": 2
|
| 115 |
+
}
|
lm-evaluation-harness/lm_eval/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__version__ = "0.4.9.1"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Lazy-load .evaluator module to improve CLI startup
|
| 9 |
+
def __getattr__(name):
|
| 10 |
+
if name == "evaluate":
|
| 11 |
+
from .evaluator import evaluate
|
| 12 |
+
|
| 13 |
+
return evaluate
|
| 14 |
+
elif name == "simple_evaluate":
|
| 15 |
+
from .evaluator import simple_evaluate
|
| 16 |
+
|
| 17 |
+
return simple_evaluate
|
| 18 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
__all__ = ["evaluate", "simple_evaluate", "__version__"]
|
lm-evaluation-harness/lm_eval/__main__.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from functools import partial
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def try_parse_json(value: str) -> Union[str, dict, None]:
|
| 12 |
+
if value is None:
|
| 13 |
+
return None
|
| 14 |
+
try:
|
| 15 |
+
return json.loads(value)
|
| 16 |
+
except json.JSONDecodeError:
|
| 17 |
+
if "{" in value:
|
| 18 |
+
raise argparse.ArgumentTypeError(
|
| 19 |
+
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
|
| 20 |
+
)
|
| 21 |
+
return value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _int_or_none_list_arg_type(
|
| 25 |
+
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
|
| 26 |
+
):
|
| 27 |
+
def parse_value(item):
|
| 28 |
+
item = item.strip().lower()
|
| 29 |
+
if item == "none":
|
| 30 |
+
return None
|
| 31 |
+
try:
|
| 32 |
+
return int(item)
|
| 33 |
+
except ValueError:
|
| 34 |
+
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
|
| 35 |
+
|
| 36 |
+
items = [parse_value(v) for v in value.split(split_char)]
|
| 37 |
+
num_items = len(items)
|
| 38 |
+
|
| 39 |
+
if num_items == 1:
|
| 40 |
+
# Makes downstream handling the same for single and multiple values
|
| 41 |
+
items = items * max_len
|
| 42 |
+
elif num_items < min_len or num_items > max_len:
|
| 43 |
+
raise argparse.ArgumentTypeError(
|
| 44 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
|
| 45 |
+
)
|
| 46 |
+
elif num_items != max_len:
|
| 47 |
+
logging.warning(
|
| 48 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
|
| 49 |
+
"Missing values will be filled with defaults."
|
| 50 |
+
)
|
| 51 |
+
default_items = [parse_value(v) for v in defaults.split(split_char)]
|
| 52 |
+
items.extend(
|
| 53 |
+
default_items[num_items:]
|
| 54 |
+
) # extend items list with missing defaults
|
| 55 |
+
|
| 56 |
+
return items
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def check_argument_types(parser: argparse.ArgumentParser):
|
| 60 |
+
"""
|
| 61 |
+
Check to make sure all CLI args are typed, raises error if not
|
| 62 |
+
"""
|
| 63 |
+
for action in parser._actions:
|
| 64 |
+
if action.dest != "help" and not action.const:
|
| 65 |
+
if action.type is None:
|
| 66 |
+
raise ValueError(
|
| 67 |
+
f"Argument '{action.dest}' doesn't have a type specified."
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 74 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--tasks",
|
| 80 |
+
"-t",
|
| 81 |
+
default=None,
|
| 82 |
+
type=str,
|
| 83 |
+
metavar="task1,task2",
|
| 84 |
+
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--model_args",
|
| 88 |
+
"-a",
|
| 89 |
+
default="",
|
| 90 |
+
type=try_parse_json,
|
| 91 |
+
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--num_fewshot",
|
| 95 |
+
"-f",
|
| 96 |
+
type=int,
|
| 97 |
+
default=None,
|
| 98 |
+
metavar="N",
|
| 99 |
+
help="Number of examples in few-shot context",
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--batch_size",
|
| 103 |
+
"-b",
|
| 104 |
+
type=str,
|
| 105 |
+
default=1,
|
| 106 |
+
metavar="auto|auto:N|N",
|
| 107 |
+
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--max_batch_size",
|
| 111 |
+
type=int,
|
| 112 |
+
default=None,
|
| 113 |
+
metavar="N",
|
| 114 |
+
help="Maximal batch size to try with --batch_size auto.",
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--device",
|
| 118 |
+
type=str,
|
| 119 |
+
default=None,
|
| 120 |
+
help="Device to use (e.g. cuda, cuda:0, cpu).",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--output_path",
|
| 124 |
+
"-o",
|
| 125 |
+
default=None,
|
| 126 |
+
type=str,
|
| 127 |
+
metavar="DIR|DIR/file.json",
|
| 128 |
+
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--limit",
|
| 132 |
+
"-L",
|
| 133 |
+
type=float,
|
| 134 |
+
default=None,
|
| 135 |
+
metavar="N|0<N<1",
|
| 136 |
+
help="Limit the number of examples per task. "
|
| 137 |
+
"If <1, limit is a percentage of the total number of examples.",
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--samples",
|
| 141 |
+
"-E",
|
| 142 |
+
default=None,
|
| 143 |
+
type=str,
|
| 144 |
+
metavar="/path/to/json",
|
| 145 |
+
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
"--use_cache",
|
| 149 |
+
"-c",
|
| 150 |
+
type=str,
|
| 151 |
+
default=None,
|
| 152 |
+
metavar="DIR",
|
| 153 |
+
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--cache_requests",
|
| 157 |
+
type=str,
|
| 158 |
+
default=None,
|
| 159 |
+
choices=["true", "refresh", "delete"],
|
| 160 |
+
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--check_integrity",
|
| 164 |
+
action="store_true",
|
| 165 |
+
help="Whether to run the relevant part of the test suite for the tasks.",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--write_out",
|
| 169 |
+
"-w",
|
| 170 |
+
action="store_true",
|
| 171 |
+
default=False,
|
| 172 |
+
help="Prints the prompt for the first few documents.",
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--log_samples",
|
| 176 |
+
"-s",
|
| 177 |
+
action="store_true",
|
| 178 |
+
default=False,
|
| 179 |
+
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--system_instruction",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
+
help="System instruction to be used in the prompt",
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--apply_chat_template",
|
| 189 |
+
type=str,
|
| 190 |
+
nargs="?",
|
| 191 |
+
const=True,
|
| 192 |
+
default=False,
|
| 193 |
+
help=(
|
| 194 |
+
"If True, apply chat template to the prompt. "
|
| 195 |
+
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
|
| 196 |
+
"To apply a specific template from the available list of templates, provide the template name as an argument. "
|
| 197 |
+
"E.g. `--apply_chat_template template_name`"
|
| 198 |
+
),
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--fewshot_as_multiturn",
|
| 202 |
+
action="store_true",
|
| 203 |
+
default=False,
|
| 204 |
+
help="If True, uses the fewshot as a multi-turn conversation",
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--show_config",
|
| 208 |
+
action="store_true",
|
| 209 |
+
default=False,
|
| 210 |
+
help="If True, shows the the full config of all tasks at the end of the evaluation.",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--include_path",
|
| 214 |
+
type=str,
|
| 215 |
+
default=None,
|
| 216 |
+
metavar="DIR",
|
| 217 |
+
help="Additional path to include if there are external tasks to include.",
|
| 218 |
+
)
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--gen_kwargs",
|
| 221 |
+
type=try_parse_json,
|
| 222 |
+
default=None,
|
| 223 |
+
help=(
|
| 224 |
+
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
|
| 225 |
+
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--verbosity",
|
| 230 |
+
"-v",
|
| 231 |
+
type=str.upper,
|
| 232 |
+
default=None,
|
| 233 |
+
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
|
| 234 |
+
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--wandb_args",
|
| 238 |
+
type=str,
|
| 239 |
+
default="",
|
| 240 |
+
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--wandb_config_args",
|
| 244 |
+
type=str,
|
| 245 |
+
default="",
|
| 246 |
+
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
|
| 247 |
+
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--hf_hub_log_args",
|
| 250 |
+
type=str,
|
| 251 |
+
default="",
|
| 252 |
+
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
|
| 253 |
+
)
|
| 254 |
+
parser.add_argument(
|
| 255 |
+
"--predict_only",
|
| 256 |
+
"-x",
|
| 257 |
+
action="store_true",
|
| 258 |
+
default=False,
|
| 259 |
+
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
|
| 260 |
+
)
|
| 261 |
+
default_seed_string = "0,1234,1234,1234"
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
"--seed",
|
| 264 |
+
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
|
| 265 |
+
default=default_seed_string, # for backward compatibility
|
| 266 |
+
help=(
|
| 267 |
+
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
|
| 268 |
+
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
|
| 269 |
+
"respectively, or a single integer to set the same seed for all four.\n"
|
| 270 |
+
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
|
| 271 |
+
"(for backward compatibility).\n"
|
| 272 |
+
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
|
| 273 |
+
"Here numpy's seed is not set since the second value is `None`.\n"
|
| 274 |
+
"E.g, `--seed 42` sets all four seeds to 42."
|
| 275 |
+
),
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
"--trust_remote_code",
|
| 279 |
+
action="store_true",
|
| 280 |
+
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
|
| 281 |
+
)
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--confirm_run_unsafe_code",
|
| 284 |
+
action="store_true",
|
| 285 |
+
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--metadata",
|
| 289 |
+
type=json.loads,
|
| 290 |
+
default=None,
|
| 291 |
+
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
|
| 292 |
+
)
|
| 293 |
+
return parser
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
| 297 |
+
check_argument_types(parser)
|
| 298 |
+
return parser.parse_args()
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
|
| 302 |
+
if not args:
|
| 303 |
+
# we allow for args to be passed externally, else we parse them ourselves
|
| 304 |
+
parser = setup_parser()
|
| 305 |
+
args = parse_eval_args(parser)
|
| 306 |
+
|
| 307 |
+
# defer loading `lm_eval` submodules for faster CLI load
|
| 308 |
+
from lm_eval import evaluator, utils
|
| 309 |
+
from lm_eval.evaluator import request_caching_arg_to_dict
|
| 310 |
+
from lm_eval.loggers import EvaluationTracker, WandbLogger
|
| 311 |
+
from lm_eval.tasks import TaskManager
|
| 312 |
+
from lm_eval.utils import (
|
| 313 |
+
handle_non_serializable,
|
| 314 |
+
make_table,
|
| 315 |
+
simple_parse_args_string,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if args.wandb_args:
|
| 319 |
+
wandb_args_dict = simple_parse_args_string(args.wandb_args)
|
| 320 |
+
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
|
| 321 |
+
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
|
| 322 |
+
|
| 323 |
+
utils.setup_logging(args.verbosity)
|
| 324 |
+
eval_logger = logging.getLogger(__name__)
|
| 325 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 326 |
+
|
| 327 |
+
# update the evaluation tracker args with the output path and the HF token
|
| 328 |
+
if args.output_path:
|
| 329 |
+
args.hf_hub_log_args += f",output_path={args.output_path}"
|
| 330 |
+
if os.environ.get("HF_TOKEN", None):
|
| 331 |
+
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
|
| 332 |
+
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
|
| 333 |
+
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
|
| 334 |
+
|
| 335 |
+
if args.predict_only:
|
| 336 |
+
args.log_samples = True
|
| 337 |
+
if (args.log_samples or args.predict_only) and not args.output_path:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
"Specify --output_path if providing --log_samples or --predict_only"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if args.fewshot_as_multiturn and args.apply_chat_template is False:
|
| 343 |
+
raise ValueError(
|
| 344 |
+
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
if args.include_path is not None:
|
| 348 |
+
eval_logger.info(f"Including path: {args.include_path}")
|
| 349 |
+
metadata = (
|
| 350 |
+
simple_parse_args_string(args.model_args)
|
| 351 |
+
if isinstance(args.model_args, str)
|
| 352 |
+
else args.model_args
|
| 353 |
+
if isinstance(args.model_args, dict)
|
| 354 |
+
else {}
|
| 355 |
+
) | (
|
| 356 |
+
args.metadata
|
| 357 |
+
if isinstance(args.metadata, dict)
|
| 358 |
+
else simple_parse_args_string(args.metadata)
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
|
| 362 |
+
|
| 363 |
+
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
|
| 364 |
+
eval_logger.warning(
|
| 365 |
+
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
if args.limit:
|
| 369 |
+
eval_logger.warning(
|
| 370 |
+
" --limit SHOULD ONLY BE USED FOR TESTING."
|
| 371 |
+
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
|
| 372 |
+
)
|
| 373 |
+
if args.samples:
|
| 374 |
+
assert args.limit is None, (
|
| 375 |
+
"If --samples is not None, then --limit must be None."
|
| 376 |
+
)
|
| 377 |
+
if (samples := Path(args.samples)).is_file():
|
| 378 |
+
args.samples = json.loads(samples.read_text())
|
| 379 |
+
else:
|
| 380 |
+
args.samples = json.loads(args.samples)
|
| 381 |
+
|
| 382 |
+
if args.tasks is None:
|
| 383 |
+
eval_logger.error("Need to specify task to evaluate.")
|
| 384 |
+
sys.exit()
|
| 385 |
+
elif args.tasks == "list":
|
| 386 |
+
print(task_manager.list_all_tasks())
|
| 387 |
+
sys.exit()
|
| 388 |
+
elif args.tasks == "list_groups":
|
| 389 |
+
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
|
| 390 |
+
sys.exit()
|
| 391 |
+
elif args.tasks == "list_tags":
|
| 392 |
+
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
|
| 393 |
+
sys.exit()
|
| 394 |
+
elif args.tasks == "list_subtasks":
|
| 395 |
+
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
|
| 396 |
+
sys.exit()
|
| 397 |
+
else:
|
| 398 |
+
if os.path.isdir(args.tasks):
|
| 399 |
+
import glob
|
| 400 |
+
|
| 401 |
+
task_names = []
|
| 402 |
+
yaml_path = os.path.join(args.tasks, "*.yaml")
|
| 403 |
+
for yaml_file in glob.glob(yaml_path):
|
| 404 |
+
config = utils.load_yaml_config(yaml_file)
|
| 405 |
+
task_names.append(config)
|
| 406 |
+
else:
|
| 407 |
+
task_list = args.tasks.split(",")
|
| 408 |
+
task_names = task_manager.match_tasks(task_list)
|
| 409 |
+
for task in [task for task in task_list if task not in task_names]:
|
| 410 |
+
if os.path.isfile(task):
|
| 411 |
+
config = utils.load_yaml_config(task)
|
| 412 |
+
task_names.append(config)
|
| 413 |
+
task_missing = [
|
| 414 |
+
task for task in task_list if task not in task_names and "*" not in task
|
| 415 |
+
] # we don't want errors if a wildcard ("*") task name was used
|
| 416 |
+
|
| 417 |
+
if task_missing:
|
| 418 |
+
missing = ", ".join(task_missing)
|
| 419 |
+
eval_logger.error(
|
| 420 |
+
f"Tasks were not found: {missing}\n"
|
| 421 |
+
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
|
| 422 |
+
)
|
| 423 |
+
raise ValueError(
|
| 424 |
+
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
|
| 428 |
+
if args.trust_remote_code:
|
| 429 |
+
eval_logger.info(
|
| 430 |
+
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
|
| 431 |
+
)
|
| 432 |
+
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
|
| 433 |
+
# because it's already been determined based on the prior env var before launching our
|
| 434 |
+
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
|
| 435 |
+
import datasets
|
| 436 |
+
from packaging.version import parse as vparse
|
| 437 |
+
|
| 438 |
+
if vparse(datasets.__version__) < vparse("4.0.0"):
|
| 439 |
+
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
|
| 440 |
+
|
| 441 |
+
if isinstance(args.model_args, dict):
|
| 442 |
+
args.model_args["trust_remote_code"] = True
|
| 443 |
+
else:
|
| 444 |
+
args.model_args = args.model_args + ",trust_remote_code=True"
|
| 445 |
+
(
|
| 446 |
+
eval_logger.info(f"Selected Tasks: {task_names}")
|
| 447 |
+
if eval_logger.getEffectiveLevel() >= logging.INFO
|
| 448 |
+
else print(f"Selected Tasks: {task_names}")
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
request_caching_args = request_caching_arg_to_dict(
|
| 452 |
+
cache_requests=args.cache_requests
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
results = evaluator.simple_evaluate(
|
| 456 |
+
model=args.model,
|
| 457 |
+
model_args=args.model_args,
|
| 458 |
+
tasks=task_names,
|
| 459 |
+
num_fewshot=args.num_fewshot,
|
| 460 |
+
batch_size=args.batch_size,
|
| 461 |
+
max_batch_size=args.max_batch_size,
|
| 462 |
+
device=args.device,
|
| 463 |
+
use_cache=args.use_cache,
|
| 464 |
+
limit=args.limit,
|
| 465 |
+
samples=args.samples,
|
| 466 |
+
check_integrity=args.check_integrity,
|
| 467 |
+
write_out=args.write_out,
|
| 468 |
+
log_samples=args.log_samples,
|
| 469 |
+
evaluation_tracker=evaluation_tracker,
|
| 470 |
+
system_instruction=args.system_instruction,
|
| 471 |
+
apply_chat_template=args.apply_chat_template,
|
| 472 |
+
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
| 473 |
+
gen_kwargs=args.gen_kwargs,
|
| 474 |
+
task_manager=task_manager,
|
| 475 |
+
predict_only=args.predict_only,
|
| 476 |
+
random_seed=args.seed[0],
|
| 477 |
+
numpy_random_seed=args.seed[1],
|
| 478 |
+
torch_random_seed=args.seed[2],
|
| 479 |
+
fewshot_random_seed=args.seed[3],
|
| 480 |
+
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
|
| 481 |
+
metadata=metadata,
|
| 482 |
+
**request_caching_args,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
if results is not None:
|
| 486 |
+
if args.log_samples:
|
| 487 |
+
samples = results.pop("samples")
|
| 488 |
+
dumped = json.dumps(
|
| 489 |
+
results, indent=2, default=handle_non_serializable, ensure_ascii=False
|
| 490 |
+
)
|
| 491 |
+
if args.show_config:
|
| 492 |
+
print(dumped)
|
| 493 |
+
|
| 494 |
+
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
|
| 495 |
+
|
| 496 |
+
# Add W&B logging
|
| 497 |
+
if args.wandb_args:
|
| 498 |
+
try:
|
| 499 |
+
wandb_logger.post_init(results)
|
| 500 |
+
wandb_logger.log_eval_result()
|
| 501 |
+
if args.log_samples:
|
| 502 |
+
wandb_logger.log_eval_samples(samples)
|
| 503 |
+
except Exception as e:
|
| 504 |
+
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
|
| 505 |
+
|
| 506 |
+
evaluation_tracker.save_results_aggregated(
|
| 507 |
+
results=results, samples=samples if args.log_samples else None
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if args.log_samples:
|
| 511 |
+
for task_name, config in results["configs"].items():
|
| 512 |
+
evaluation_tracker.save_results_samples(
|
| 513 |
+
task_name=task_name, samples=samples[task_name]
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if (
|
| 517 |
+
evaluation_tracker.push_results_to_hub
|
| 518 |
+
or evaluation_tracker.push_samples_to_hub
|
| 519 |
+
):
|
| 520 |
+
evaluation_tracker.recreate_metadata_card()
|
| 521 |
+
|
| 522 |
+
print(
|
| 523 |
+
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
|
| 524 |
+
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
|
| 525 |
+
)
|
| 526 |
+
print(make_table(results))
|
| 527 |
+
if "groups" in results:
|
| 528 |
+
print(make_table(results, "groups"))
|
| 529 |
+
|
| 530 |
+
if args.wandb_args:
|
| 531 |
+
# Tear down wandb run once all the logging is done.
|
| 532 |
+
wandb_logger.run.finish()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
if __name__ == "__main__":
|
| 536 |
+
cli_evaluate()
|
lm-evaluation-harness/lm_eval/api/__init__.py
ADDED
|
File without changes
|
lm-evaluation-harness/lm_eval/api/filter.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable, Iterable, List, Union
|
| 4 |
+
|
| 5 |
+
from lm_eval.api.instance import Instance
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Filter(ABC):
|
| 9 |
+
"""
|
| 10 |
+
Filter classes operate on a per-task level.
|
| 11 |
+
They take all model outputs (`instance.resps` for all `task.instances`)
|
| 12 |
+
across all instances of a task, and perform operations.
|
| 13 |
+
In a single run, one can configure any number of separate filters or lists of filters.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, **kwargs) -> None:
|
| 18 |
+
"""
|
| 19 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
|
| 24 |
+
"""
|
| 25 |
+
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
|
| 26 |
+
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
|
| 27 |
+
if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
|
| 28 |
+
[<filtered resps for instance 0>, <filtered resps for instance 1>]
|
| 29 |
+
"""
|
| 30 |
+
return resps
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class FilterEnsemble:
|
| 35 |
+
"""
|
| 36 |
+
FilterEnsemble creates a pipeline applying multiple filters.
|
| 37 |
+
Its intended usage is to stack multiple post-processing steps in order.
|
| 38 |
+
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
|
| 39 |
+
pipeline separately.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
name: str
|
| 43 |
+
filters: List[Callable[[], Filter]]
|
| 44 |
+
|
| 45 |
+
def apply(self, instances: List[Instance]) -> None:
|
| 46 |
+
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
|
| 47 |
+
resps, docs = list(resps), list(docs)
|
| 48 |
+
|
| 49 |
+
for f in self.filters:
|
| 50 |
+
# apply filters in sequence
|
| 51 |
+
resps = f().apply(resps, docs)
|
| 52 |
+
|
| 53 |
+
# add the end results after filtering to filtered_requests of their respective source instances.
|
| 54 |
+
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
|
| 55 |
+
for inst, resp in zip(instances, resps):
|
| 56 |
+
inst.filtered_resps[self.name] = resp
|
lm-evaluation-harness/lm_eval/api/group.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from dataclasses import asdict, dataclass
|
| 3 |
+
from inspect import getsource
|
| 4 |
+
from typing import Any, Callable, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class AggMetricConfig(dict):
|
| 9 |
+
metric: Optional[str] = None
|
| 10 |
+
aggregation: Optional[str] = "mean"
|
| 11 |
+
weight_by_size: Optional[str] = False
|
| 12 |
+
# list of filter names which should be incorporated into the aggregated metric.
|
| 13 |
+
filter_list: Optional[Union[str, list]] = "none"
|
| 14 |
+
|
| 15 |
+
def __post_init__(self):
|
| 16 |
+
if self.aggregation != "mean" and not callable(self.aggregation):
|
| 17 |
+
raise ValueError(
|
| 18 |
+
f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
if isinstance(self.filter_list, str):
|
| 22 |
+
self.filter_list = [self.filter_list]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class GroupConfig(dict):
|
| 27 |
+
group: Optional[str] = None
|
| 28 |
+
group_alias: Optional[str] = None
|
| 29 |
+
task: Optional[Union[str, list]] = None
|
| 30 |
+
aggregate_metric_list: Optional[
|
| 31 |
+
Union[List[AggMetricConfig], AggMetricConfig, dict]
|
| 32 |
+
] = None
|
| 33 |
+
metadata: Optional[dict] = (
|
| 34 |
+
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, item):
|
| 38 |
+
return getattr(self, item)
|
| 39 |
+
|
| 40 |
+
def __setitem__(self, item, value):
|
| 41 |
+
return setattr(self, item, value)
|
| 42 |
+
|
| 43 |
+
def __post_init__(self):
|
| 44 |
+
if self.aggregate_metric_list is not None:
|
| 45 |
+
if isinstance(self.aggregate_metric_list, dict):
|
| 46 |
+
self.aggregate_metric_list = [self.aggregate_metric_list]
|
| 47 |
+
|
| 48 |
+
self.aggregate_metric_list = [
|
| 49 |
+
AggMetricConfig(**item) if isinstance(item, dict) else item
|
| 50 |
+
for item in self.aggregate_metric_list
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
def to_dict(self, keep_callable: bool = False) -> dict:
|
| 54 |
+
"""dumps the current config as a dictionary object, as a printable format.
|
| 55 |
+
null fields will not be printed.
|
| 56 |
+
Used for dumping results alongside full task configuration
|
| 57 |
+
|
| 58 |
+
:return: dict
|
| 59 |
+
A printable dictionary version of the TaskConfig object.
|
| 60 |
+
|
| 61 |
+
# TODO: should any default value in the TaskConfig not be printed?
|
| 62 |
+
"""
|
| 63 |
+
cfg_dict = asdict(self)
|
| 64 |
+
# remove values that are `None`
|
| 65 |
+
for k, v in list(cfg_dict.items()):
|
| 66 |
+
if callable(v):
|
| 67 |
+
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
|
| 68 |
+
return cfg_dict
|
| 69 |
+
|
| 70 |
+
def serialize_function(
|
| 71 |
+
self, value: Union[Callable, str], keep_callable=False
|
| 72 |
+
) -> Union[Callable, str]:
|
| 73 |
+
"""Serializes a given function or string.
|
| 74 |
+
|
| 75 |
+
If 'keep_callable' is True, the original callable is returned.
|
| 76 |
+
Otherwise, attempts to return the source code of the callable using 'getsource'.
|
| 77 |
+
"""
|
| 78 |
+
if keep_callable:
|
| 79 |
+
return value
|
| 80 |
+
else:
|
| 81 |
+
try:
|
| 82 |
+
return getsource(value)
|
| 83 |
+
except (TypeError, OSError):
|
| 84 |
+
return str(value)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConfigurableGroup(abc.ABC):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
config: Optional[dict] = None,
|
| 91 |
+
) -> None:
|
| 92 |
+
self._config = GroupConfig(**config)
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def group(self):
|
| 96 |
+
return self._config.group
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def group_alias(self):
|
| 100 |
+
return self._config.group_alias
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def version(self):
|
| 104 |
+
return self._config.version
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def config(self):
|
| 108 |
+
return self._config.to_dict()
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def group_name(self) -> Any:
|
| 112 |
+
return self._config.group
|
| 113 |
+
|
| 114 |
+
def __repr__(self):
|
| 115 |
+
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
|
lm-evaluation-harness/lm_eval/api/instance.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Literal, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
OutputType = Literal[
|
| 6 |
+
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class Instance:
|
| 12 |
+
request_type: OutputType
|
| 13 |
+
doc: dict
|
| 14 |
+
arguments: tuple
|
| 15 |
+
idx: int
|
| 16 |
+
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
|
| 17 |
+
default_factory=lambda: (None, None, None)
|
| 18 |
+
)
|
| 19 |
+
resps: list = field(default_factory=list)
|
| 20 |
+
filtered_resps: dict = field(default_factory=dict)
|
| 21 |
+
|
| 22 |
+
# initialized after init
|
| 23 |
+
task_name: Optional[str] = None
|
| 24 |
+
doc_id: Optional[int] = None
|
| 25 |
+
repeats: Optional[int] = None
|
| 26 |
+
|
| 27 |
+
def __post_init__(self) -> None:
|
| 28 |
+
# unpack metadata field
|
| 29 |
+
self.task_name, self.doc_id, self.repeats = self.metadata
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def args(self):
|
| 33 |
+
"""
|
| 34 |
+
Returns (string,) where `string` is the string to calculate loglikelihood over
|
| 35 |
+
"""
|
| 36 |
+
return (
|
| 37 |
+
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
|
| 38 |
+
)
|
lm-evaluation-harness/lm_eval/api/metrics.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
import string
|
| 7 |
+
from collections.abc import Iterable
|
| 8 |
+
from typing import Callable, List, Optional, Sequence, TypeVar
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import sacrebleu
|
| 12 |
+
|
| 13 |
+
from lm_eval.api.registry import register_aggregation, register_metric
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
T = TypeVar("T")
|
| 17 |
+
|
| 18 |
+
eval_logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Register Aggregations First
|
| 22 |
+
@register_aggregation("bypass")
|
| 23 |
+
def bypass_agg(arr):
|
| 24 |
+
return 999
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@register_aggregation("nanmean")
|
| 28 |
+
def nanmean(arr):
|
| 29 |
+
if len(arr) == 0 or all(np.isnan(arr)):
|
| 30 |
+
return np.nan
|
| 31 |
+
return np.nanmean(arr)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@register_aggregation("mean")
|
| 35 |
+
def mean(arr):
|
| 36 |
+
return sum(arr) / len(arr)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@register_aggregation("median")
|
| 40 |
+
def median(arr):
|
| 41 |
+
return arr[len(arr) // 2]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Certain metrics must be calculated across all documents in a benchmark.
|
| 45 |
+
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
|
| 46 |
+
@register_aggregation("perplexity")
|
| 47 |
+
def perplexity(items):
|
| 48 |
+
return math.exp(-mean(items))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@register_aggregation("weighted_perplexity")
|
| 52 |
+
def weighted_perplexity(items):
|
| 53 |
+
return math.exp(-weighted_mean(items))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@register_aggregation("bits_per_byte")
|
| 57 |
+
def bits_per_byte(items):
|
| 58 |
+
return -weighted_mean(items) / math.log(2)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@register_aggregation("f1")
|
| 62 |
+
def f1_score(items):
|
| 63 |
+
from sklearn.metrics import f1_score
|
| 64 |
+
|
| 65 |
+
unzipped_list = list(zip(*items))
|
| 66 |
+
golds = unzipped_list[0]
|
| 67 |
+
preds = unzipped_list[1]
|
| 68 |
+
fscore = f1_score(golds, preds)
|
| 69 |
+
|
| 70 |
+
return np.max(fscore)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@register_aggregation("matthews_corrcoef")
|
| 74 |
+
def matthews_corrcoef(items):
|
| 75 |
+
from sklearn.metrics import matthews_corrcoef
|
| 76 |
+
|
| 77 |
+
unzipped_list = list(zip(*items))
|
| 78 |
+
golds = unzipped_list[0]
|
| 79 |
+
preds = unzipped_list[1]
|
| 80 |
+
return matthews_corrcoef(golds, preds)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@register_aggregation("bleu")
|
| 84 |
+
def bleu(items):
|
| 85 |
+
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
|
| 86 |
+
for evaluating a generated sentence to a reference sentence. It counts matching
|
| 87 |
+
n-grams in the candidate translation to n-grams in the reference text, where
|
| 88 |
+
1-gram or unigram would be each token and a bigram comparison would be each
|
| 89 |
+
word pair. The comparison is made regardless of word order
|
| 90 |
+
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
|
| 91 |
+
Paper: https://www.aclweb.org/anthology/P02-1040/
|
| 92 |
+
|
| 93 |
+
Higher is better
|
| 94 |
+
"""
|
| 95 |
+
refs = list(zip(*items))[0]
|
| 96 |
+
preds = list(zip(*items))[1]
|
| 97 |
+
refs, preds = _sacreformat(refs, preds)
|
| 98 |
+
return sacrebleu.corpus_bleu(preds, refs).score
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@register_aggregation("chrf")
|
| 102 |
+
def chrf(items):
|
| 103 |
+
"""chrF++ is a tool for automatic evaluation of machine translation output
|
| 104 |
+
based on character n-gram precision and recall enhanced with word n-grams.
|
| 105 |
+
Source: https://github.com/m-popovic/chrF
|
| 106 |
+
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
|
| 107 |
+
|
| 108 |
+
Higher is better # TODO I think
|
| 109 |
+
"""
|
| 110 |
+
refs = list(zip(*items))[0]
|
| 111 |
+
preds = list(zip(*items))[1]
|
| 112 |
+
refs, preds = _sacreformat(refs, preds)
|
| 113 |
+
return sacrebleu.corpus_chrf(preds, refs).score
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@register_aggregation("ter")
|
| 117 |
+
def ter(items):
|
| 118 |
+
"""Translation Error Rate is an error metric for machine translation that
|
| 119 |
+
measures the number of edits required to change a system output into one
|
| 120 |
+
of the references
|
| 121 |
+
Source: http://www.cs.umd.edu/~snover/tercom/
|
| 122 |
+
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
|
| 123 |
+
|
| 124 |
+
Lower is better
|
| 125 |
+
"""
|
| 126 |
+
refs = list(zip(*items))[0]
|
| 127 |
+
preds = list(zip(*items))[1]
|
| 128 |
+
refs, preds = _sacreformat(refs, preds)
|
| 129 |
+
return sacrebleu.corpus_ter(preds, refs).score
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@register_aggregation("brier_score")
|
| 133 |
+
def brier_score(items): # This is a passthrough function
|
| 134 |
+
gold, predictions = list(zip(*items))
|
| 135 |
+
bs, num_class = np.array(predictions).shape
|
| 136 |
+
|
| 137 |
+
gold = list(gold)
|
| 138 |
+
gold_one_hot = np.eye(num_class)[gold]
|
| 139 |
+
return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@register_metric(
|
| 143 |
+
metric="brier_score",
|
| 144 |
+
higher_is_better=False,
|
| 145 |
+
output_type=["multiple_choice"],
|
| 146 |
+
aggregation="brier_score",
|
| 147 |
+
)
|
| 148 |
+
def brier_score_fn(items): # This is a passthrough function
|
| 149 |
+
return items
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@register_metric(
|
| 153 |
+
metric="acc",
|
| 154 |
+
higher_is_better=True,
|
| 155 |
+
output_type=["loglikelihood", "multiple_choice"],
|
| 156 |
+
aggregation="mean",
|
| 157 |
+
)
|
| 158 |
+
def acc_fn(items): # This is a passthrough function
|
| 159 |
+
return items
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@register_metric(
|
| 163 |
+
metric="acc_norm",
|
| 164 |
+
higher_is_better=True,
|
| 165 |
+
output_type=["loglikelihood", "multiple_choice"],
|
| 166 |
+
aggregation="mean",
|
| 167 |
+
)
|
| 168 |
+
def acc_norm_fn(items): # This is a passthrough function
|
| 169 |
+
return items
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@register_metric(
|
| 173 |
+
metric="acc_mutual_info",
|
| 174 |
+
higher_is_better=True,
|
| 175 |
+
output_type="multiple_choice",
|
| 176 |
+
aggregation="mean",
|
| 177 |
+
)
|
| 178 |
+
def acc_mutual_info_fn(items): # This is a passthrough function
|
| 179 |
+
return items
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
### the code used in the `exact_match_hf_evaluate` function is ported from
|
| 183 |
+
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
|
| 184 |
+
### which is under the apache license.
|
| 185 |
+
|
| 186 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
| 187 |
+
|
| 188 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 189 |
+
# you may not use this file except in compliance with the License.
|
| 190 |
+
# You may obtain a copy of the License at
|
| 191 |
+
|
| 192 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 196 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 197 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 198 |
+
# See the License for the specific language governing permissions and
|
| 199 |
+
# limitations under the License.
|
| 200 |
+
def exact_match_hf_evaluate(
|
| 201 |
+
predictions,
|
| 202 |
+
references,
|
| 203 |
+
regexes_to_ignore=None,
|
| 204 |
+
ignore_case=False,
|
| 205 |
+
ignore_punctuation=False,
|
| 206 |
+
ignore_numbers=False,
|
| 207 |
+
):
|
| 208 |
+
if regexes_to_ignore is not None:
|
| 209 |
+
for s in regexes_to_ignore:
|
| 210 |
+
predictions = np.array([re.sub(s, "", x) for x in predictions])
|
| 211 |
+
references = np.array([re.sub(s, "", x) for x in references])
|
| 212 |
+
else:
|
| 213 |
+
predictions = np.asarray(predictions)
|
| 214 |
+
references = np.asarray(references)
|
| 215 |
+
|
| 216 |
+
if ignore_case:
|
| 217 |
+
predictions = np.char.lower(predictions)
|
| 218 |
+
references = np.char.lower(references)
|
| 219 |
+
|
| 220 |
+
if ignore_punctuation:
|
| 221 |
+
repl_table = string.punctuation.maketrans("", "", string.punctuation)
|
| 222 |
+
predictions = np.char.translate(predictions, table=repl_table)
|
| 223 |
+
references = np.char.translate(references, table=repl_table)
|
| 224 |
+
|
| 225 |
+
if ignore_numbers:
|
| 226 |
+
repl_table = string.digits.maketrans("", "", string.digits)
|
| 227 |
+
predictions = np.char.translate(predictions, table=repl_table)
|
| 228 |
+
references = np.char.translate(references, table=repl_table)
|
| 229 |
+
|
| 230 |
+
score_list = predictions == references
|
| 231 |
+
|
| 232 |
+
return {"exact_match": np.mean(score_list)}
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
###
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@register_metric(
|
| 239 |
+
metric="exact_match",
|
| 240 |
+
higher_is_better=True,
|
| 241 |
+
output_type="generate_until",
|
| 242 |
+
aggregation="mean",
|
| 243 |
+
)
|
| 244 |
+
def exact_match_fn(**kwargs):
|
| 245 |
+
return exact_match_hf_evaluate(**kwargs)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@register_metric(
|
| 249 |
+
metric="perplexity",
|
| 250 |
+
higher_is_better=False,
|
| 251 |
+
output_type="loglikelihood",
|
| 252 |
+
aggregation="perplexity",
|
| 253 |
+
)
|
| 254 |
+
def perplexity_fn(items): # This is a passthrough function
|
| 255 |
+
return items
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@register_metric(
|
| 259 |
+
metric="word_perplexity",
|
| 260 |
+
higher_is_better=False,
|
| 261 |
+
output_type="loglikelihood_rolling",
|
| 262 |
+
aggregation="weighted_perplexity",
|
| 263 |
+
)
|
| 264 |
+
def word_perplexity_fn(items): # This is a passthrough function
|
| 265 |
+
return items
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@register_metric(
|
| 269 |
+
metric="byte_perplexity",
|
| 270 |
+
higher_is_better=False,
|
| 271 |
+
output_type="loglikelihood_rolling",
|
| 272 |
+
aggregation="weighted_perplexity",
|
| 273 |
+
)
|
| 274 |
+
def byte_perplexity_fn(items): # This is a passthrough function
|
| 275 |
+
return items
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@register_metric(
|
| 279 |
+
metric="bits_per_byte",
|
| 280 |
+
higher_is_better=False,
|
| 281 |
+
output_type="loglikelihood_rolling",
|
| 282 |
+
aggregation="bits_per_byte",
|
| 283 |
+
)
|
| 284 |
+
def bits_per_byte_fn(items): # This is a passthrough function
|
| 285 |
+
return items
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def pop_stddev(arr):
|
| 289 |
+
mu = mean(arr)
|
| 290 |
+
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def sample_stddev(arr: Sequence[T]) -> float:
|
| 294 |
+
mu = mean(arr)
|
| 295 |
+
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def mean_stderr(arr):
|
| 299 |
+
return sample_stddev(arr) / math.sqrt(len(arr))
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@register_metric(
|
| 303 |
+
metric="bypass",
|
| 304 |
+
higher_is_better=True,
|
| 305 |
+
output_type=["loglikelihood", "multiple_choice", "generate_until"],
|
| 306 |
+
aggregation="bypass",
|
| 307 |
+
)
|
| 308 |
+
def bypass(items):
|
| 309 |
+
return None
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
@register_metric(
|
| 313 |
+
metric="mcc",
|
| 314 |
+
higher_is_better=True,
|
| 315 |
+
output_type="multiple_choice",
|
| 316 |
+
aggregation="matthews_corrcoef",
|
| 317 |
+
)
|
| 318 |
+
def mcc_fn(items): # This is a passthrough function
|
| 319 |
+
return items
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@register_metric(
|
| 323 |
+
metric="f1",
|
| 324 |
+
higher_is_better=True,
|
| 325 |
+
output_type="multiple_choice",
|
| 326 |
+
aggregation="f1",
|
| 327 |
+
)
|
| 328 |
+
def f1_fn(items): # This is a passthrough function
|
| 329 |
+
return items
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@register_metric(
|
| 333 |
+
metric="bleu",
|
| 334 |
+
higher_is_better=True,
|
| 335 |
+
output_type="generate_until",
|
| 336 |
+
aggregation="bleu",
|
| 337 |
+
)
|
| 338 |
+
def bleu_fn(items): # This is a passthrough function
|
| 339 |
+
return items
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@register_metric(
|
| 343 |
+
metric="chrf",
|
| 344 |
+
higher_is_better=True,
|
| 345 |
+
output_type="generate_until",
|
| 346 |
+
aggregation="chrf",
|
| 347 |
+
)
|
| 348 |
+
def chrf_fn(items): # This is a passthrough function
|
| 349 |
+
return items
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@register_metric(
|
| 353 |
+
metric="ter",
|
| 354 |
+
higher_is_better=True,
|
| 355 |
+
output_type="generate_until",
|
| 356 |
+
aggregation="ter",
|
| 357 |
+
)
|
| 358 |
+
def ter_fn(items): # This is a passthrough function
|
| 359 |
+
return items
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@register_metric(
|
| 363 |
+
metric="acc_all",
|
| 364 |
+
higher_is_better=True,
|
| 365 |
+
output_type="loglikelihood",
|
| 366 |
+
aggregation="mean",
|
| 367 |
+
)
|
| 368 |
+
def acc_all(items):
|
| 369 |
+
# Only count as correct if all answers are labeled correctly for each question
|
| 370 |
+
question_scoring_dict = {}
|
| 371 |
+
preds = list(zip(*items))[0]
|
| 372 |
+
docs = list(zip(*items))[1]
|
| 373 |
+
|
| 374 |
+
for doc, pred in zip(docs, preds):
|
| 375 |
+
paragraph_id = doc["idx"]["paragraph"]
|
| 376 |
+
question_id = doc["idx"]["question"]
|
| 377 |
+
if (paragraph_id, question_id) not in question_scoring_dict:
|
| 378 |
+
question_scoring_dict[(paragraph_id, question_id)] = []
|
| 379 |
+
|
| 380 |
+
gold_label = doc["label"] == 1
|
| 381 |
+
|
| 382 |
+
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
|
| 383 |
+
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
|
| 384 |
+
return acc
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def acc_all_stderr(items):
|
| 388 |
+
# Only count as correct if all answers are labeled correctly for each question
|
| 389 |
+
question_scoring_dict = {}
|
| 390 |
+
preds = list(zip(*items))[0]
|
| 391 |
+
docs = list(zip(*items))[1]
|
| 392 |
+
|
| 393 |
+
for doc, pred in zip(docs, preds):
|
| 394 |
+
question_id = doc["idx"]["question"]
|
| 395 |
+
if question_id not in question_scoring_dict:
|
| 396 |
+
question_scoring_dict[question_id] = []
|
| 397 |
+
|
| 398 |
+
gold_label = doc["label"] == 1
|
| 399 |
+
question_scoring_dict[question_id].append(gold_label == pred)
|
| 400 |
+
|
| 401 |
+
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
|
| 402 |
+
return acc
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
| 406 |
+
"""Compute max metric between prediction and each ground truth."""
|
| 407 |
+
scores_for_ground_truths = []
|
| 408 |
+
for ground_truth in ground_truths:
|
| 409 |
+
score = metric_fn(prediction, ground_truth)
|
| 410 |
+
scores_for_ground_truths.append(score)
|
| 411 |
+
return max(scores_for_ground_truths)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def weighted_mean(items):
|
| 415 |
+
a, b = zip(*items)
|
| 416 |
+
return sum(a) / sum(b)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def is_non_str_iterable(obj):
|
| 420 |
+
return isinstance(obj, Iterable) and not isinstance(obj, str)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _sacreformat(refs, preds):
|
| 424 |
+
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
|
| 425 |
+
# Sacrebleu expects (List[str], List[List[str])
|
| 426 |
+
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
|
| 427 |
+
|
| 428 |
+
# Note [ref1_stream] is the first reference for each pred.
|
| 429 |
+
# So lists are size N and (M, N) for N preds and M possible refs for each pred
|
| 430 |
+
# This is a different order of dimensions that I would expect
|
| 431 |
+
|
| 432 |
+
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
|
| 433 |
+
# Must become List[List[str]] with the inner list corresponding to preds
|
| 434 |
+
if not is_non_str_iterable(refs):
|
| 435 |
+
refs = list(refs)
|
| 436 |
+
if not is_non_str_iterable(refs[0]):
|
| 437 |
+
refs = [[ref] for ref in refs]
|
| 438 |
+
refs = list(zip(*refs))
|
| 439 |
+
# Note the number of refs in each ref list much match the number of preds
|
| 440 |
+
|
| 441 |
+
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
|
| 442 |
+
if not is_non_str_iterable(preds):
|
| 443 |
+
preds = list(preds)
|
| 444 |
+
if is_non_str_iterable(preds[0]):
|
| 445 |
+
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
|
| 446 |
+
preds = [pred[0] for pred in preds]
|
| 447 |
+
|
| 448 |
+
return refs, preds
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# stderr stuff
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class _bootstrap_internal:
|
| 455 |
+
"""
|
| 456 |
+
Pool worker: `(i, xs)` → `n` bootstrap replicates
|
| 457 |
+
of `f(xs)`using a RNG seeded with `i`.
|
| 458 |
+
"""
|
| 459 |
+
|
| 460 |
+
def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
|
| 461 |
+
self.f = f
|
| 462 |
+
self.n = n
|
| 463 |
+
|
| 464 |
+
def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
|
| 465 |
+
i, xs = v
|
| 466 |
+
rnd = random.Random()
|
| 467 |
+
rnd.seed(i)
|
| 468 |
+
res = []
|
| 469 |
+
for _ in range(self.n):
|
| 470 |
+
res.append(self.f(rnd.choices(xs, k=len(xs))))
|
| 471 |
+
return res
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def _bootstrap_internal_no_mp(
|
| 475 |
+
f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
|
| 476 |
+
) -> list[float]:
|
| 477 |
+
"""
|
| 478 |
+
Single-process fallback: compute `iters` bootstrap replicates
|
| 479 |
+
of statistic`f(xs)`, chunked (≤ 1000 draws).
|
| 480 |
+
"""
|
| 481 |
+
res = []
|
| 482 |
+
chunk_size = min(1000, iters)
|
| 483 |
+
from tqdm import tqdm
|
| 484 |
+
|
| 485 |
+
print(f"bootstrapping for stddev: {f.__name__}")
|
| 486 |
+
|
| 487 |
+
# A single loop replaces the multiprocessing pool.
|
| 488 |
+
for i in tqdm(range(iters // chunk_size)):
|
| 489 |
+
rnd = random.Random(i)
|
| 490 |
+
for _ in range(chunk_size):
|
| 491 |
+
res.append(f(rnd.choices(xs, k=len(xs))))
|
| 492 |
+
|
| 493 |
+
return res
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def bootstrap_stderr(
|
| 497 |
+
f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
|
| 498 |
+
) -> float:
|
| 499 |
+
"""
|
| 500 |
+
Bootstrap estimate of the standard error of statistic `f(xs)`
|
| 501 |
+
using up to `iters` resamples, chunked (≤ 1000 draws)
|
| 502 |
+
|
| 503 |
+
Executes in parallel unless the env-var `DISABLE_MULTIPROC` is set;
|
| 504 |
+
"""
|
| 505 |
+
if not os.getenv("DISABLE_MULTIPROC"):
|
| 506 |
+
import multiprocessing as mp
|
| 507 |
+
|
| 508 |
+
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
|
| 509 |
+
# equivalent to stderr calculated without Bessel's correction in the stddev.
|
| 510 |
+
# Unfortunately, I haven't been able to figure out what the right correction is
|
| 511 |
+
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
|
| 512 |
+
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
|
| 513 |
+
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
|
| 514 |
+
res = []
|
| 515 |
+
chunk_size = min(1000, iters)
|
| 516 |
+
from tqdm import tqdm
|
| 517 |
+
|
| 518 |
+
print("bootstrapping for stddev:", f.__name__)
|
| 519 |
+
with mp.Pool(mp.cpu_count()) as pool:
|
| 520 |
+
for bootstrap in tqdm(
|
| 521 |
+
pool.imap(
|
| 522 |
+
_bootstrap_internal(f, chunk_size),
|
| 523 |
+
[(i, xs) for i in range(iters // chunk_size)],
|
| 524 |
+
),
|
| 525 |
+
total=iters // chunk_size,
|
| 526 |
+
):
|
| 527 |
+
# sample w replacement
|
| 528 |
+
res.extend(bootstrap)
|
| 529 |
+
else:
|
| 530 |
+
res = _bootstrap_internal_no_mp(f, xs, iters)
|
| 531 |
+
|
| 532 |
+
return sample_stddev(res)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def stderr_for_metric(
|
| 536 |
+
metric: Callable[[Sequence[T]], float], bootstrap_iters: int
|
| 537 |
+
) -> Optional[Callable[[Sequence[T]], float]]:
|
| 538 |
+
"""
|
| 539 |
+
Return a function that estimates the standard error of `metric(xs)`.
|
| 540 |
+
|
| 541 |
+
* If `bootstrap_iters > 0` and the metric is in the pre-approved
|
| 542 |
+
bootstrappable list, use `bootstrap_stderr` with that many draws.
|
| 543 |
+
* If the metric has a closed-form SE (e.g. `mean`, `acc_all`), use it.
|
| 544 |
+
* Otherwise, return `None`.
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
if bootstrap_iters <= 0:
|
| 548 |
+
# return no function (don't compute stderr) if bootstrap iters = 0
|
| 549 |
+
return None
|
| 550 |
+
|
| 551 |
+
bootstrappable = [
|
| 552 |
+
median,
|
| 553 |
+
matthews_corrcoef,
|
| 554 |
+
f1_score,
|
| 555 |
+
perplexity,
|
| 556 |
+
bleu,
|
| 557 |
+
chrf,
|
| 558 |
+
ter,
|
| 559 |
+
nanmean,
|
| 560 |
+
]
|
| 561 |
+
|
| 562 |
+
if metric in bootstrappable:
|
| 563 |
+
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
|
| 564 |
+
|
| 565 |
+
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
|
| 566 |
+
|
| 567 |
+
return stderr.get(metric, None)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
|
| 571 |
+
# Used to aggregate bootstrapped stderrs across subtasks in a group,
|
| 572 |
+
# when we are weighting by the size of each subtask.
|
| 573 |
+
#
|
| 574 |
+
|
| 575 |
+
assert len(stderrs) == len(sizes)
|
| 576 |
+
|
| 577 |
+
# formula source: https://en.wikipedia.org/wiki/Pooled_variance
|
| 578 |
+
# and: https://stats.stackexchange.com/a/4841331
|
| 579 |
+
# this empirically seems to match running `stderr_for_metric` on all instances
|
| 580 |
+
# from the subtasks concatenated with each other.
|
| 581 |
+
pooled_sample_var = (
|
| 582 |
+
sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
|
| 583 |
+
) / (sum(sizes) - len(sizes))
|
| 584 |
+
|
| 585 |
+
return np.sqrt(pooled_sample_var / sum(sizes))
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
|
| 589 |
+
assert metrics is not None, (
|
| 590 |
+
"Need to pass a list of each subtask's metric for this stderr aggregation"
|
| 591 |
+
)
|
| 592 |
+
assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
|
| 593 |
+
|
| 594 |
+
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
|
| 595 |
+
# This formula depends on sample means.
|
| 596 |
+
# removed because it seems to give erroneously huge stderrs for groupings of tasks
|
| 597 |
+
# and does not seem to match up with bootstrap-calculated stderrs for groups.
|
| 598 |
+
|
| 599 |
+
### don't use this unless a statistician has told you it's the right thing to do ###
|
| 600 |
+
|
| 601 |
+
# accumulators: we'll aggregate pairwise N - 1 times
|
| 602 |
+
variance = stderrs[0] ** 2
|
| 603 |
+
curr_size = sizes[0]
|
| 604 |
+
curr_score = metrics[0]
|
| 605 |
+
|
| 606 |
+
for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
|
| 607 |
+
curr_score = ((curr_score * curr_size) + (score * size)) / (
|
| 608 |
+
curr_size + size
|
| 609 |
+
) # NOTE: this assumes our aggregation fn is "mean"
|
| 610 |
+
|
| 611 |
+
variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
|
| 612 |
+
curr_size + size - 1
|
| 613 |
+
) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
|
| 614 |
+
curr_score - score
|
| 615 |
+
) ** 2
|
| 616 |
+
|
| 617 |
+
return np.sqrt(variance)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
|
| 621 |
+
# A helper function that is used to aggregate
|
| 622 |
+
# subtask scores cross-task.
|
| 623 |
+
# TODO: does not hold for non-mean aggregations
|
| 624 |
+
if not weight_by_size:
|
| 625 |
+
sizes = [1] * len(sizes)
|
| 626 |
+
|
| 627 |
+
assert len(metrics) == len(sizes)
|
| 628 |
+
|
| 629 |
+
return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
|
lm-evaluation-harness/lm_eval/api/model.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import hashlib
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from lm_eval import utils
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from sqlitedict import SqliteDict
|
| 15 |
+
|
| 16 |
+
from lm_eval.api.instance import Instance
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
eval_logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
T = TypeVar("T", bound="LM")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LM(abc.ABC):
|
| 25 |
+
def __init__(self) -> None:
|
| 26 |
+
"""Defines the interface that should be implemented by all LM subclasses.
|
| 27 |
+
LMs are assumed to take text (strings) as input and yield strings as output
|
| 28 |
+
(inputs/outputs should be tokenization-agnostic.)
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
# set rank and world size to a single process, by default.
|
| 32 |
+
self._rank = 0
|
| 33 |
+
self._world_size = 1
|
| 34 |
+
self.cache_hook: "CacheHook" = CacheHook(None)
|
| 35 |
+
|
| 36 |
+
@abc.abstractmethod
|
| 37 |
+
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
|
| 38 |
+
"""Compute log-likelihood of generating a continuation from a context.
|
| 39 |
+
Downstream tasks should attempt to use loglikelihood instead of other
|
| 40 |
+
LM calls whenever possible.
|
| 41 |
+
|
| 42 |
+
:param requests: list[Instance]
|
| 43 |
+
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
|
| 44 |
+
`context: str`
|
| 45 |
+
Context string. Implementations of LM must be able to handle an
|
| 46 |
+
empty context string.
|
| 47 |
+
`continuation: str`
|
| 48 |
+
The continuation over which log likelihood will be calculated. If
|
| 49 |
+
there is a word boundary, the space should be in the continuation.
|
| 50 |
+
For example, context="hello" continuation=" world" is correct.
|
| 51 |
+
|
| 52 |
+
:return: list[tuple[float, bool]]
|
| 53 |
+
A list of pairs (logprob, isgreedy)
|
| 54 |
+
`logprob: float`
|
| 55 |
+
The log probability of `continuation`.
|
| 56 |
+
`isgreedy`:
|
| 57 |
+
Whether `continuation` would be generated by greedy sampling from `context`.
|
| 58 |
+
"""
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
@abc.abstractmethod
|
| 62 |
+
def loglikelihood_rolling(self, requests) -> list[float]:
|
| 63 |
+
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
| 64 |
+
- We will use the full max context length of the model.
|
| 65 |
+
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
|
| 66 |
+
the max context length.
|
| 67 |
+
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
|
| 68 |
+
which may simply concatenate multiple documents together.
|
| 69 |
+
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
|
| 70 |
+
multiple chunks, the last input will still a full-sized context.
|
| 71 |
+
Example:
|
| 72 |
+
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
|
| 73 |
+
Prefix: BOS/EOS
|
| 74 |
+
Max context length: 4
|
| 75 |
+
Resulting input/prediction pairs:
|
| 76 |
+
|
| 77 |
+
INPUT: BOS 0 1 2
|
| 78 |
+
PRED: 0 1 2 3
|
| 79 |
+
|
| 80 |
+
INPUT: 3 4 5 6
|
| 81 |
+
PRED: 4 5 6 7
|
| 82 |
+
|
| 83 |
+
INPUT: 5 6 7 8
|
| 84 |
+
PRED: 8 9
|
| 85 |
+
|
| 86 |
+
Observe that:
|
| 87 |
+
1. Each token is predicted exactly once
|
| 88 |
+
2. For the last pair, we provide the full context, but only score the last two tokens
|
| 89 |
+
|
| 90 |
+
:param requests: list[Instance]
|
| 91 |
+
A list of Instance objects with property `args` which returns a tuple (context,).
|
| 92 |
+
string: str
|
| 93 |
+
String for which we are computing overall loglikelihood
|
| 94 |
+
:return: list[tuple[float]]
|
| 95 |
+
A list of tuples (logprob,)
|
| 96 |
+
logprob: float
|
| 97 |
+
The log probability of `context` conditioned on the BOS/EOS token.
|
| 98 |
+
Can also be overridden for custom cases by `prefix_token_id`.
|
| 99 |
+
"""
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
# TODO: Add an optional max length
|
| 103 |
+
@abc.abstractmethod
|
| 104 |
+
def generate_until(self, requests) -> list[str]:
|
| 105 |
+
"""Generate greedily until a stopping sequence
|
| 106 |
+
|
| 107 |
+
:param requests: list[Instance]
|
| 108 |
+
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
|
| 109 |
+
context: str
|
| 110 |
+
Context string
|
| 111 |
+
gen_kwargs: dict
|
| 112 |
+
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
|
| 113 |
+
:return: list[str]
|
| 114 |
+
A list of model generated continuations.
|
| 115 |
+
continuation: str
|
| 116 |
+
The generated continuation.
|
| 117 |
+
"""
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
def apply_chat_template(
|
| 121 |
+
self, chat_history: list[dict[str, str]], add_generation_prompt=True
|
| 122 |
+
) -> str:
|
| 123 |
+
"""
|
| 124 |
+
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
|
| 125 |
+
|
| 126 |
+
:param chat_history: list[dict[str, str]]
|
| 127 |
+
A list of dictionaries with keys 'role' and 'content'.
|
| 128 |
+
Values are strings representing the role name and the content of the message, respectively.
|
| 129 |
+
:param add_generation_prompt: bool
|
| 130 |
+
Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message.
|
| 131 |
+
:return: str
|
| 132 |
+
A string representing the chat history in a format that can be used as input to the LM.
|
| 133 |
+
"""
|
| 134 |
+
raise NotImplementedError(
|
| 135 |
+
"To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
@classmethod
|
| 139 |
+
def create_from_arg_string(
|
| 140 |
+
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
|
| 141 |
+
) -> T:
|
| 142 |
+
"""
|
| 143 |
+
Creates an instance of the LM class using the given argument string and additional config.
|
| 144 |
+
|
| 145 |
+
Parameters:
|
| 146 |
+
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 147 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
- Instance of the LM class.
|
| 151 |
+
"""
|
| 152 |
+
additional_config = {} if additional_config is None else additional_config
|
| 153 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 154 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 155 |
+
return cls(**args, **args2)
|
| 156 |
+
|
| 157 |
+
@classmethod
|
| 158 |
+
def create_from_arg_obj(
|
| 159 |
+
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
|
| 160 |
+
) -> T:
|
| 161 |
+
"""
|
| 162 |
+
Creates an instance of the LM class using the given arg_obj
|
| 163 |
+
|
| 164 |
+
Parameters:
|
| 165 |
+
- arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
|
| 166 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
- Instance of the LM class.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
additional_config = additional_config or {} | {
|
| 173 |
+
k: v for k, v in additional_config.items() if v is not None
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
return cls(**arg_dict, **additional_config)
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def rank(self):
|
| 180 |
+
# used in the case of parallelism. Hardcoded to
|
| 181 |
+
# ensure no errors arise using API models which do
|
| 182 |
+
# not support multi-device parallelism nor expect it.
|
| 183 |
+
return self._rank
|
| 184 |
+
|
| 185 |
+
@property
|
| 186 |
+
def world_size(self):
|
| 187 |
+
# used in the case of parallelism. Hardcoded to
|
| 188 |
+
# ensure no errors arise using API models which do
|
| 189 |
+
# not support multi-device parallelism nor expect it.
|
| 190 |
+
return self._world_size
|
| 191 |
+
|
| 192 |
+
@property
|
| 193 |
+
def tokenizer_name(self) -> str:
|
| 194 |
+
"""Must be defined for LM subclasses which implement Chat Templating.
|
| 195 |
+
Should return the name of the tokenizer or chat template used.
|
| 196 |
+
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
|
| 197 |
+
"""
|
| 198 |
+
raise NotImplementedError(
|
| 199 |
+
"To use this model with chat templates, please implement the 'tokenizer_name' property."
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
|
| 203 |
+
"""Returns the chat template structure for user/assistant messages if a template is provided.
|
| 204 |
+
This method is intended to be overridden in a subclass to define a specific chat template format.
|
| 205 |
+
For models that do not support chat templates, this method returns None by default.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
return ""
|
| 209 |
+
|
| 210 |
+
def set_cache_hook(self, cache_hook: "CacheHook") -> None:
|
| 211 |
+
self.cache_hook = cache_hook
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
### SQLite-based caching of LM responses
|
| 215 |
+
def hash_args(attr: str, args: Iterable[Any]) -> str:
|
| 216 |
+
dat = json.dumps([attr] + list(args))
|
| 217 |
+
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class CacheHook:
|
| 221 |
+
def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
|
| 222 |
+
if cachinglm is None:
|
| 223 |
+
self.dbdict: Optional["SqliteDict"] = None
|
| 224 |
+
return
|
| 225 |
+
|
| 226 |
+
self.dbdict = cachinglm.dbdict
|
| 227 |
+
|
| 228 |
+
def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
|
| 229 |
+
if self.dbdict is None:
|
| 230 |
+
return
|
| 231 |
+
hsh = hash_args(attr, req)
|
| 232 |
+
self.dbdict[hsh] = res
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CachingLM:
|
| 236 |
+
def __init__(self, lm: LM, cache_db: str) -> None:
|
| 237 |
+
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
|
| 238 |
+
|
| 239 |
+
:param lm: LM
|
| 240 |
+
Underlying LM
|
| 241 |
+
:param cache_db: str
|
| 242 |
+
Path to cache db
|
| 243 |
+
"""
|
| 244 |
+
from sqlitedict import SqliteDict
|
| 245 |
+
|
| 246 |
+
self.lm: LM = lm
|
| 247 |
+
self.cache_db: str = cache_db
|
| 248 |
+
if os.path.dirname(cache_db):
|
| 249 |
+
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
|
| 250 |
+
self.dbdict = SqliteDict(cache_db, autocommit=True)
|
| 251 |
+
|
| 252 |
+
# add hook to lm
|
| 253 |
+
lm.set_cache_hook(self.get_cache_hook())
|
| 254 |
+
|
| 255 |
+
def __getattr__(self, attr: str) -> Any:
|
| 256 |
+
lm_attr = getattr(self.lm, attr)
|
| 257 |
+
if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
|
| 258 |
+
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
|
| 259 |
+
return lm_attr
|
| 260 |
+
|
| 261 |
+
def _fn(requests: list["Instance"]) -> list["Instance"]:
|
| 262 |
+
res = []
|
| 263 |
+
remaining_reqs = []
|
| 264 |
+
warned = False
|
| 265 |
+
# figure out which ones are cached and which ones are new
|
| 266 |
+
eval_logger.info(
|
| 267 |
+
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
|
| 268 |
+
)
|
| 269 |
+
for req in tqdm(requests, desc="Checking cached requests"):
|
| 270 |
+
hsh = hash_args(attr, req.args)
|
| 271 |
+
if attr == "generate_until" and req.args[1].get("do_sample", False):
|
| 272 |
+
# when we are doing non-greedy generation, don't use the cache
|
| 273 |
+
# (else every "randomly sampled" generation would be identical for repeats > 1).
|
| 274 |
+
if not warned:
|
| 275 |
+
eval_logger.warning(
|
| 276 |
+
f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
|
| 277 |
+
)
|
| 278 |
+
warned = True
|
| 279 |
+
res.append(None)
|
| 280 |
+
remaining_reqs.append(req)
|
| 281 |
+
elif hsh in self.dbdict:
|
| 282 |
+
ob = self.dbdict[hsh]
|
| 283 |
+
|
| 284 |
+
assert ob is not None
|
| 285 |
+
|
| 286 |
+
res.append(ob)
|
| 287 |
+
else:
|
| 288 |
+
res.append(None)
|
| 289 |
+
remaining_reqs.append(req)
|
| 290 |
+
eval_logger.info(
|
| 291 |
+
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
|
| 292 |
+
)
|
| 293 |
+
if remaining_reqs:
|
| 294 |
+
# actually run the LM on the requests that do not have cached results
|
| 295 |
+
rem_res = getattr(self.lm, attr)(remaining_reqs)
|
| 296 |
+
else:
|
| 297 |
+
rem_res = []
|
| 298 |
+
|
| 299 |
+
# stick the new ones back into the list and also cache any of the new ones
|
| 300 |
+
resptr = 0
|
| 301 |
+
for req, r in zip(remaining_reqs, rem_res):
|
| 302 |
+
while res[resptr] is not None:
|
| 303 |
+
resptr += 1
|
| 304 |
+
|
| 305 |
+
res[resptr] = r
|
| 306 |
+
|
| 307 |
+
# caching
|
| 308 |
+
hsh = hash_args(attr, req.args)
|
| 309 |
+
self.dbdict[hsh] = r
|
| 310 |
+
self.dbdict.commit()
|
| 311 |
+
|
| 312 |
+
return res
|
| 313 |
+
|
| 314 |
+
return _fn
|
| 315 |
+
|
| 316 |
+
def get_cache_hook(self) -> "CacheHook":
|
| 317 |
+
return CacheHook(self)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class TemplateLM(LM):
|
| 321 |
+
"""
|
| 322 |
+
A class acting as intermediary between the LM base class
|
| 323 |
+
and boilerplate often included in other LM subclasses.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
tokenizer = None
|
| 327 |
+
|
| 328 |
+
@property
|
| 329 |
+
@abc.abstractmethod
|
| 330 |
+
def eot_token_id(self):
|
| 331 |
+
pass
|
| 332 |
+
|
| 333 |
+
@property
|
| 334 |
+
def prefix_token_id(self):
|
| 335 |
+
# it is used as prefix for loglikelihood
|
| 336 |
+
return self.eot_token_id
|
| 337 |
+
|
| 338 |
+
@abc.abstractmethod
|
| 339 |
+
def tok_encode(self, string: str, **kwargs) -> list[int]:
|
| 340 |
+
"""
|
| 341 |
+
Tokenize a string using the model's tokenizer and return a list of token IDs.
|
| 342 |
+
"""
|
| 343 |
+
pass
|
| 344 |
+
|
| 345 |
+
@abc.abstractmethod
|
| 346 |
+
def _loglikelihood_tokens(
|
| 347 |
+
self, requests: list["Instance"], **kwargs
|
| 348 |
+
) -> list[tuple[float, bool]]:
|
| 349 |
+
pass
|
| 350 |
+
|
| 351 |
+
def _encode_pair(
|
| 352 |
+
self, context: str, continuation: str
|
| 353 |
+
) -> tuple[list[int], list[int]]:
|
| 354 |
+
import transformers
|
| 355 |
+
|
| 356 |
+
n_spaces = len(context) - len(context.rstrip())
|
| 357 |
+
if n_spaces > 0:
|
| 358 |
+
continuation = context[-n_spaces:] + continuation
|
| 359 |
+
context = context[:-n_spaces]
|
| 360 |
+
|
| 361 |
+
model_class = getattr(self, "AUTO_MODEL_CLASS", None)
|
| 362 |
+
|
| 363 |
+
if model_class == transformers.AutoModelForSeq2SeqLM:
|
| 364 |
+
context_enc = self.tok_encode(context)
|
| 365 |
+
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
|
| 366 |
+
else:
|
| 367 |
+
whole_enc = self.tok_encode(context + continuation)
|
| 368 |
+
context_enc = self.tok_encode(context)
|
| 369 |
+
|
| 370 |
+
context_enc_len = len(context_enc)
|
| 371 |
+
continuation_enc = whole_enc[context_enc_len:]
|
| 372 |
+
|
| 373 |
+
return context_enc, continuation_enc
|
| 374 |
+
|
| 375 |
+
def loglikelihood(
|
| 376 |
+
self, requests: list["Instance"], disable_tqdm: bool = False
|
| 377 |
+
) -> list[tuple[float, bool]]:
|
| 378 |
+
new_reqs = []
|
| 379 |
+
for context, continuation in [req.args for req in requests]:
|
| 380 |
+
if context == "":
|
| 381 |
+
# BOS or EOS as context
|
| 382 |
+
context_enc, continuation_enc = (
|
| 383 |
+
[self.prefix_token_id],
|
| 384 |
+
self.tok_encode(continuation),
|
| 385 |
+
)
|
| 386 |
+
else:
|
| 387 |
+
context_enc, continuation_enc = self._encode_pair(context, continuation)
|
| 388 |
+
|
| 389 |
+
new_reqs.append(((context, continuation), context_enc, continuation_enc))
|
| 390 |
+
|
| 391 |
+
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
|
| 392 |
+
|
| 393 |
+
@abc.abstractmethod
|
| 394 |
+
def loglikelihood_rolling(
|
| 395 |
+
self, requests, disable_tqdm: bool = False
|
| 396 |
+
) -> list[float]:
|
| 397 |
+
pass
|
| 398 |
+
|
| 399 |
+
@abc.abstractmethod
|
| 400 |
+
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
|
| 401 |
+
pass
|
| 402 |
+
|
| 403 |
+
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
|
| 404 |
+
"""
|
| 405 |
+
Set and get the appropriate chat template for the model.
|
| 406 |
+
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
|
| 407 |
+
|
| 408 |
+
The template selection logic is adapted from the Transformers library's `apply_chat_template`
|
| 409 |
+
method in the Tokenizer class. The original implementation can be found at:
|
| 410 |
+
https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
|
| 411 |
+
|
| 412 |
+
This method ensures that the right template is chosen based on the following:
|
| 413 |
+
0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
|
| 414 |
+
1. If the model's tokenizer has multiple templates:
|
| 415 |
+
a. Use the specified template if it exists in the dictionary.
|
| 416 |
+
b. Use the default template from the list if no specific template is provided.
|
| 417 |
+
c. Raise an error if no default template exists and no specific template is provided.
|
| 418 |
+
2. If the model's tokenizer has a single template or no template:
|
| 419 |
+
a. Use the tokenizer's chat template if available.
|
| 420 |
+
b. Fall back to the default chat template if no tokenizer chat template exists.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
chat_template (Union[bool, str]): Specifies the chat template to use.
|
| 424 |
+
- If False or None, no template is applied.
|
| 425 |
+
- If True, the default or only available template is used.
|
| 426 |
+
- If a string, the template with the matching name is used.
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
Optional[str]: The selected chat template, or None if no template is applied.
|
| 430 |
+
"""
|
| 431 |
+
if self.tokenizer is None:
|
| 432 |
+
return ""
|
| 433 |
+
|
| 434 |
+
if chat_template is False or chat_template is None:
|
| 435 |
+
eval_logger.warning(
|
| 436 |
+
"model.chat_template was called with the chat_template set to False or None. "
|
| 437 |
+
"Therefore no chat template will be applied. Make sure this is an intended behavior."
|
| 438 |
+
)
|
| 439 |
+
return None
|
| 440 |
+
|
| 441 |
+
# Convert boolean chat_template to None to ensure compatibility with the adapted logic
|
| 442 |
+
if isinstance(chat_template, bool):
|
| 443 |
+
chat_template = None
|
| 444 |
+
using_default_template = False
|
| 445 |
+
|
| 446 |
+
# First, handle the cases when the model has a dict of multiple templates
|
| 447 |
+
try:
|
| 448 |
+
template = (
|
| 449 |
+
self.tokenizer.chat_template or self.tokenizer.default_chat_template
|
| 450 |
+
)
|
| 451 |
+
except AttributeError:
|
| 452 |
+
return None
|
| 453 |
+
|
| 454 |
+
if isinstance(template, dict):
|
| 455 |
+
using_default_dict = self.tokenizer.chat_template is None
|
| 456 |
+
|
| 457 |
+
if chat_template is not None:
|
| 458 |
+
if chat_template in template:
|
| 459 |
+
selected_template = template[chat_template]
|
| 460 |
+
if using_default_dict:
|
| 461 |
+
using_default_template = True
|
| 462 |
+
else:
|
| 463 |
+
raise ValueError(
|
| 464 |
+
f"The specified chat template '{chat_template}' is not available. "
|
| 465 |
+
f"Available template names are {sorted(template.keys())}."
|
| 466 |
+
)
|
| 467 |
+
else:
|
| 468 |
+
# If user didn't pass a chat template, use the default template from the dict
|
| 469 |
+
if "default" in template:
|
| 470 |
+
selected_template = template["default"]
|
| 471 |
+
using_default_template = True
|
| 472 |
+
else:
|
| 473 |
+
raise ValueError(
|
| 474 |
+
"This model has multiple chat templates with no default specified! Please either pass a chat "
|
| 475 |
+
"template or the name of the template you wish to use to the `chat_template` argument. Available "
|
| 476 |
+
f"template names are {sorted(template.keys())}."
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Cases when the model has a single template or no template
|
| 480 |
+
else:
|
| 481 |
+
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
|
| 482 |
+
if isinstance(chat_template, str):
|
| 483 |
+
eval_logger.warning(
|
| 484 |
+
"Chat template name provided, but the tokenizer's chat template is not a dictionary. "
|
| 485 |
+
"Using the tokenizer's chat template or the default template instead."
|
| 486 |
+
)
|
| 487 |
+
if self.tokenizer.chat_template is not None:
|
| 488 |
+
selected_template = self.tokenizer.chat_template
|
| 489 |
+
else:
|
| 490 |
+
selected_template = self.tokenizer.default_chat_template
|
| 491 |
+
using_default_template = True
|
| 492 |
+
|
| 493 |
+
if using_default_template:
|
| 494 |
+
eval_logger.warning(
|
| 495 |
+
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
|
| 496 |
+
"very error-prone, because models are often trained with templates different from the class default! "
|
| 497 |
+
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
|
| 498 |
+
"point any code depending on them will stop working. We recommend setting a valid chat template before "
|
| 499 |
+
"then to ensure that this model continues working without issues."
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
return selected_template
|
lm-evaluation-harness/lm_eval/api/registry.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Callable, Dict, Union
|
| 3 |
+
|
| 4 |
+
import evaluate as hf_evaluate
|
| 5 |
+
|
| 6 |
+
from lm_eval.api.model import LM
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
eval_logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
MODEL_REGISTRY = {}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def register_model(*names):
|
| 15 |
+
# either pass a list or a single alias.
|
| 16 |
+
# function receives them as a tuple of strings
|
| 17 |
+
|
| 18 |
+
def decorate(cls):
|
| 19 |
+
for name in names:
|
| 20 |
+
assert issubclass(cls, LM), (
|
| 21 |
+
f"Model '{name}' ({cls.__name__}) must extend LM class"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
assert name not in MODEL_REGISTRY, (
|
| 25 |
+
f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
MODEL_REGISTRY[name] = cls
|
| 29 |
+
return cls
|
| 30 |
+
|
| 31 |
+
return decorate
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_model(model_name):
|
| 35 |
+
try:
|
| 36 |
+
return MODEL_REGISTRY[model_name]
|
| 37 |
+
except KeyError:
|
| 38 |
+
raise ValueError(
|
| 39 |
+
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
TASK_REGISTRY = {}
|
| 44 |
+
GROUP_REGISTRY = {}
|
| 45 |
+
ALL_TASKS = set()
|
| 46 |
+
func2task_index = {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def register_task(name):
|
| 50 |
+
def decorate(fn):
|
| 51 |
+
assert name not in TASK_REGISTRY, (
|
| 52 |
+
f"task named '{name}' conflicts with existing registered task!"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
TASK_REGISTRY[name] = fn
|
| 56 |
+
ALL_TASKS.add(name)
|
| 57 |
+
func2task_index[fn.__name__] = name
|
| 58 |
+
return fn
|
| 59 |
+
|
| 60 |
+
return decorate
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def register_group(name):
|
| 64 |
+
def decorate(fn):
|
| 65 |
+
func_name = func2task_index[fn.__name__]
|
| 66 |
+
if name in GROUP_REGISTRY:
|
| 67 |
+
GROUP_REGISTRY[name].append(func_name)
|
| 68 |
+
else:
|
| 69 |
+
GROUP_REGISTRY[name] = [func_name]
|
| 70 |
+
ALL_TASKS.add(name)
|
| 71 |
+
return fn
|
| 72 |
+
|
| 73 |
+
return decorate
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
OUTPUT_TYPE_REGISTRY = {}
|
| 77 |
+
METRIC_REGISTRY = {}
|
| 78 |
+
METRIC_AGGREGATION_REGISTRY = {}
|
| 79 |
+
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
|
| 80 |
+
HIGHER_IS_BETTER_REGISTRY = {}
|
| 81 |
+
FILTER_REGISTRY = {}
|
| 82 |
+
|
| 83 |
+
DEFAULT_METRIC_REGISTRY = {
|
| 84 |
+
"loglikelihood": [
|
| 85 |
+
"perplexity",
|
| 86 |
+
"acc",
|
| 87 |
+
],
|
| 88 |
+
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
|
| 89 |
+
"multiple_choice": ["acc", "acc_norm"],
|
| 90 |
+
"generate_until": ["exact_match"],
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def register_metric(**args):
|
| 95 |
+
# TODO: do we want to enforce a certain interface to registered metrics?
|
| 96 |
+
def decorate(fn):
|
| 97 |
+
assert "metric" in args
|
| 98 |
+
name = args["metric"]
|
| 99 |
+
|
| 100 |
+
for key, registry in [
|
| 101 |
+
("metric", METRIC_REGISTRY),
|
| 102 |
+
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
|
| 103 |
+
("aggregation", METRIC_AGGREGATION_REGISTRY),
|
| 104 |
+
]:
|
| 105 |
+
if key in args:
|
| 106 |
+
value = args[key]
|
| 107 |
+
assert value not in registry, (
|
| 108 |
+
f"{key} named '{value}' conflicts with existing registered {key}!"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if key == "metric":
|
| 112 |
+
registry[name] = fn
|
| 113 |
+
elif key == "aggregation":
|
| 114 |
+
registry[name] = AGGREGATION_REGISTRY[value]
|
| 115 |
+
else:
|
| 116 |
+
registry[name] = value
|
| 117 |
+
|
| 118 |
+
return fn
|
| 119 |
+
|
| 120 |
+
return decorate
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
|
| 124 |
+
if not hf_evaluate_metric:
|
| 125 |
+
if name in METRIC_REGISTRY:
|
| 126 |
+
return METRIC_REGISTRY[name]
|
| 127 |
+
else:
|
| 128 |
+
eval_logger.warning(
|
| 129 |
+
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
metric_object = hf_evaluate.load(name)
|
| 134 |
+
return metric_object.compute
|
| 135 |
+
except Exception:
|
| 136 |
+
eval_logger.error(
|
| 137 |
+
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def register_aggregation(name: str):
|
| 142 |
+
def decorate(fn):
|
| 143 |
+
assert name not in AGGREGATION_REGISTRY, (
|
| 144 |
+
f"aggregation named '{name}' conflicts with existing registered aggregation!"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
AGGREGATION_REGISTRY[name] = fn
|
| 148 |
+
return fn
|
| 149 |
+
|
| 150 |
+
return decorate
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
|
| 154 |
+
try:
|
| 155 |
+
return AGGREGATION_REGISTRY[name]
|
| 156 |
+
except KeyError:
|
| 157 |
+
eval_logger.warning(f"{name} not a registered aggregation metric!")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
|
| 161 |
+
try:
|
| 162 |
+
return METRIC_AGGREGATION_REGISTRY[name]
|
| 163 |
+
except KeyError:
|
| 164 |
+
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def is_higher_better(metric_name) -> bool:
|
| 168 |
+
try:
|
| 169 |
+
return HIGHER_IS_BETTER_REGISTRY[metric_name]
|
| 170 |
+
except KeyError:
|
| 171 |
+
eval_logger.warning(
|
| 172 |
+
f"higher_is_better not specified for metric '{metric_name}'!"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def register_filter(name):
|
| 177 |
+
def decorate(cls):
|
| 178 |
+
if name in FILTER_REGISTRY:
|
| 179 |
+
eval_logger.info(
|
| 180 |
+
f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
|
| 181 |
+
)
|
| 182 |
+
FILTER_REGISTRY[name] = cls
|
| 183 |
+
return cls
|
| 184 |
+
|
| 185 |
+
return decorate
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_filter(filter_name: Union[str, Callable]) -> Callable:
|
| 189 |
+
try:
|
| 190 |
+
return FILTER_REGISTRY[filter_name]
|
| 191 |
+
except KeyError as e:
|
| 192 |
+
if callable(filter_name):
|
| 193 |
+
return filter_name
|
| 194 |
+
else:
|
| 195 |
+
eval_logger.warning(f"filter `{filter_name}` is not registered!")
|
| 196 |
+
raise e
|
lm-evaluation-harness/lm_eval/api/samplers.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import warnings
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import TYPE_CHECKING, Iterable, Optional, Union
|
| 5 |
+
|
| 6 |
+
import datasets
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from random import Random
|
| 11 |
+
|
| 12 |
+
from lm_eval.api.task import ConfigurableTask, Task
|
| 13 |
+
|
| 14 |
+
eval_logger = logging.getLogger("lm-eval")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ContextSampler:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
docs: list[dict],
|
| 21 |
+
task: Union["Task", "ConfigurableTask"],
|
| 22 |
+
fewshot_indices: Optional[Iterable] = None,
|
| 23 |
+
rnd: Optional["Random"] = None,
|
| 24 |
+
) -> None:
|
| 25 |
+
self.rnd = rnd
|
| 26 |
+
if not self.rnd:
|
| 27 |
+
raise ValueError(
|
| 28 |
+
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self.task = task
|
| 32 |
+
self.config = task._config
|
| 33 |
+
|
| 34 |
+
self.target_delimiter = self.config.target_delimiter
|
| 35 |
+
self.fewshot_delimiter = self.config.fewshot_delimiter
|
| 36 |
+
|
| 37 |
+
if (
|
| 38 |
+
self.config.fewshot_config is not None
|
| 39 |
+
and self.config.fewshot_config.get("doc_to_text", None) is not None
|
| 40 |
+
):
|
| 41 |
+
self.doc_to_text = partial(
|
| 42 |
+
self.task.doc_to_text,
|
| 43 |
+
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
|
| 44 |
+
)
|
| 45 |
+
else:
|
| 46 |
+
self.doc_to_text = self.task.doc_to_text
|
| 47 |
+
|
| 48 |
+
if (
|
| 49 |
+
self.config.fewshot_config is not None
|
| 50 |
+
and self.config.fewshot_config.get("doc_to_target", None) is not None
|
| 51 |
+
):
|
| 52 |
+
self.doc_to_target = partial(
|
| 53 |
+
self.task.doc_to_target,
|
| 54 |
+
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
|
| 55 |
+
)
|
| 56 |
+
else:
|
| 57 |
+
self.doc_to_target = self.task.doc_to_target
|
| 58 |
+
|
| 59 |
+
if (
|
| 60 |
+
self.config.fewshot_config is not None
|
| 61 |
+
and self.config.fewshot_config.get("doc_to_choice", None) is not None
|
| 62 |
+
):
|
| 63 |
+
self.doc_to_choice = partial(
|
| 64 |
+
self.task.doc_to_choice,
|
| 65 |
+
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
self.doc_to_choice = self.task.doc_to_choice
|
| 69 |
+
|
| 70 |
+
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
|
| 71 |
+
if fewshot_indices: # subset few-shot docs from
|
| 72 |
+
if not isinstance(self.docs, datasets.Dataset):
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
|
| 75 |
+
)
|
| 76 |
+
self.docs = self.docs.select(fewshot_indices)
|
| 77 |
+
|
| 78 |
+
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
|
| 79 |
+
# draw an extra fewshot sample if using same split as evaluating on
|
| 80 |
+
prefix = gen_prefix + " " if gen_prefix else ""
|
| 81 |
+
n_samples = (
|
| 82 |
+
num_fewshot + 1
|
| 83 |
+
if self.config.fewshot_split == self.config.test_split
|
| 84 |
+
else num_fewshot
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# draw `n_samples` docs from fewshot_docs
|
| 88 |
+
fewshotex = self.sample(n_samples)
|
| 89 |
+
|
| 90 |
+
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 91 |
+
# TODO: should we just stop people from using fewshot from same split as evaluating?
|
| 92 |
+
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 93 |
+
|
| 94 |
+
labeled_examples = ""
|
| 95 |
+
for doc in selected_docs:
|
| 96 |
+
doc_content = self.doc_to_text(doc)
|
| 97 |
+
doc_target = self.doc_to_target(doc)
|
| 98 |
+
if self.config.doc_to_choice is None or isinstance(doc_content, str):
|
| 99 |
+
labeled_examples += doc_content
|
| 100 |
+
else:
|
| 101 |
+
labeled_examples += self.doc_to_choice(doc)[doc_content]
|
| 102 |
+
|
| 103 |
+
if doc_target != "":
|
| 104 |
+
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
|
| 105 |
+
# TODO: add logger warn once here.
|
| 106 |
+
warnings.warn(
|
| 107 |
+
"Both target_delimiter and target start with a space. This may cause issues.",
|
| 108 |
+
Warning,
|
| 109 |
+
stacklevel=2,
|
| 110 |
+
)
|
| 111 |
+
labeled_examples += self.target_delimiter
|
| 112 |
+
labeled_examples += prefix
|
| 113 |
+
labeled_examples += (
|
| 114 |
+
str(doc_target[0])
|
| 115 |
+
if isinstance(doc_target, list)
|
| 116 |
+
else doc_target
|
| 117 |
+
if self.config.doc_to_choice is None or isinstance(doc_target, str)
|
| 118 |
+
else str(self.doc_to_choice(doc)[doc_target])
|
| 119 |
+
)
|
| 120 |
+
labeled_examples += self.fewshot_delimiter
|
| 121 |
+
|
| 122 |
+
return labeled_examples
|
| 123 |
+
|
| 124 |
+
def get_chat_context(
|
| 125 |
+
self,
|
| 126 |
+
doc: dict,
|
| 127 |
+
num_fewshot: int,
|
| 128 |
+
fewshot_as_multiturn: bool = False,
|
| 129 |
+
gen_prefix: Optional[str] = None,
|
| 130 |
+
):
|
| 131 |
+
# TODO: Do we need any other delimiter
|
| 132 |
+
prefix = gen_prefix + " " if gen_prefix else ""
|
| 133 |
+
chat_history = []
|
| 134 |
+
# draw an extra fewshot sample if using same split as evaluating on
|
| 135 |
+
n_samples = (
|
| 136 |
+
num_fewshot + 1
|
| 137 |
+
if self.config.fewshot_split == self.config.test_split
|
| 138 |
+
else num_fewshot
|
| 139 |
+
)
|
| 140 |
+
# draw `n_samples` docs from fewshot_docs
|
| 141 |
+
fewshotex = self.sample(n_samples)
|
| 142 |
+
|
| 143 |
+
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 144 |
+
# TODO: should we just stop people from using fewshot from same split as evaluating?
|
| 145 |
+
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 146 |
+
|
| 147 |
+
if fewshot_as_multiturn:
|
| 148 |
+
for doc in selected_docs:
|
| 149 |
+
doc_content = self.doc_to_text(doc)
|
| 150 |
+
doc_target = self.doc_to_target(doc)
|
| 151 |
+
chat_history.append(
|
| 152 |
+
{
|
| 153 |
+
"role": "user",
|
| 154 |
+
"content": doc_content
|
| 155 |
+
if self.config.doc_to_choice is None
|
| 156 |
+
or isinstance(doc_content, str)
|
| 157 |
+
else self.doc_to_choice(doc)[doc_content],
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
chat_history.append(
|
| 161 |
+
{
|
| 162 |
+
"role": "assistant",
|
| 163 |
+
"content": prefix + str(doc_target[0])
|
| 164 |
+
if isinstance(doc_target, list)
|
| 165 |
+
else prefix + doc_target
|
| 166 |
+
if self.config.doc_to_choice is None
|
| 167 |
+
or isinstance(doc_target, str)
|
| 168 |
+
else prefix + str(self.doc_to_choice(doc)[doc_target]),
|
| 169 |
+
}
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
# get fewshot context as one user turn
|
| 173 |
+
chat_history.append(
|
| 174 |
+
{
|
| 175 |
+
"role": "user",
|
| 176 |
+
"content": self.get_context(
|
| 177 |
+
doc, num_fewshot, gen_prefix=gen_prefix
|
| 178 |
+
),
|
| 179 |
+
}
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
return chat_history
|
| 183 |
+
|
| 184 |
+
def sample(self, n: int):
|
| 185 |
+
"""
|
| 186 |
+
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
return self.rnd.sample(self.docs, n)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class FirstNSampler(ContextSampler):
|
| 193 |
+
def sample(self, n: int) -> None:
|
| 194 |
+
"""
|
| 195 |
+
Draw the first `n` samples in order from the specified split.
|
| 196 |
+
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
|
| 197 |
+
"""
|
| 198 |
+
assert n <= len(self.docs), (
|
| 199 |
+
f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
|
| 200 |
+
)
|
| 201 |
+
return self.docs[:n]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class BalancedSampler(ContextSampler):
|
| 205 |
+
def sample(self, n: int) -> None:
|
| 206 |
+
"""
|
| 207 |
+
TODO: this should return approximately class-balanced samples from our fewshot examples.
|
| 208 |
+
TODO: what order should they be in? maybe random?
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
pass
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ManualSampler(ContextSampler):
|
| 215 |
+
def sample(self, n: int) -> None:
|
| 216 |
+
""" """
|
| 217 |
+
pass
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
SAMPLER_REGISTRY = {
|
| 221 |
+
"default": ContextSampler,
|
| 222 |
+
"first_n": FirstNSampler,
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_sampler(name: str):
|
| 227 |
+
try:
|
| 228 |
+
return SAMPLER_REGISTRY[name]
|
| 229 |
+
except KeyError:
|
| 230 |
+
raise ValueError(
|
| 231 |
+
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
|
| 232 |
+
)
|
lm-evaluation-harness/lm_eval/api/task.py
ADDED
|
@@ -0,0 +1,1885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import ast
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
import re
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from dataclasses import asdict, dataclass
|
| 9 |
+
from inspect import getsource
|
| 10 |
+
from typing import (
|
| 11 |
+
Any,
|
| 12 |
+
Dict,
|
| 13 |
+
Iterable,
|
| 14 |
+
Iterator,
|
| 15 |
+
List,
|
| 16 |
+
Literal,
|
| 17 |
+
Mapping,
|
| 18 |
+
Optional,
|
| 19 |
+
Tuple,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import datasets
|
| 24 |
+
import numpy as np
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from lm_eval import utils
|
| 28 |
+
from lm_eval.api import samplers
|
| 29 |
+
from lm_eval.api.instance import Instance, OutputType
|
| 30 |
+
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
|
| 31 |
+
from lm_eval.api.registry import (
|
| 32 |
+
AGGREGATION_REGISTRY,
|
| 33 |
+
DEFAULT_METRIC_REGISTRY,
|
| 34 |
+
get_aggregation,
|
| 35 |
+
get_metric,
|
| 36 |
+
get_metric_aggregation,
|
| 37 |
+
is_higher_better,
|
| 38 |
+
)
|
| 39 |
+
from lm_eval.caching.cache import load_from_cache, save_to_cache
|
| 40 |
+
from lm_eval.filters import build_filter_ensemble
|
| 41 |
+
from lm_eval.prompts import get_prompt
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
ALL_OUTPUT_TYPES = [
|
| 45 |
+
"loglikelihood",
|
| 46 |
+
"multiple_choice",
|
| 47 |
+
"loglikelihood_rolling",
|
| 48 |
+
"generate_until",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
eval_logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class TaskConfig(dict):
|
| 56 |
+
# task naming/registry
|
| 57 |
+
task: Optional[str] = None
|
| 58 |
+
task_alias: Optional[str] = None
|
| 59 |
+
tag: Optional[Union[str, list]] = None
|
| 60 |
+
# HF dataset options.
|
| 61 |
+
# which dataset to use,
|
| 62 |
+
# and what splits for what purpose
|
| 63 |
+
custom_dataset: Optional[Callable] = None
|
| 64 |
+
dataset_path: Optional[str] = None
|
| 65 |
+
dataset_name: Optional[str] = None
|
| 66 |
+
dataset_kwargs: Optional[dict] = None
|
| 67 |
+
training_split: Optional[str] = None
|
| 68 |
+
validation_split: Optional[str] = None
|
| 69 |
+
test_split: Optional[str] = None
|
| 70 |
+
fewshot_split: Optional[str] = (
|
| 71 |
+
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
|
| 72 |
+
)
|
| 73 |
+
# formatting / prompting options.
|
| 74 |
+
# see docs/advanced_task_guide.md for more info
|
| 75 |
+
process_docs: Optional[Callable] = None
|
| 76 |
+
doc_to_text: Optional[Union[Callable, str]] = None
|
| 77 |
+
doc_to_target: Optional[Union[Callable, str]] = None
|
| 78 |
+
doc_to_image: Union[Callable, str] = None
|
| 79 |
+
doc_to_audio: Union[Callable, str] = None
|
| 80 |
+
unsafe_code: bool = False
|
| 81 |
+
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
|
| 82 |
+
process_results: Optional[Union[Callable, str]] = None
|
| 83 |
+
use_prompt: Optional[str] = None
|
| 84 |
+
description: str = ""
|
| 85 |
+
target_delimiter: str = " "
|
| 86 |
+
fewshot_delimiter: str = "\n\n"
|
| 87 |
+
fewshot_config: Optional[dict] = None
|
| 88 |
+
# runtime configuration options
|
| 89 |
+
num_fewshot: Optional[int] = None
|
| 90 |
+
# scoring options
|
| 91 |
+
metric_list: Optional[list] = None
|
| 92 |
+
output_type: OutputType = "generate_until"
|
| 93 |
+
generation_kwargs: Optional[dict] = None
|
| 94 |
+
repeats: int = 1
|
| 95 |
+
filter_list: Optional[Union[str, list]] = None
|
| 96 |
+
should_decontaminate: bool = False
|
| 97 |
+
doc_to_decontamination_query: Optional[str] = None
|
| 98 |
+
gen_prefix: Optional[str] = None
|
| 99 |
+
metadata: Optional[dict] = (
|
| 100 |
+
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def __post_init__(self) -> None:
|
| 104 |
+
if self.generation_kwargs is not None:
|
| 105 |
+
if self.output_type != "generate_until":
|
| 106 |
+
eval_logger.warning(
|
| 107 |
+
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if "temperature" in self.generation_kwargs:
|
| 111 |
+
self.generation_kwargs["temperature"] = float(
|
| 112 |
+
self.generation_kwargs["temperature"]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if "until" not in self.generation_kwargs:
|
| 116 |
+
eval_logger.warning(
|
| 117 |
+
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
|
| 118 |
+
)
|
| 119 |
+
self.generation_kwargs["until"] = [self.fewshot_delimiter]
|
| 120 |
+
else:
|
| 121 |
+
if self.output_type == "generate_until":
|
| 122 |
+
# ensure that we greedily generate in absence of explicit arguments otherwise
|
| 123 |
+
self.generation_kwargs = {
|
| 124 |
+
"until": (
|
| 125 |
+
None
|
| 126 |
+
if self.fewshot_delimiter is None
|
| 127 |
+
else [self.fewshot_delimiter]
|
| 128 |
+
),
|
| 129 |
+
"do_sample": False,
|
| 130 |
+
"temperature": 0,
|
| 131 |
+
}
|
| 132 |
+
eval_logger.warning(
|
| 133 |
+
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, item):
|
| 137 |
+
return getattr(self, item)
|
| 138 |
+
|
| 139 |
+
def __setitem__(self, item, value):
|
| 140 |
+
return setattr(self, item, value)
|
| 141 |
+
|
| 142 |
+
def to_dict(self, keep_callable: bool = False) -> dict:
|
| 143 |
+
"""dumps the current config as a dictionary object, as a printable format.
|
| 144 |
+
null fields will not be printed.
|
| 145 |
+
Used for dumping results alongside full task configuration
|
| 146 |
+
|
| 147 |
+
:return: dict
|
| 148 |
+
A printable dictionary version of the TaskConfig object.
|
| 149 |
+
|
| 150 |
+
# TODO: should any default value in the TaskConfig not be printed?
|
| 151 |
+
"""
|
| 152 |
+
cfg_dict = asdict(self)
|
| 153 |
+
# remove values that are `None`
|
| 154 |
+
for k, v in list(cfg_dict.items()):
|
| 155 |
+
if v is None:
|
| 156 |
+
cfg_dict.pop(k)
|
| 157 |
+
elif k == "metric_list":
|
| 158 |
+
for metric_dict in v:
|
| 159 |
+
for metric_key, metric_value in metric_dict.items():
|
| 160 |
+
if callable(metric_value):
|
| 161 |
+
metric_dict[metric_key] = self.serialize_function(
|
| 162 |
+
metric_value, keep_callable=keep_callable
|
| 163 |
+
)
|
| 164 |
+
cfg_dict[k] = v
|
| 165 |
+
elif callable(v):
|
| 166 |
+
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
|
| 167 |
+
return cfg_dict
|
| 168 |
+
|
| 169 |
+
def serialize_function(
|
| 170 |
+
self, value: Union[Callable, str], keep_callable=False
|
| 171 |
+
) -> Union[Callable, str]:
|
| 172 |
+
"""Serializes a given function or string.
|
| 173 |
+
|
| 174 |
+
If 'keep_callable' is True, the original callable is returned.
|
| 175 |
+
Otherwise, attempts to return the source code of the callable using 'getsource'.
|
| 176 |
+
"""
|
| 177 |
+
if keep_callable:
|
| 178 |
+
return value
|
| 179 |
+
else:
|
| 180 |
+
try:
|
| 181 |
+
return getsource(value)
|
| 182 |
+
except (TypeError, OSError):
|
| 183 |
+
return str(value)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class Task(abc.ABC):
|
| 187 |
+
"""A task represents an entire benchmark including its dataset, problems,
|
| 188 |
+
answers, and evaluation methods. See BoolQ for a simple example implementation
|
| 189 |
+
|
| 190 |
+
A `doc` can be any python object which represents one instance of evaluation.
|
| 191 |
+
This is usually a dictionary e.g.
|
| 192 |
+
{"question": ..., "answer": ...} or
|
| 193 |
+
{"question": ..., question, answer)
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
VERSION: Optional[Union[int, str]] = None
|
| 197 |
+
|
| 198 |
+
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
|
| 199 |
+
# or a path to a custom `datasets` loading script.
|
| 200 |
+
DATASET_PATH: Optional[str] = None
|
| 201 |
+
|
| 202 |
+
# The name of a subset within `DATASET_PATH`.
|
| 203 |
+
DATASET_NAME: Optional[str] = None
|
| 204 |
+
|
| 205 |
+
OUTPUT_TYPE: Optional[OutputType] = None
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
data_dir: Optional[str] = None,
|
| 210 |
+
cache_dir: Optional[str] = None,
|
| 211 |
+
download_mode: Optional[datasets.DownloadMode] = None,
|
| 212 |
+
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
|
| 213 |
+
) -> None:
|
| 214 |
+
"""
|
| 215 |
+
:param data_dir: str
|
| 216 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 217 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 218 |
+
the dataset is not publicly accessible).
|
| 219 |
+
:param cache_dir: str
|
| 220 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 221 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 222 |
+
`~/.cache/huggingface/datasets`
|
| 223 |
+
NOTE: You can change the cache location globally for a given process
|
| 224 |
+
to another directory:
|
| 225 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 226 |
+
:param download_mode: datasets.DownloadMode
|
| 227 |
+
How to treat pre-existing `Task` downloads and data.
|
| 228 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 229 |
+
Reuse download and reuse dataset.
|
| 230 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 231 |
+
Reuse download with fresh dataset.
|
| 232 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 233 |
+
Fresh download and fresh dataset.
|
| 234 |
+
"""
|
| 235 |
+
self.download(data_dir, cache_dir, download_mode)
|
| 236 |
+
self._training_docs: Optional[list] = None
|
| 237 |
+
self._fewshot_docs: Optional[list] = None
|
| 238 |
+
self._instances: Optional[List[Instance]] = None
|
| 239 |
+
|
| 240 |
+
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
|
| 241 |
+
|
| 242 |
+
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 243 |
+
self.fewshot_rnd: Optional[random.Random] = (
|
| 244 |
+
None # purposely induce errors in case of improper usage
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
def download(
|
| 248 |
+
self,
|
| 249 |
+
data_dir: Optional[str] = None,
|
| 250 |
+
cache_dir: Optional[str] = None,
|
| 251 |
+
download_mode=None,
|
| 252 |
+
) -> None:
|
| 253 |
+
"""Downloads and returns the task dataset.
|
| 254 |
+
Override this method to download the dataset from a custom API.
|
| 255 |
+
|
| 256 |
+
:param data_dir: str
|
| 257 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 258 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 259 |
+
the dataset is not publicly accessible).
|
| 260 |
+
:param cache_dir: str
|
| 261 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 262 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 263 |
+
`~/.cache/huggingface/datasets`
|
| 264 |
+
NOTE: You can change the cache location globally for a given process
|
| 265 |
+
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
| 266 |
+
to another directory:
|
| 267 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 268 |
+
:param download_mode: datasets.DownloadMode
|
| 269 |
+
How to treat pre-existing `Task` downloads and data.
|
| 270 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 271 |
+
Reuse download and reuse dataset.
|
| 272 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 273 |
+
Reuse download with fresh dataset.
|
| 274 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 275 |
+
Fresh download and fresh dataset.
|
| 276 |
+
"""
|
| 277 |
+
self.dataset = datasets.load_dataset(
|
| 278 |
+
path=self.DATASET_PATH,
|
| 279 |
+
name=self.DATASET_NAME,
|
| 280 |
+
data_dir=data_dir,
|
| 281 |
+
cache_dir=cache_dir,
|
| 282 |
+
download_mode=download_mode,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
@property
|
| 286 |
+
def config(self) -> TaskConfig:
|
| 287 |
+
"""Returns the TaskConfig associated with this class."""
|
| 288 |
+
return self._config
|
| 289 |
+
|
| 290 |
+
@abc.abstractmethod
|
| 291 |
+
def has_training_docs(self):
|
| 292 |
+
"""Whether the task has a training set"""
|
| 293 |
+
pass
|
| 294 |
+
|
| 295 |
+
@abc.abstractmethod
|
| 296 |
+
def has_validation_docs(self):
|
| 297 |
+
"""Whether the task has a validation set"""
|
| 298 |
+
pass
|
| 299 |
+
|
| 300 |
+
@abc.abstractmethod
|
| 301 |
+
def has_test_docs(self):
|
| 302 |
+
"""Whether the task has a test set"""
|
| 303 |
+
pass
|
| 304 |
+
|
| 305 |
+
def training_docs(self) -> Iterable:
|
| 306 |
+
"""
|
| 307 |
+
:return: Iterable[obj]
|
| 308 |
+
A iterable of any object, that doc_to_text can handle
|
| 309 |
+
"""
|
| 310 |
+
return []
|
| 311 |
+
|
| 312 |
+
def validation_docs(self) -> Iterable:
|
| 313 |
+
"""
|
| 314 |
+
:return: Iterable[obj]
|
| 315 |
+
A iterable of any object, that doc_to_text can handle
|
| 316 |
+
"""
|
| 317 |
+
return []
|
| 318 |
+
|
| 319 |
+
def test_docs(self) -> Iterable:
|
| 320 |
+
"""
|
| 321 |
+
:return: Iterable[obj]
|
| 322 |
+
A iterable of any object, that doc_to_text can handle
|
| 323 |
+
"""
|
| 324 |
+
return []
|
| 325 |
+
|
| 326 |
+
def fewshot_docs(self) -> Iterable:
|
| 327 |
+
"""
|
| 328 |
+
:return: Iterable[obj]
|
| 329 |
+
A iterable of any object, that doc_to_text can handle
|
| 330 |
+
"""
|
| 331 |
+
if self.has_training_docs():
|
| 332 |
+
return self.training_docs()
|
| 333 |
+
elif self.has_validation_docs():
|
| 334 |
+
return self.validation_docs()
|
| 335 |
+
else:
|
| 336 |
+
if self.config.get("num_fewshot", 0) > 0:
|
| 337 |
+
eval_logger.warning(
|
| 338 |
+
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
|
| 339 |
+
", using test_docs as fewshot_docs but this is not recommended."
|
| 340 |
+
)
|
| 341 |
+
return self.test_docs()
|
| 342 |
+
|
| 343 |
+
def _process_doc(self, doc: dict) -> dict:
|
| 344 |
+
"""
|
| 345 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 346 |
+
documents. This can be used in a map over documents of a data split.
|
| 347 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 348 |
+
|
| 349 |
+
:return: dict
|
| 350 |
+
The processed version of the specified `doc`.
|
| 351 |
+
"""
|
| 352 |
+
return doc
|
| 353 |
+
|
| 354 |
+
@property
|
| 355 |
+
def instances(self) -> List[Instance]:
|
| 356 |
+
"""After calling `task.build_all_requests()`, tasks
|
| 357 |
+
maintain a list of the dataset instances which will be evaluated.
|
| 358 |
+
"""
|
| 359 |
+
return self._instances
|
| 360 |
+
|
| 361 |
+
def fewshot_examples(self, k, rnd):
|
| 362 |
+
if self._training_docs is None:
|
| 363 |
+
self._training_docs = list(self.training_docs())
|
| 364 |
+
|
| 365 |
+
return rnd.sample(self._training_docs, k)
|
| 366 |
+
|
| 367 |
+
def doc_to_decontamination_query(self, doc):
|
| 368 |
+
raise NotImplementedError(
|
| 369 |
+
"Override doc_to_decontamination_query with document specific decontamination query."
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
@abc.abstractmethod
|
| 373 |
+
def doc_to_text(self, doc):
|
| 374 |
+
pass
|
| 375 |
+
|
| 376 |
+
@abc.abstractmethod
|
| 377 |
+
def doc_to_target(self, doc):
|
| 378 |
+
pass
|
| 379 |
+
|
| 380 |
+
# not an abstractmethod because not every language-only task has to implement this
|
| 381 |
+
def doc_to_image(self, doc):
|
| 382 |
+
raise NotImplementedError
|
| 383 |
+
|
| 384 |
+
def doc_to_audio(self, doc):
|
| 385 |
+
raise NotImplementedError
|
| 386 |
+
|
| 387 |
+
def doc_to_prefix(self, doc):
|
| 388 |
+
return ""
|
| 389 |
+
|
| 390 |
+
def build_all_requests(
|
| 391 |
+
self,
|
| 392 |
+
*,
|
| 393 |
+
limit: Union[int, None] = None,
|
| 394 |
+
samples: Optional[List[int]] = None,
|
| 395 |
+
rank: int = 0,
|
| 396 |
+
world_size: int = 1,
|
| 397 |
+
cache_requests: bool = False,
|
| 398 |
+
rewrite_requests_cache: bool = False,
|
| 399 |
+
system_instruction: Optional[str] = None,
|
| 400 |
+
apply_chat_template: bool = False,
|
| 401 |
+
fewshot_as_multiturn: bool = False,
|
| 402 |
+
chat_template: Optional[Callable] = None,
|
| 403 |
+
tokenizer_name: str = "",
|
| 404 |
+
) -> None:
|
| 405 |
+
"""Build a set of Instances for a task, and store them in task.instances"""
|
| 406 |
+
|
| 407 |
+
# used with caching
|
| 408 |
+
og_limit = limit
|
| 409 |
+
|
| 410 |
+
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
|
| 411 |
+
cache_key += "-chat_template" if apply_chat_template else ""
|
| 412 |
+
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
|
| 413 |
+
cache_key += (
|
| 414 |
+
f"-system_prompt_hash{utils.hash_string(system_instruction)}"
|
| 415 |
+
if system_instruction is not None
|
| 416 |
+
else ""
|
| 417 |
+
)
|
| 418 |
+
cache_key += f"-tokenizer{tokenizer_name}"
|
| 419 |
+
|
| 420 |
+
cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
|
| 421 |
+
|
| 422 |
+
if cache_requests and cached_instances and not rewrite_requests_cache:
|
| 423 |
+
cached_instances = cached_instances[:limit]
|
| 424 |
+
|
| 425 |
+
flattened_instances = [
|
| 426 |
+
instance
|
| 427 |
+
for instance_group in cached_instances
|
| 428 |
+
for instance in instance_group
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
self._instances = flattened_instances
|
| 432 |
+
return
|
| 433 |
+
|
| 434 |
+
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
|
| 435 |
+
|
| 436 |
+
instances = []
|
| 437 |
+
|
| 438 |
+
# process all documents when caching is specified for simplicity
|
| 439 |
+
if (
|
| 440 |
+
cache_requests
|
| 441 |
+
and (not cached_instances or rewrite_requests_cache)
|
| 442 |
+
and limit is not None
|
| 443 |
+
):
|
| 444 |
+
limit = None
|
| 445 |
+
|
| 446 |
+
doc_id_docs = list(
|
| 447 |
+
self.doc_iterator(
|
| 448 |
+
rank=rank, limit=limit, samples=samples, world_size=world_size
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
num_docs = len(doc_id_docs)
|
| 453 |
+
|
| 454 |
+
for doc_id, doc in tqdm(
|
| 455 |
+
doc_id_docs,
|
| 456 |
+
total=num_docs,
|
| 457 |
+
):
|
| 458 |
+
# sample fewshot context #TODO: need to offset doc_id by rank now!
|
| 459 |
+
fewshot_ctx = self.fewshot_context(
|
| 460 |
+
doc,
|
| 461 |
+
num_fewshot=0
|
| 462 |
+
if self.config.num_fewshot is None
|
| 463 |
+
else self.config.num_fewshot,
|
| 464 |
+
system_instruction=system_instruction,
|
| 465 |
+
apply_chat_template=apply_chat_template,
|
| 466 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 467 |
+
chat_template=chat_template,
|
| 468 |
+
gen_prefix=self.doc_to_prefix(doc),
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
|
| 472 |
+
inst = self.construct_requests(
|
| 473 |
+
doc=doc,
|
| 474 |
+
ctx=fewshot_ctx,
|
| 475 |
+
metadata=(self.config["task"], doc_id, self.config.repeats),
|
| 476 |
+
apply_chat_template=apply_chat_template,
|
| 477 |
+
chat_template=chat_template,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
if not isinstance(inst, list):
|
| 481 |
+
inst = [inst]
|
| 482 |
+
|
| 483 |
+
instances.append(inst)
|
| 484 |
+
|
| 485 |
+
# now flatten, this is to allow slicing to work with pickles
|
| 486 |
+
|
| 487 |
+
sliced_instances = instances[:og_limit]
|
| 488 |
+
|
| 489 |
+
flattened_instances = [
|
| 490 |
+
instance
|
| 491 |
+
for instance_group in sliced_instances
|
| 492 |
+
for instance in instance_group
|
| 493 |
+
]
|
| 494 |
+
|
| 495 |
+
self._instances = flattened_instances
|
| 496 |
+
|
| 497 |
+
if len(self._instances) == 0:
|
| 498 |
+
raise ValueError("task.build_requests() did not find any docs!")
|
| 499 |
+
|
| 500 |
+
if cache_requests and (not cached_instances or rewrite_requests_cache):
|
| 501 |
+
save_to_cache(file_name=cache_key, obj=instances)
|
| 502 |
+
|
| 503 |
+
@abc.abstractmethod
|
| 504 |
+
def construct_requests(self, doc, ctx, **kwargs):
|
| 505 |
+
"""Uses RequestFactory to construct Requests and returns an iterable of
|
| 506 |
+
Requests which will be sent to the LM.
|
| 507 |
+
|
| 508 |
+
:param doc:
|
| 509 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 510 |
+
:param ctx: str
|
| 511 |
+
The context string, generated by fewshot_context. This includes the natural
|
| 512 |
+
language description, as well as the few shot examples, and the question
|
| 513 |
+
part of the document for `doc`.
|
| 514 |
+
:param doc_idx: int
|
| 515 |
+
The index of a document within `self.test_docs()` or `self.validation_docs()`,
|
| 516 |
+
whichever is the main split used.
|
| 517 |
+
:param repeats: int
|
| 518 |
+
TODO: update this docstring
|
| 519 |
+
The number of times each instance in a dataset is inferred on. Defaults to 1,
|
| 520 |
+
can be increased for techniques like majority voting.
|
| 521 |
+
"""
|
| 522 |
+
pass
|
| 523 |
+
|
| 524 |
+
@abc.abstractmethod
|
| 525 |
+
def process_results(self, doc, results):
|
| 526 |
+
"""Take a single document and the LM results and evaluates, returning a
|
| 527 |
+
dict where keys are the names of submetrics and values are the values of
|
| 528 |
+
the metric for that one document
|
| 529 |
+
|
| 530 |
+
:param doc:
|
| 531 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 532 |
+
:param results:
|
| 533 |
+
The results of the requests created in construct_requests.
|
| 534 |
+
"""
|
| 535 |
+
pass
|
| 536 |
+
|
| 537 |
+
@abc.abstractmethod
|
| 538 |
+
def aggregation(self):
|
| 539 |
+
"""
|
| 540 |
+
:returns: {str: [metric_score] -> float}
|
| 541 |
+
A dictionary where keys are the names of submetrics and values are
|
| 542 |
+
functions that aggregate a list of metric scores
|
| 543 |
+
"""
|
| 544 |
+
pass
|
| 545 |
+
|
| 546 |
+
@abc.abstractmethod
|
| 547 |
+
def higher_is_better(self):
|
| 548 |
+
"""
|
| 549 |
+
:returns: {str: bool}
|
| 550 |
+
A dictionary where keys are the names of submetrics and values are
|
| 551 |
+
whether a higher value of the submetric is better
|
| 552 |
+
"""
|
| 553 |
+
pass
|
| 554 |
+
|
| 555 |
+
def get_config(self, key: str) -> Any:
|
| 556 |
+
return getattr(self._config, key, None)
|
| 557 |
+
|
| 558 |
+
@classmethod
|
| 559 |
+
def count_bytes(cls, doc):
|
| 560 |
+
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
|
| 561 |
+
return len(doc.encode("utf-8"))
|
| 562 |
+
|
| 563 |
+
@classmethod
|
| 564 |
+
def count_words(cls, doc):
|
| 565 |
+
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
|
| 566 |
+
return len(re.split(r"\s+", doc))
|
| 567 |
+
|
| 568 |
+
@utils.positional_deprecated
|
| 569 |
+
def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
|
| 570 |
+
"""Returns a fewshot context string that is made up of a prepended description
|
| 571 |
+
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 572 |
+
|
| 573 |
+
:param doc: str
|
| 574 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 575 |
+
:param num_fewshot: int
|
| 576 |
+
The number of fewshot examples to provide in the returned context string.
|
| 577 |
+
:param rnd: random.Random
|
| 578 |
+
The pseudo-random number generator used to randomly sample examples.
|
| 579 |
+
WARNING: This is currently a required arg although it's optionalized with a default `None`.
|
| 580 |
+
:param description: str
|
| 581 |
+
The task's description that will be prepended to the fewshot examples.
|
| 582 |
+
:returns: str
|
| 583 |
+
The fewshot context.
|
| 584 |
+
"""
|
| 585 |
+
if rnd is None:
|
| 586 |
+
if self.fewshot_rnd is not None:
|
| 587 |
+
rnd = self.fewshot_rnd
|
| 588 |
+
else:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
"A `random.Random` generator argument must be provided to `rnd`"
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
description = description if description else ""
|
| 594 |
+
|
| 595 |
+
if num_fewshot == 0:
|
| 596 |
+
labeled_examples = ""
|
| 597 |
+
else:
|
| 598 |
+
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
|
| 599 |
+
if self.has_training_docs():
|
| 600 |
+
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
|
| 601 |
+
else:
|
| 602 |
+
if self._fewshot_docs is None:
|
| 603 |
+
self._fewshot_docs = list(
|
| 604 |
+
self.validation_docs()
|
| 605 |
+
if self.has_validation_docs()
|
| 606 |
+
else self.test_docs()
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
|
| 610 |
+
|
| 611 |
+
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 612 |
+
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 613 |
+
|
| 614 |
+
labeled_examples = (
|
| 615 |
+
"\n\n".join(
|
| 616 |
+
[
|
| 617 |
+
self.doc_to_text(doc) + self.doc_to_target(doc)
|
| 618 |
+
for doc in fewshotex
|
| 619 |
+
]
|
| 620 |
+
)
|
| 621 |
+
+ "\n\n"
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
example = self.doc_to_text(doc)
|
| 625 |
+
return description + labeled_examples + example
|
| 626 |
+
|
| 627 |
+
def apply_filters(self) -> Optional[List[Instance]]:
|
| 628 |
+
"""Iterates over FilterEnsembles and applies them to instances"""
|
| 629 |
+
if hasattr(self, "_filters"):
|
| 630 |
+
for f in self._filters:
|
| 631 |
+
f.apply(self._instances)
|
| 632 |
+
else:
|
| 633 |
+
eval_logger.warning("No filter defined, passing through instances")
|
| 634 |
+
return self._instances
|
| 635 |
+
|
| 636 |
+
def dump_config(self) -> dict:
|
| 637 |
+
"""Returns the config as a dictionary."""
|
| 638 |
+
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
|
| 639 |
+
# (num_fewshot)
|
| 640 |
+
return self.config.to_dict()
|
| 641 |
+
|
| 642 |
+
def set_config(self, key: str, value: Any, update: bool = False) -> None:
|
| 643 |
+
"""Set or update the configuration for a given key."""
|
| 644 |
+
if key is None:
|
| 645 |
+
raise ValueError("Key must be provided.")
|
| 646 |
+
|
| 647 |
+
if update:
|
| 648 |
+
current_value = getattr(self._config, key, {})
|
| 649 |
+
if not isinstance(current_value, dict):
|
| 650 |
+
raise TypeError(
|
| 651 |
+
f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
|
| 652 |
+
)
|
| 653 |
+
current_value.update(value)
|
| 654 |
+
else:
|
| 655 |
+
setattr(self._config, key, value)
|
| 656 |
+
|
| 657 |
+
def override_metric(self, metric_name: str) -> None:
|
| 658 |
+
"""
|
| 659 |
+
Override the default metrics used for evaluation with custom metrics.
|
| 660 |
+
|
| 661 |
+
Parameters:
|
| 662 |
+
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
|
| 663 |
+
"""
|
| 664 |
+
(
|
| 665 |
+
self._metric_fn_list,
|
| 666 |
+
self._aggregation_list,
|
| 667 |
+
self._metric_fn_kwargs,
|
| 668 |
+
self._higher_is_better,
|
| 669 |
+
) = ({}, {}, {}, {})
|
| 670 |
+
self._metric_fn_list[metric_name] = get_metric(metric_name)
|
| 671 |
+
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
|
| 672 |
+
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 673 |
+
self._metric_fn_kwargs[metric_name] = {}
|
| 674 |
+
if not isinstance(self, ConfigurableTask):
|
| 675 |
+
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
|
| 676 |
+
self.aggregation = lambda: {
|
| 677 |
+
metric_name: get_metric_aggregation(metric_name)
|
| 678 |
+
}
|
| 679 |
+
setattr(self._config, "metric_list", [{"metric": metric_name}])
|
| 680 |
+
setattr(self._config, "process_results", None)
|
| 681 |
+
|
| 682 |
+
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
|
| 683 |
+
self.fewshot_rnd = random.Random(seed)
|
| 684 |
+
if hasattr(self, "sampler"):
|
| 685 |
+
self.sampler.rnd = self.fewshot_rnd
|
| 686 |
+
|
| 687 |
+
@property
|
| 688 |
+
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
|
| 689 |
+
if self.has_test_docs():
|
| 690 |
+
return self.test_docs()
|
| 691 |
+
elif self.has_validation_docs():
|
| 692 |
+
return self.validation_docs()
|
| 693 |
+
else:
|
| 694 |
+
raise ValueError(
|
| 695 |
+
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
def doc_iterator(
|
| 699 |
+
self,
|
| 700 |
+
*,
|
| 701 |
+
rank: int = 0,
|
| 702 |
+
limit: Union[int, None] = None,
|
| 703 |
+
world_size: int = 1,
|
| 704 |
+
samples: Optional[List[int]] = None,
|
| 705 |
+
) -> Iterator[Tuple[int, Any]]:
|
| 706 |
+
if samples:
|
| 707 |
+
n = len(self.eval_docs)
|
| 708 |
+
assert all([e < n for e in samples]), (
|
| 709 |
+
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
|
| 710 |
+
)
|
| 711 |
+
eval_logger.info(
|
| 712 |
+
f"{self.config.task}: Evaluating on {len(samples)} examples"
|
| 713 |
+
)
|
| 714 |
+
doc_iterator = utils.create_iterator(
|
| 715 |
+
enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
|
| 716 |
+
rank=int(rank),
|
| 717 |
+
limit=None, # limit does not matter here since we are selecting samples directly
|
| 718 |
+
world_size=int(world_size),
|
| 719 |
+
)
|
| 720 |
+
else:
|
| 721 |
+
limit = int(limit) if limit else None
|
| 722 |
+
doc_iterator = utils.create_iterator(
|
| 723 |
+
enumerate(self.eval_docs),
|
| 724 |
+
rank=int(rank),
|
| 725 |
+
limit=limit,
|
| 726 |
+
world_size=int(world_size),
|
| 727 |
+
)
|
| 728 |
+
return doc_iterator
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
class ConfigurableTask(Task):
|
| 732 |
+
VERSION = "Yaml"
|
| 733 |
+
OUTPUT_TYPE = None
|
| 734 |
+
CONFIG = None
|
| 735 |
+
|
| 736 |
+
def __init__(
|
| 737 |
+
self,
|
| 738 |
+
data_dir=None,
|
| 739 |
+
cache_dir=None,
|
| 740 |
+
download_mode=None,
|
| 741 |
+
config: Optional[dict] = None,
|
| 742 |
+
) -> None: # TODO no super() call here
|
| 743 |
+
# Get pre-configured attributes
|
| 744 |
+
self._config = self.CONFIG
|
| 745 |
+
|
| 746 |
+
# Use new configurations if there was no preconfiguration
|
| 747 |
+
if self.config is None:
|
| 748 |
+
self._config = TaskConfig(**config)
|
| 749 |
+
# Overwrite configs
|
| 750 |
+
else:
|
| 751 |
+
if config is not None:
|
| 752 |
+
self._config.__dict__.update(config)
|
| 753 |
+
|
| 754 |
+
if self.config is None:
|
| 755 |
+
raise ValueError(
|
| 756 |
+
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
if isinstance(self.config.metadata, dict):
|
| 760 |
+
if "version" in self.config.metadata:
|
| 761 |
+
self.VERSION = self.config.metadata["version"]
|
| 762 |
+
|
| 763 |
+
if self.config.output_type is not None:
|
| 764 |
+
if self.config.output_type not in ALL_OUTPUT_TYPES:
|
| 765 |
+
raise ValueError(
|
| 766 |
+
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
|
| 767 |
+
)
|
| 768 |
+
self.OUTPUT_TYPE = self.config.output_type
|
| 769 |
+
|
| 770 |
+
if self.config.doc_to_image is not None:
|
| 771 |
+
# mark the task as requiring multimodality.
|
| 772 |
+
self.MULTIMODAL = True
|
| 773 |
+
|
| 774 |
+
if self.config.doc_to_audio:
|
| 775 |
+
# mark the task as requiring multimodality.
|
| 776 |
+
self.MULTIMODAL = True
|
| 777 |
+
|
| 778 |
+
if self.config.unsafe_code is not False:
|
| 779 |
+
self.UNSAFE_CODE = True
|
| 780 |
+
|
| 781 |
+
if self.config.dataset_path is not None:
|
| 782 |
+
self.DATASET_PATH = self.config.dataset_path
|
| 783 |
+
|
| 784 |
+
if self.config.dataset_name is not None:
|
| 785 |
+
self.DATASET_NAME = self.config.dataset_name
|
| 786 |
+
|
| 787 |
+
self._metric_fn_list = {}
|
| 788 |
+
self._metric_fn_kwargs = {}
|
| 789 |
+
self._aggregation_list = {}
|
| 790 |
+
self._higher_is_better = {}
|
| 791 |
+
|
| 792 |
+
if self.config.metric_list is None:
|
| 793 |
+
# TODO: handle this in TaskConfig.__post_init__ ?
|
| 794 |
+
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
|
| 795 |
+
|
| 796 |
+
for metric_name in _metric_list:
|
| 797 |
+
self._metric_fn_list[metric_name] = get_metric(metric_name)
|
| 798 |
+
self._metric_fn_kwargs[metric_name] = {}
|
| 799 |
+
self._aggregation_list[metric_name] = get_metric_aggregation(
|
| 800 |
+
metric_name
|
| 801 |
+
)
|
| 802 |
+
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 803 |
+
else:
|
| 804 |
+
for metric_config in self.config.metric_list:
|
| 805 |
+
if "metric" not in metric_config:
|
| 806 |
+
raise ValueError(
|
| 807 |
+
"'metric' key not provided for an entry in 'metric_list', must be specified!"
|
| 808 |
+
)
|
| 809 |
+
metric_name = metric_config["metric"]
|
| 810 |
+
kwargs = {
|
| 811 |
+
key: metric_config[key]
|
| 812 |
+
for key in metric_config
|
| 813 |
+
if key
|
| 814 |
+
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
|
| 815 |
+
}
|
| 816 |
+
hf_evaluate_metric = (
|
| 817 |
+
"hf_evaluate" in metric_config
|
| 818 |
+
and metric_config["hf_evaluate"] is True
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
if self.config.process_results is not None:
|
| 822 |
+
self._metric_fn_list[metric_name] = None
|
| 823 |
+
self._metric_fn_kwargs[metric_name] = {}
|
| 824 |
+
elif callable(metric_name):
|
| 825 |
+
metric_fn = metric_name.__call__
|
| 826 |
+
metric_name = metric_name.__name__
|
| 827 |
+
self._metric_fn_list[metric_name] = metric_fn
|
| 828 |
+
self._metric_fn_kwargs[metric_name] = kwargs
|
| 829 |
+
else:
|
| 830 |
+
self._metric_fn_list[metric_name] = get_metric(
|
| 831 |
+
metric_name, hf_evaluate_metric
|
| 832 |
+
)
|
| 833 |
+
self._metric_fn_kwargs[metric_name] = kwargs
|
| 834 |
+
|
| 835 |
+
if "aggregation" in metric_config:
|
| 836 |
+
agg_name = metric_config["aggregation"]
|
| 837 |
+
if isinstance(agg_name, str):
|
| 838 |
+
self._aggregation_list[metric_name] = get_aggregation(agg_name)
|
| 839 |
+
elif callable(agg_name): # noqa: E721
|
| 840 |
+
self._aggregation_list[metric_name] = metric_config[
|
| 841 |
+
"aggregation"
|
| 842 |
+
]
|
| 843 |
+
else:
|
| 844 |
+
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
|
| 845 |
+
metric_agg = get_metric_aggregation(metric_name)
|
| 846 |
+
eval_logger.warning(
|
| 847 |
+
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
|
| 848 |
+
f"using default "
|
| 849 |
+
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
|
| 850 |
+
)
|
| 851 |
+
self._aggregation_list[metric_name] = metric_agg
|
| 852 |
+
|
| 853 |
+
if "higher_is_better" in metric_config:
|
| 854 |
+
self._higher_is_better[metric_name] = metric_config[
|
| 855 |
+
"higher_is_better"
|
| 856 |
+
]
|
| 857 |
+
else:
|
| 858 |
+
eval_logger.warning(
|
| 859 |
+
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
|
| 860 |
+
f"using default "
|
| 861 |
+
f"higher_is_better={is_higher_better(metric_name)}"
|
| 862 |
+
)
|
| 863 |
+
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 864 |
+
|
| 865 |
+
self.download(self.config.dataset_kwargs)
|
| 866 |
+
self._training_docs = None
|
| 867 |
+
self._fewshot_docs = None
|
| 868 |
+
|
| 869 |
+
if self.config.filter_list is not None:
|
| 870 |
+
self._filters = []
|
| 871 |
+
for filter_config in self.config.filter_list:
|
| 872 |
+
filter_name = filter_config["name"]
|
| 873 |
+
filter_functions = filter_config["filter"]
|
| 874 |
+
components = []
|
| 875 |
+
for function in filter_functions:
|
| 876 |
+
kwargs = {
|
| 877 |
+
key: function[key] for key in function if key != "function"
|
| 878 |
+
}
|
| 879 |
+
components.append([function["function"], kwargs])
|
| 880 |
+
filter_pipeline = build_filter_ensemble(filter_name, components)
|
| 881 |
+
self._filters.append(filter_pipeline)
|
| 882 |
+
else:
|
| 883 |
+
# TODO: handle repeats in a more general way rather than just discarding
|
| 884 |
+
eval_logger.debug(
|
| 885 |
+
"No custom filters defined. Using default 'take_first' filter for handling repeats."
|
| 886 |
+
)
|
| 887 |
+
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 888 |
+
|
| 889 |
+
if self.config.use_prompt is not None:
|
| 890 |
+
eval_logger.info(f"loading prompt {self.config.use_prompt}")
|
| 891 |
+
self.prompt = get_prompt(
|
| 892 |
+
self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
|
| 893 |
+
)
|
| 894 |
+
else:
|
| 895 |
+
self.prompt = None
|
| 896 |
+
|
| 897 |
+
if self.fewshot_docs() is not None:
|
| 898 |
+
self.fewshot_rnd = (
|
| 899 |
+
random.Random()
|
| 900 |
+
) # setting with no seed, to be overridden at a later time
|
| 901 |
+
config_sampler: Union[str, Callable] = (
|
| 902 |
+
self.config.fewshot_config.get("sampler", "default")
|
| 903 |
+
if self.config.fewshot_config
|
| 904 |
+
else "default"
|
| 905 |
+
)
|
| 906 |
+
if isinstance(config_sampler, str):
|
| 907 |
+
self.sampler = samplers.get_sampler(config_sampler)(
|
| 908 |
+
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
|
| 909 |
+
)
|
| 910 |
+
elif callable(config_sampler) and issubclass(
|
| 911 |
+
config_sampler, samplers.ContextSampler
|
| 912 |
+
):
|
| 913 |
+
self.sampler = config_sampler(
|
| 914 |
+
docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
|
| 915 |
+
)
|
| 916 |
+
else:
|
| 917 |
+
raise TypeError(
|
| 918 |
+
f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
|
| 919 |
+
f"not {type(config_sampler)}"
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
self.task_docs = self.eval_docs
|
| 923 |
+
|
| 924 |
+
# Test One Doc
|
| 925 |
+
self.features = list(self.task_docs.features.keys())
|
| 926 |
+
self.multiple_input = 0
|
| 927 |
+
self.multiple_target = 0
|
| 928 |
+
test_doc = self.task_docs[0]
|
| 929 |
+
test_text = self.doc_to_text(test_doc)
|
| 930 |
+
test_target = self.doc_to_target(test_doc)
|
| 931 |
+
|
| 932 |
+
if self.config.doc_to_choice is not None:
|
| 933 |
+
test_choice = self.doc_to_choice(test_doc)
|
| 934 |
+
if not isinstance(test_choice, list):
|
| 935 |
+
eval_logger.error("doc_to_choice must return list")
|
| 936 |
+
else:
|
| 937 |
+
num_choice = len(test_choice)
|
| 938 |
+
|
| 939 |
+
if isinstance(test_text, int):
|
| 940 |
+
eval_logger.debug(
|
| 941 |
+
"doc_to_text returned an int. Assuming multiple inputs."
|
| 942 |
+
)
|
| 943 |
+
self.multiple_input = num_choice
|
| 944 |
+
else:
|
| 945 |
+
test_choice = None
|
| 946 |
+
|
| 947 |
+
if isinstance(test_target, list):
|
| 948 |
+
eval_logger.debug(
|
| 949 |
+
"doc_to_target returned a list. Assuming multiple targets."
|
| 950 |
+
)
|
| 951 |
+
self.multiple_target = len(test_target)
|
| 952 |
+
else:
|
| 953 |
+
if (isinstance(test_target, int)) and (test_choice is not None):
|
| 954 |
+
test_target = test_choice[test_target]
|
| 955 |
+
else:
|
| 956 |
+
test_target = str(test_target)
|
| 957 |
+
|
| 958 |
+
if test_choice is not None:
|
| 959 |
+
check_choices = test_choice
|
| 960 |
+
else:
|
| 961 |
+
check_choices = [test_target]
|
| 962 |
+
if self.config.doc_to_choice is not None:
|
| 963 |
+
for choice in check_choices:
|
| 964 |
+
choice_has_whitespace = True if choice[0].isspace() else False
|
| 965 |
+
delimiter_has_whitespace = (
|
| 966 |
+
True
|
| 967 |
+
if self.config.target_delimiter.rstrip()
|
| 968 |
+
!= self.config.target_delimiter
|
| 969 |
+
else False
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
if delimiter_has_whitespace and choice_has_whitespace:
|
| 973 |
+
eval_logger.debug(
|
| 974 |
+
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
|
| 975 |
+
)
|
| 976 |
+
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
|
| 977 |
+
eval_logger.debug(
|
| 978 |
+
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
def download(
|
| 982 |
+
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
|
| 983 |
+
) -> None:
|
| 984 |
+
from packaging.version import parse as vparse
|
| 985 |
+
|
| 986 |
+
if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
|
| 987 |
+
dataset_kwargs.pop("trust_remote_code", None)
|
| 988 |
+
if isinstance(self.config.custom_dataset, Callable):
|
| 989 |
+
eval_logger.warning(
|
| 990 |
+
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
|
| 991 |
+
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
|
| 992 |
+
)
|
| 993 |
+
self.dataset = self.config.custom_dataset(
|
| 994 |
+
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
|
| 995 |
+
)
|
| 996 |
+
else:
|
| 997 |
+
self.dataset = datasets.load_dataset(
|
| 998 |
+
path=self.DATASET_PATH,
|
| 999 |
+
name=self.DATASET_NAME,
|
| 1000 |
+
**dataset_kwargs if dataset_kwargs is not None else {},
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
def has_training_docs(self) -> bool:
|
| 1004 |
+
if self.config.training_split is not None:
|
| 1005 |
+
return True
|
| 1006 |
+
else:
|
| 1007 |
+
return False
|
| 1008 |
+
|
| 1009 |
+
def has_validation_docs(self) -> bool:
|
| 1010 |
+
if self.config.validation_split is not None:
|
| 1011 |
+
return True
|
| 1012 |
+
else:
|
| 1013 |
+
return False
|
| 1014 |
+
|
| 1015 |
+
def has_test_docs(self) -> bool:
|
| 1016 |
+
if self.config.test_split is not None:
|
| 1017 |
+
return True
|
| 1018 |
+
else:
|
| 1019 |
+
return False
|
| 1020 |
+
|
| 1021 |
+
def training_docs(self) -> datasets.Dataset:
|
| 1022 |
+
if self.has_training_docs():
|
| 1023 |
+
if self.config.process_docs is not None:
|
| 1024 |
+
return self.config.process_docs(
|
| 1025 |
+
self.dataset[self.config.training_split]
|
| 1026 |
+
)
|
| 1027 |
+
return self.dataset[self.config.training_split]
|
| 1028 |
+
|
| 1029 |
+
def validation_docs(self) -> datasets.Dataset:
|
| 1030 |
+
if self.has_validation_docs():
|
| 1031 |
+
if self.config.process_docs is not None:
|
| 1032 |
+
return self.config.process_docs(
|
| 1033 |
+
self.dataset[self.config.validation_split]
|
| 1034 |
+
)
|
| 1035 |
+
return self.dataset[self.config.validation_split]
|
| 1036 |
+
|
| 1037 |
+
def test_docs(self) -> datasets.Dataset:
|
| 1038 |
+
if self.has_test_docs():
|
| 1039 |
+
if self.config.process_docs is not None:
|
| 1040 |
+
return self.config.process_docs(self.dataset[self.config.test_split])
|
| 1041 |
+
return self.dataset[self.config.test_split]
|
| 1042 |
+
|
| 1043 |
+
def fewshot_docs(self):
|
| 1044 |
+
if self.config.fewshot_split is not None:
|
| 1045 |
+
if self.config.process_docs is not None:
|
| 1046 |
+
return self.config.process_docs(self.dataset[self.config.fewshot_split])
|
| 1047 |
+
return self.dataset[self.config.fewshot_split]
|
| 1048 |
+
elif (
|
| 1049 |
+
self.config.fewshot_config is not None
|
| 1050 |
+
and self.config.fewshot_config.get("samples", None) is not None
|
| 1051 |
+
):
|
| 1052 |
+
if isinstance(self.config.fewshot_config["samples"], list):
|
| 1053 |
+
return self.config.fewshot_config["samples"]
|
| 1054 |
+
elif callable(self.config.fewshot_config["samples"]):
|
| 1055 |
+
return self.config.fewshot_config["samples"]()
|
| 1056 |
+
else:
|
| 1057 |
+
raise Exception(
|
| 1058 |
+
"`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
|
| 1059 |
+
)
|
| 1060 |
+
else:
|
| 1061 |
+
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
|
| 1062 |
+
eval_logger.warning(
|
| 1063 |
+
f"[Task: {self.config.task}] "
|
| 1064 |
+
"num_fewshot > 0 but fewshot_split is None. "
|
| 1065 |
+
"using preconfigured rule."
|
| 1066 |
+
)
|
| 1067 |
+
return super().fewshot_docs()
|
| 1068 |
+
|
| 1069 |
+
@staticmethod
|
| 1070 |
+
def append_target_question(
|
| 1071 |
+
labeled_examples: List[Dict[str, str]],
|
| 1072 |
+
question: str,
|
| 1073 |
+
fewshot_as_multiturn: bool = False,
|
| 1074 |
+
gen_prefix: Optional[str] = None,
|
| 1075 |
+
) -> None:
|
| 1076 |
+
"""Adds a target question to the labeled examples list.
|
| 1077 |
+
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
|
| 1078 |
+
Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
|
| 1079 |
+
"""
|
| 1080 |
+
if not fewshot_as_multiturn:
|
| 1081 |
+
# if no messages or last message is system, append as new user entry
|
| 1082 |
+
if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
|
| 1083 |
+
labeled_examples.append({"role": "user", "content": question})
|
| 1084 |
+
# if last message is user, append to it to avoid two user messages in a row
|
| 1085 |
+
else:
|
| 1086 |
+
labeled_examples[-1]["content"] += question
|
| 1087 |
+
else:
|
| 1088 |
+
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
|
| 1089 |
+
labeled_examples.append({"role": "user", "content": question})
|
| 1090 |
+
if gen_prefix:
|
| 1091 |
+
labeled_examples.append({"role": "assistant", "content": gen_prefix})
|
| 1092 |
+
|
| 1093 |
+
@utils.positional_deprecated
|
| 1094 |
+
def fewshot_context(
|
| 1095 |
+
self,
|
| 1096 |
+
doc: dict,
|
| 1097 |
+
num_fewshot: int,
|
| 1098 |
+
system_instruction: Optional[str] = None,
|
| 1099 |
+
apply_chat_template: bool = False,
|
| 1100 |
+
fewshot_as_multiturn: bool = False,
|
| 1101 |
+
chat_template: Optional[Callable] = None,
|
| 1102 |
+
gen_prefix: Optional[str] = None,
|
| 1103 |
+
) -> Union[str, List[str]]:
|
| 1104 |
+
"""Returns a fewshot context string that is made up of a prepended description
|
| 1105 |
+
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 1106 |
+
|
| 1107 |
+
:param doc: str
|
| 1108 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 1109 |
+
:param num_fewshot: int
|
| 1110 |
+
The number of fewshot examples to provide in the returned context string.
|
| 1111 |
+
:param system_instruction: str
|
| 1112 |
+
System instruction to be applied to the prompt.
|
| 1113 |
+
:param apply_chat_template: bool
|
| 1114 |
+
Whether to apply the chat template to the fewshot context.
|
| 1115 |
+
:param fewshot_as_multiturn: bool
|
| 1116 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 1117 |
+
:param chat_template:
|
| 1118 |
+
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
|
| 1119 |
+
:param gen_prefix:
|
| 1120 |
+
String to append after the <|assistant|> token.
|
| 1121 |
+
:returns: str
|
| 1122 |
+
The fewshot context.
|
| 1123 |
+
"""
|
| 1124 |
+
if apply_chat_template:
|
| 1125 |
+
labeled_examples = []
|
| 1126 |
+
else:
|
| 1127 |
+
labeled_examples = ""
|
| 1128 |
+
|
| 1129 |
+
# get task description
|
| 1130 |
+
if description := self.config.description:
|
| 1131 |
+
description = utils.apply_template(self.config.description, doc)
|
| 1132 |
+
|
| 1133 |
+
# create system prompt based on the provided system instruction and description
|
| 1134 |
+
if system_instruction is not None and description:
|
| 1135 |
+
system_prompt = (
|
| 1136 |
+
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
|
| 1137 |
+
)
|
| 1138 |
+
elif system_instruction is not None:
|
| 1139 |
+
system_prompt = system_instruction
|
| 1140 |
+
elif description:
|
| 1141 |
+
system_prompt = description
|
| 1142 |
+
else:
|
| 1143 |
+
system_prompt = ""
|
| 1144 |
+
|
| 1145 |
+
# add system prompt if specified
|
| 1146 |
+
if system_prompt:
|
| 1147 |
+
if apply_chat_template:
|
| 1148 |
+
labeled_examples.append({"role": "system", "content": system_prompt})
|
| 1149 |
+
else:
|
| 1150 |
+
labeled_examples = system_prompt
|
| 1151 |
+
# if few-shot - append examples after the system prompt
|
| 1152 |
+
if num_fewshot > 0:
|
| 1153 |
+
if apply_chat_template:
|
| 1154 |
+
labeled_examples.extend(
|
| 1155 |
+
self.sampler.get_chat_context(
|
| 1156 |
+
doc,
|
| 1157 |
+
num_fewshot,
|
| 1158 |
+
fewshot_as_multiturn,
|
| 1159 |
+
gen_prefix=gen_prefix,
|
| 1160 |
+
)
|
| 1161 |
+
)
|
| 1162 |
+
else:
|
| 1163 |
+
labeled_examples += self.sampler.get_context(
|
| 1164 |
+
doc, num_fewshot, gen_prefix=gen_prefix
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
example = self.doc_to_text(doc)
|
| 1168 |
+
if apply_chat_template:
|
| 1169 |
+
if self.multiple_input:
|
| 1170 |
+
# TODO: append prefill?
|
| 1171 |
+
if not labeled_examples:
|
| 1172 |
+
return ""
|
| 1173 |
+
return chat_template(labeled_examples)
|
| 1174 |
+
if isinstance(example, str):
|
| 1175 |
+
self.append_target_question(
|
| 1176 |
+
labeled_examples,
|
| 1177 |
+
example,
|
| 1178 |
+
fewshot_as_multiturn,
|
| 1179 |
+
gen_prefix=gen_prefix,
|
| 1180 |
+
)
|
| 1181 |
+
# for loglikelihood create a list of questions with appended choices
|
| 1182 |
+
elif isinstance(example, list):
|
| 1183 |
+
labeled_examples_list = []
|
| 1184 |
+
# copy chat history for each example and append the answer
|
| 1185 |
+
for ex in example:
|
| 1186 |
+
chat = deepcopy(labeled_examples)
|
| 1187 |
+
self.append_target_question(
|
| 1188 |
+
chat,
|
| 1189 |
+
ex,
|
| 1190 |
+
fewshot_as_multiturn,
|
| 1191 |
+
gen_prefix=gen_prefix,
|
| 1192 |
+
)
|
| 1193 |
+
# TODO: append prefill?
|
| 1194 |
+
labeled_examples_list.append(
|
| 1195 |
+
chat_template(
|
| 1196 |
+
chat,
|
| 1197 |
+
add_generation_prompt=False if gen_prefix else True,
|
| 1198 |
+
)
|
| 1199 |
+
)
|
| 1200 |
+
return labeled_examples_list
|
| 1201 |
+
# if example is an integer, append the choice or convert to string
|
| 1202 |
+
elif isinstance(example, int):
|
| 1203 |
+
if self.config.doc_to_choice is not None:
|
| 1204 |
+
choices = self.doc_to_choice(doc)
|
| 1205 |
+
self.append_target_question(
|
| 1206 |
+
labeled_examples,
|
| 1207 |
+
choices[example],
|
| 1208 |
+
fewshot_as_multiturn,
|
| 1209 |
+
gen_prefix=gen_prefix,
|
| 1210 |
+
)
|
| 1211 |
+
else:
|
| 1212 |
+
self.append_target_question(
|
| 1213 |
+
labeled_examples,
|
| 1214 |
+
str(example),
|
| 1215 |
+
fewshot_as_multiturn,
|
| 1216 |
+
gen_prefix=gen_prefix,
|
| 1217 |
+
)
|
| 1218 |
+
# return lm.apply_chat_template(labeled_examples)
|
| 1219 |
+
return chat_template(
|
| 1220 |
+
labeled_examples,
|
| 1221 |
+
add_generation_prompt=False if gen_prefix else True,
|
| 1222 |
+
)
|
| 1223 |
+
else:
|
| 1224 |
+
prefix = (
|
| 1225 |
+
self.config.target_delimiter + gen_prefix
|
| 1226 |
+
if gen_prefix is not None
|
| 1227 |
+
else ""
|
| 1228 |
+
)
|
| 1229 |
+
if self.multiple_input:
|
| 1230 |
+
return labeled_examples
|
| 1231 |
+
if isinstance(example, str):
|
| 1232 |
+
return labeled_examples + example + prefix
|
| 1233 |
+
elif isinstance(example, list):
|
| 1234 |
+
return [labeled_examples + ex + prefix for ex in example]
|
| 1235 |
+
elif isinstance(example, int):
|
| 1236 |
+
if self.config.doc_to_choice is not None:
|
| 1237 |
+
choices = self.doc_to_choice(doc)
|
| 1238 |
+
return labeled_examples + choices[example] + prefix
|
| 1239 |
+
else:
|
| 1240 |
+
return labeled_examples + str(example) + prefix
|
| 1241 |
+
|
| 1242 |
+
def apply_filters(self) -> Optional[List[Instance]]:
|
| 1243 |
+
"""Iterates over FilterEnsembles and applies them to instances"""
|
| 1244 |
+
if hasattr(self, "_filters"):
|
| 1245 |
+
for f in self._filters:
|
| 1246 |
+
f.apply(self._instances)
|
| 1247 |
+
else:
|
| 1248 |
+
eval_logger.warning("No filter defined, passing through instances")
|
| 1249 |
+
return self._instances
|
| 1250 |
+
|
| 1251 |
+
def should_decontaminate(self):
|
| 1252 |
+
return self.config.should_decontaminate
|
| 1253 |
+
|
| 1254 |
+
def doc_to_decontamination_query(self, doc: dict):
|
| 1255 |
+
if self.config.should_decontaminate:
|
| 1256 |
+
if self.config.doc_to_decontamination_query is None:
|
| 1257 |
+
return self.doc_to_text(doc)
|
| 1258 |
+
else:
|
| 1259 |
+
doc_to_decontamination_query = self.config.doc_to_decontamination_query
|
| 1260 |
+
if doc_to_decontamination_query in self.features:
|
| 1261 |
+
return doc[doc_to_decontamination_query]
|
| 1262 |
+
elif callable(doc_to_decontamination_query):
|
| 1263 |
+
return doc_to_decontamination_query(doc)
|
| 1264 |
+
else:
|
| 1265 |
+
return ast.literal_eval(
|
| 1266 |
+
utils.apply_template(
|
| 1267 |
+
self.config.doc_to_decontamination_query, doc
|
| 1268 |
+
)
|
| 1269 |
+
)
|
| 1270 |
+
|
| 1271 |
+
def _process_doc(self, doc: dict) -> dict:
|
| 1272 |
+
"""
|
| 1273 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 1274 |
+
documents. This can be used in a map over documents of a data split.
|
| 1275 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 1276 |
+
|
| 1277 |
+
:return: dict
|
| 1278 |
+
The processed version of the specified `doc`.
|
| 1279 |
+
"""
|
| 1280 |
+
return doc
|
| 1281 |
+
|
| 1282 |
+
def doc_to_text(self, doc, doc_to_text=None):
|
| 1283 |
+
if self.prompt is not None:
|
| 1284 |
+
doc_to_text = self.prompt
|
| 1285 |
+
elif doc_to_text is not None:
|
| 1286 |
+
doc_to_text = doc_to_text
|
| 1287 |
+
else:
|
| 1288 |
+
doc_to_text = self.config.doc_to_text
|
| 1289 |
+
|
| 1290 |
+
if isinstance(doc_to_text, int):
|
| 1291 |
+
return doc_to_text
|
| 1292 |
+
elif isinstance(doc_to_text, str):
|
| 1293 |
+
if doc_to_text in self.features:
|
| 1294 |
+
# if self.config.doc_to_choice is not None:
|
| 1295 |
+
# return self.doc_to_choice(doc)[doc[doc_to_text]]
|
| 1296 |
+
# else:
|
| 1297 |
+
return doc[doc_to_text]
|
| 1298 |
+
else:
|
| 1299 |
+
text_string = utils.apply_template(doc_to_text, doc)
|
| 1300 |
+
if text_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1301 |
+
return ast.literal_eval(text_string)
|
| 1302 |
+
else:
|
| 1303 |
+
return text_string
|
| 1304 |
+
elif callable(doc_to_text):
|
| 1305 |
+
return doc_to_text(doc)
|
| 1306 |
+
# Used when applying a Promptsource template
|
| 1307 |
+
elif hasattr(doc_to_text, "apply"):
|
| 1308 |
+
applied_prompt = doc_to_text.apply(doc)
|
| 1309 |
+
if len(applied_prompt) == 2:
|
| 1310 |
+
return applied_prompt[0]
|
| 1311 |
+
else:
|
| 1312 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 1313 |
+
return self.config.fewshot_delimiter
|
| 1314 |
+
else:
|
| 1315 |
+
print(type(doc_to_text))
|
| 1316 |
+
raise TypeError
|
| 1317 |
+
|
| 1318 |
+
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
|
| 1319 |
+
if self.prompt is not None:
|
| 1320 |
+
doc_to_target = self.prompt
|
| 1321 |
+
elif doc_to_target is not None:
|
| 1322 |
+
doc_to_target = doc_to_target
|
| 1323 |
+
else:
|
| 1324 |
+
doc_to_target = self.config.doc_to_target
|
| 1325 |
+
|
| 1326 |
+
if isinstance(doc_to_target, int):
|
| 1327 |
+
return doc_to_target
|
| 1328 |
+
elif isinstance(doc_to_target, str):
|
| 1329 |
+
if doc_to_target in self.features:
|
| 1330 |
+
# if self.config.doc_to_choice is not None:
|
| 1331 |
+
# return self.doc_to_choice(doc)[doc[doc_to_target]]
|
| 1332 |
+
# else:
|
| 1333 |
+
return doc[doc_to_target]
|
| 1334 |
+
else:
|
| 1335 |
+
target_string = utils.apply_template(doc_to_target, doc)
|
| 1336 |
+
if target_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1337 |
+
return ast.literal_eval(target_string)
|
| 1338 |
+
elif (
|
| 1339 |
+
len(target_string) >= 2
|
| 1340 |
+
and (target_string[0] == "[")
|
| 1341 |
+
and (target_string[-1] == "]")
|
| 1342 |
+
):
|
| 1343 |
+
try:
|
| 1344 |
+
return ast.literal_eval(target_string)
|
| 1345 |
+
except (SyntaxError, ValueError):
|
| 1346 |
+
return target_string
|
| 1347 |
+
else:
|
| 1348 |
+
return target_string
|
| 1349 |
+
elif isinstance(doc_to_target, list):
|
| 1350 |
+
return doc_to_target
|
| 1351 |
+
elif callable(doc_to_target):
|
| 1352 |
+
return doc_to_target(doc)
|
| 1353 |
+
# Used when applying a Promptsource template
|
| 1354 |
+
elif hasattr(doc_to_target, "apply"):
|
| 1355 |
+
applied_prompt = doc_to_target.apply(doc)
|
| 1356 |
+
if len(applied_prompt) == 2:
|
| 1357 |
+
return applied_prompt[1]
|
| 1358 |
+
else:
|
| 1359 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 1360 |
+
return self.config.fewshot_delimiter
|
| 1361 |
+
else:
|
| 1362 |
+
raise TypeError
|
| 1363 |
+
|
| 1364 |
+
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
|
| 1365 |
+
if self.prompt is not None:
|
| 1366 |
+
doc_to_choice = self.prompt
|
| 1367 |
+
elif doc_to_choice is not None:
|
| 1368 |
+
doc_to_choice = doc_to_choice
|
| 1369 |
+
elif self.config.doc_to_choice is None:
|
| 1370 |
+
eval_logger.error("doc_to_choice was called but not set in config")
|
| 1371 |
+
else:
|
| 1372 |
+
doc_to_choice = self.config.doc_to_choice
|
| 1373 |
+
|
| 1374 |
+
if isinstance(doc_to_choice, str):
|
| 1375 |
+
if doc_to_choice in self.features:
|
| 1376 |
+
return doc[doc_to_choice]
|
| 1377 |
+
else:
|
| 1378 |
+
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
|
| 1379 |
+
elif isinstance(doc_to_choice, list):
|
| 1380 |
+
return doc_to_choice
|
| 1381 |
+
elif isinstance(doc_to_choice, dict):
|
| 1382 |
+
return list(doc_to_choice.values())
|
| 1383 |
+
elif callable(doc_to_choice):
|
| 1384 |
+
return doc_to_choice(doc)
|
| 1385 |
+
elif hasattr(doc_to_choice, "get_answer_choices_list"):
|
| 1386 |
+
return doc_to_choice.get_answer_choices_list(doc)
|
| 1387 |
+
else:
|
| 1388 |
+
raise TypeError
|
| 1389 |
+
|
| 1390 |
+
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
|
| 1391 |
+
if doc_to_image is not None:
|
| 1392 |
+
doc_to_image = doc_to_image
|
| 1393 |
+
elif self.config.doc_to_image is not None:
|
| 1394 |
+
doc_to_image = self.config.doc_to_image
|
| 1395 |
+
else:
|
| 1396 |
+
return None
|
| 1397 |
+
|
| 1398 |
+
if isinstance(doc_to_image, list):
|
| 1399 |
+
image_feature = [
|
| 1400 |
+
self.doc_to_image(doc, feature) for feature in doc_to_image
|
| 1401 |
+
]
|
| 1402 |
+
return [feature for feature in image_feature if feature is not None]
|
| 1403 |
+
elif isinstance(doc_to_image, str):
|
| 1404 |
+
if doc_to_image in self.features:
|
| 1405 |
+
return doc[doc_to_image]
|
| 1406 |
+
else:
|
| 1407 |
+
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
|
| 1408 |
+
elif callable(doc_to_image):
|
| 1409 |
+
return doc_to_image(doc)
|
| 1410 |
+
else:
|
| 1411 |
+
return None
|
| 1412 |
+
|
| 1413 |
+
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
|
| 1414 |
+
if doc_to_audio is not None:
|
| 1415 |
+
doc_to_audio = doc_to_audio
|
| 1416 |
+
elif self.config.doc_to_audio is not None:
|
| 1417 |
+
doc_to_audio = self.config.doc_to_audio
|
| 1418 |
+
else:
|
| 1419 |
+
return None
|
| 1420 |
+
|
| 1421 |
+
if isinstance(doc_to_audio, list):
|
| 1422 |
+
audio_feature = [
|
| 1423 |
+
self.doc_to_audio(doc, feature) for feature in doc_to_audio
|
| 1424 |
+
]
|
| 1425 |
+
return [feature for feature in audio_feature if feature is not None]
|
| 1426 |
+
elif isinstance(doc_to_audio, str):
|
| 1427 |
+
if doc_to_audio in self.features:
|
| 1428 |
+
return doc[doc_to_audio]
|
| 1429 |
+
else:
|
| 1430 |
+
return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
|
| 1431 |
+
elif callable(doc_to_audio):
|
| 1432 |
+
return doc_to_audio(doc)
|
| 1433 |
+
else:
|
| 1434 |
+
return None
|
| 1435 |
+
|
| 1436 |
+
def doc_to_prefix(self, doc):
|
| 1437 |
+
if (gen_prefix := self.config.gen_prefix) is not None:
|
| 1438 |
+
if gen_prefix in self.features:
|
| 1439 |
+
return doc[gen_prefix]
|
| 1440 |
+
else:
|
| 1441 |
+
return utils.apply_template(gen_prefix, doc)
|
| 1442 |
+
return None
|
| 1443 |
+
|
| 1444 |
+
def construct_requests(
|
| 1445 |
+
self, doc: dict, ctx: str, **kwargs
|
| 1446 |
+
) -> Union[List[Instance], Instance]:
|
| 1447 |
+
apply_chat_template = kwargs.pop("apply_chat_template", False)
|
| 1448 |
+
chat_template: Callable | None = kwargs.pop("chat_template", None)
|
| 1449 |
+
|
| 1450 |
+
aux_arguments = None
|
| 1451 |
+
|
| 1452 |
+
if self.OUTPUT_TYPE == "loglikelihood":
|
| 1453 |
+
arguments = (ctx, self.doc_to_target(doc))
|
| 1454 |
+
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
|
| 1455 |
+
arguments = (self.doc_to_target(doc),)
|
| 1456 |
+
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 1457 |
+
choices = self.doc_to_choice(doc)
|
| 1458 |
+
target_delimiter = self.config.target_delimiter
|
| 1459 |
+
if apply_chat_template:
|
| 1460 |
+
target_delimiter = ""
|
| 1461 |
+
if self.multiple_input:
|
| 1462 |
+
# If there are multiple inputs, choices are placed in the ctx
|
| 1463 |
+
# apply chat_template to choices if apply_chat_template
|
| 1464 |
+
cont = self.doc_to_target(doc)
|
| 1465 |
+
|
| 1466 |
+
arguments = [
|
| 1467 |
+
(
|
| 1468 |
+
ctx
|
| 1469 |
+
+ (
|
| 1470 |
+
chat_template([{"role": "user", "content": choice}])
|
| 1471 |
+
if apply_chat_template
|
| 1472 |
+
else choice
|
| 1473 |
+
),
|
| 1474 |
+
f"{target_delimiter}{cont}",
|
| 1475 |
+
)
|
| 1476 |
+
for choice in choices
|
| 1477 |
+
]
|
| 1478 |
+
else:
|
| 1479 |
+
# Otherwise they are placed in the continuation
|
| 1480 |
+
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
|
| 1481 |
+
|
| 1482 |
+
# TODO: we should raise a warning telling users this will at most ~2x runtime.
|
| 1483 |
+
if "acc_mutual_info" in self._metric_fn_list.keys():
|
| 1484 |
+
# if we are calculating multiple choice accuracy
|
| 1485 |
+
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
|
| 1486 |
+
|
| 1487 |
+
# here mutual info refers to calculating
|
| 1488 |
+
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
|
| 1489 |
+
# in other words normalizing by subtracting the unconditional logprob of each choice.
|
| 1490 |
+
# TODO: should these be strided? will have to modify the processing in process_results if so
|
| 1491 |
+
aux_arguments = [
|
| 1492 |
+
("", f"{target_delimiter}{choice}") for choice in choices
|
| 1493 |
+
]
|
| 1494 |
+
|
| 1495 |
+
arguments.extend(aux_arguments)
|
| 1496 |
+
|
| 1497 |
+
elif self.OUTPUT_TYPE == "generate_until":
|
| 1498 |
+
arguments = (ctx, deepcopy(self.config.generation_kwargs))
|
| 1499 |
+
|
| 1500 |
+
multimodal_arg = {}
|
| 1501 |
+
if (
|
| 1502 |
+
self.config.doc_to_image
|
| 1503 |
+
): # TODO: ensure that non-multimodal tasks aren't getting visual args
|
| 1504 |
+
multimodal_arg = {
|
| 1505 |
+
**multimodal_arg,
|
| 1506 |
+
**{"visual": self.doc_to_image(doc)},
|
| 1507 |
+
}
|
| 1508 |
+
|
| 1509 |
+
if (
|
| 1510 |
+
self.config.doc_to_audio
|
| 1511 |
+
): # TODO: ensure that non-multimodal tasks aren't getting audio args
|
| 1512 |
+
multimodal_arg = {
|
| 1513 |
+
**multimodal_arg,
|
| 1514 |
+
**{"audio": self.doc_to_audio(doc)},
|
| 1515 |
+
}
|
| 1516 |
+
|
| 1517 |
+
if bool(multimodal_arg):
|
| 1518 |
+
if isinstance(arguments, list):
|
| 1519 |
+
arguments = [arg + (multimodal_arg,) for arg in arguments]
|
| 1520 |
+
else:
|
| 1521 |
+
arguments = arguments + (multimodal_arg,)
|
| 1522 |
+
|
| 1523 |
+
if self.OUTPUT_TYPE == "multiple_choice":
|
| 1524 |
+
request_list = [
|
| 1525 |
+
Instance(
|
| 1526 |
+
request_type="loglikelihood",
|
| 1527 |
+
doc=doc,
|
| 1528 |
+
arguments=arg,
|
| 1529 |
+
idx=i,
|
| 1530 |
+
**kwargs,
|
| 1531 |
+
)
|
| 1532 |
+
for i, arg in enumerate(arguments)
|
| 1533 |
+
]
|
| 1534 |
+
|
| 1535 |
+
return request_list
|
| 1536 |
+
|
| 1537 |
+
return Instance(
|
| 1538 |
+
request_type=self.OUTPUT_TYPE,
|
| 1539 |
+
doc=doc,
|
| 1540 |
+
arguments=arguments,
|
| 1541 |
+
idx=0,
|
| 1542 |
+
**kwargs,
|
| 1543 |
+
)
|
| 1544 |
+
|
| 1545 |
+
def process_results(self, doc, results):
|
| 1546 |
+
if callable(self.config.process_results):
|
| 1547 |
+
return self.config.process_results(doc, results)
|
| 1548 |
+
|
| 1549 |
+
result_dict = {}
|
| 1550 |
+
use_metric = list(self._metric_fn_list.keys())
|
| 1551 |
+
if self.OUTPUT_TYPE == "loglikelihood":
|
| 1552 |
+
results = results[0]
|
| 1553 |
+
ll, is_greedy = results
|
| 1554 |
+
return {
|
| 1555 |
+
**({"perplexity": ll} if "perplexity" in use_metric else {}),
|
| 1556 |
+
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
|
| 1557 |
+
}
|
| 1558 |
+
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
|
| 1559 |
+
(loglikelihood,) = results
|
| 1560 |
+
_words = self.count_words(self.doc_to_target(doc))
|
| 1561 |
+
_bytes = self.count_bytes(self.doc_to_target(doc))
|
| 1562 |
+
return {
|
| 1563 |
+
**(
|
| 1564 |
+
{"word_perplexity": (loglikelihood, _words)}
|
| 1565 |
+
if "word_perplexity" in use_metric
|
| 1566 |
+
else {}
|
| 1567 |
+
),
|
| 1568 |
+
**(
|
| 1569 |
+
{"byte_perplexity": (loglikelihood, _bytes)}
|
| 1570 |
+
if "byte_perplexity" in use_metric
|
| 1571 |
+
else {}
|
| 1572 |
+
),
|
| 1573 |
+
**(
|
| 1574 |
+
{"bits_per_byte": (loglikelihood, _bytes)}
|
| 1575 |
+
if "bits_per_byte" in use_metric
|
| 1576 |
+
else {}
|
| 1577 |
+
),
|
| 1578 |
+
}
|
| 1579 |
+
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 1580 |
+
lls, is_greedy = zip(*results)
|
| 1581 |
+
|
| 1582 |
+
# retrieve choices in List[str] form, to compute choice lengths, etc.
|
| 1583 |
+
choices = self.doc_to_choice(doc)
|
| 1584 |
+
completion_len = np.array([float(len(i)) for i in choices])
|
| 1585 |
+
|
| 1586 |
+
if (
|
| 1587 |
+
2 * len(choices) == len(lls)
|
| 1588 |
+
and "acc_mutual_info" in self._metric_fn_list.keys()
|
| 1589 |
+
):
|
| 1590 |
+
# then we are doing mutual info.
|
| 1591 |
+
# this stores the "dryrun" / unconditional answer loglikelihoods
|
| 1592 |
+
# as we extend the args list with unconditional ("", continuation) pairs
|
| 1593 |
+
lls_unconditional = lls[len(choices) :]
|
| 1594 |
+
if len(lls_unconditional) != len(choices):
|
| 1595 |
+
raise ValueError
|
| 1596 |
+
# and this stores our "regular" conditional loglikelihoods
|
| 1597 |
+
lls = lls[: len(choices)]
|
| 1598 |
+
|
| 1599 |
+
pred = np.argmax(lls)
|
| 1600 |
+
pred_norm = np.argmax(lls / completion_len)
|
| 1601 |
+
|
| 1602 |
+
if self.multiple_input:
|
| 1603 |
+
gold = self.doc_to_text(doc)
|
| 1604 |
+
else:
|
| 1605 |
+
gold = self.doc_to_target(doc)
|
| 1606 |
+
|
| 1607 |
+
gold_index_error = False
|
| 1608 |
+
if isinstance(gold, list):
|
| 1609 |
+
gold = [i if i < len(choices) else -100 for i in gold]
|
| 1610 |
+
if -100 in gold:
|
| 1611 |
+
gold_index_error = True
|
| 1612 |
+
else:
|
| 1613 |
+
if isinstance(gold, int):
|
| 1614 |
+
gold = gold if gold < len(choices) else -100
|
| 1615 |
+
elif isinstance(gold, str):
|
| 1616 |
+
gold = choices.index(gold) if gold in choices else -100
|
| 1617 |
+
|
| 1618 |
+
if gold == -100:
|
| 1619 |
+
gold_index_error = True
|
| 1620 |
+
|
| 1621 |
+
if gold_index_error:
|
| 1622 |
+
eval_logger.warning(
|
| 1623 |
+
f"Label index was not in within range of available choices,"
|
| 1624 |
+
f"Sample:\n\n{doc}\n\n"
|
| 1625 |
+
)
|
| 1626 |
+
|
| 1627 |
+
if self.multiple_target:
|
| 1628 |
+
acc = 1.0 if pred in gold else 0.0
|
| 1629 |
+
acc_norm = 1.0 if pred_norm in gold else 0.0
|
| 1630 |
+
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
|
| 1631 |
+
else:
|
| 1632 |
+
acc = 1.0 if pred == gold else 0.0
|
| 1633 |
+
acc_norm = 1.0 if pred_norm == gold else 0.0
|
| 1634 |
+
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
|
| 1635 |
+
exact_match = int(is_greedy[gold]) if gold != -100 else 0
|
| 1636 |
+
|
| 1637 |
+
prob_norm = utils.softmax(lls)
|
| 1638 |
+
|
| 1639 |
+
# TODO use keyword arguments to the metric?
|
| 1640 |
+
# gold, pred, norm stuff, the original lls,
|
| 1641 |
+
result_dict = {
|
| 1642 |
+
**({"acc": acc} if "acc" in use_metric else {}),
|
| 1643 |
+
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
|
| 1644 |
+
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
|
| 1645 |
+
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
|
| 1646 |
+
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
|
| 1647 |
+
**(
|
| 1648 |
+
{"brier_score": (gold, prob_norm)}
|
| 1649 |
+
if "brier_score" in use_metric
|
| 1650 |
+
else {}
|
| 1651 |
+
),
|
| 1652 |
+
}
|
| 1653 |
+
|
| 1654 |
+
if "acc_mutual_info" in use_metric:
|
| 1655 |
+
lls_mutual_info = [
|
| 1656 |
+
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
|
| 1657 |
+
]
|
| 1658 |
+
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
|
| 1659 |
+
result_dict["acc_mutual_info"] = acc_mutual_info
|
| 1660 |
+
|
| 1661 |
+
elif self.OUTPUT_TYPE == "generate_until":
|
| 1662 |
+
gold = self.doc_to_target(doc)
|
| 1663 |
+
result = results[0]
|
| 1664 |
+
if self.config.doc_to_choice is not None:
|
| 1665 |
+
# If you set doc_to_choice,
|
| 1666 |
+
# it assumes that doc_to_target returns a number.
|
| 1667 |
+
choices = self.doc_to_choice(doc)
|
| 1668 |
+
gold = choices[gold]
|
| 1669 |
+
# we expect multiple_targets to be a list.
|
| 1670 |
+
elif self.multiple_target:
|
| 1671 |
+
gold = list(gold)
|
| 1672 |
+
# TODO: handle this better
|
| 1673 |
+
elif type(gold) is not type(result) and not (
|
| 1674 |
+
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
|
| 1675 |
+
):
|
| 1676 |
+
# cast gold to the same type as result
|
| 1677 |
+
gold = type(result)(gold)
|
| 1678 |
+
|
| 1679 |
+
for metric in self._metric_fn_list.keys():
|
| 1680 |
+
if self.multiple_target:
|
| 1681 |
+
# in the case where we have multiple targets,
|
| 1682 |
+
# return true if any are true
|
| 1683 |
+
# TODO: this may break for multipLe_target, non zero-or-1 metrics
|
| 1684 |
+
scores = []
|
| 1685 |
+
if not isinstance(gold, list):
|
| 1686 |
+
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
|
| 1687 |
+
# print(gold)
|
| 1688 |
+
gold = [gold]
|
| 1689 |
+
if metric == "exact_match":
|
| 1690 |
+
result = [result for _ in range(len(gold))]
|
| 1691 |
+
scores = self._metric_fn_list[metric](
|
| 1692 |
+
references=gold,
|
| 1693 |
+
predictions=result,
|
| 1694 |
+
**self._metric_fn_kwargs[metric],
|
| 1695 |
+
)[metric]
|
| 1696 |
+
result_score = 1.0 if scores > 0.0 else 0.0
|
| 1697 |
+
else:
|
| 1698 |
+
for gold_option in gold:
|
| 1699 |
+
try:
|
| 1700 |
+
result_score = self._metric_fn_list[metric](
|
| 1701 |
+
references=[gold_option],
|
| 1702 |
+
predictions=[result],
|
| 1703 |
+
**self._metric_fn_kwargs[metric],
|
| 1704 |
+
)
|
| 1705 |
+
except (
|
| 1706 |
+
TypeError
|
| 1707 |
+
): # TODO: this is hacky and I don't want to do it
|
| 1708 |
+
result_score = self._metric_fn_list[metric](
|
| 1709 |
+
[gold_option, result]
|
| 1710 |
+
)
|
| 1711 |
+
if isinstance(result_score, dict):
|
| 1712 |
+
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1713 |
+
result_score = result_score[metric]
|
| 1714 |
+
scores.append(result_score)
|
| 1715 |
+
if any(scores):
|
| 1716 |
+
result_score = 1.0
|
| 1717 |
+
else:
|
| 1718 |
+
result_score = 0.0
|
| 1719 |
+
else:
|
| 1720 |
+
try:
|
| 1721 |
+
result_score = self._metric_fn_list[metric](
|
| 1722 |
+
references=[gold],
|
| 1723 |
+
predictions=[result],
|
| 1724 |
+
**self._metric_fn_kwargs[metric],
|
| 1725 |
+
)
|
| 1726 |
+
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
|
| 1727 |
+
result_score = self._metric_fn_list[metric]([gold, result])
|
| 1728 |
+
if isinstance(result_score, dict):
|
| 1729 |
+
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1730 |
+
# This allows for multiple metrics to be returned from the same function
|
| 1731 |
+
for k, v in result_score.items():
|
| 1732 |
+
result_dict[k] = v
|
| 1733 |
+
else:
|
| 1734 |
+
result_dict[metric] = result_score
|
| 1735 |
+
else:
|
| 1736 |
+
raise ValueError(
|
| 1737 |
+
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
|
| 1738 |
+
"'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
|
| 1739 |
+
)
|
| 1740 |
+
|
| 1741 |
+
return result_dict
|
| 1742 |
+
|
| 1743 |
+
def aggregation(self) -> dict:
|
| 1744 |
+
return self._aggregation_list
|
| 1745 |
+
|
| 1746 |
+
def higher_is_better(self) -> dict:
|
| 1747 |
+
return self._higher_is_better
|
| 1748 |
+
|
| 1749 |
+
def get_config(self, key: str) -> Any:
|
| 1750 |
+
return getattr(self._config, key, None)
|
| 1751 |
+
|
| 1752 |
+
@property
|
| 1753 |
+
def task_name(self) -> Any:
|
| 1754 |
+
return getattr(self.config, "task", None)
|
| 1755 |
+
|
| 1756 |
+
def __repr__(self):
|
| 1757 |
+
return (
|
| 1758 |
+
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
|
| 1759 |
+
f"output_type={self.OUTPUT_TYPE},"
|
| 1760 |
+
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
|
| 1761 |
+
f"num_samples={len(self.eval_docs)})"
|
| 1762 |
+
)
|
| 1763 |
+
|
| 1764 |
+
|
| 1765 |
+
class MultipleChoiceTask(Task):
|
| 1766 |
+
OUTPUT_TYPE = "loglikelihood"
|
| 1767 |
+
|
| 1768 |
+
def doc_to_target(self, doc: dict) -> str:
|
| 1769 |
+
return " " + doc["choices"][doc["gold"]]
|
| 1770 |
+
|
| 1771 |
+
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
|
| 1772 |
+
# TODO: add mutual info here?
|
| 1773 |
+
return [
|
| 1774 |
+
Instance(
|
| 1775 |
+
request_type="loglikelihood",
|
| 1776 |
+
doc=doc,
|
| 1777 |
+
arguments=(ctx, " {}".format(choice)),
|
| 1778 |
+
idx=i,
|
| 1779 |
+
**kwargs,
|
| 1780 |
+
)
|
| 1781 |
+
for i, choice in enumerate(doc["choices"])
|
| 1782 |
+
]
|
| 1783 |
+
|
| 1784 |
+
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
|
| 1785 |
+
results = [
|
| 1786 |
+
res[0] for res in results
|
| 1787 |
+
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
|
| 1788 |
+
gold = doc["gold"]
|
| 1789 |
+
|
| 1790 |
+
acc = 1.0 if np.argmax(results) == gold else 0.0
|
| 1791 |
+
completion_len = np.array([float(len(i)) for i in doc["choices"]])
|
| 1792 |
+
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
|
| 1793 |
+
|
| 1794 |
+
return {
|
| 1795 |
+
"acc": acc,
|
| 1796 |
+
"acc_norm": acc_norm,
|
| 1797 |
+
}
|
| 1798 |
+
|
| 1799 |
+
def higher_is_better(self) -> dict:
|
| 1800 |
+
return {
|
| 1801 |
+
"acc": True,
|
| 1802 |
+
"acc_norm": True,
|
| 1803 |
+
}
|
| 1804 |
+
|
| 1805 |
+
def aggregation(self) -> dict:
|
| 1806 |
+
return {
|
| 1807 |
+
"acc": mean,
|
| 1808 |
+
"acc_norm": mean,
|
| 1809 |
+
}
|
| 1810 |
+
|
| 1811 |
+
|
| 1812 |
+
class PerplexityTask(Task):
|
| 1813 |
+
OUTPUT_TYPE = "loglikelihood_rolling"
|
| 1814 |
+
|
| 1815 |
+
def has_training_docs(self) -> bool:
|
| 1816 |
+
return False
|
| 1817 |
+
|
| 1818 |
+
def fewshot_examples(self, k: int, rnd) -> List:
|
| 1819 |
+
if k != 0:
|
| 1820 |
+
raise ValueError(
|
| 1821 |
+
"The number of fewshot examples must be 0 for perplexity tasks."
|
| 1822 |
+
)
|
| 1823 |
+
return []
|
| 1824 |
+
|
| 1825 |
+
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
|
| 1826 |
+
if num_fewshot != 0:
|
| 1827 |
+
raise ValueError(
|
| 1828 |
+
"The number of fewshot examples must be 0 for perplexity tasks."
|
| 1829 |
+
)
|
| 1830 |
+
|
| 1831 |
+
return ""
|
| 1832 |
+
|
| 1833 |
+
def higher_is_better(self) -> dict:
|
| 1834 |
+
return {
|
| 1835 |
+
"word_perplexity": False,
|
| 1836 |
+
"byte_perplexity": False,
|
| 1837 |
+
"bits_per_byte": False,
|
| 1838 |
+
}
|
| 1839 |
+
|
| 1840 |
+
def doc_to_decontamination_query(self, doc):
|
| 1841 |
+
return doc
|
| 1842 |
+
|
| 1843 |
+
def doc_to_text(self, doc) -> str:
|
| 1844 |
+
return ""
|
| 1845 |
+
|
| 1846 |
+
def doc_to_target(self, doc):
|
| 1847 |
+
return doc
|
| 1848 |
+
|
| 1849 |
+
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
|
| 1850 |
+
if bool(ctx):
|
| 1851 |
+
raise ValueError
|
| 1852 |
+
|
| 1853 |
+
return Instance(
|
| 1854 |
+
request_type=self.OUTPUT_TYPE,
|
| 1855 |
+
doc=doc,
|
| 1856 |
+
arguments=(self.doc_to_target(doc),),
|
| 1857 |
+
idx=0,
|
| 1858 |
+
**kwargs,
|
| 1859 |
+
)
|
| 1860 |
+
|
| 1861 |
+
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
|
| 1862 |
+
(loglikelihood,) = results
|
| 1863 |
+
words = self.count_words(self.doc_to_target(doc))
|
| 1864 |
+
bytes_ = self.count_bytes(self.doc_to_target(doc))
|
| 1865 |
+
return {
|
| 1866 |
+
"word_perplexity": (loglikelihood, words),
|
| 1867 |
+
"byte_perplexity": (loglikelihood, bytes_),
|
| 1868 |
+
"bits_per_byte": (loglikelihood, bytes_),
|
| 1869 |
+
}
|
| 1870 |
+
|
| 1871 |
+
def aggregation(self) -> dict:
|
| 1872 |
+
return {
|
| 1873 |
+
"word_perplexity": weighted_perplexity,
|
| 1874 |
+
"byte_perplexity": weighted_perplexity,
|
| 1875 |
+
"bits_per_byte": bits_per_byte,
|
| 1876 |
+
}
|
| 1877 |
+
|
| 1878 |
+
@classmethod
|
| 1879 |
+
def count_bytes(cls, doc) -> int:
|
| 1880 |
+
return len(doc.encode("utf-8"))
|
| 1881 |
+
|
| 1882 |
+
@classmethod
|
| 1883 |
+
def count_words(cls, doc) -> int:
|
| 1884 |
+
"""Downstream tasks with custom word boundaries should override this!"""
|
| 1885 |
+
return len(re.split(r"\s+", doc))
|
lm-evaluation-harness/lm_eval/caching/__init__.py
ADDED
|
File without changes
|
lm-evaluation-harness/lm_eval/caching/cache.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import dill
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
eval_logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
|
| 12 |
+
|
| 13 |
+
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
|
| 17 |
+
|
| 18 |
+
# This should be sufficient for uniqueness
|
| 19 |
+
HASH_INPUT = "EleutherAI-lm-evaluation-harness"
|
| 20 |
+
|
| 21 |
+
HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
|
| 22 |
+
|
| 23 |
+
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_from_cache(file_name: str, cache: bool = False):
|
| 27 |
+
if not cache:
|
| 28 |
+
return
|
| 29 |
+
try:
|
| 30 |
+
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
|
| 31 |
+
|
| 32 |
+
with open(path, "rb") as file:
|
| 33 |
+
cached_task_dict = dill.loads(file.read())
|
| 34 |
+
return cached_task_dict
|
| 35 |
+
|
| 36 |
+
except Exception:
|
| 37 |
+
eval_logger.debug(f"{file_name} is not cached, generating...")
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def save_to_cache(file_name, obj):
|
| 42 |
+
if not os.path.exists(PATH):
|
| 43 |
+
os.mkdir(PATH)
|
| 44 |
+
|
| 45 |
+
file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
|
| 46 |
+
|
| 47 |
+
eval_logger.debug(f"Saving {file_path} to cache...")
|
| 48 |
+
with open(file_path, "wb") as file:
|
| 49 |
+
file.write(dill.dumps(obj))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# NOTE the "key" param is to allow for flexibility
|
| 53 |
+
def delete_cache(key: str = ""):
|
| 54 |
+
files = os.listdir(PATH)
|
| 55 |
+
|
| 56 |
+
for file in files:
|
| 57 |
+
if file.startswith(key) and file.endswith(FILE_SUFFIX):
|
| 58 |
+
file_path = f"{PATH}/{file}"
|
| 59 |
+
os.unlink(file_path)
|
lm-evaluation-harness/lm_eval/decontamination/__init__.py
ADDED
|
File without changes
|
lm-evaluation-harness/lm_eval/decontamination/archiver.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import mmap
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import jsonlines
|
| 10 |
+
import tqdm
|
| 11 |
+
import zstandard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def json_serial(obj: Any) -> str:
|
| 15 |
+
"""JSON serializer for objects not serializable by default json code"""
|
| 16 |
+
|
| 17 |
+
if isinstance(obj, (datetime.datetime,)):
|
| 18 |
+
return obj.isoformat()
|
| 19 |
+
raise TypeError("Type %s not serializable" % type(obj))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Modified version of lm_dataformat Archive for single file.
|
| 23 |
+
class Archive:
|
| 24 |
+
def __init__(self, file_path: str, compression_level: int = 3) -> None:
|
| 25 |
+
self.file_path = file_path
|
| 26 |
+
dir_name = os.path.dirname(file_path)
|
| 27 |
+
if dir_name:
|
| 28 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 29 |
+
self.fh = open(self.file_path, "wb")
|
| 30 |
+
self.cctx = zstandard.ZstdCompressor(level=compression_level)
|
| 31 |
+
self.compressor = self.cctx.stream_writer(self.fh)
|
| 32 |
+
|
| 33 |
+
def add_data(self, data, meta=None) -> None:
|
| 34 |
+
if meta is None:
|
| 35 |
+
meta = {}
|
| 36 |
+
self.compressor.write(
|
| 37 |
+
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
|
| 38 |
+
"UTF-8"
|
| 39 |
+
)
|
| 40 |
+
+ b"\n"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def commit(self) -> None:
|
| 44 |
+
self.compressor.flush(zstandard.FLUSH_FRAME)
|
| 45 |
+
self.fh.flush()
|
| 46 |
+
self.fh.close()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
|
| 50 |
+
class Reader:
|
| 51 |
+
def __init__(self) -> None:
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
def read(
|
| 55 |
+
self,
|
| 56 |
+
file,
|
| 57 |
+
get_meta: bool = False,
|
| 58 |
+
autojoin_paragraphs: bool = True,
|
| 59 |
+
para_joiner: str = "\n\n",
|
| 60 |
+
):
|
| 61 |
+
with open(file, "rb") as fh:
|
| 62 |
+
self.fh = fh
|
| 63 |
+
cctx = zstandard.ZstdDecompressor()
|
| 64 |
+
reader = io.BufferedReader(cctx.stream_reader(fh))
|
| 65 |
+
rdr = jsonlines.Reader(reader)
|
| 66 |
+
for ob in rdr:
|
| 67 |
+
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
|
| 68 |
+
if isinstance(ob, str):
|
| 69 |
+
assert not get_meta
|
| 70 |
+
yield ob
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
text = ob["text"]
|
| 74 |
+
|
| 75 |
+
if autojoin_paragraphs and isinstance(text, list):
|
| 76 |
+
text = para_joiner.join(text)
|
| 77 |
+
|
| 78 |
+
if get_meta:
|
| 79 |
+
yield text, (ob["meta"] if "meta" in ob else {})
|
| 80 |
+
else:
|
| 81 |
+
yield text
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TextArchive:
|
| 85 |
+
def __init__(self, file_path, mode: str = "rb+") -> None:
|
| 86 |
+
self.file_path = file_path
|
| 87 |
+
dir_name = os.path.dirname(file_path)
|
| 88 |
+
if dir_name:
|
| 89 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 90 |
+
|
| 91 |
+
if not os.path.exists(file_path):
|
| 92 |
+
Path(file_path).touch()
|
| 93 |
+
|
| 94 |
+
self.fh = open(self.file_path, mode)
|
| 95 |
+
|
| 96 |
+
def add_data(self, data) -> None:
|
| 97 |
+
self.fh.write(data.encode("UTF-8") + b"\n")
|
| 98 |
+
|
| 99 |
+
def commit(self) -> None:
|
| 100 |
+
self.fh.flush()
|
| 101 |
+
self.fh.close()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class TextReader:
|
| 105 |
+
def __init__(self, file_path) -> None:
|
| 106 |
+
self.file_path = file_path
|
| 107 |
+
|
| 108 |
+
# Optimized mmap read with infrequent tqdm updates to maintain speed
|
| 109 |
+
# Tested up to 250MB/s.
|
| 110 |
+
def read_tqdm(self, update_frequency: int = 10000):
|
| 111 |
+
current_file_position = 0
|
| 112 |
+
line_counter = 0
|
| 113 |
+
with (
|
| 114 |
+
open(self.file_path, "r", encoding="utf-8") as fh,
|
| 115 |
+
tqdm.tqdm(
|
| 116 |
+
total=os.path.getsize(self.file_path),
|
| 117 |
+
dynamic_ncols=True,
|
| 118 |
+
unit="byte",
|
| 119 |
+
unit_scale=1,
|
| 120 |
+
) as progress,
|
| 121 |
+
):
|
| 122 |
+
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 123 |
+
for line in iter(mmap_obj.readline, b""):
|
| 124 |
+
line = line.decode("utf-8")
|
| 125 |
+
line_counter += 1
|
| 126 |
+
if line_counter == update_frequency:
|
| 127 |
+
new_file_pos = mmap_obj.tell()
|
| 128 |
+
bytes_read = new_file_pos - current_file_position
|
| 129 |
+
current_file_position = new_file_pos
|
| 130 |
+
progress.update(bytes_read)
|
| 131 |
+
line_counter = 0
|
| 132 |
+
yield line[:-1]
|
| 133 |
+
|
| 134 |
+
def read_and_tell(self):
|
| 135 |
+
current_file_position = 0
|
| 136 |
+
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 137 |
+
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 138 |
+
for line in iter(mmap_obj.readline, b""):
|
| 139 |
+
line = line.decode("utf-8")
|
| 140 |
+
new_file_pos = mmap_obj.tell()
|
| 141 |
+
raw_bytes_read = new_file_pos - current_file_position
|
| 142 |
+
current_file_position = new_file_pos
|
| 143 |
+
yield line[:-1], raw_bytes_read
|
| 144 |
+
|
| 145 |
+
def read(self):
|
| 146 |
+
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 147 |
+
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 148 |
+
for line in iter(mmap_obj.readline, b""):
|
| 149 |
+
line = line.decode("utf-8")
|
| 150 |
+
yield line[:-1]
|
| 151 |
+
|
| 152 |
+
def read_slow(self):
|
| 153 |
+
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 154 |
+
while True:
|
| 155 |
+
line = fh.readline()
|
| 156 |
+
if line == -1 or line == "":
|
| 157 |
+
break
|
| 158 |
+
else:
|
| 159 |
+
yield line[:-1]
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# Optimized for speed. Decompresses the archive in shell before
|
| 163 |
+
# using the mmap'd TextReader.
|
| 164 |
+
class ZStdTextReader:
|
| 165 |
+
def __init__(self, file) -> None:
|
| 166 |
+
self.file = file
|
| 167 |
+
|
| 168 |
+
def read_tqdm(self):
|
| 169 |
+
decompressed_file = self.file[:-4]
|
| 170 |
+
print("Decompressing file, please wait...")
|
| 171 |
+
os.system(f"zstd -d {self.file}") # linux decompress is faster
|
| 172 |
+
reader = TextReader(decompressed_file)
|
| 173 |
+
yield from reader.read_tqdm()
|
| 174 |
+
os.remove(decompressed_file)
|
lm-evaluation-harness/lm_eval/decontamination/decontaminate.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from .archiver import ZStdTextReader
|
| 10 |
+
from .janitor import Janitor, word_ngrams
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Was used for testing the evaluator decoupled from the full logic below
|
| 14 |
+
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
|
| 15 |
+
simulated_overlap = 0.1
|
| 16 |
+
contaminated = int(len(docs) * simulated_overlap)
|
| 17 |
+
return random.sample(range(len(docs)), contaminated)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Returns a dictionary containing all overlapping documents in each
|
| 21 |
+
# task. In the standard use case, an overlap occurs when any of the 13-grams
|
| 22 |
+
# found in the task document exist in the training set documents.
|
| 23 |
+
#
|
| 24 |
+
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
|
| 25 |
+
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
|
| 26 |
+
# files. These should exist in the "ngrams_path" provided to this function.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Algorithm:
|
| 30 |
+
# 1. Build lookups for each dataset {ngram: list(document_ids)}
|
| 31 |
+
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
|
| 32 |
+
# 3. Full scan the 13-grams from the training set against the merged lookup,
|
| 33 |
+
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
|
| 34 |
+
# 4. Strip the task_set from the dictionary keys and return
|
| 35 |
+
#
|
| 36 |
+
# We cache the task+set lookups as well as the overlaps.
|
| 37 |
+
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
|
| 38 |
+
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
|
| 39 |
+
|
| 40 |
+
info_dict_path = os.path.join(ngrams_path, "info.json")
|
| 41 |
+
info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
|
| 42 |
+
ngrams_n_size = info_dict["ngram_size"]
|
| 43 |
+
|
| 44 |
+
janitor = Janitor()
|
| 45 |
+
|
| 46 |
+
# Build lookup for each dataset first in case we use different task combinations later
|
| 47 |
+
print("Building Lookups...")
|
| 48 |
+
start = time.perf_counter()
|
| 49 |
+
|
| 50 |
+
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
|
| 51 |
+
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
|
| 52 |
+
|
| 53 |
+
lookups = {}
|
| 54 |
+
duplicates = {} # (task_name, task_set): set(doc_ids)}
|
| 55 |
+
sets_to_decontaminate = len(docs_by_task_set.keys())
|
| 56 |
+
|
| 57 |
+
for (task_name, task_set), docs in docs_by_task_set.items():
|
| 58 |
+
if not os.path.exists(f"data/{task_name}"):
|
| 59 |
+
os.mkdir(f"data/{task_name}")
|
| 60 |
+
|
| 61 |
+
# Check if we've decontaminated this combination before
|
| 62 |
+
overlaps_dump_path = get_overlaps_dump_path(
|
| 63 |
+
task_name, task_set, ngrams_n_size, limit
|
| 64 |
+
)
|
| 65 |
+
if os.path.exists(overlaps_dump_path):
|
| 66 |
+
duplicates[(task_name, task_set)] = pickle.load(
|
| 67 |
+
open(overlaps_dump_path, "rb")
|
| 68 |
+
)
|
| 69 |
+
sets_to_decontaminate -= 1
|
| 70 |
+
continue
|
| 71 |
+
else:
|
| 72 |
+
duplicates[(task_name, task_set)] = set()
|
| 73 |
+
|
| 74 |
+
# Build/load the task lookup {ngram: set(documents)}.
|
| 75 |
+
task_set_lookup_path = (
|
| 76 |
+
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
|
| 77 |
+
)
|
| 78 |
+
if os.path.exists(task_set_lookup_path):
|
| 79 |
+
print(f"{task_set_lookup_path} available, loading...")
|
| 80 |
+
lookups[(task_name, task_set)] = pickle.load(
|
| 81 |
+
open(task_set_lookup_path, "rb")
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
print(f"{task_set_lookup_path} not available, building...")
|
| 85 |
+
lookup = collections.defaultdict(set)
|
| 86 |
+
|
| 87 |
+
for doc_id, document in enumerate(docs):
|
| 88 |
+
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
|
| 89 |
+
for ngram in ngrams:
|
| 90 |
+
lookup[ngram].add(doc_id)
|
| 91 |
+
|
| 92 |
+
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
|
| 93 |
+
lookups[(task_name, task_set)] = lookup
|
| 94 |
+
|
| 95 |
+
elapsed = time.perf_counter() - start
|
| 96 |
+
print(f"Building lookups took {elapsed:0.5f} seconds.")
|
| 97 |
+
|
| 98 |
+
matched_ngrams = []
|
| 99 |
+
|
| 100 |
+
if sets_to_decontaminate > 0:
|
| 101 |
+
print("Merging lookups...")
|
| 102 |
+
start = time.perf_counter()
|
| 103 |
+
merged_lookup = collections.defaultdict(list)
|
| 104 |
+
for (task_name, task_set), lookup in lookups.items():
|
| 105 |
+
for ngram, doc_ids in lookup.items():
|
| 106 |
+
merged_lookup[ngram].append((task_name, task_set, doc_ids))
|
| 107 |
+
|
| 108 |
+
elapsed = time.perf_counter() - start
|
| 109 |
+
print(f"Merging lookups took {elapsed:0.5f} seconds.")
|
| 110 |
+
|
| 111 |
+
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
|
| 112 |
+
files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
|
| 113 |
+
print(files)
|
| 114 |
+
|
| 115 |
+
for file in files:
|
| 116 |
+
start = time.perf_counter()
|
| 117 |
+
print(f"Scanning {file}")
|
| 118 |
+
reader = ZStdTextReader(file)
|
| 119 |
+
total_ngrams = 0
|
| 120 |
+
unique_ngrams = 0
|
| 121 |
+
matching_unique = 0
|
| 122 |
+
non_matching_unique = 0
|
| 123 |
+
|
| 124 |
+
current_ngram = ""
|
| 125 |
+
for line in reader.read_tqdm(): # Scan training set ngrams file
|
| 126 |
+
total_ngrams += 1
|
| 127 |
+
[ngram, document_id] = line.rsplit(" ", 1)
|
| 128 |
+
if (
|
| 129 |
+
ngram != current_ngram
|
| 130 |
+
): # Only need to match the ngram once in training set
|
| 131 |
+
unique_ngrams += 1
|
| 132 |
+
current_ngram = ngram
|
| 133 |
+
if ngram in merged_lookup:
|
| 134 |
+
matched_ngrams.append(ngram) # For logging
|
| 135 |
+
matching_unique += 1
|
| 136 |
+
for task_name, task_set, doc_ids in merged_lookup[ngram]:
|
| 137 |
+
task_doc_set = duplicates[(task_name, task_set)]
|
| 138 |
+
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
|
| 139 |
+
task_doc_set.add(doc_id)
|
| 140 |
+
del merged_lookup[ngram] # No point matching again
|
| 141 |
+
else:
|
| 142 |
+
non_matching_unique += 1
|
| 143 |
+
|
| 144 |
+
print(f"Total Ngrams: {total_ngrams}")
|
| 145 |
+
print(f"Unique Ngrams: {unique_ngrams}")
|
| 146 |
+
print(f"Unique Matching: {matching_unique}")
|
| 147 |
+
print(f"Unique Non Matching: {non_matching_unique}")
|
| 148 |
+
print("Matched ngrams:")
|
| 149 |
+
for ngram in matched_ngrams:
|
| 150 |
+
print(ngram)
|
| 151 |
+
|
| 152 |
+
elapsed = time.perf_counter() - start
|
| 153 |
+
print(f"Read took {elapsed:0.5f} seconds.")
|
| 154 |
+
print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
|
| 155 |
+
|
| 156 |
+
print(duplicates)
|
| 157 |
+
|
| 158 |
+
# Dump overlaps separately
|
| 159 |
+
for (task_name, task_set), doc_ids in duplicates.items():
|
| 160 |
+
overlaps_dump_path = get_overlaps_dump_path(
|
| 161 |
+
task_name, task_set, ngrams_n_size, limit
|
| 162 |
+
)
|
| 163 |
+
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
|
| 164 |
+
|
| 165 |
+
# Strip task set and return
|
| 166 |
+
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
|
lm-evaluation-harness/lm_eval/decontamination/janitor.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import re
|
| 3 |
+
import string
|
| 4 |
+
import traceback
|
| 5 |
+
from typing import Iterator, List, Sequence, Tuple, TypeVar
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# This is a cpp module.
|
| 9 |
+
# See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import janitor_util
|
| 13 |
+
|
| 14 |
+
JANITOR_CPP = True
|
| 15 |
+
except Exception:
|
| 16 |
+
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
|
| 17 |
+
traceback.print_exc()
|
| 18 |
+
JANITOR_CPP = False
|
| 19 |
+
|
| 20 |
+
T = TypeVar("T")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Implementation from nltk source
|
| 24 |
+
# https://www.nltk.org/_modules/nltk/util.html
|
| 25 |
+
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
|
| 26 |
+
history = []
|
| 27 |
+
while n > 1:
|
| 28 |
+
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
|
| 29 |
+
try:
|
| 30 |
+
next_item = next(sequence)
|
| 31 |
+
except StopIteration:
|
| 32 |
+
# no more data, terminate the generator
|
| 33 |
+
return
|
| 34 |
+
history.append(next_item)
|
| 35 |
+
n -= 1
|
| 36 |
+
for item in sequence:
|
| 37 |
+
history.append(item)
|
| 38 |
+
yield tuple(history)
|
| 39 |
+
del history[0]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def word_ngrams(s: str, n: int) -> Iterator[str]:
|
| 43 |
+
"""Splits a string into ngram words"""
|
| 44 |
+
tokens = s.split() # not a generator :(
|
| 45 |
+
ngram_seqs = form_ngrams(iter(tokens), n)
|
| 46 |
+
return (" ".join(ngram) for ngram in ngram_seqs)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Does character sequences only - combined faster function to play around with later
|
| 50 |
+
# def word_ngrams_indices_combined(sequence, n):
|
| 51 |
+
# current_word = ""
|
| 52 |
+
# history = []
|
| 53 |
+
# gap = False;
|
| 54 |
+
# start = 0
|
| 55 |
+
# end = 0
|
| 56 |
+
# for character in sequence:
|
| 57 |
+
# if character == " ":
|
| 58 |
+
# if not gap:
|
| 59 |
+
# gap = True
|
| 60 |
+
# history.append(current_word)
|
| 61 |
+
# end += len(current_word) - 1
|
| 62 |
+
# current_word = ""
|
| 63 |
+
# if len(history) == n:
|
| 64 |
+
# yield (tuple(history), start, end)
|
| 65 |
+
# del history[0]
|
| 66 |
+
# start = end + 1
|
| 67 |
+
# end = start
|
| 68 |
+
# else:
|
| 69 |
+
# gap = False
|
| 70 |
+
# current_word += character
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
|
| 74 |
+
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
|
| 75 |
+
"""Splits a string on whitespaces and records the indices of each in the original string.
|
| 76 |
+
@:return generator((word, (start_idx, end_idx)), ...)
|
| 77 |
+
"""
|
| 78 |
+
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
|
| 82 |
+
"""Splits a string into pairs of (ngram words, their start/end indices)"""
|
| 83 |
+
tokens_with_indices = split_indices(s)
|
| 84 |
+
|
| 85 |
+
# Generator of ngrams of (word, idx_pairs)
|
| 86 |
+
# (
|
| 87 |
+
# [(word, (start,end)), (word, (start, end))...],
|
| 88 |
+
# [(word, (start, end)), ...],
|
| 89 |
+
# ...
|
| 90 |
+
# )
|
| 91 |
+
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
|
| 92 |
+
|
| 93 |
+
# Generator of pairs of word and index ngrams
|
| 94 |
+
# (
|
| 95 |
+
# ([word, word, ...], [(start,end), (start,end), ...]),
|
| 96 |
+
# ...
|
| 97 |
+
# )
|
| 98 |
+
ngram_indices_pairs = (
|
| 99 |
+
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
|
| 103 |
+
return (
|
| 104 |
+
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
|
| 105 |
+
for ngram_seq, indices in ngram_indices_pairs
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Janitor:
|
| 110 |
+
# FIXME delete_chars: Should anything else go here? Special chars?
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
ngram_n: int = 13,
|
| 114 |
+
window_to_remove: int = 200,
|
| 115 |
+
too_dirty_cutoff: int = 10,
|
| 116 |
+
minimum_slice_length: int = 200,
|
| 117 |
+
delete_chars: str = string.punctuation,
|
| 118 |
+
) -> None:
|
| 119 |
+
self.ngram_n = ngram_n
|
| 120 |
+
self.window_to_remove = window_to_remove
|
| 121 |
+
self.too_dirty_cutoff = too_dirty_cutoff
|
| 122 |
+
self.minimum_slice_length = minimum_slice_length
|
| 123 |
+
self.delete_chars = delete_chars
|
| 124 |
+
|
| 125 |
+
self.dirt_ngrams = set()
|
| 126 |
+
|
| 127 |
+
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
|
| 128 |
+
# This is fast by python standards
|
| 129 |
+
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
|
| 130 |
+
self.translation_table = str.maketrans(
|
| 131 |
+
string.ascii_lowercase + string.ascii_uppercase, # These characters
|
| 132 |
+
string.ascii_lowercase * 2, # Become these characters
|
| 133 |
+
self.delete_chars, # These are deleted
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
##############
|
| 137 |
+
# I/O for saving contamination ngrams
|
| 138 |
+
##############
|
| 139 |
+
|
| 140 |
+
def save_contamination_ngrams(self, filename: str) -> None:
|
| 141 |
+
with open(filename, "wb") as fp:
|
| 142 |
+
pickle.dump(filename, fp)
|
| 143 |
+
|
| 144 |
+
def load_contamination_ngrams(self, filename: str) -> None:
|
| 145 |
+
with open(filename, "rb") as fp:
|
| 146 |
+
self.dirt_ngrams = pickle.load(fp)
|
| 147 |
+
|
| 148 |
+
##############
|
| 149 |
+
# Call these :)
|
| 150 |
+
##############
|
| 151 |
+
|
| 152 |
+
def register_contaminant(self, dirt_string: str) -> None:
|
| 153 |
+
"""Register a string as contamination to be removed, e.g. a test set
|
| 154 |
+
This breaks the dirt_string into ngrams to store for future cleaning"""
|
| 155 |
+
if JANITOR_CPP:
|
| 156 |
+
return self.register_contaminant_cpp(dirt_string)
|
| 157 |
+
else:
|
| 158 |
+
print("WARNING: Janitor running in python mode")
|
| 159 |
+
return self.register_contaminant_python(dirt_string)
|
| 160 |
+
|
| 161 |
+
def clean(self, dirty_string: str) -> List[str]:
|
| 162 |
+
"""Clean a string (e.g. a training set) by removing all ngrams previously
|
| 163 |
+
registered as contaminants. Returns a list of clean chunks, or empty if
|
| 164 |
+
the string was too dirty"""
|
| 165 |
+
if JANITOR_CPP:
|
| 166 |
+
return self.clean_cpp(dirty_string)
|
| 167 |
+
else:
|
| 168 |
+
print("WARNING: Janitor running in python mode")
|
| 169 |
+
return self.clean_python(dirty_string)
|
| 170 |
+
|
| 171 |
+
def _split_chunks(
|
| 172 |
+
self, dirty_string: str, dirty_parts: Sequence[Tuple]
|
| 173 |
+
) -> List[str]:
|
| 174 |
+
clean_chunks = []
|
| 175 |
+
splice_idx = 0
|
| 176 |
+
end = -1
|
| 177 |
+
for i, (ngram, start, end) in enumerate(dirty_parts):
|
| 178 |
+
if i >= self.too_dirty_cutoff:
|
| 179 |
+
return []
|
| 180 |
+
start = max(0, start - self.window_to_remove)
|
| 181 |
+
end = min(len(dirty_string), end + self.window_to_remove)
|
| 182 |
+
|
| 183 |
+
if start - splice_idx > self.minimum_slice_length:
|
| 184 |
+
clean_chunks.append(dirty_string[splice_idx:start])
|
| 185 |
+
splice_idx = end
|
| 186 |
+
|
| 187 |
+
if end < len(dirty_string) - self.minimum_slice_length:
|
| 188 |
+
clean_chunks.append(dirty_string[end + 1 :])
|
| 189 |
+
|
| 190 |
+
return clean_chunks
|
| 191 |
+
|
| 192 |
+
##############
|
| 193 |
+
# Fast C++
|
| 194 |
+
##############
|
| 195 |
+
|
| 196 |
+
def register_contaminant_cpp(self, dirt_string) -> None:
|
| 197 |
+
self.dirt_ngrams.update(
|
| 198 |
+
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def clean_cpp(self, dirty_string: str) -> List[str]:
|
| 202 |
+
contamination_indices = janitor_util.clean_ngram_with_indices(
|
| 203 |
+
dirty_string, self.delete_chars, self.ngram_n
|
| 204 |
+
)
|
| 205 |
+
return self._split_chunks(dirty_string, contamination_indices)
|
| 206 |
+
|
| 207 |
+
##############
|
| 208 |
+
# Slow python
|
| 209 |
+
##############
|
| 210 |
+
|
| 211 |
+
def normalize_string(self, s: str) -> str:
|
| 212 |
+
return s.translate(self.translation_table)
|
| 213 |
+
|
| 214 |
+
def register_contaminant_python(self, dirt_string: str) -> None:
|
| 215 |
+
self.dirt_ngrams.update(
|
| 216 |
+
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def clean_python(self, dirty_string: str) -> List[str]:
|
| 220 |
+
contamination_indices = (
|
| 221 |
+
(None, *idx_pair)
|
| 222 |
+
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
|
| 223 |
+
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
|
| 224 |
+
)
|
| 225 |
+
return self._split_chunks(dirty_string, contamination_indices)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
##################################################################
|
| 229 |
+
# Tests
|
| 230 |
+
#################################################################
|
| 231 |
+
|
| 232 |
+
# def print_cpp():
|
| 233 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 234 |
+
|
| 235 |
+
# for i in range(1, 10, 2):
|
| 236 |
+
# pprint(janitor_util.clean_ngram(source, string.punctuation, i))
|
| 237 |
+
# for ngram, start, end in \
|
| 238 |
+
# janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
|
| 239 |
+
# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# def test_cpp():
|
| 243 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 244 |
+
# contaminant = "dirty boy. Clean he he"
|
| 245 |
+
|
| 246 |
+
# jan_python = Janitor()
|
| 247 |
+
# jan_cpp = Janitor()
|
| 248 |
+
|
| 249 |
+
# jan_python.register_contaminant_python(contaminant)
|
| 250 |
+
# jan_cpp.register_contaminant(contaminant)
|
| 251 |
+
|
| 252 |
+
# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
|
| 253 |
+
|
| 254 |
+
# assert jan_python.clean_python(source) == jan_cpp.clean(source), \
|
| 255 |
+
# (jan_python.clean_python(source), jan_cpp.clean(source))
|
| 256 |
+
|
| 257 |
+
# print("Passed test, python==cpp")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# def benchmark():
|
| 261 |
+
# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
|
| 262 |
+
# setup = \
|
| 263 |
+
# """
|
| 264 |
+
# with open("data/enwik8", "r") as f:
|
| 265 |
+
# data = f.read()
|
| 266 |
+
# jan = Janitor(too_dirty_cutoff=1000)
|
| 267 |
+
# jan.register_contaminant('''
|
| 268 |
+
# theories is that there is a connection between "geekdom" and autism.
|
| 269 |
+
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled "
|
| 270 |
+
# The [[Geek]] Syndrome", which is a point argued by many in the autism rights
|
| 271 |
+
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
|
| 272 |
+
# the media's application of mental disease labels to what is actually variant normal behavior
|
| 273 |
+
# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
|
| 274 |
+
# interests, even when they seem unusual to others, are not in themselves signs of autism or
|
| 275 |
+
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
|
| 276 |
+
# mental disease labels to children who in the past would have simply been accepted as a little
|
| 277 |
+
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
|
| 278 |
+
# Due to the recent publicity surrounding autism and autis
|
| 279 |
+
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
|
| 280 |
+
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
|
| 281 |
+
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
|
| 282 |
+
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
|
| 283 |
+
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
|
| 284 |
+
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
|
| 285 |
+
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
|
| 286 |
+
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
|
| 287 |
+
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
|
| 288 |
+
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
|
| 289 |
+
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
|
| 290 |
+
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
|
| 291 |
+
# ''')
|
| 292 |
+
# """
|
| 293 |
+
|
| 294 |
+
# n = 1
|
| 295 |
+
# print(f"Timing {n} run on 100 MB")
|
| 296 |
+
# print("Register contaminant")
|
| 297 |
+
# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
|
| 298 |
+
# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
|
| 299 |
+
|
| 300 |
+
# print("Clean")
|
| 301 |
+
# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
|
| 302 |
+
# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# def test_janitor_general():
|
| 306 |
+
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 307 |
+
# contaminant = "dirty boy. Clean he he"
|
| 308 |
+
|
| 309 |
+
# jan = Janitor(ngram_n=3)
|
| 310 |
+
# jan.register_contaminant(contaminant)
|
| 311 |
+
# cleaned = " ".join(jan.clean(source))
|
| 312 |
+
# for contam in jan.dirt_ngrams:
|
| 313 |
+
# assert contam not in cleaned, contam
|
| 314 |
+
|
| 315 |
+
# filename = "data/saved_contam"
|
| 316 |
+
# jan.save_contamination_ngrams(filename)
|
| 317 |
+
|
| 318 |
+
# jan = Janitor(ngram_n=3)
|
| 319 |
+
# jan.load_contamination_ngrams(filename)
|
| 320 |
+
# cleaned = " ".join(jan.clean(source))
|
| 321 |
+
# for contam in jan.dirt_ngrams:
|
| 322 |
+
# assert contam not in cleaned, contam
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# if __name__ == "__main__":
|
| 326 |
+
# test()
|
| 327 |
+
# # print_cpp()
|
| 328 |
+
# # test_cpp()
|
| 329 |
+
# # benchmark()
|
lm-evaluation-harness/lm_eval/evaluator.py
ADDED
|
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from typing import TYPE_CHECKING, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
import lm_eval.api.metrics
|
| 14 |
+
import lm_eval.api.registry
|
| 15 |
+
import lm_eval.api.task
|
| 16 |
+
import lm_eval.models
|
| 17 |
+
from lm_eval.caching.cache import delete_cache
|
| 18 |
+
from lm_eval.evaluator_utils import (
|
| 19 |
+
consolidate_group_results,
|
| 20 |
+
consolidate_results,
|
| 21 |
+
get_sample_size,
|
| 22 |
+
get_subtask_list,
|
| 23 |
+
get_task_list,
|
| 24 |
+
prepare_print_tasks,
|
| 25 |
+
print_writeout,
|
| 26 |
+
run_task_tests,
|
| 27 |
+
)
|
| 28 |
+
from lm_eval.loggers import EvaluationTracker
|
| 29 |
+
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
|
| 30 |
+
from lm_eval.tasks import TaskManager, get_task_dict
|
| 31 |
+
from lm_eval.utils import (
|
| 32 |
+
handle_non_serializable,
|
| 33 |
+
hash_dict_images,
|
| 34 |
+
hash_string,
|
| 35 |
+
positional_deprecated,
|
| 36 |
+
setup_logging,
|
| 37 |
+
simple_parse_args_string,
|
| 38 |
+
wrap_text,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if TYPE_CHECKING:
|
| 43 |
+
from lm_eval.api.model import LM
|
| 44 |
+
from lm_eval.api.task import Task
|
| 45 |
+
|
| 46 |
+
eval_logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@positional_deprecated
|
| 50 |
+
def simple_evaluate(
|
| 51 |
+
model,
|
| 52 |
+
model_args: Optional[Union[str, dict]] = None,
|
| 53 |
+
tasks: Optional[List[Union[str, dict, object]]] = None,
|
| 54 |
+
num_fewshot: Optional[int] = None,
|
| 55 |
+
batch_size: Optional[Union[int, str]] = None,
|
| 56 |
+
max_batch_size: Optional[int] = None,
|
| 57 |
+
device: Optional[str] = None,
|
| 58 |
+
use_cache: Optional[str] = None,
|
| 59 |
+
cache_requests: bool = False,
|
| 60 |
+
rewrite_requests_cache: bool = False,
|
| 61 |
+
delete_requests_cache: bool = False,
|
| 62 |
+
limit: Optional[Union[int, float]] = None,
|
| 63 |
+
samples: Optional[dict] = None,
|
| 64 |
+
bootstrap_iters: int = 100000,
|
| 65 |
+
check_integrity: bool = False,
|
| 66 |
+
write_out: bool = False,
|
| 67 |
+
log_samples: bool = True,
|
| 68 |
+
evaluation_tracker: Optional[EvaluationTracker] = None,
|
| 69 |
+
system_instruction: Optional[str] = None,
|
| 70 |
+
apply_chat_template: Union[bool, str] = False,
|
| 71 |
+
fewshot_as_multiturn: bool = False,
|
| 72 |
+
gen_kwargs: Union[str, dict, None] = None,
|
| 73 |
+
task_manager: Optional[TaskManager] = None,
|
| 74 |
+
verbosity=None,
|
| 75 |
+
predict_only: bool = False,
|
| 76 |
+
random_seed: int = 0,
|
| 77 |
+
numpy_random_seed: int = 1234,
|
| 78 |
+
torch_random_seed: int = 1234,
|
| 79 |
+
fewshot_random_seed: int = 1234,
|
| 80 |
+
confirm_run_unsafe_code: bool = False,
|
| 81 |
+
metadata: Optional[dict] = None,
|
| 82 |
+
):
|
| 83 |
+
"""Instantiate and evaluate a model on a list of tasks.
|
| 84 |
+
|
| 85 |
+
:param model: Union[str, LM]
|
| 86 |
+
Name of model or LM object, see lm_eval.models.get_model
|
| 87 |
+
:param model_args: Optional[str, dict]
|
| 88 |
+
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
|
| 89 |
+
Ignored if `model` argument is a LM object.
|
| 90 |
+
:param tasks: list[Union[str, dict, Task]]
|
| 91 |
+
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
|
| 92 |
+
:param num_fewshot: int
|
| 93 |
+
Number of examples in few-shot context
|
| 94 |
+
:param batch_size: int or str, optional
|
| 95 |
+
Batch size for model
|
| 96 |
+
:param max_batch_size: int, optional
|
| 97 |
+
Maximal batch size to try with automatic batch size detection
|
| 98 |
+
:param device: str, optional
|
| 99 |
+
PyTorch device (e.g. "cpu" or "cuda:0") for running models
|
| 100 |
+
:param use_cache: str, optional
|
| 101 |
+
A path to a sqlite db file for caching model responses. `None` if not caching.
|
| 102 |
+
:param cache_requests: bool, optional
|
| 103 |
+
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
|
| 104 |
+
:param rewrite_requests_cache: bool, optional
|
| 105 |
+
Rewrites all the request cache if set to `True`. `None` if not desired.
|
| 106 |
+
:param delete_requests_cache: bool, optional
|
| 107 |
+
Deletes all the request cache if set to `True`. `None` if not desired.
|
| 108 |
+
:param limit: int or float, optional
|
| 109 |
+
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
|
| 110 |
+
:param samples: dictionary, optional
|
| 111 |
+
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 112 |
+
:param bootstrap_iters:
|
| 113 |
+
Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
|
| 114 |
+
:param check_integrity: bool
|
| 115 |
+
Whether to run the relevant part of the test suite for the tasks
|
| 116 |
+
:param write_out: bool
|
| 117 |
+
If True, write out an example document and model input for checking task integrity
|
| 118 |
+
:param log_samples: bool
|
| 119 |
+
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 120 |
+
:param system_instruction: str
|
| 121 |
+
System instruction to be applied to the prompt
|
| 122 |
+
:param apply_chat_template: Union[bool, str]
|
| 123 |
+
Specifies whether to apply a chat template to the prompt.
|
| 124 |
+
- If set to True, the default chat template is applied.
|
| 125 |
+
- If set to a string, applies the specified chat template by name.
|
| 126 |
+
Defaults to False (no chat template applied).
|
| 127 |
+
:param fewshot_as_multiturn: bool
|
| 128 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 129 |
+
:param gen_kwargs: dict or comma-separated string
|
| 130 |
+
Arguments for model generation
|
| 131 |
+
Ignored for all tasks with loglikelihood output_type
|
| 132 |
+
:param verbosity: str
|
| 133 |
+
Verbosity level for logging
|
| 134 |
+
:param predict_only: bool
|
| 135 |
+
If true only model outputs will be generated and returned. Metrics will not be evaluated
|
| 136 |
+
:param random_seed: int
|
| 137 |
+
Random seed for python's random module. If set to None, the seed will not be set.
|
| 138 |
+
:param numpy_random_seed: int
|
| 139 |
+
Random seed for numpy. If set to None, the seed will not be set.
|
| 140 |
+
:param torch_random_seed: int
|
| 141 |
+
Random seed for torch. If set to None, the seed will not be set.
|
| 142 |
+
:param fewshot_random_seed: int
|
| 143 |
+
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
|
| 144 |
+
:param metadata: dict
|
| 145 |
+
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
|
| 146 |
+
return
|
| 147 |
+
Dictionary of results
|
| 148 |
+
"""
|
| 149 |
+
if verbosity is not None:
|
| 150 |
+
setup_logging(verbosity=verbosity)
|
| 151 |
+
start_date = time.time()
|
| 152 |
+
|
| 153 |
+
if limit is not None and samples is not None:
|
| 154 |
+
raise ValueError(
|
| 155 |
+
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
_NEEDS_CHAT_TEMPLATE = ("inst", "chat")
|
| 159 |
+
if (
|
| 160 |
+
(
|
| 161 |
+
isinstance(model_args, str)
|
| 162 |
+
and any(kw in model_args.lower() for kw in _NEEDS_CHAT_TEMPLATE)
|
| 163 |
+
)
|
| 164 |
+
or (
|
| 165 |
+
isinstance(model_args, dict)
|
| 166 |
+
and any(
|
| 167 |
+
any(kw in str(v).lower() for kw in _NEEDS_CHAT_TEMPLATE)
|
| 168 |
+
for v in model_args.values()
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
) and not apply_chat_template:
|
| 172 |
+
eval_logger.warning(
|
| 173 |
+
wrap_text(
|
| 174 |
+
f"""pretrained={model_args.get("pretrained") if isinstance(model_args, dict) else model_args} appears to be an
|
| 175 |
+
instruct or chat variant but chat template is not applied.
|
| 176 |
+
Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).""",
|
| 177 |
+
)
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if delete_requests_cache:
|
| 181 |
+
eval_logger.info("Deleting requests cache...")
|
| 182 |
+
delete_cache()
|
| 183 |
+
|
| 184 |
+
seed_message = []
|
| 185 |
+
if random_seed is not None:
|
| 186 |
+
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
|
| 187 |
+
seed_message.append(f"Setting random seed to {random_seed}")
|
| 188 |
+
random.seed(random_seed)
|
| 189 |
+
|
| 190 |
+
if numpy_random_seed is not None:
|
| 191 |
+
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
|
| 192 |
+
np.random.seed(numpy_random_seed)
|
| 193 |
+
|
| 194 |
+
if torch_random_seed is not None:
|
| 195 |
+
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
|
| 196 |
+
torch.manual_seed(torch_random_seed)
|
| 197 |
+
|
| 198 |
+
if fewshot_random_seed is not None:
|
| 199 |
+
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
|
| 200 |
+
|
| 201 |
+
if seed_message:
|
| 202 |
+
eval_logger.info(" | ".join(seed_message))
|
| 203 |
+
|
| 204 |
+
if tasks is None:
|
| 205 |
+
tasks = []
|
| 206 |
+
if len(tasks) == 0:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
"No tasks specified, or no tasks found. Please verify the task names."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if gen_kwargs is not None:
|
| 212 |
+
if isinstance(gen_kwargs, str):
|
| 213 |
+
gen_kwargs = simple_parse_args_string(gen_kwargs)
|
| 214 |
+
eval_logger.warning(
|
| 215 |
+
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
|
| 216 |
+
"Ensure 'do_sample=True' for non-greedy decoding!"
|
| 217 |
+
)
|
| 218 |
+
if not gen_kwargs:
|
| 219 |
+
gen_kwargs = None
|
| 220 |
+
|
| 221 |
+
if isinstance(model, str):
|
| 222 |
+
if model_args is None:
|
| 223 |
+
eval_logger.warning("model_args not specified. Using defaults.")
|
| 224 |
+
model_args = ""
|
| 225 |
+
|
| 226 |
+
if isinstance(model_args, dict):
|
| 227 |
+
eval_logger.info(
|
| 228 |
+
f"Initializing {model} model, with arguments: {model_args}"
|
| 229 |
+
)
|
| 230 |
+
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
|
| 231 |
+
model_args,
|
| 232 |
+
{
|
| 233 |
+
"batch_size": batch_size,
|
| 234 |
+
"max_batch_size": max_batch_size,
|
| 235 |
+
"device": device,
|
| 236 |
+
},
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
else:
|
| 240 |
+
eval_logger.info(
|
| 241 |
+
wrap_text(
|
| 242 |
+
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
|
| 246 |
+
model_args,
|
| 247 |
+
{
|
| 248 |
+
"batch_size": batch_size,
|
| 249 |
+
"max_batch_size": max_batch_size,
|
| 250 |
+
"device": device,
|
| 251 |
+
},
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
if not isinstance(model, lm_eval.api.model.LM):
|
| 255 |
+
raise TypeError(
|
| 256 |
+
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
|
| 257 |
+
)
|
| 258 |
+
eval_logger.info("Using pre-initialized model")
|
| 259 |
+
lm = model
|
| 260 |
+
|
| 261 |
+
if use_cache is not None:
|
| 262 |
+
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
|
| 263 |
+
lm = lm_eval.api.model.CachingLM(
|
| 264 |
+
lm,
|
| 265 |
+
use_cache
|
| 266 |
+
# each rank receives a different cache db.
|
| 267 |
+
# necessary to avoid multiple writes to cache at once
|
| 268 |
+
+ "_rank"
|
| 269 |
+
+ str(lm.rank)
|
| 270 |
+
+ ".db",
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if task_manager is None:
|
| 274 |
+
metadata = (
|
| 275 |
+
simple_parse_args_string(model_args)
|
| 276 |
+
if isinstance(model_args, str)
|
| 277 |
+
else model_args
|
| 278 |
+
if isinstance(model_args, dict)
|
| 279 |
+
else {}
|
| 280 |
+
) | (metadata or {})
|
| 281 |
+
task_manager = TaskManager(metadata=metadata)
|
| 282 |
+
|
| 283 |
+
task_dict = get_task_dict(
|
| 284 |
+
tasks,
|
| 285 |
+
task_manager,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
|
| 289 |
+
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
|
| 290 |
+
def _adjust_config(task_dict):
|
| 291 |
+
adjusted_task_dict = {}
|
| 292 |
+
for task_name, task_obj in task_dict.items():
|
| 293 |
+
if isinstance(task_obj, dict):
|
| 294 |
+
adjusted_task_dict = {
|
| 295 |
+
**adjusted_task_dict,
|
| 296 |
+
**{task_name: _adjust_config(task_obj)},
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
else:
|
| 300 |
+
if task_obj.get_config("output_type") == "generate_until":
|
| 301 |
+
if gen_kwargs is not None:
|
| 302 |
+
task_obj.set_config(
|
| 303 |
+
key="generation_kwargs", value=gen_kwargs, update=True
|
| 304 |
+
)
|
| 305 |
+
eval_logger.info(
|
| 306 |
+
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if predict_only:
|
| 310 |
+
eval_logger.info(
|
| 311 |
+
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
|
| 312 |
+
)
|
| 313 |
+
# we have to change the class properties post-hoc. This is pretty hacky.
|
| 314 |
+
task_obj.override_metric(metric_name="bypass")
|
| 315 |
+
|
| 316 |
+
# override tasks' fewshot values to the provided num_fewshot arg value
|
| 317 |
+
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
|
| 318 |
+
if num_fewshot is not None:
|
| 319 |
+
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
|
| 320 |
+
eval_logger.info(
|
| 321 |
+
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
|
| 322 |
+
)
|
| 323 |
+
else:
|
| 324 |
+
eval_logger.warning(
|
| 325 |
+
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
|
| 326 |
+
)
|
| 327 |
+
task_obj.set_config(key="num_fewshot", value=num_fewshot)
|
| 328 |
+
else:
|
| 329 |
+
# if num_fewshot not provided, and the task does not define a default one, default to 0
|
| 330 |
+
if (
|
| 331 |
+
default_num_fewshot := task_obj.get_config("num_fewshot")
|
| 332 |
+
) is None:
|
| 333 |
+
task_obj.set_config(key="num_fewshot", value=0)
|
| 334 |
+
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
|
| 335 |
+
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
|
| 336 |
+
|
| 337 |
+
adjusted_task_dict[task_name] = task_obj
|
| 338 |
+
|
| 339 |
+
return adjusted_task_dict
|
| 340 |
+
|
| 341 |
+
task_dict = _adjust_config(task_dict)
|
| 342 |
+
|
| 343 |
+
if check_integrity:
|
| 344 |
+
run_task_tests(task_list=tasks)
|
| 345 |
+
|
| 346 |
+
if evaluation_tracker is not None:
|
| 347 |
+
evaluation_tracker.general_config_tracker.log_experiment_args(
|
| 348 |
+
model_source=model,
|
| 349 |
+
model_args=model_args,
|
| 350 |
+
system_instruction=system_instruction,
|
| 351 |
+
chat_template=lm.chat_template(apply_chat_template)
|
| 352 |
+
if apply_chat_template
|
| 353 |
+
else None,
|
| 354 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
results = evaluate(
|
| 358 |
+
lm=lm,
|
| 359 |
+
task_dict=task_dict,
|
| 360 |
+
limit=limit,
|
| 361 |
+
samples=samples,
|
| 362 |
+
cache_requests=cache_requests,
|
| 363 |
+
rewrite_requests_cache=rewrite_requests_cache,
|
| 364 |
+
bootstrap_iters=bootstrap_iters,
|
| 365 |
+
write_out=write_out,
|
| 366 |
+
log_samples=True if predict_only else log_samples,
|
| 367 |
+
system_instruction=system_instruction,
|
| 368 |
+
apply_chat_template=apply_chat_template,
|
| 369 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 370 |
+
verbosity=verbosity,
|
| 371 |
+
confirm_run_unsafe_code=confirm_run_unsafe_code,
|
| 372 |
+
)
|
| 373 |
+
if verbosity is not None:
|
| 374 |
+
setup_logging(verbosity=verbosity)
|
| 375 |
+
|
| 376 |
+
if lm.rank == 0:
|
| 377 |
+
if isinstance(model, str):
|
| 378 |
+
model_name = model
|
| 379 |
+
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
|
| 380 |
+
model_name = model.config._name_or_path
|
| 381 |
+
else:
|
| 382 |
+
model_name = type(model).__name__
|
| 383 |
+
|
| 384 |
+
# add info about the model and few shot config
|
| 385 |
+
results["config"] = {
|
| 386 |
+
"model": model_name,
|
| 387 |
+
"model_args": model_args,
|
| 388 |
+
}
|
| 389 |
+
# add more detailed model info if available
|
| 390 |
+
if isinstance(lm, lm_eval.models.huggingface.HFLM):
|
| 391 |
+
results["config"].update(lm.get_model_info())
|
| 392 |
+
# add info about execution
|
| 393 |
+
results["config"].update(
|
| 394 |
+
{
|
| 395 |
+
"batch_size": batch_size,
|
| 396 |
+
"batch_sizes": (
|
| 397 |
+
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
|
| 398 |
+
),
|
| 399 |
+
"device": device,
|
| 400 |
+
"use_cache": use_cache,
|
| 401 |
+
"limit": limit,
|
| 402 |
+
"bootstrap_iters": bootstrap_iters,
|
| 403 |
+
"gen_kwargs": gen_kwargs,
|
| 404 |
+
"random_seed": random_seed,
|
| 405 |
+
"numpy_seed": numpy_random_seed,
|
| 406 |
+
"torch_seed": torch_random_seed,
|
| 407 |
+
"fewshot_seed": fewshot_random_seed,
|
| 408 |
+
}
|
| 409 |
+
)
|
| 410 |
+
results["git_hash"] = get_git_commit_hash()
|
| 411 |
+
results["date"] = start_date
|
| 412 |
+
add_env_info(results) # additional environment info to results
|
| 413 |
+
add_tokenizer_info(results, lm) # additional info about tokenizer
|
| 414 |
+
return results
|
| 415 |
+
else:
|
| 416 |
+
return None
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
@positional_deprecated
|
| 420 |
+
def evaluate(
|
| 421 |
+
lm: "LM",
|
| 422 |
+
task_dict,
|
| 423 |
+
limit: Optional[int] = None,
|
| 424 |
+
samples: Optional[dict] = None,
|
| 425 |
+
cache_requests: bool = False,
|
| 426 |
+
rewrite_requests_cache: bool = False,
|
| 427 |
+
bootstrap_iters: Optional[int] = 100000,
|
| 428 |
+
write_out: bool = False,
|
| 429 |
+
log_samples: bool = True,
|
| 430 |
+
system_instruction: Optional[str] = None,
|
| 431 |
+
apply_chat_template: Union[bool, str] = False,
|
| 432 |
+
fewshot_as_multiturn: bool = False,
|
| 433 |
+
verbosity: str = "INFO",
|
| 434 |
+
confirm_run_unsafe_code: bool = False,
|
| 435 |
+
):
|
| 436 |
+
"""Instantiate and evaluate a model on a list of tasks.
|
| 437 |
+
|
| 438 |
+
:param lm: obj
|
| 439 |
+
Language Model
|
| 440 |
+
:param task_dict: dict[str, Task]
|
| 441 |
+
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
|
| 442 |
+
:param limit: int, optional
|
| 443 |
+
Limit the number of examples per task (only use this for testing)
|
| 444 |
+
:param samples: dictionary, optional
|
| 445 |
+
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 446 |
+
:param cache_requests: bool, optional
|
| 447 |
+
Speed up evaluation by caching the building of dataset requests.
|
| 448 |
+
:param rewrite_requests_cache: bool, optional
|
| 449 |
+
Rewrites all the request cache if set to `True`.
|
| 450 |
+
:param bootstrap_iters:
|
| 451 |
+
Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
|
| 452 |
+
:param write_out: bool
|
| 453 |
+
If True, write out an example document and model input for checking task integrity
|
| 454 |
+
:param log_samples: bool
|
| 455 |
+
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 456 |
+
:param system_instruction: str
|
| 457 |
+
System instruction to be applied to the prompt
|
| 458 |
+
:param apply_chat_template: Union[bool, str]
|
| 459 |
+
Specifies whether to apply a chat template to the prompt.
|
| 460 |
+
- If set to True, the default chat template is applied.
|
| 461 |
+
- If set to a string, applies the specified chat template by name.
|
| 462 |
+
Defaults to False (no chat template applied).
|
| 463 |
+
:param fewshot_as_multiturn: bool
|
| 464 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 465 |
+
:param verbosity: str
|
| 466 |
+
Verbosity level for logging
|
| 467 |
+
:param confirm_run_unsafe_code: bool
|
| 468 |
+
Whether to confirm running tasks marked as unsafe.
|
| 469 |
+
:return
|
| 470 |
+
Dictionary of results
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
if limit is not None and samples is not None:
|
| 474 |
+
raise ValueError(
|
| 475 |
+
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 476 |
+
)
|
| 477 |
+
if samples is not None:
|
| 478 |
+
eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
|
| 479 |
+
if apply_chat_template:
|
| 480 |
+
eval_logger.warning(
|
| 481 |
+
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
|
| 482 |
+
)
|
| 483 |
+
# tracks all Instances/requests a model must generate output on.
|
| 484 |
+
requests = defaultdict(list)
|
| 485 |
+
# stores the amount to pad out reqs per req. type so that
|
| 486 |
+
# number of fwd passes per distributed rank is equal
|
| 487 |
+
padding_requests = defaultdict(int)
|
| 488 |
+
|
| 489 |
+
# get lists of group hierarchy and each type of request
|
| 490 |
+
eval_tasks = get_task_list(task_dict)
|
| 491 |
+
if not log_samples:
|
| 492 |
+
if not all(
|
| 493 |
+
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
|
| 494 |
+
for task_output in eval_tasks
|
| 495 |
+
):
|
| 496 |
+
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
|
| 497 |
+
|
| 498 |
+
# validation checks:
|
| 499 |
+
# 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
|
| 500 |
+
# 2.are we running code that is marked as unsafe.
|
| 501 |
+
incompatible_tasks = []
|
| 502 |
+
for task_output in eval_tasks:
|
| 503 |
+
task: Task = task_output.task
|
| 504 |
+
|
| 505 |
+
if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
|
| 506 |
+
incompatible_tasks.append(task_output.task_name)
|
| 507 |
+
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
|
| 508 |
+
raise ValueError(
|
| 509 |
+
f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
|
| 510 |
+
)
|
| 511 |
+
if len(incompatible_tasks) > 0:
|
| 512 |
+
if not getattr(lm, "MULTIMODAL", False):
|
| 513 |
+
raise ValueError(
|
| 514 |
+
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
|
| 515 |
+
)
|
| 516 |
+
# end validation check
|
| 517 |
+
|
| 518 |
+
# Cache the limit arg.
|
| 519 |
+
limit_arg = limit
|
| 520 |
+
limits = []
|
| 521 |
+
for task_output in eval_tasks:
|
| 522 |
+
task: Task = task_output.task
|
| 523 |
+
|
| 524 |
+
limit = get_sample_size(task, limit_arg)
|
| 525 |
+
limits.append(limit)
|
| 526 |
+
task.build_all_requests(
|
| 527 |
+
limit=limit,
|
| 528 |
+
samples=samples.get(task_output.task_name, None)
|
| 529 |
+
if samples is not None
|
| 530 |
+
else samples,
|
| 531 |
+
rank=lm.rank,
|
| 532 |
+
world_size=lm.world_size,
|
| 533 |
+
cache_requests=cache_requests,
|
| 534 |
+
rewrite_requests_cache=rewrite_requests_cache,
|
| 535 |
+
system_instruction=system_instruction,
|
| 536 |
+
apply_chat_template=bool(apply_chat_template),
|
| 537 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 538 |
+
chat_template=getattr(lm, "apply_chat_template")
|
| 539 |
+
if apply_chat_template
|
| 540 |
+
else None,
|
| 541 |
+
tokenizer_name=getattr(lm, "tokenizer_name", "")
|
| 542 |
+
if apply_chat_template
|
| 543 |
+
else "",
|
| 544 |
+
)
|
| 545 |
+
eval_logger.debug(
|
| 546 |
+
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
|
| 547 |
+
)
|
| 548 |
+
if write_out:
|
| 549 |
+
print_writeout(task)
|
| 550 |
+
# aggregate Instances by LM method requested to get output.
|
| 551 |
+
for instance in task.instances:
|
| 552 |
+
reqtype = instance.request_type
|
| 553 |
+
requests[reqtype].append(instance)
|
| 554 |
+
|
| 555 |
+
if lm.world_size > 1:
|
| 556 |
+
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
|
| 557 |
+
gathered_item = (
|
| 558 |
+
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
|
| 559 |
+
)
|
| 560 |
+
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
|
| 561 |
+
reqtype = (
|
| 562 |
+
"loglikelihood"
|
| 563 |
+
if task.OUTPUT_TYPE == "multiple_choice"
|
| 564 |
+
else task.OUTPUT_TYPE
|
| 565 |
+
)
|
| 566 |
+
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
|
| 567 |
+
numpad = max(gathered_item) - gathered_item[lm.rank]
|
| 568 |
+
# todo: may not account for padding in cases like SquadV2 which has multiple req types
|
| 569 |
+
padding_requests[reqtype] += numpad
|
| 570 |
+
|
| 571 |
+
### Run LM on inputs, get all outputs ###
|
| 572 |
+
# execute each type of request
|
| 573 |
+
for reqtype, reqs in requests.items():
|
| 574 |
+
eval_logger.info(f"Running {reqtype} requests")
|
| 575 |
+
# create `K` copies of each request `req` based off `K = req.repeats`
|
| 576 |
+
cloned_reqs = []
|
| 577 |
+
for req in reqs:
|
| 578 |
+
cloned_reqs.extend([req] * req.repeats)
|
| 579 |
+
|
| 580 |
+
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
|
| 581 |
+
for _ in range(padding_requests[reqtype]):
|
| 582 |
+
cloned_reqs.extend([req] * req.repeats)
|
| 583 |
+
|
| 584 |
+
# run requests through model
|
| 585 |
+
resps = getattr(lm, reqtype)(cloned_reqs)
|
| 586 |
+
|
| 587 |
+
# put responses from model into a list of length K for each request.
|
| 588 |
+
for x, req in zip(resps, cloned_reqs):
|
| 589 |
+
req.resps.append(x)
|
| 590 |
+
|
| 591 |
+
if lm.world_size > 1:
|
| 592 |
+
lm.accelerator.wait_for_everyone()
|
| 593 |
+
|
| 594 |
+
RANK = lm.rank
|
| 595 |
+
WORLD_SIZE = lm.world_size
|
| 596 |
+
### Postprocess outputs ###
|
| 597 |
+
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
|
| 598 |
+
for task_output, limit in zip(eval_tasks, limits):
|
| 599 |
+
task = task_output.task
|
| 600 |
+
task.apply_filters()
|
| 601 |
+
|
| 602 |
+
### Collect values of metrics on all datapoints ###
|
| 603 |
+
# # unpack results and sort back in order and return control to Task
|
| 604 |
+
# TODO: make it possible to use a different metric per filter
|
| 605 |
+
# Pre-process task.instances to group by doc_id
|
| 606 |
+
instances_by_doc_id = defaultdict(list)
|
| 607 |
+
for instance in task.instances:
|
| 608 |
+
instances_by_doc_id[instance.doc_id].append(instance)
|
| 609 |
+
# Sort instances within each group
|
| 610 |
+
for instances in instances_by_doc_id.values():
|
| 611 |
+
instances.sort(key=lambda x: x.idx)
|
| 612 |
+
# iterate over different filters used
|
| 613 |
+
for filter_key in task.instances[0].filtered_resps.keys():
|
| 614 |
+
indices = (
|
| 615 |
+
samples.get(task_output.task_name, None)
|
| 616 |
+
if samples is not None
|
| 617 |
+
else None
|
| 618 |
+
)
|
| 619 |
+
doc_iterator = task.doc_iterator(
|
| 620 |
+
rank=RANK,
|
| 621 |
+
limit=limit,
|
| 622 |
+
world_size=WORLD_SIZE,
|
| 623 |
+
samples=indices,
|
| 624 |
+
)
|
| 625 |
+
for doc_id, doc in doc_iterator:
|
| 626 |
+
if indices:
|
| 627 |
+
doc_id_true = indices[doc_id]
|
| 628 |
+
else:
|
| 629 |
+
doc_id_true = doc_id
|
| 630 |
+
requests = instances_by_doc_id[doc_id]
|
| 631 |
+
metrics = task.process_results(
|
| 632 |
+
doc, [req.filtered_resps[filter_key] for req in requests]
|
| 633 |
+
)
|
| 634 |
+
if log_samples:
|
| 635 |
+
target = task.doc_to_target(doc)
|
| 636 |
+
example = {
|
| 637 |
+
"doc_id": doc_id_true,
|
| 638 |
+
"doc": doc,
|
| 639 |
+
"target": target,
|
| 640 |
+
"arguments": [req.args for req in requests],
|
| 641 |
+
"resps": [req.resps for req in requests],
|
| 642 |
+
"filtered_resps": [
|
| 643 |
+
req.filtered_resps[filter_key] for req in requests
|
| 644 |
+
],
|
| 645 |
+
"filter": filter_key,
|
| 646 |
+
"metrics": list(metrics.keys()),
|
| 647 |
+
"doc_hash": hash_string(
|
| 648 |
+
json.dumps(
|
| 649 |
+
requests[0].doc,
|
| 650 |
+
indent=2,
|
| 651 |
+
default=handle_non_serializable,
|
| 652 |
+
ensure_ascii=False,
|
| 653 |
+
)
|
| 654 |
+
),
|
| 655 |
+
"prompt_hash": hash_string(requests[0].arguments[0]),
|
| 656 |
+
"target_hash": hash_string(str(target)),
|
| 657 |
+
}
|
| 658 |
+
example.update(metrics)
|
| 659 |
+
task_output.logged_samples.append(example)
|
| 660 |
+
for metric, value in metrics.items():
|
| 661 |
+
task_output.sample_metrics[(metric, filter_key)].append(value)
|
| 662 |
+
|
| 663 |
+
if WORLD_SIZE > 1:
|
| 664 |
+
# if multigpu, then gather data across all ranks to rank 0
|
| 665 |
+
# first gather logged samples across all ranks
|
| 666 |
+
for task_output in eval_tasks:
|
| 667 |
+
if log_samples:
|
| 668 |
+
# for task_name, task_samples in list(samples.items()):
|
| 669 |
+
full_samples = [None] * WORLD_SIZE if RANK == 0 else None
|
| 670 |
+
torch.distributed.gather_object(
|
| 671 |
+
obj=task_output.logged_samples,
|
| 672 |
+
object_gather_list=full_samples,
|
| 673 |
+
dst=0,
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
if RANK == 0:
|
| 677 |
+
task_output.logged_samples = list(
|
| 678 |
+
itertools.chain.from_iterable(full_samples)
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# then collect metrics across all ranks
|
| 682 |
+
for metrics in task_output.sample_metrics:
|
| 683 |
+
metric_list = [None] * WORLD_SIZE if RANK == 0 else None
|
| 684 |
+
torch.distributed.gather_object(
|
| 685 |
+
obj=task_output.sample_metrics[metrics],
|
| 686 |
+
object_gather_list=metric_list,
|
| 687 |
+
dst=0,
|
| 688 |
+
)
|
| 689 |
+
if RANK == 0:
|
| 690 |
+
task_output.sample_metrics[metrics] = list(
|
| 691 |
+
itertools.chain.from_iterable(metric_list)
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if RANK == 0:
|
| 695 |
+
### Aggregate results over all datapoints ###
|
| 696 |
+
# aggregate results ; run bootstrap CIs
|
| 697 |
+
for task_output in eval_tasks:
|
| 698 |
+
task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
|
| 699 |
+
(
|
| 700 |
+
results,
|
| 701 |
+
samples,
|
| 702 |
+
configs,
|
| 703 |
+
versions,
|
| 704 |
+
num_fewshot,
|
| 705 |
+
higher_is_better,
|
| 706 |
+
) = consolidate_results(eval_tasks)
|
| 707 |
+
|
| 708 |
+
### Calculate group metrics ###
|
| 709 |
+
if bool(results):
|
| 710 |
+
results, versions, show_group_table, *_ = consolidate_group_results(
|
| 711 |
+
results, versions, task_dict
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
results_agg, group_agg = prepare_print_tasks(task_dict, results)
|
| 715 |
+
subtask_list = get_subtask_list(task_dict)
|
| 716 |
+
|
| 717 |
+
# collect all higher_is_better values for metrics
|
| 718 |
+
# in the group's subtasks.
|
| 719 |
+
# TODO: clean this up ; unify with the below metric_list loop?
|
| 720 |
+
_higher_is_better = {}
|
| 721 |
+
for group, task_list in subtask_list.items():
|
| 722 |
+
if (
|
| 723 |
+
len(task_list) != 0
|
| 724 |
+
): # subtask list will list "task_name": [] for solo tasks
|
| 725 |
+
for task in task_list:
|
| 726 |
+
for m, h in higher_is_better[task].items():
|
| 727 |
+
if m not in _higher_is_better.keys():
|
| 728 |
+
_higher_is_better[m] = h
|
| 729 |
+
|
| 730 |
+
if (
|
| 731 |
+
m in _higher_is_better
|
| 732 |
+
and _higher_is_better[m] is not None
|
| 733 |
+
and _higher_is_better[m] != h
|
| 734 |
+
):
|
| 735 |
+
eval_logger.warning(
|
| 736 |
+
f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
|
| 737 |
+
)
|
| 738 |
+
_higher_is_better[m] = None
|
| 739 |
+
higher_is_better[group] = _higher_is_better
|
| 740 |
+
|
| 741 |
+
results_dict = {
|
| 742 |
+
"results": dict(results_agg.items()),
|
| 743 |
+
**(
|
| 744 |
+
{"groups": dict(group_agg.items())}
|
| 745 |
+
if (bool(group_agg) & show_group_table)
|
| 746 |
+
else {}
|
| 747 |
+
),
|
| 748 |
+
"group_subtasks": dict(reversed(subtask_list.items())),
|
| 749 |
+
"configs": dict(sorted(configs.items())),
|
| 750 |
+
"versions": dict(sorted(versions.items())),
|
| 751 |
+
"n-shot": dict(sorted(num_fewshot.items())),
|
| 752 |
+
"higher_is_better": dict(sorted(higher_is_better.items())),
|
| 753 |
+
"n-samples": {
|
| 754 |
+
task_output.task_name: {
|
| 755 |
+
"original": len(task_output.task.eval_docs),
|
| 756 |
+
"effective": min(
|
| 757 |
+
limit if limit else len(task_output.task.eval_docs),
|
| 758 |
+
len(task_output.task.eval_docs),
|
| 759 |
+
),
|
| 760 |
+
}
|
| 761 |
+
for task_output, limit in zip(eval_tasks, limits)
|
| 762 |
+
},
|
| 763 |
+
}
|
| 764 |
+
if log_samples:
|
| 765 |
+
# default: hash images
|
| 766 |
+
samples = (
|
| 767 |
+
hash_dict_images(samples)
|
| 768 |
+
if os.environ.get("LMEVAL_HASHMM", "1") != "0"
|
| 769 |
+
and (hasattr(lm, "MULTIMODAL"))
|
| 770 |
+
else samples
|
| 771 |
+
)
|
| 772 |
+
results_dict["samples"] = dict(samples)
|
| 773 |
+
|
| 774 |
+
return results_dict
|
| 775 |
+
|
| 776 |
+
else:
|
| 777 |
+
return None
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def request_caching_arg_to_dict(cache_requests: str) -> dict:
|
| 781 |
+
request_caching_args = {
|
| 782 |
+
"cache_requests": cache_requests in {"true", "refresh"},
|
| 783 |
+
"rewrite_requests_cache": cache_requests == "refresh",
|
| 784 |
+
"delete_requests_cache": cache_requests == "delete",
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
return request_caching_args
|
lm-evaluation-harness/lm_eval/evaluator_utils.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import pathlib
|
| 5 |
+
import sys
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from lm_eval.api.group import ConfigurableGroup
|
| 9 |
+
from lm_eval.api.metrics import (
|
| 10 |
+
aggregate_subtask_metrics,
|
| 11 |
+
mean,
|
| 12 |
+
pooled_sample_stderr,
|
| 13 |
+
stderr_for_metric,
|
| 14 |
+
)
|
| 15 |
+
from lm_eval.api.task import Task
|
| 16 |
+
from lm_eval.utils import positional_deprecated
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
eval_logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TaskOutput:
|
| 23 |
+
"""
|
| 24 |
+
Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.
|
| 25 |
+
|
| 26 |
+
Attributes:
|
| 27 |
+
task (object): The task object.
|
| 28 |
+
task_name (str): The name of the task.
|
| 29 |
+
task_config (dict): The configuration of the task.
|
| 30 |
+
version (str): The version of the task.
|
| 31 |
+
group_name (str): The name of the task group.
|
| 32 |
+
n_shot (int): The number of shots for the task.
|
| 33 |
+
task_alias (str): The alias of the task.
|
| 34 |
+
group_alias (str): The alias of the task group.
|
| 35 |
+
is_group (bool): Indicates if the task is a group.
|
| 36 |
+
logged_samples (list): The list of logged samples.
|
| 37 |
+
sample_len (int): The length of the samples.
|
| 38 |
+
sample_metrics (defaultdict): The dictionary of samples' metrics.
|
| 39 |
+
agg_metrics (defaultdict): The dictionary of aggregate metrics.
|
| 40 |
+
|
| 41 |
+
Methods:
|
| 42 |
+
from_taskdict(cls, task_name: str, task):
|
| 43 |
+
Creates a TaskOutput instance from a task dictionary.
|
| 44 |
+
|
| 45 |
+
calculate_aggregate_metric(bootstrap_iters=100000) -> None:
|
| 46 |
+
Calculates the aggregate metrics for the task.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
task=None,
|
| 52 |
+
task_name=None,
|
| 53 |
+
task_config=None,
|
| 54 |
+
version=None,
|
| 55 |
+
group_name=None,
|
| 56 |
+
n_shot=None,
|
| 57 |
+
task_alias=None,
|
| 58 |
+
group_alias=None,
|
| 59 |
+
is_group=None,
|
| 60 |
+
):
|
| 61 |
+
self.task = task
|
| 62 |
+
self.task_config = task_config
|
| 63 |
+
self.task_name = task_name
|
| 64 |
+
self.group_name = group_name
|
| 65 |
+
self.version = version
|
| 66 |
+
self.n_shot = n_shot
|
| 67 |
+
self.task_alias = task_alias
|
| 68 |
+
self.group_alias = group_alias
|
| 69 |
+
self.is_group = is_group
|
| 70 |
+
self.logged_samples = []
|
| 71 |
+
self.sample_len = None
|
| 72 |
+
self.sample_metrics = collections.defaultdict(list)
|
| 73 |
+
self.agg_metrics = collections.defaultdict(list)
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def from_taskdict(cls, task_name: str, task):
|
| 77 |
+
if isinstance(task, tuple):
|
| 78 |
+
group_name, task = task
|
| 79 |
+
else:
|
| 80 |
+
group_name = None
|
| 81 |
+
if not task:
|
| 82 |
+
# these gets filtered out in get_task_list
|
| 83 |
+
# once they are added to group hierarchy
|
| 84 |
+
is_group = True
|
| 85 |
+
return cls(
|
| 86 |
+
task=task, task_name=task_name, is_group=is_group, group_name=group_name
|
| 87 |
+
)
|
| 88 |
+
version = task.VERSION
|
| 89 |
+
task_config = dict(task.dump_config())
|
| 90 |
+
if (n_shot := task_config.get("num_fewshot")) == 0:
|
| 91 |
+
n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
|
| 92 |
+
task_alias = task_config.get("alias")
|
| 93 |
+
group_alias = task_config.get("group_alias")
|
| 94 |
+
return cls(
|
| 95 |
+
task=task,
|
| 96 |
+
task_name=task_name,
|
| 97 |
+
task_config=task_config,
|
| 98 |
+
group_name=group_name,
|
| 99 |
+
version=version,
|
| 100 |
+
n_shot=n_shot,
|
| 101 |
+
task_alias=task_alias,
|
| 102 |
+
group_alias=group_alias,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
|
| 106 |
+
for (metric, filter_key), items in self.sample_metrics.items():
|
| 107 |
+
try:
|
| 108 |
+
agg_fn = self.task.aggregation()[metric]
|
| 109 |
+
except KeyError:
|
| 110 |
+
# This is when process results output an arbitrary metric
|
| 111 |
+
# TODO: Handle this better and allow other aggregate functions other than mean.
|
| 112 |
+
agg_fn = mean
|
| 113 |
+
metric_key = f"{metric},{filter_key}"
|
| 114 |
+
self.agg_metrics[metric_key] = agg_fn(items)
|
| 115 |
+
self.sample_len = len(items) # TODO: same sample size for each metric?
|
| 116 |
+
if isinstance(bootstrap_iters, int):
|
| 117 |
+
stderr_fn = stderr_for_metric(
|
| 118 |
+
metric=agg_fn,
|
| 119 |
+
bootstrap_iters=min(bootstrap_iters, 100)
|
| 120 |
+
if metric in ["bleu", "chrf", "ter"]
|
| 121 |
+
else bootstrap_iters,
|
| 122 |
+
)
|
| 123 |
+
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
|
| 124 |
+
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def __repr__(self):
|
| 132 |
+
return (
|
| 133 |
+
f"TaskOutput(task_name={self.task_name}, "
|
| 134 |
+
f"group_name={self.group_name}, "
|
| 135 |
+
f"version={self.version}, "
|
| 136 |
+
f"n_shot={self.n_shot}, "
|
| 137 |
+
f"task_alias={self.task_alias}, "
|
| 138 |
+
f"group_alias={self.group_alias})"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_task_list(task_dict: dict) -> List[TaskOutput]:
|
| 143 |
+
outputs = []
|
| 144 |
+
for task_name, task_obj in task_dict.items():
|
| 145 |
+
if isinstance(task_obj, dict):
|
| 146 |
+
_outputs = get_task_list(task_obj)
|
| 147 |
+
outputs.extend(_outputs)
|
| 148 |
+
else:
|
| 149 |
+
task_output = TaskOutput.from_taskdict(task_name, task_obj)
|
| 150 |
+
outputs.append(task_output)
|
| 151 |
+
|
| 152 |
+
return outputs
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_subtask_list(task_dict, task_root=None, depth=0):
|
| 156 |
+
subtask_list = {}
|
| 157 |
+
for group_obj, task_obj in task_dict.items():
|
| 158 |
+
if isinstance(group_obj, ConfigurableGroup):
|
| 159 |
+
# group_name = group_obj.group_name
|
| 160 |
+
group_name = group_obj.group_name
|
| 161 |
+
else:
|
| 162 |
+
group_name = group_obj
|
| 163 |
+
if isinstance(task_obj, dict):
|
| 164 |
+
_subtask_list = get_subtask_list(
|
| 165 |
+
task_obj, task_root=group_name, depth=depth + 1
|
| 166 |
+
)
|
| 167 |
+
if task_root:
|
| 168 |
+
subtask_list.setdefault((task_root, depth), []).extend(
|
| 169 |
+
[
|
| 170 |
+
_task
|
| 171 |
+
for (_task, _depth) in _subtask_list.keys()
|
| 172 |
+
if (_depth - 1) == depth
|
| 173 |
+
]
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
subtask_list = {**subtask_list, **_subtask_list}
|
| 177 |
+
else:
|
| 178 |
+
if isinstance(task_obj, ConfigurableGroup):
|
| 179 |
+
# group_or_task_name = task_obj.group_name
|
| 180 |
+
group_or_task_name = task_obj.group_name
|
| 181 |
+
elif isinstance(task_obj, Task):
|
| 182 |
+
# group_or_task_name = task_obj.task_name
|
| 183 |
+
group_or_task_name = task_obj.task_name
|
| 184 |
+
|
| 185 |
+
if task_root is None:
|
| 186 |
+
subtask_list.setdefault((group_or_task_name, depth), [])
|
| 187 |
+
else:
|
| 188 |
+
subtask_list.setdefault((task_root, depth), []).append(
|
| 189 |
+
group_or_task_name
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if depth == 0:
|
| 193 |
+
_subtask_list = {}
|
| 194 |
+
for group_key, task_list in subtask_list.items():
|
| 195 |
+
group_name, depth = group_key
|
| 196 |
+
_subtask_list[group_name] = task_list
|
| 197 |
+
subtask_list = _subtask_list
|
| 198 |
+
|
| 199 |
+
return subtask_list
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def print_writeout(task) -> None:
|
| 203 |
+
for inst in task.instances:
|
| 204 |
+
# print the prompt for the first few documents
|
| 205 |
+
if inst.doc_id < 1:
|
| 206 |
+
eval_logger.info(
|
| 207 |
+
f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
|
| 208 |
+
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
|
| 209 |
+
)
|
| 210 |
+
eval_logger.info(f"Request: {str(inst)}")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
|
| 214 |
+
if limit is not None:
|
| 215 |
+
limit = (
|
| 216 |
+
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
|
| 217 |
+
)
|
| 218 |
+
return limit
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def prepare_print_tasks(
|
| 222 |
+
task_dict: dict,
|
| 223 |
+
results: dict,
|
| 224 |
+
task_depth=0,
|
| 225 |
+
group_depth=0,
|
| 226 |
+
) -> Tuple[dict, dict]:
|
| 227 |
+
"""
|
| 228 |
+
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
|
| 229 |
+
value is a list of task names.
|
| 230 |
+
@param results: Dictionary containing the results of each task. Each key is a
|
| 231 |
+
group name and its value is a dictionary of task results.
|
| 232 |
+
@param task_depth: The indentation level for printing the task
|
| 233 |
+
hierarchy. Default is 0.
|
| 234 |
+
@param group_depth: The indentation level for printing the group
|
| 235 |
+
hierarchy. Default is 0.
|
| 236 |
+
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
|
| 237 |
+
aggregated results for each task, and groups_agg contains aggregated results for each group.
|
| 238 |
+
|
| 239 |
+
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def _sort_task_dict(task_dict):
|
| 243 |
+
"""
|
| 244 |
+
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
|
| 245 |
+
Required so that we end up sorting within each sub-header correctly.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
return dict(
|
| 249 |
+
sorted(
|
| 250 |
+
task_dict.items(),
|
| 251 |
+
key=lambda item: item[0].group_name
|
| 252 |
+
if isinstance(item[0], ConfigurableGroup)
|
| 253 |
+
else item[0],
|
| 254 |
+
)
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
task_agg = collections.defaultdict(dict)
|
| 258 |
+
group_agg = collections.defaultdict(dict)
|
| 259 |
+
task_dict = _sort_task_dict(task_dict)
|
| 260 |
+
for task_or_group_name, task_or_group_obj in task_dict.items():
|
| 261 |
+
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
|
| 262 |
+
if isinstance(task_or_group_name, ConfigurableGroup):
|
| 263 |
+
# string_name = task_or_group_name.group_name
|
| 264 |
+
name = task_or_group_name.group_name
|
| 265 |
+
from_configurable_group = True
|
| 266 |
+
task_or_group_obj = _sort_task_dict(task_or_group_obj)
|
| 267 |
+
elif isinstance(task_or_group_name, str):
|
| 268 |
+
name = task_or_group_name
|
| 269 |
+
if isinstance(task_or_group_obj, Task):
|
| 270 |
+
# string_name = task_or_group_obj.task_name
|
| 271 |
+
name = task_or_group_obj.task_name
|
| 272 |
+
from_configurable_group = False
|
| 273 |
+
|
| 274 |
+
task_agg[name] = results[name].copy()
|
| 275 |
+
if from_configurable_group:
|
| 276 |
+
if task_or_group_name.group_alias is not None:
|
| 277 |
+
alias = task_or_group_name.group_alias
|
| 278 |
+
else:
|
| 279 |
+
alias = task_or_group_name.group
|
| 280 |
+
else:
|
| 281 |
+
if "alias" in task_agg[name]:
|
| 282 |
+
alias = task_agg[name]["alias"]
|
| 283 |
+
else:
|
| 284 |
+
alias = name
|
| 285 |
+
|
| 286 |
+
task_agg[name]["alias"] = tab_string + alias
|
| 287 |
+
if "samples" in task_agg[name]:
|
| 288 |
+
task_agg[name].pop("samples")
|
| 289 |
+
|
| 290 |
+
if from_configurable_group and (" " not in results[name]):
|
| 291 |
+
group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
|
| 292 |
+
group_agg[name] = results[name].copy()
|
| 293 |
+
group_agg[name]["alias"] = group_tab_string + alias
|
| 294 |
+
if "samples" in group_agg[name]:
|
| 295 |
+
group_agg[name].pop("samples")
|
| 296 |
+
|
| 297 |
+
if isinstance(task_or_group_obj, dict):
|
| 298 |
+
task_depth += 1
|
| 299 |
+
group_depth += 1
|
| 300 |
+
_task_agg, _group_agg = prepare_print_tasks(
|
| 301 |
+
task_or_group_obj, results, task_depth, group_depth
|
| 302 |
+
)
|
| 303 |
+
task_agg = {
|
| 304 |
+
**task_agg,
|
| 305 |
+
**_task_agg,
|
| 306 |
+
}
|
| 307 |
+
group_agg = {**group_agg, **_group_agg}
|
| 308 |
+
task_depth -= 1
|
| 309 |
+
group_depth -= 1
|
| 310 |
+
return task_agg, group_agg
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def consolidate_results(
|
| 314 |
+
eval_tasks: List[TaskOutput],
|
| 315 |
+
) -> Tuple[dict, dict, dict, dict, dict, dict]:
|
| 316 |
+
"""
|
| 317 |
+
@param eval_tasks: list(TaskOutput).
|
| 318 |
+
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
|
| 319 |
+
|
| 320 |
+
Consolidates the results of multiple evaluation tasks into a single structure.
|
| 321 |
+
|
| 322 |
+
The method iterates over each evaluation instance and extracts relevant information to create the consolidated
|
| 323 |
+
results structure. The consolidated results structure has the following properties:
|
| 324 |
+
|
| 325 |
+
- results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
|
| 326 |
+
metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
|
| 327 |
+
aliases specified in the task configuration.
|
| 328 |
+
- samples: A defaultdict with task names as keys and lists of log samples as values.
|
| 329 |
+
- configs: A defaultdict with task names as keys and task configurations as values.
|
| 330 |
+
- versions: A defaultdict with task names as keys and task versions as values.
|
| 331 |
+
- num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
|
| 332 |
+
- higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
|
| 333 |
+
for each metric as values.
|
| 334 |
+
|
| 335 |
+
The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
|
| 336 |
+
"""
|
| 337 |
+
# stores the final result for each task, for each metric/filter pair.
|
| 338 |
+
results = collections.defaultdict(dict)
|
| 339 |
+
# logs info about each document evaluated.
|
| 340 |
+
samples = collections.defaultdict(list)
|
| 341 |
+
# store num-fewshot value per task
|
| 342 |
+
num_fewshot = collections.defaultdict(int)
|
| 343 |
+
# Tracks the YAML configs of all chosen task
|
| 344 |
+
configs = collections.defaultdict(dict)
|
| 345 |
+
# Tracks each task's version.
|
| 346 |
+
versions = collections.defaultdict(dict)
|
| 347 |
+
# Track `higher_is_better` for each metric
|
| 348 |
+
higher_is_better = collections.defaultdict(dict)
|
| 349 |
+
|
| 350 |
+
for task_output in eval_tasks:
|
| 351 |
+
if "task_alias" in (task_config := task_output.task_config):
|
| 352 |
+
results[task_output.task_name]["alias"] = task_config["task_alias"]
|
| 353 |
+
else:
|
| 354 |
+
results[task_output.task_name]["alias"] = task_output.task_name
|
| 355 |
+
if group_alias := task_output.group_alias:
|
| 356 |
+
if group_alias not in results and (group_name := task_output.group_name):
|
| 357 |
+
results[group_name]["alias"] = group_alias
|
| 358 |
+
num_fewshot[task_output.task_name] = task_output.n_shot
|
| 359 |
+
configs[task_output.task_name] = task_output.task_config
|
| 360 |
+
versions[task_output.task_name] = task_output.version
|
| 361 |
+
samples[task_output.task_name] = task_output.logged_samples
|
| 362 |
+
higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
|
| 363 |
+
for (metric, filter_key), items in task_output.sample_metrics.items():
|
| 364 |
+
metric_key = f"{metric},{filter_key}"
|
| 365 |
+
results[task_output.task_name][metric_key] = task_output.agg_metrics[
|
| 366 |
+
metric_key
|
| 367 |
+
]
|
| 368 |
+
results[task_output.task_name]["samples"] = task_output.sample_len
|
| 369 |
+
results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
|
| 370 |
+
task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
|
| 371 |
+
)
|
| 372 |
+
return results, samples, configs, versions, num_fewshot, higher_is_better
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def consolidate_group_results(
|
| 376 |
+
results,
|
| 377 |
+
versions,
|
| 378 |
+
task_dict,
|
| 379 |
+
task_root=None,
|
| 380 |
+
show_group_table=False,
|
| 381 |
+
task_aggregation_list=None,
|
| 382 |
+
) -> Tuple[dict, dict, bool, Union[None,]]:
|
| 383 |
+
"""
|
| 384 |
+
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
|
| 385 |
+
|
| 386 |
+
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
|
| 387 |
+
|
| 388 |
+
- results: A defaultdict with task names (and, after this function is called, group names of
|
| 389 |
+
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
|
| 390 |
+
- versions: A defaultdict with task names (and, after this function is called, group names of
|
| 391 |
+
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
|
| 392 |
+
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
|
| 393 |
+
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
|
| 394 |
+
|
| 395 |
+
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
|
| 396 |
+
In the top-level invocation of this function, task_aggregation_list is ignored.
|
| 397 |
+
"""
|
| 398 |
+
if task_root is None:
|
| 399 |
+
task_root = {}
|
| 400 |
+
|
| 401 |
+
if task_aggregation_list is None:
|
| 402 |
+
task_aggregation_list = {}
|
| 403 |
+
|
| 404 |
+
for group_or_task, group_or_task_info in task_dict.items():
|
| 405 |
+
# Convert to string
|
| 406 |
+
if isinstance(group_or_task, ConfigurableGroup):
|
| 407 |
+
group_config = group_or_task.config
|
| 408 |
+
group_or_task = group_or_task.group_name
|
| 409 |
+
else:
|
| 410 |
+
group_config = None
|
| 411 |
+
|
| 412 |
+
if isinstance(group_or_task_info, Task):
|
| 413 |
+
if task_root:
|
| 414 |
+
task_aggregation_list.setdefault(task_root, []).append(
|
| 415 |
+
group_or_task_info.task_name
|
| 416 |
+
)
|
| 417 |
+
else:
|
| 418 |
+
(
|
| 419 |
+
results,
|
| 420 |
+
versions,
|
| 421 |
+
show_group_table,
|
| 422 |
+
_task_aggregation_list,
|
| 423 |
+
) = consolidate_group_results(
|
| 424 |
+
results,
|
| 425 |
+
versions,
|
| 426 |
+
group_or_task_info,
|
| 427 |
+
group_or_task,
|
| 428 |
+
show_group_table,
|
| 429 |
+
task_aggregation_list,
|
| 430 |
+
)
|
| 431 |
+
if task_root:
|
| 432 |
+
task_aggregation_list.setdefault(task_root, []).extend(
|
| 433 |
+
task_aggregation_list.get(group_or_task, [])
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if (group_config is None) or (
|
| 437 |
+
group_config["aggregate_metric_list"] is None
|
| 438 |
+
):
|
| 439 |
+
results[group_or_task][" "] = " "
|
| 440 |
+
continue
|
| 441 |
+
|
| 442 |
+
if "aggregate_metric_list" in group_config:
|
| 443 |
+
agg_metric_list = group_config["aggregate_metric_list"]
|
| 444 |
+
|
| 445 |
+
show_group_table = show_group_table | bool(
|
| 446 |
+
group_config["aggregate_metric_list"]
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
task_list = _task_aggregation_list[group_or_task]
|
| 450 |
+
|
| 451 |
+
metric_list = list(
|
| 452 |
+
{
|
| 453 |
+
key
|
| 454 |
+
for task in task_list
|
| 455 |
+
for key in results[task].keys()
|
| 456 |
+
if "_stderr" not in key and key not in ["task", "alias", "samples"]
|
| 457 |
+
}
|
| 458 |
+
)
|
| 459 |
+
for metric in metric_list:
|
| 460 |
+
stderr = "_stderr,".join(metric.split(","))
|
| 461 |
+
|
| 462 |
+
# gather metrics, sizes, and stderrs from subtasks
|
| 463 |
+
metrics = [
|
| 464 |
+
results[task][metric]
|
| 465 |
+
for task in task_list
|
| 466 |
+
if metric in results[task]
|
| 467 |
+
] # TODO: copy?
|
| 468 |
+
stderrs = [
|
| 469 |
+
results[task][stderr]
|
| 470 |
+
for task in task_list
|
| 471 |
+
if stderr in results[task]
|
| 472 |
+
]
|
| 473 |
+
sizes = [
|
| 474 |
+
results[task]["samples"]
|
| 475 |
+
for task in task_list
|
| 476 |
+
if metric in results[task]
|
| 477 |
+
]
|
| 478 |
+
|
| 479 |
+
for metric_config in agg_metric_list:
|
| 480 |
+
for filter_name in metric_config["filter_list"]:
|
| 481 |
+
if metric != ",".join([metric_config["metric"], filter_name]):
|
| 482 |
+
continue
|
| 483 |
+
|
| 484 |
+
# compute group's pooled metric and stderr
|
| 485 |
+
if metric_config["aggregation"] == "mean":
|
| 486 |
+
aggregate_fn = aggregate_subtask_metrics
|
| 487 |
+
elif callable(metric_config["aggregation"]):
|
| 488 |
+
aggregate_fn = metric_config["aggregation"]
|
| 489 |
+
else:
|
| 490 |
+
raise ValueError(
|
| 491 |
+
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
results[group_or_task][metric] = aggregate_fn(
|
| 495 |
+
metrics,
|
| 496 |
+
sizes,
|
| 497 |
+
metric_config["weight_by_size"],
|
| 498 |
+
)
|
| 499 |
+
# TODO: calculate groups' metrics using arbitrary agg fns
|
| 500 |
+
if "N/A" in stderrs:
|
| 501 |
+
results[group_or_task][stderr] = "N/A"
|
| 502 |
+
else:
|
| 503 |
+
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
|
| 504 |
+
results[group_or_task][stderr] = pooled_sample_stderr(
|
| 505 |
+
stderrs, sizes
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
results[group_or_task]["samples"] = sum(sizes)
|
| 509 |
+
group_metadata = group_config.get("metadata", None)
|
| 510 |
+
if group_metadata is not None:
|
| 511 |
+
versions[group_or_task] = group_metadata.get("version", None)
|
| 512 |
+
# print(results)
|
| 513 |
+
return results, versions, show_group_table, task_aggregation_list
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
@positional_deprecated
|
| 517 |
+
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
|
| 518 |
+
"""
|
| 519 |
+
Search upward in the directory tree to a maximum of three layers
|
| 520 |
+
to find and return the package root (containing the 'tests' folder)
|
| 521 |
+
"""
|
| 522 |
+
cur_path = start_path.resolve()
|
| 523 |
+
max_layers = 3
|
| 524 |
+
for _ in range(max_layers):
|
| 525 |
+
if (cur_path / "tests" / "test_version_stable.py").exists():
|
| 526 |
+
return cur_path
|
| 527 |
+
else:
|
| 528 |
+
cur_path = cur_path.parent.resolve()
|
| 529 |
+
raise FileNotFoundError(
|
| 530 |
+
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@positional_deprecated
|
| 535 |
+
def run_task_tests(task_list: List[str]):
|
| 536 |
+
"""
|
| 537 |
+
Find the package root and run the tests for the given tasks
|
| 538 |
+
"""
|
| 539 |
+
import pytest
|
| 540 |
+
|
| 541 |
+
package_root = find_test_root(start_path=pathlib.Path(__file__))
|
| 542 |
+
task_string = " or ".join(task_list)
|
| 543 |
+
args = [
|
| 544 |
+
f"{package_root}/tests/test_version_stable.py",
|
| 545 |
+
f"--rootdir={package_root}",
|
| 546 |
+
"-k",
|
| 547 |
+
f"{task_string}",
|
| 548 |
+
]
|
| 549 |
+
sys.path.append(str(package_root))
|
| 550 |
+
pytest_return_val = pytest.main(args)
|
| 551 |
+
if pytest_return_val:
|
| 552 |
+
raise ValueError(
|
| 553 |
+
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
|
| 554 |
+
)
|
lm-evaluation-harness/lm_eval/filters/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from lm_eval.api.filter import FilterEnsemble
|
| 5 |
+
from lm_eval.api.registry import get_filter
|
| 6 |
+
|
| 7 |
+
from . import custom, extraction, selection, transformation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def build_filter_ensemble(
|
| 11 |
+
filter_name: str, components: List[List[str]]
|
| 12 |
+
) -> FilterEnsemble:
|
| 13 |
+
"""
|
| 14 |
+
Create a filtering pipeline.
|
| 15 |
+
"""
|
| 16 |
+
filters = []
|
| 17 |
+
for function, kwargs in components:
|
| 18 |
+
if kwargs is None:
|
| 19 |
+
kwargs = {}
|
| 20 |
+
# create a filter given its name in the registry
|
| 21 |
+
f = partial(get_filter(function), **kwargs)
|
| 22 |
+
# add the filter as a pipeline step
|
| 23 |
+
filters.append(f)
|
| 24 |
+
|
| 25 |
+
return FilterEnsemble(name=filter_name, filters=filters)
|
lm-evaluation-harness/lm_eval/filters/custom.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lm_eval.api.filter import Filter
|
| 2 |
+
from lm_eval.api.registry import register_filter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@register_filter("custom")
|
| 6 |
+
class CustomFilter(Filter):
|
| 7 |
+
"""
|
| 8 |
+
Custom filter that applies a custom, user-defined function to the model responses.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, **kwargs) -> None:
|
| 12 |
+
self.filter_fn = kwargs.pop("filter_fn")
|
| 13 |
+
|
| 14 |
+
super().__init__(**kwargs)
|
| 15 |
+
|
| 16 |
+
def apply(self, resps, docs):
|
| 17 |
+
return self.filter_fn(resps, docs)
|
lm-evaluation-harness/lm_eval/filters/decontamination.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lm_eval.api.filter import Filter
|
| 2 |
+
from lm_eval.api.registry import register_filter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@register_filter("decontaminate")
|
| 6 |
+
class DecontaminationFilter(Filter):
|
| 7 |
+
"""
|
| 8 |
+
A filter which evaluates
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
name = "track_decontamination"
|
| 12 |
+
|
| 13 |
+
def __init__(self, path) -> None:
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
|
| 17 |
+
should further cache result on a given (task_name, doc_id)
|
| 18 |
+
"""
|
| 19 |
+
self._decontam_results = None
|
| 20 |
+
|
| 21 |
+
def apply(self, resps, docs) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
|
| 24 |
+
"""
|
| 25 |
+
pass
|
lm-evaluation-harness/lm_eval/filters/extraction.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import sys
|
| 3 |
+
import unicodedata
|
| 4 |
+
|
| 5 |
+
from lm_eval.api.filter import Filter
|
| 6 |
+
from lm_eval.api.registry import register_filter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_filter("regex")
|
| 10 |
+
class RegexFilter(Filter):
|
| 11 |
+
"""A filter that extracts values from text using regex pattern matching.
|
| 12 |
+
|
| 13 |
+
This filter applies a regex pattern to each model response and extracts matched values.
|
| 14 |
+
If no match is found, returns a fallback value. Useful for extracting structured data
|
| 15 |
+
(like numbers) from unstructured model outputs.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 21 |
+
group_select: int = 0,
|
| 22 |
+
fallback: str = "[invalid]",
|
| 23 |
+
) -> None:
|
| 24 |
+
"""
|
| 25 |
+
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 26 |
+
`fallback` defines the output returned if no matches for the regex are located.
|
| 27 |
+
"""
|
| 28 |
+
self.regex_pattern = regex_pattern
|
| 29 |
+
self.regex = re.compile(regex_pattern)
|
| 30 |
+
self.group_select = group_select
|
| 31 |
+
self.fallback = fallback
|
| 32 |
+
|
| 33 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 34 |
+
# here, we assume we have a list, in which each element is
|
| 35 |
+
# a list of model responses for some particular input/target pair.
|
| 36 |
+
# so we process each of these (same input/target response sets)
|
| 37 |
+
# independently (and keep them a list.)
|
| 38 |
+
def filter_set(inst):
|
| 39 |
+
filtered = []
|
| 40 |
+
for resp in inst:
|
| 41 |
+
match = self.regex.findall(resp)
|
| 42 |
+
if match:
|
| 43 |
+
match = match[self.group_select]
|
| 44 |
+
if isinstance(match, tuple):
|
| 45 |
+
match = [m for m in match if m]
|
| 46 |
+
if match:
|
| 47 |
+
match = match[0]
|
| 48 |
+
else:
|
| 49 |
+
match = self.fallback
|
| 50 |
+
match = match.strip()
|
| 51 |
+
else:
|
| 52 |
+
match = self.fallback
|
| 53 |
+
filtered.append(match)
|
| 54 |
+
return filtered
|
| 55 |
+
|
| 56 |
+
filtered_resps = list(map(lambda x: filter_set(x), resps))
|
| 57 |
+
return filtered_resps
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@register_filter("regex_pos")
|
| 61 |
+
class POSFilter(Filter):
|
| 62 |
+
""" """
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
regex_pattern: str = r"\['(.*?)'\]",
|
| 67 |
+
group_select=0,
|
| 68 |
+
fallback=None,
|
| 69 |
+
) -> None:
|
| 70 |
+
"""
|
| 71 |
+
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 72 |
+
`fallback` defines the output returned if no matches for the regex are located.
|
| 73 |
+
"""
|
| 74 |
+
if fallback is None:
|
| 75 |
+
fallback = ["invalid"]
|
| 76 |
+
self.regex_pattern = regex_pattern
|
| 77 |
+
self.regex = re.compile(regex_pattern)
|
| 78 |
+
self.group_select = group_select
|
| 79 |
+
self.fallback = fallback
|
| 80 |
+
|
| 81 |
+
def apply(self, resps, docs):
|
| 82 |
+
def extract_tagged_tokens(text):
|
| 83 |
+
# Extract tagged tokens list from text input using regex
|
| 84 |
+
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
|
| 85 |
+
return [(token, pos) for token, pos in tokens]
|
| 86 |
+
|
| 87 |
+
def extract_pos_tags(result):
|
| 88 |
+
pos_tags = []
|
| 89 |
+
if isinstance(result, str):
|
| 90 |
+
result = extract_tagged_tokens(result)
|
| 91 |
+
pos_tags.extend(pos for _, pos in result)
|
| 92 |
+
return pos_tags if pos_tags else self.fallback
|
| 93 |
+
|
| 94 |
+
def filter_set(inst):
|
| 95 |
+
filtered = []
|
| 96 |
+
for resp in inst:
|
| 97 |
+
match = extract_pos_tags(resp)
|
| 98 |
+
filtered.append(match)
|
| 99 |
+
return filtered
|
| 100 |
+
|
| 101 |
+
filtered_resps = map(lambda x: filter_set(x), resps)
|
| 102 |
+
|
| 103 |
+
return filtered_resps
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@register_filter("remove_whitespace")
|
| 107 |
+
class WhitespaceFilter(Filter):
|
| 108 |
+
"""Filters out leading whitespace from responses."""
|
| 109 |
+
|
| 110 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 111 |
+
def filter_set(inst):
|
| 112 |
+
filtered_resp = []
|
| 113 |
+
for resp in inst:
|
| 114 |
+
resp = resp.lstrip()
|
| 115 |
+
filtered_resp.append(resp)
|
| 116 |
+
return filtered_resp
|
| 117 |
+
|
| 118 |
+
filtered_resps = [filter_set(resp) for resp in resps]
|
| 119 |
+
|
| 120 |
+
return filtered_resps
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@register_filter("multi_choice_regex")
|
| 124 |
+
class MultiChoiceRegexFilter(RegexFilter):
|
| 125 |
+
"""
|
| 126 |
+
A filter used to extract a model's answer on multiple choice questions with
|
| 127 |
+
letter answers. assumes each document has a "choices" field
|
| 128 |
+
containing the list of answer choices and that the answer label symbols
|
| 129 |
+
are of the form (A), (B), (C), ... or A, B, C.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 135 |
+
group_select=0,
|
| 136 |
+
fallback: str = "[invalid]",
|
| 137 |
+
ignore_case=False,
|
| 138 |
+
ignore_punctuation=False,
|
| 139 |
+
regexes_to_ignore=None,
|
| 140 |
+
) -> None:
|
| 141 |
+
"""
|
| 142 |
+
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
|
| 143 |
+
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
|
| 144 |
+
- step 2 : We parse the choice with regex: r's*([A-?])', where ? varies by number of choices.
|
| 145 |
+
group_select: Selects the (group_select)th match from the findall result.
|
| 146 |
+
ignore_case: Ignores the case during step 1 matching
|
| 147 |
+
ignore_punctuation: Remove the punctuation during step 1 matching
|
| 148 |
+
regexes_to_ignore: Remove these regexes during step 1 matching
|
| 149 |
+
"""
|
| 150 |
+
super().__init__(regex_pattern, group_select, fallback)
|
| 151 |
+
self.ignore_case = ignore_case
|
| 152 |
+
self.ignore_punctuation = ignore_punctuation
|
| 153 |
+
self.regexes_to_ignore = regexes_to_ignore
|
| 154 |
+
|
| 155 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 156 |
+
# here, we assume we have a list, in which each element is
|
| 157 |
+
# a list of model responses for some particular input/target pair.
|
| 158 |
+
# so we process each of these (same input/target response sets)
|
| 159 |
+
# independently (and keep them a list.)
|
| 160 |
+
|
| 161 |
+
def find_match(regex, resp, convert_dict={}):
|
| 162 |
+
match = regex.findall(resp)
|
| 163 |
+
if match:
|
| 164 |
+
match = match[self.group_select]
|
| 165 |
+
if isinstance(match, tuple):
|
| 166 |
+
match = [m for m in match if m][0]
|
| 167 |
+
match = match.strip()
|
| 168 |
+
if match and match in convert_dict:
|
| 169 |
+
match = convert_dict[match]
|
| 170 |
+
return match
|
| 171 |
+
|
| 172 |
+
punct_tbl = dict.fromkeys(
|
| 173 |
+
i
|
| 174 |
+
for i in range(sys.maxunicode)
|
| 175 |
+
if unicodedata.category(chr(i)).startswith("P")
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def filter_ignores(st):
|
| 179 |
+
if self.regexes_to_ignore is not None:
|
| 180 |
+
for s in self.regexes_to_ignore:
|
| 181 |
+
st = re.sub(s, "", st)
|
| 182 |
+
|
| 183 |
+
if self.ignore_case:
|
| 184 |
+
st = st.lower()
|
| 185 |
+
|
| 186 |
+
if self.ignore_punctuation:
|
| 187 |
+
# https://stackoverflow.com/a/266162
|
| 188 |
+
st = st.translate(punct_tbl)
|
| 189 |
+
return st
|
| 190 |
+
|
| 191 |
+
filtered_resps = []
|
| 192 |
+
|
| 193 |
+
for r, doc in zip(resps, docs):
|
| 194 |
+
fallback_regexes = []
|
| 195 |
+
choice_to_alpha = {}
|
| 196 |
+
next_alpha = "A"
|
| 197 |
+
|
| 198 |
+
without_paren_fallback_regexes = []
|
| 199 |
+
without_paren_to_target = {}
|
| 200 |
+
|
| 201 |
+
choices = doc["choices"]
|
| 202 |
+
for c in choices:
|
| 203 |
+
m = filter_ignores(c.strip())
|
| 204 |
+
fallback_regexes.append(f"{re.escape(m)}")
|
| 205 |
+
choice_to_alpha[m] = f"({next_alpha})"
|
| 206 |
+
|
| 207 |
+
without_paren_fallback_regexes.append(next_alpha)
|
| 208 |
+
without_paren_to_target[next_alpha] = f"({next_alpha})"
|
| 209 |
+
|
| 210 |
+
next_alpha = chr(ord(next_alpha) + 1)
|
| 211 |
+
fallback_regex = re.compile("|".join(fallback_regexes))
|
| 212 |
+
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
|
| 213 |
+
without_paren_fallback_regex = re.compile(
|
| 214 |
+
rf":[\s]*({without_paren_fallback_regex})"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
filtered = []
|
| 218 |
+
for resp in r:
|
| 219 |
+
match = find_match(self.regex, resp)
|
| 220 |
+
if not match:
|
| 221 |
+
match = find_match(
|
| 222 |
+
fallback_regex, filter_ignores(resp), choice_to_alpha
|
| 223 |
+
)
|
| 224 |
+
if not match:
|
| 225 |
+
match = find_match(
|
| 226 |
+
without_paren_fallback_regex, resp, without_paren_to_target
|
| 227 |
+
)
|
| 228 |
+
if not match:
|
| 229 |
+
match = self.fallback
|
| 230 |
+
filtered.append(match)
|
| 231 |
+
filtered_resps.append(filtered)
|
| 232 |
+
|
| 233 |
+
return filtered_resps
|
lm-evaluation-harness/lm_eval/filters/selection.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
from lm_eval.api.filter import Filter
|
| 4 |
+
from lm_eval.api.registry import register_filter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
|
| 8 |
+
# that takes an input and returns a scalar and then should select the max reward,
|
| 9 |
+
# or should implement different filters for different ways of handling a reward model's inference.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@register_filter("take_first")
|
| 13 |
+
class TakeFirstFilter(Filter):
|
| 14 |
+
def __init__(self) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def apply(self, resps, docs):
|
| 20 |
+
"""
|
| 21 |
+
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
|
| 22 |
+
"""
|
| 23 |
+
return map(lambda r: r[0], resps)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@register_filter("take_first_k")
|
| 27 |
+
class TakeKFilter(Filter):
|
| 28 |
+
def __init__(self, **kwargs) -> None:
|
| 29 |
+
self.k = kwargs.pop("k")
|
| 30 |
+
|
| 31 |
+
super().__init__(**kwargs)
|
| 32 |
+
|
| 33 |
+
def apply(self, resps, docs):
|
| 34 |
+
# need resp to be subscriptable to check below
|
| 35 |
+
resps = list(resps)
|
| 36 |
+
# check we have at least k responses per doc, else we can't take the first k
|
| 37 |
+
assert len(resps[0]) >= self.k, (
|
| 38 |
+
f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
|
| 39 |
+
)
|
| 40 |
+
return map(lambda r: r[: self.k], resps)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@register_filter("majority_vote")
|
| 44 |
+
class MajorityVoteFilter(Filter):
|
| 45 |
+
def __init__(self) -> None:
|
| 46 |
+
"""
|
| 47 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def apply(self, resps, docs):
|
| 51 |
+
"""
|
| 52 |
+
Each entry of `resps` is a list of model responses.
|
| 53 |
+
We select the response that occurs most frequently in each entry of `resps`.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def select_majority(resp):
|
| 57 |
+
counts = Counter(resp)
|
| 58 |
+
vote = counts.most_common(1)[0][0]
|
| 59 |
+
return vote
|
| 60 |
+
|
| 61 |
+
return map(lambda r: [select_majority(r)], resps)
|
lm-evaluation-harness/lm_eval/filters/transformation.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from lm_eval.api.filter import Filter
|
| 4 |
+
from lm_eval.api.registry import register_filter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@register_filter("lowercase")
|
| 8 |
+
class LowercaseFilter(Filter):
|
| 9 |
+
def __init__(self) -> None:
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def apply(self, resps, docs):
|
| 13 |
+
def filter_set(inst):
|
| 14 |
+
return [resp.lower() for resp in inst]
|
| 15 |
+
|
| 16 |
+
return [filter_set(resp) for resp in resps]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@register_filter("uppercase")
|
| 20 |
+
class UppercaseFilter(Filter):
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
def apply(self, resps, docs):
|
| 25 |
+
def filter_set(inst):
|
| 26 |
+
return [resp.upper() for resp in inst]
|
| 27 |
+
|
| 28 |
+
return [filter_set(resp) for resp in resps]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@register_filter("map")
|
| 32 |
+
class MapFilter(Filter):
|
| 33 |
+
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Initializes the MapFilter with a given mapping dictionary and default value.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
- mapping_dict (dict): A dictionary containing the key-value mappings.
|
| 39 |
+
Default is an empty dictionary.
|
| 40 |
+
- default_value (Any): The value to be returned when a key is not found in the mapping_dict.
|
| 41 |
+
Default is None.
|
| 42 |
+
|
| 43 |
+
Example:
|
| 44 |
+
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
|
| 45 |
+
"""
|
| 46 |
+
if mapping_dict is None:
|
| 47 |
+
mapping_dict = {}
|
| 48 |
+
assert isinstance(mapping_dict, dict), (
|
| 49 |
+
"Provided mapping_dict is not a dictionary"
|
| 50 |
+
)
|
| 51 |
+
self.mapping_dict = mapping_dict
|
| 52 |
+
self.default_value = default_value
|
| 53 |
+
|
| 54 |
+
def apply(self, resps, docs):
|
| 55 |
+
def filter_set(inst):
|
| 56 |
+
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
|
| 57 |
+
|
| 58 |
+
return [filter_set(resp) for resp in resps]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@register_filter("format_span")
|
| 62 |
+
class SPANFilter(Filter):
|
| 63 |
+
def __init__(self) -> None:
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
def apply(self, resps, docs):
|
| 67 |
+
def format_ner_text(text):
|
| 68 |
+
label_dict = {
|
| 69 |
+
"person": "PER",
|
| 70 |
+
"location": "LOC",
|
| 71 |
+
"organization": "ORG",
|
| 72 |
+
"counties": "LOC",
|
| 73 |
+
"places": "LOC",
|
| 74 |
+
"people": "PER",
|
| 75 |
+
"persons": "PER",
|
| 76 |
+
"company": "ORG",
|
| 77 |
+
"country": "LOC",
|
| 78 |
+
"continent": "LOC",
|
| 79 |
+
"time": "DATE",
|
| 80 |
+
"date": "DATE",
|
| 81 |
+
"per": "PER",
|
| 82 |
+
"loc": "LOC",
|
| 83 |
+
"org": "ORG",
|
| 84 |
+
}
|
| 85 |
+
text = text.lower()
|
| 86 |
+
for key, value in label_dict.items():
|
| 87 |
+
text = text.replace(key, value)
|
| 88 |
+
|
| 89 |
+
text = "$".join(i for i in text.split("$$"))
|
| 90 |
+
return text.rstrip("$$")
|
| 91 |
+
|
| 92 |
+
def format_named_entities(text):
|
| 93 |
+
"""
|
| 94 |
+
Extract named entities from text and format them as 'label: value $$ label: value'.
|
| 95 |
+
Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
|
| 96 |
+
"""
|
| 97 |
+
# Regular expression to match label: entities pattern
|
| 98 |
+
pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
|
| 99 |
+
# Normalize newline characters
|
| 100 |
+
text = text.replace("\n", "$").strip()
|
| 101 |
+
matches = re.findall(pattern, text)
|
| 102 |
+
|
| 103 |
+
formatted_entities = []
|
| 104 |
+
|
| 105 |
+
for label, values in matches:
|
| 106 |
+
# Split multiple entities separated by commas and strip whitespace
|
| 107 |
+
entities = [value.strip() for value in values.split(",")]
|
| 108 |
+
|
| 109 |
+
# Exclude 'none' entities
|
| 110 |
+
for entity in entities:
|
| 111 |
+
if entity.lower() != "none":
|
| 112 |
+
formatted_entities.append(f"{label.lower()}: {entity}")
|
| 113 |
+
|
| 114 |
+
# Join entities with the desired separator
|
| 115 |
+
return " $ ".join(formatted_entities)
|
| 116 |
+
|
| 117 |
+
def filter_set(inst):
|
| 118 |
+
return [
|
| 119 |
+
format_named_entities(format_ner_text(resp.lower())) for resp in inst
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
return [filter_set(resp) for resp in resps]
|
lm-evaluation-harness/lm_eval/loggers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluation_tracker import EvaluationTracker
|
| 2 |
+
from .wandb_logger import WandbLogger
|
lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
from datasets.utils.metadata import MetadataConfigs
|
| 13 |
+
from huggingface_hub import (
|
| 14 |
+
DatasetCard,
|
| 15 |
+
DatasetCardData,
|
| 16 |
+
HfApi,
|
| 17 |
+
hf_hub_url,
|
| 18 |
+
)
|
| 19 |
+
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
|
| 20 |
+
|
| 21 |
+
from lm_eval.utils import (
|
| 22 |
+
get_file_datetime,
|
| 23 |
+
get_file_task_name,
|
| 24 |
+
get_results_filenames,
|
| 25 |
+
get_sample_results_filenames,
|
| 26 |
+
handle_non_serializable,
|
| 27 |
+
hash_string,
|
| 28 |
+
sanitize_list,
|
| 29 |
+
sanitize_model_name,
|
| 30 |
+
sanitize_task_name,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
eval_logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass(init=False)
|
| 38 |
+
class GeneralConfigTracker:
|
| 39 |
+
"""
|
| 40 |
+
Tracker for the evaluation parameters.
|
| 41 |
+
|
| 42 |
+
Attributes:
|
| 43 |
+
model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
|
| 44 |
+
model_name (str): Name of the model.
|
| 45 |
+
model_name_sanitized (str): Sanitized model name for directory creation.
|
| 46 |
+
start_time (float): Start time of the experiment. Logged at class init.
|
| 47 |
+
end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
|
| 48 |
+
total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
model_source: str = None
|
| 52 |
+
model_name: str = None
|
| 53 |
+
model_name_sanitized: str = None
|
| 54 |
+
system_instruction: str = None
|
| 55 |
+
system_instruction_sha: str = None
|
| 56 |
+
fewshot_as_multiturn: bool = None
|
| 57 |
+
chat_template: str = None
|
| 58 |
+
chat_template_sha: str = None
|
| 59 |
+
start_time: float = None
|
| 60 |
+
end_time: float = None
|
| 61 |
+
total_evaluation_time_seconds: str = None
|
| 62 |
+
|
| 63 |
+
def __init__(self) -> None:
|
| 64 |
+
"""Starts the evaluation timer."""
|
| 65 |
+
self.start_time = time.perf_counter()
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def _get_model_name(model_args: str) -> str:
|
| 69 |
+
"""Extracts the model name from the model arguments."""
|
| 70 |
+
|
| 71 |
+
def extract_model_name(model_args: str, key: str) -> str:
|
| 72 |
+
"""Extracts the model name from the model arguments using a key."""
|
| 73 |
+
args_after_key = model_args.split(key)[1]
|
| 74 |
+
return args_after_key.split(",")[0]
|
| 75 |
+
|
| 76 |
+
# order does matter, e.g. peft and delta are provided together with pretrained
|
| 77 |
+
prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
|
| 78 |
+
for prefix in prefixes:
|
| 79 |
+
if prefix in model_args:
|
| 80 |
+
return extract_model_name(model_args, prefix)
|
| 81 |
+
return ""
|
| 82 |
+
|
| 83 |
+
def log_experiment_args(
|
| 84 |
+
self,
|
| 85 |
+
model_source: str,
|
| 86 |
+
model_args: str,
|
| 87 |
+
system_instruction: str,
|
| 88 |
+
chat_template: str,
|
| 89 |
+
fewshot_as_multiturn: bool,
|
| 90 |
+
) -> None:
|
| 91 |
+
"""Logs model parameters and job ID."""
|
| 92 |
+
self.model_source = model_source
|
| 93 |
+
self.model_name = GeneralConfigTracker._get_model_name(model_args)
|
| 94 |
+
self.model_name_sanitized = sanitize_model_name(self.model_name)
|
| 95 |
+
self.system_instruction = system_instruction
|
| 96 |
+
self.system_instruction_sha = (
|
| 97 |
+
hash_string(system_instruction) if system_instruction else None
|
| 98 |
+
)
|
| 99 |
+
self.chat_template = chat_template
|
| 100 |
+
self.chat_template_sha = hash_string(chat_template) if chat_template else None
|
| 101 |
+
self.fewshot_as_multiturn = fewshot_as_multiturn
|
| 102 |
+
|
| 103 |
+
def log_end_time(self) -> None:
|
| 104 |
+
"""Logs the end time of the evaluation and calculates the total evaluation time."""
|
| 105 |
+
self.end_time = time.perf_counter()
|
| 106 |
+
self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class EvaluationTracker:
|
| 110 |
+
"""
|
| 111 |
+
Keeps track and saves relevant information of the evaluation process.
|
| 112 |
+
Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
output_path: str = None,
|
| 118 |
+
hub_results_org: str = "",
|
| 119 |
+
hub_repo_name: str = "",
|
| 120 |
+
details_repo_name: str = "",
|
| 121 |
+
results_repo_name: str = "",
|
| 122 |
+
push_results_to_hub: bool = False,
|
| 123 |
+
push_samples_to_hub: bool = False,
|
| 124 |
+
public_repo: bool = False,
|
| 125 |
+
token: str = "",
|
| 126 |
+
leaderboard_url: str = "",
|
| 127 |
+
point_of_contact: str = "",
|
| 128 |
+
gated: bool = False,
|
| 129 |
+
) -> None:
|
| 130 |
+
"""
|
| 131 |
+
Creates all the necessary loggers for evaluation tracking.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
output_path (str): Path to save the results. If not provided, the results won't be saved.
|
| 135 |
+
hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
|
| 136 |
+
hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
|
| 137 |
+
details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
|
| 138 |
+
result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
|
| 139 |
+
push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
|
| 140 |
+
push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
|
| 141 |
+
public_repo (bool): Whether to push the results to a public or private repository.
|
| 142 |
+
token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
|
| 143 |
+
leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
|
| 144 |
+
point_of_contact (str): Contact information on the Hugging Face hub dataset card.
|
| 145 |
+
gated (bool): Whether to gate the repository.
|
| 146 |
+
"""
|
| 147 |
+
self.general_config_tracker = GeneralConfigTracker()
|
| 148 |
+
|
| 149 |
+
self.output_path = output_path
|
| 150 |
+
self.push_results_to_hub = push_results_to_hub
|
| 151 |
+
self.push_samples_to_hub = push_samples_to_hub
|
| 152 |
+
self.public_repo = public_repo
|
| 153 |
+
self.leaderboard_url = leaderboard_url
|
| 154 |
+
self.point_of_contact = point_of_contact
|
| 155 |
+
self.api = HfApi(token=token) if token else None
|
| 156 |
+
self.gated_repo = gated
|
| 157 |
+
|
| 158 |
+
if not self.api and (push_results_to_hub or push_samples_to_hub):
|
| 159 |
+
raise ValueError(
|
| 160 |
+
"Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
|
| 161 |
+
"Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
if (
|
| 165 |
+
self.api
|
| 166 |
+
and hub_results_org == ""
|
| 167 |
+
and (push_results_to_hub or push_samples_to_hub)
|
| 168 |
+
):
|
| 169 |
+
hub_results_org = self.api.whoami()["name"]
|
| 170 |
+
eval_logger.warning(
|
| 171 |
+
f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if hub_repo_name == "":
|
| 175 |
+
details_repo_name = (
|
| 176 |
+
details_repo_name if details_repo_name != "" else "lm-eval-results"
|
| 177 |
+
)
|
| 178 |
+
results_repo_name = (
|
| 179 |
+
results_repo_name if results_repo_name != "" else details_repo_name
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
details_repo_name = hub_repo_name
|
| 183 |
+
results_repo_name = hub_repo_name
|
| 184 |
+
eval_logger.warning(
|
| 185 |
+
"hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.details_repo = f"{hub_results_org}/{details_repo_name}"
|
| 189 |
+
self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
|
| 190 |
+
self.results_repo = f"{hub_results_org}/{results_repo_name}"
|
| 191 |
+
self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
|
| 192 |
+
|
| 193 |
+
def save_results_aggregated(
|
| 194 |
+
self,
|
| 195 |
+
results: dict,
|
| 196 |
+
samples: dict,
|
| 197 |
+
) -> None:
|
| 198 |
+
"""
|
| 199 |
+
Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
results (dict): The aggregated results to save.
|
| 203 |
+
samples (dict): The samples results to save.
|
| 204 |
+
"""
|
| 205 |
+
self.general_config_tracker.log_end_time()
|
| 206 |
+
|
| 207 |
+
if self.output_path:
|
| 208 |
+
try:
|
| 209 |
+
eval_logger.info("Saving results aggregated")
|
| 210 |
+
|
| 211 |
+
# calculate cumulative hash for each task - only if samples are provided
|
| 212 |
+
task_hashes = {}
|
| 213 |
+
if samples:
|
| 214 |
+
for task_name, task_samples in samples.items():
|
| 215 |
+
sample_hashes = [
|
| 216 |
+
s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
|
| 217 |
+
for s in task_samples
|
| 218 |
+
]
|
| 219 |
+
task_hashes[task_name] = hash_string("".join(sample_hashes))
|
| 220 |
+
|
| 221 |
+
# update initial results dict
|
| 222 |
+
results.update({"task_hashes": task_hashes})
|
| 223 |
+
results.update(asdict(self.general_config_tracker))
|
| 224 |
+
dumped = json.dumps(
|
| 225 |
+
results,
|
| 226 |
+
indent=2,
|
| 227 |
+
default=handle_non_serializable,
|
| 228 |
+
ensure_ascii=False,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
path = Path(self.output_path if self.output_path else Path.cwd())
|
| 232 |
+
self.date_id = datetime.now().isoformat().replace(":", "-")
|
| 233 |
+
if path.suffix == ".json":
|
| 234 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 235 |
+
file_results_aggregated = path.with_name(
|
| 236 |
+
f"{path.stem}_{self.date_id}.json"
|
| 237 |
+
)
|
| 238 |
+
else:
|
| 239 |
+
path = path.joinpath(
|
| 240 |
+
self.general_config_tracker.model_name_sanitized
|
| 241 |
+
)
|
| 242 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 243 |
+
file_results_aggregated = path.joinpath(
|
| 244 |
+
f"results_{self.date_id}.json"
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
file_results_aggregated.open("w", encoding="utf-8").write(dumped)
|
| 248 |
+
|
| 249 |
+
if self.api and self.push_results_to_hub:
|
| 250 |
+
repo_id = (
|
| 251 |
+
self.results_repo
|
| 252 |
+
if self.public_repo
|
| 253 |
+
else self.results_repo_private
|
| 254 |
+
)
|
| 255 |
+
self.api.create_repo(
|
| 256 |
+
repo_id=repo_id,
|
| 257 |
+
repo_type="dataset",
|
| 258 |
+
private=not self.public_repo,
|
| 259 |
+
exist_ok=True,
|
| 260 |
+
)
|
| 261 |
+
self.api.upload_file(
|
| 262 |
+
repo_id=repo_id,
|
| 263 |
+
path_or_fileobj=str(file_results_aggregated),
|
| 264 |
+
path_in_repo=os.path.join(
|
| 265 |
+
self.general_config_tracker.model_name,
|
| 266 |
+
file_results_aggregated.name,
|
| 267 |
+
),
|
| 268 |
+
repo_type="dataset",
|
| 269 |
+
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
|
| 270 |
+
)
|
| 271 |
+
eval_logger.info(
|
| 272 |
+
"Successfully pushed aggregated results to the Hugging Face Hub. "
|
| 273 |
+
f"You can find them at: {repo_id}"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
except Exception as e:
|
| 277 |
+
eval_logger.warning("Could not save results aggregated")
|
| 278 |
+
eval_logger.info(repr(e))
|
| 279 |
+
else:
|
| 280 |
+
eval_logger.info(
|
| 281 |
+
"Output path not provided, skipping saving results aggregated"
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def save_results_samples(
|
| 285 |
+
self,
|
| 286 |
+
task_name: str,
|
| 287 |
+
samples: dict,
|
| 288 |
+
) -> None:
|
| 289 |
+
"""
|
| 290 |
+
Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
task_name (str): The task name to save the samples for.
|
| 294 |
+
samples (dict): The samples results to save.
|
| 295 |
+
"""
|
| 296 |
+
if self.output_path:
|
| 297 |
+
try:
|
| 298 |
+
eval_logger.info(f"Saving per-sample results for: {task_name}")
|
| 299 |
+
|
| 300 |
+
path = Path(self.output_path if self.output_path else Path.cwd())
|
| 301 |
+
if path.suffix == ".json":
|
| 302 |
+
path = path.parent
|
| 303 |
+
else:
|
| 304 |
+
path = path.joinpath(
|
| 305 |
+
self.general_config_tracker.model_name_sanitized
|
| 306 |
+
)
|
| 307 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 308 |
+
|
| 309 |
+
file_results_samples = path.joinpath(
|
| 310 |
+
f"samples_{task_name}_{self.date_id}.jsonl"
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
for sample in samples:
|
| 314 |
+
# we first need to sanitize arguments and resps
|
| 315 |
+
# otherwise we won't be able to load the dataset
|
| 316 |
+
# using the datasets library
|
| 317 |
+
arguments = {}
|
| 318 |
+
for i, arg in enumerate(sample["arguments"]):
|
| 319 |
+
arguments[f"gen_args_{i}"] = {}
|
| 320 |
+
for j, tmp in enumerate(arg):
|
| 321 |
+
arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
|
| 322 |
+
|
| 323 |
+
sample["resps"] = sanitize_list(sample["resps"])
|
| 324 |
+
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
|
| 325 |
+
sample["arguments"] = arguments
|
| 326 |
+
sample["target"] = str(sample["target"])
|
| 327 |
+
|
| 328 |
+
sample_dump = (
|
| 329 |
+
json.dumps(
|
| 330 |
+
sample,
|
| 331 |
+
default=handle_non_serializable,
|
| 332 |
+
ensure_ascii=False,
|
| 333 |
+
)
|
| 334 |
+
+ "\n"
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
with open(file_results_samples, "a", encoding="utf-8") as f:
|
| 338 |
+
f.write(sample_dump)
|
| 339 |
+
|
| 340 |
+
if self.api and self.push_samples_to_hub:
|
| 341 |
+
repo_id = (
|
| 342 |
+
self.details_repo
|
| 343 |
+
if self.public_repo
|
| 344 |
+
else self.details_repo_private
|
| 345 |
+
)
|
| 346 |
+
self.api.create_repo(
|
| 347 |
+
repo_id=repo_id,
|
| 348 |
+
repo_type="dataset",
|
| 349 |
+
private=not self.public_repo,
|
| 350 |
+
exist_ok=True,
|
| 351 |
+
)
|
| 352 |
+
try:
|
| 353 |
+
if self.gated_repo:
|
| 354 |
+
headers = build_hf_headers()
|
| 355 |
+
r = get_session().put(
|
| 356 |
+
url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
|
| 357 |
+
headers=headers,
|
| 358 |
+
json={"gated": "auto"},
|
| 359 |
+
)
|
| 360 |
+
hf_raise_for_status(r)
|
| 361 |
+
except Exception as e:
|
| 362 |
+
eval_logger.warning("Could not gate the repository")
|
| 363 |
+
eval_logger.info(repr(e))
|
| 364 |
+
self.api.upload_folder(
|
| 365 |
+
repo_id=repo_id,
|
| 366 |
+
folder_path=str(path),
|
| 367 |
+
path_in_repo=self.general_config_tracker.model_name_sanitized,
|
| 368 |
+
repo_type="dataset",
|
| 369 |
+
commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
|
| 370 |
+
)
|
| 371 |
+
eval_logger.info(
|
| 372 |
+
f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
|
| 373 |
+
f"You can find them at: {repo_id}"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
except Exception as e:
|
| 377 |
+
eval_logger.warning("Could not save sample results")
|
| 378 |
+
eval_logger.info(repr(e))
|
| 379 |
+
else:
|
| 380 |
+
eval_logger.info("Output path not provided, skipping saving sample results")
|
| 381 |
+
|
| 382 |
+
def recreate_metadata_card(self) -> None:
|
| 383 |
+
"""
|
| 384 |
+
Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
|
| 385 |
+
"""
|
| 386 |
+
|
| 387 |
+
eval_logger.info("Recreating metadata card")
|
| 388 |
+
repo_id = self.details_repo if self.public_repo else self.details_repo_private
|
| 389 |
+
|
| 390 |
+
files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
|
| 391 |
+
results_files = get_results_filenames(files_in_repo)
|
| 392 |
+
sample_files = get_sample_results_filenames(files_in_repo)
|
| 393 |
+
|
| 394 |
+
# Build a dictionary to store the latest evaluation datetime for:
|
| 395 |
+
# - Each tested model and its aggregated results
|
| 396 |
+
# - Each task and sample results, if existing
|
| 397 |
+
# i.e. {
|
| 398 |
+
# "org__model_name__gsm8k": "2021-09-01T12:00:00",
|
| 399 |
+
# "org__model_name__ifeval": "2021-09-01T12:00:00",
|
| 400 |
+
# "org__model_name__results": "2021-09-01T12:00:00"
|
| 401 |
+
# }
|
| 402 |
+
latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())
|
| 403 |
+
|
| 404 |
+
for file_path in sample_files:
|
| 405 |
+
file_path = Path(file_path)
|
| 406 |
+
filename = file_path.name
|
| 407 |
+
model_name = file_path.parent
|
| 408 |
+
task_name = get_file_task_name(filename)
|
| 409 |
+
results_datetime = get_file_datetime(filename)
|
| 410 |
+
task_name_sanitized = sanitize_task_name(task_name)
|
| 411 |
+
# Results and sample results for the same model and task will have the same datetime
|
| 412 |
+
samples_key = f"{model_name}__{task_name_sanitized}"
|
| 413 |
+
results_key = f"{model_name}__results"
|
| 414 |
+
latest_datetime = max(
|
| 415 |
+
latest_task_results_datetime[samples_key],
|
| 416 |
+
results_datetime,
|
| 417 |
+
)
|
| 418 |
+
latest_task_results_datetime[samples_key] = latest_datetime
|
| 419 |
+
latest_task_results_datetime[results_key] = max(
|
| 420 |
+
latest_task_results_datetime[results_key],
|
| 421 |
+
latest_datetime,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Create metadata card
|
| 425 |
+
card_metadata = MetadataConfigs()
|
| 426 |
+
|
| 427 |
+
# Add the latest aggregated results to the metadata card for easy access
|
| 428 |
+
for file_path in results_files:
|
| 429 |
+
file_path = Path(file_path)
|
| 430 |
+
results_filename = file_path.name
|
| 431 |
+
model_name = file_path.parent
|
| 432 |
+
eval_date = get_file_datetime(results_filename)
|
| 433 |
+
eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
|
| 434 |
+
results_filename = Path("**") / Path(results_filename).name
|
| 435 |
+
config_name = f"{model_name}__results"
|
| 436 |
+
sanitized_last_eval_date_results = re.sub(
|
| 437 |
+
r"[^\w\.]", "_", latest_task_results_datetime[config_name]
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
if eval_date_sanitized == sanitized_last_eval_date_results:
|
| 441 |
+
# Ensure that all results files are listed in the metadata card
|
| 442 |
+
current_results = card_metadata.get(config_name, {"data_files": []})
|
| 443 |
+
current_results["data_files"].append(
|
| 444 |
+
{"split": eval_date_sanitized, "path": [str(results_filename)]}
|
| 445 |
+
)
|
| 446 |
+
card_metadata[config_name] = current_results
|
| 447 |
+
# If the results file is the newest, update the "latest" field in the metadata card
|
| 448 |
+
card_metadata[config_name]["data_files"].append(
|
| 449 |
+
{"split": "latest", "path": [str(results_filename)]}
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Add the tasks details configs
|
| 453 |
+
for file_path in sample_files:
|
| 454 |
+
file_path = Path(file_path)
|
| 455 |
+
filename = file_path.name
|
| 456 |
+
model_name = file_path.parent
|
| 457 |
+
task_name = get_file_task_name(filename)
|
| 458 |
+
eval_date = get_file_datetime(filename)
|
| 459 |
+
task_name_sanitized = sanitize_task_name(task_name)
|
| 460 |
+
eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
|
| 461 |
+
results_filename = Path("**") / Path(filename).name
|
| 462 |
+
config_name = f"{model_name}__{task_name_sanitized}"
|
| 463 |
+
sanitized_last_eval_date_results = re.sub(
|
| 464 |
+
r"[^\w\.]", "_", latest_task_results_datetime[config_name]
|
| 465 |
+
)
|
| 466 |
+
if eval_date_sanitized == sanitized_last_eval_date_results:
|
| 467 |
+
# Ensure that all sample results files are listed in the metadata card
|
| 468 |
+
current_details_for_task = card_metadata.get(
|
| 469 |
+
config_name, {"data_files": []}
|
| 470 |
+
)
|
| 471 |
+
current_details_for_task["data_files"].append(
|
| 472 |
+
{"split": eval_date_sanitized, "path": [str(results_filename)]}
|
| 473 |
+
)
|
| 474 |
+
card_metadata[config_name] = current_details_for_task
|
| 475 |
+
# If the samples results file is the newest, update the "latest" field in the metadata card
|
| 476 |
+
card_metadata[config_name]["data_files"].append(
|
| 477 |
+
{"split": "latest", "path": [str(results_filename)]}
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# Get latest results and extract info to update metadata card examples
|
| 481 |
+
latest_datetime = max(latest_task_results_datetime.values())
|
| 482 |
+
latest_model_name = max(
|
| 483 |
+
latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
|
| 484 |
+
)
|
| 485 |
+
last_results_file = [
|
| 486 |
+
f for f in results_files if latest_datetime.replace(":", "-") in f
|
| 487 |
+
][0]
|
| 488 |
+
last_results_file_path = hf_hub_url(
|
| 489 |
+
repo_id=repo_id, filename=last_results_file, repo_type="dataset"
|
| 490 |
+
)
|
| 491 |
+
latest_results_file = load_dataset(
|
| 492 |
+
"json", data_files=last_results_file_path, split="train"
|
| 493 |
+
)
|
| 494 |
+
results_dict = latest_results_file["results"][0]
|
| 495 |
+
new_dictionary = {"all": results_dict}
|
| 496 |
+
new_dictionary.update(results_dict)
|
| 497 |
+
results_string = json.dumps(new_dictionary, indent=4)
|
| 498 |
+
|
| 499 |
+
dataset_summary = (
|
| 500 |
+
"Dataset automatically created during the evaluation run of model "
|
| 501 |
+
)
|
| 502 |
+
if self.general_config_tracker.model_source == "hf":
|
| 503 |
+
dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
|
| 504 |
+
else:
|
| 505 |
+
dataset_summary += f"{self.general_config_tracker.model_name}\n"
|
| 506 |
+
dataset_summary += (
|
| 507 |
+
f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
|
| 508 |
+
f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
|
| 509 |
+
'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
|
| 510 |
+
'An additional configuration "results" store all the aggregated results of the run.\n\n'
|
| 511 |
+
"To load the details from a run, you can for instance do the following:\n"
|
| 512 |
+
)
|
| 513 |
+
if self.general_config_tracker.model_source == "hf":
|
| 514 |
+
dataset_summary += (
|
| 515 |
+
"```python\nfrom datasets import load_dataset\n"
|
| 516 |
+
f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
|
| 517 |
+
)
|
| 518 |
+
dataset_summary += (
|
| 519 |
+
"## Latest results\n\n"
|
| 520 |
+
f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
|
| 521 |
+
"(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
|
| 522 |
+
'You find each in the results and the "latest" split for each eval):\n\n'
|
| 523 |
+
f"```python\n{results_string}\n```"
|
| 524 |
+
)
|
| 525 |
+
card_data = DatasetCardData(
|
| 526 |
+
dataset_summary=dataset_summary,
|
| 527 |
+
repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
|
| 528 |
+
pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
|
| 529 |
+
leaderboard_url=self.leaderboard_url,
|
| 530 |
+
point_of_contact=self.point_of_contact,
|
| 531 |
+
)
|
| 532 |
+
card_metadata.to_dataset_card_data(card_data)
|
| 533 |
+
card = DatasetCard.from_template(
|
| 534 |
+
card_data,
|
| 535 |
+
pretty_name=card_data.pretty_name,
|
| 536 |
+
)
|
| 537 |
+
card.push_to_hub(repo_id, repo_type="dataset")
|
lm-evaluation-harness/lm_eval/loggers/utils.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import subprocess
|
| 5 |
+
from importlib.metadata import version
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.collect_env import get_pretty_env_info
|
| 11 |
+
from transformers import __version__ as trans_version
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
|
| 18 |
+
"""Remove the ',none' substring from the input_string if it exists at the end.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
input_string (str): The input string from which to remove the ',none' substring.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
|
| 25 |
+
and a boolean indicating whether the modification was made (True) or not (False).
|
| 26 |
+
"""
|
| 27 |
+
# Define the pattern to match ',none' at the end of the string
|
| 28 |
+
pattern = re.compile(r",none$")
|
| 29 |
+
|
| 30 |
+
# Use sub() to replace ',none' with an empty string
|
| 31 |
+
result = re.sub(pattern, "", input_string)
|
| 32 |
+
|
| 33 |
+
# check if the input_string changed
|
| 34 |
+
removed = result != input_string
|
| 35 |
+
|
| 36 |
+
return result, removed
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _handle_non_serializable(o: Any) -> Union[int, str, list]:
|
| 40 |
+
"""Handle non-serializable objects by converting them to serializable types.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
o (Any): The object to be handled.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
|
| 47 |
+
it will be converted to int. If the object is of type set, it will be converted
|
| 48 |
+
to a list. Otherwise, it will be converted to str.
|
| 49 |
+
"""
|
| 50 |
+
if isinstance(o, np.int64) or isinstance(o, np.int32):
|
| 51 |
+
return int(o)
|
| 52 |
+
elif isinstance(o, set):
|
| 53 |
+
return list(o)
|
| 54 |
+
else:
|
| 55 |
+
return str(o)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
|
| 59 |
+
try:
|
| 60 |
+
git_folder = Path(repo_path, ".git")
|
| 61 |
+
if git_folder.is_file():
|
| 62 |
+
git_folder = Path(
|
| 63 |
+
git_folder.parent,
|
| 64 |
+
git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
|
| 65 |
+
)
|
| 66 |
+
if Path(git_folder, "HEAD").exists():
|
| 67 |
+
head_name = (
|
| 68 |
+
Path(git_folder, "HEAD")
|
| 69 |
+
.read_text(encoding="utf-8")
|
| 70 |
+
.split("\n")[0]
|
| 71 |
+
.split(" ")[-1]
|
| 72 |
+
)
|
| 73 |
+
head_ref = Path(git_folder, head_name)
|
| 74 |
+
git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
|
| 75 |
+
else:
|
| 76 |
+
git_hash = None
|
| 77 |
+
except Exception as err:
|
| 78 |
+
logger.debug(
|
| 79 |
+
f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
|
| 80 |
+
)
|
| 81 |
+
return None
|
| 82 |
+
return git_hash
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_git_commit_hash():
|
| 86 |
+
"""
|
| 87 |
+
Gets the git commit hash of your current repo (if it exists).
|
| 88 |
+
Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
|
| 89 |
+
"""
|
| 90 |
+
try:
|
| 91 |
+
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
|
| 92 |
+
git_hash = git_hash.decode()
|
| 93 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
| 94 |
+
# FileNotFoundError occurs when git not installed on system
|
| 95 |
+
git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
|
| 96 |
+
return git_hash
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def add_env_info(storage: Dict[str, Any]):
|
| 100 |
+
try:
|
| 101 |
+
pretty_env_info = get_pretty_env_info()
|
| 102 |
+
except Exception as err:
|
| 103 |
+
pretty_env_info = str(err)
|
| 104 |
+
try:
|
| 105 |
+
lm_eval_version = version("lm_eval")
|
| 106 |
+
except Exception as err:
|
| 107 |
+
lm_eval_version = str(err)
|
| 108 |
+
transformers_version = trans_version
|
| 109 |
+
upper_dir_commit = get_commit_from_path(
|
| 110 |
+
Path(os.getcwd(), "..")
|
| 111 |
+
) # git hash of upper repo if exists
|
| 112 |
+
added_info = {
|
| 113 |
+
"pretty_env_info": pretty_env_info,
|
| 114 |
+
"transformers_version": transformers_version,
|
| 115 |
+
"lm_eval_version": lm_eval_version,
|
| 116 |
+
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
|
| 117 |
+
}
|
| 118 |
+
storage.update(added_info)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def add_tokenizer_info(storage: Dict[str, Any], lm):
|
| 122 |
+
if getattr(lm, "tokenizer", False):
|
| 123 |
+
try:
|
| 124 |
+
tokenizer_info = {
|
| 125 |
+
"tokenizer_pad_token": [
|
| 126 |
+
lm.tokenizer.pad_token,
|
| 127 |
+
str(lm.tokenizer.pad_token_id),
|
| 128 |
+
],
|
| 129 |
+
"tokenizer_eos_token": [
|
| 130 |
+
lm.tokenizer.eos_token,
|
| 131 |
+
str(lm.tokenizer.eos_token_id),
|
| 132 |
+
],
|
| 133 |
+
"tokenizer_bos_token": [
|
| 134 |
+
lm.tokenizer.bos_token,
|
| 135 |
+
str(lm.tokenizer.bos_token_id),
|
| 136 |
+
],
|
| 137 |
+
"eot_token_id": getattr(lm, "eot_token_id", None),
|
| 138 |
+
"max_length": getattr(lm, "max_length", None),
|
| 139 |
+
}
|
| 140 |
+
storage.update(tokenizer_info)
|
| 141 |
+
except Exception as err:
|
| 142 |
+
logger.debug(
|
| 143 |
+
f"Logging detailed tokenizer info failed with {err}, skipping..."
|
| 144 |
+
)
|
| 145 |
+
# seems gguf and textsynth do not have tokenizer
|
| 146 |
+
else:
|
| 147 |
+
logger.debug(
|
| 148 |
+
"LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
|
| 149 |
+
)
|
lm-evaluation-harness/lm_eval/loggers/wandb_logger.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Dict, List, Literal, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from packaging.version import Version
|
| 9 |
+
|
| 10 |
+
from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_wandb_printer() -> Literal["Printer"]:
|
| 17 |
+
"""Returns a wandb printer instance for pretty stdout."""
|
| 18 |
+
from wandb.sdk.lib.printer import new_printer
|
| 19 |
+
|
| 20 |
+
printer = new_printer()
|
| 21 |
+
return printer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class WandbLogger:
|
| 25 |
+
def __init__(self, init_args=None, config_args=None) -> None:
|
| 26 |
+
"""Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
init_args Optional[Dict]: Arguments for init configuration.
|
| 30 |
+
config_args Optional[Dict]: Arguments for config
|
| 31 |
+
|
| 32 |
+
Parse and log the results returned from evaluator.simple_evaluate() with:
|
| 33 |
+
wandb_logger.post_init(results)
|
| 34 |
+
wandb_logger.log_eval_result()
|
| 35 |
+
wandb_logger.log_eval_samples(results["samples"])
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
import wandb
|
| 39 |
+
|
| 40 |
+
assert Version(wandb.__version__) >= Version("0.13.6")
|
| 41 |
+
if Version(wandb.__version__) < Version("0.13.6"):
|
| 42 |
+
wandb.require("report-editing:v0")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.warning(
|
| 45 |
+
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
|
| 46 |
+
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
|
| 47 |
+
f"{e}"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.wandb_args: Dict[str, Any] = init_args or {}
|
| 51 |
+
self.wandb_config_args: Dict[str, Any] = config_args or {}
|
| 52 |
+
|
| 53 |
+
# pop the step key from the args to save for all logging calls
|
| 54 |
+
self.step = self.wandb_args.pop("step", None)
|
| 55 |
+
|
| 56 |
+
# initialize a W&B run
|
| 57 |
+
if wandb.run is None:
|
| 58 |
+
self.run = wandb.init(**self.wandb_args)
|
| 59 |
+
if self.wandb_config_args:
|
| 60 |
+
self.run.config.update(self.wandb_config_args)
|
| 61 |
+
else:
|
| 62 |
+
self.run = wandb.run
|
| 63 |
+
|
| 64 |
+
self.printer = get_wandb_printer()
|
| 65 |
+
|
| 66 |
+
def post_init(self, results: Dict[str, Any]) -> None:
|
| 67 |
+
self.results: Dict[str, Any] = copy.deepcopy(results)
|
| 68 |
+
self.task_names: List[str] = list(results.get("results", {}).keys())
|
| 69 |
+
self.group_names: List[str] = list(results.get("groups", {}).keys())
|
| 70 |
+
|
| 71 |
+
def _get_config(self) -> Dict[str, Any]:
|
| 72 |
+
"""Get configuration parameters."""
|
| 73 |
+
self.task_configs = self.results.get("configs", {})
|
| 74 |
+
cli_configs = self.results.get("config", {})
|
| 75 |
+
configs = {
|
| 76 |
+
"task_configs": self.task_configs,
|
| 77 |
+
"cli_configs": cli_configs,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
return configs
|
| 81 |
+
|
| 82 |
+
def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
|
| 83 |
+
"""Sanitize the results dictionary."""
|
| 84 |
+
_results = copy.deepcopy(self.results.get("results", dict()))
|
| 85 |
+
|
| 86 |
+
# Remove None from the metric string name
|
| 87 |
+
tmp_results = copy.deepcopy(_results)
|
| 88 |
+
for task_name in self.task_names:
|
| 89 |
+
task_result = tmp_results.get(task_name, dict())
|
| 90 |
+
for metric_name, metric_value in task_result.items():
|
| 91 |
+
_metric_name, removed = remove_none_pattern(metric_name)
|
| 92 |
+
if removed:
|
| 93 |
+
_results[task_name][_metric_name] = metric_value
|
| 94 |
+
_results[task_name].pop(metric_name)
|
| 95 |
+
|
| 96 |
+
# remove string valued keys from the results dict
|
| 97 |
+
wandb_summary = {}
|
| 98 |
+
for task in self.task_names:
|
| 99 |
+
task_result = _results.get(task, dict())
|
| 100 |
+
for metric_name, metric_value in task_result.items():
|
| 101 |
+
if isinstance(metric_value, str):
|
| 102 |
+
wandb_summary[f"{task}/{metric_name}"] = metric_value
|
| 103 |
+
|
| 104 |
+
for summary_metric, summary_value in wandb_summary.items():
|
| 105 |
+
_task, _summary_metric = summary_metric.split("/")
|
| 106 |
+
_results[_task].pop(_summary_metric)
|
| 107 |
+
|
| 108 |
+
tmp_results = copy.deepcopy(_results)
|
| 109 |
+
for task_name, task_results in tmp_results.items():
|
| 110 |
+
for metric_name, metric_value in task_results.items():
|
| 111 |
+
_results[f"{task_name}/{metric_name}"] = metric_value
|
| 112 |
+
_results[task_name].pop(metric_name)
|
| 113 |
+
for task in self.task_names:
|
| 114 |
+
_results.pop(task)
|
| 115 |
+
|
| 116 |
+
return wandb_summary, _results
|
| 117 |
+
|
| 118 |
+
def _log_results_as_table(self) -> None:
|
| 119 |
+
"""Generate and log evaluation results as a table to W&B."""
|
| 120 |
+
columns = [
|
| 121 |
+
"Version",
|
| 122 |
+
"Filter",
|
| 123 |
+
"num_fewshot",
|
| 124 |
+
"Metric",
|
| 125 |
+
"Value",
|
| 126 |
+
"Stderr",
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
def make_table(columns: List[str], key: str = "results"):
|
| 130 |
+
import wandb
|
| 131 |
+
|
| 132 |
+
table = wandb.Table(columns=columns)
|
| 133 |
+
results = copy.deepcopy(self.results)
|
| 134 |
+
|
| 135 |
+
for k, dic in results.get(key).items():
|
| 136 |
+
if k in self.group_names and not key == "groups":
|
| 137 |
+
continue
|
| 138 |
+
version = results.get("versions").get(k)
|
| 139 |
+
if version == "N/A":
|
| 140 |
+
version = None
|
| 141 |
+
n = results.get("n-shot").get(k)
|
| 142 |
+
|
| 143 |
+
for (mf), v in dic.items():
|
| 144 |
+
m, _, f = mf.partition(",")
|
| 145 |
+
if m.endswith("_stderr"):
|
| 146 |
+
continue
|
| 147 |
+
if m == "alias":
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
if m + "_stderr" + "," + f in dic:
|
| 151 |
+
se = dic[m + "_stderr" + "," + f]
|
| 152 |
+
if se != "N/A":
|
| 153 |
+
se = "%.4f" % se
|
| 154 |
+
table.add_data(*[k, version, f, n, m, str(v), str(se)])
|
| 155 |
+
else:
|
| 156 |
+
table.add_data(*[k, version, f, n, m, str(v), ""])
|
| 157 |
+
|
| 158 |
+
return table
|
| 159 |
+
|
| 160 |
+
# log the complete eval result to W&B Table
|
| 161 |
+
table = make_table(["Tasks"] + columns, "results")
|
| 162 |
+
self.run.log({"evaluation/eval_results": table}, step=self.step)
|
| 163 |
+
|
| 164 |
+
if "groups" in self.results.keys():
|
| 165 |
+
table = make_table(["Groups"] + columns, "groups")
|
| 166 |
+
self.run.log({"evaluation/group_eval_results": table}, step=self.step)
|
| 167 |
+
|
| 168 |
+
def _log_results_as_artifact(self) -> None:
|
| 169 |
+
"""Log results as JSON artifact to W&B."""
|
| 170 |
+
import wandb
|
| 171 |
+
|
| 172 |
+
dumped = json.dumps(
|
| 173 |
+
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
|
| 174 |
+
)
|
| 175 |
+
artifact = wandb.Artifact("results", type="eval_results")
|
| 176 |
+
with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
|
| 177 |
+
f.write(dumped)
|
| 178 |
+
self.run.log_artifact(artifact)
|
| 179 |
+
|
| 180 |
+
def log_eval_result(self) -> None:
|
| 181 |
+
"""Log evaluation results to W&B."""
|
| 182 |
+
# Log configs to wandb
|
| 183 |
+
configs = self._get_config()
|
| 184 |
+
self.run.config.update(configs, allow_val_change=self.step is not None)
|
| 185 |
+
|
| 186 |
+
wandb_summary, self.wandb_results = self._sanitize_results_dict()
|
| 187 |
+
# update wandb.run.summary with items that were removed
|
| 188 |
+
self.run.summary.update(wandb_summary)
|
| 189 |
+
# Log the evaluation metrics to wandb
|
| 190 |
+
self.run.log(self.wandb_results, step=self.step)
|
| 191 |
+
# Log the evaluation metrics as W&B Table
|
| 192 |
+
self._log_results_as_table()
|
| 193 |
+
# Log the results dict as json to W&B Artifacts
|
| 194 |
+
self._log_results_as_artifact()
|
| 195 |
+
|
| 196 |
+
def _generate_dataset(
|
| 197 |
+
self, data: List[Dict[str, Any]], config: Dict[str, Any]
|
| 198 |
+
) -> pd.DataFrame:
|
| 199 |
+
"""Generate a dataset from evaluation data.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
data (List[Dict[str, Any]]): The data to generate a dataset for.
|
| 203 |
+
config (Dict[str, Any]): The configuration of the task.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
|
| 207 |
+
"""
|
| 208 |
+
ids = [x["doc_id"] for x in data]
|
| 209 |
+
labels = [x["target"] for x in data]
|
| 210 |
+
instance = [""] * len(ids)
|
| 211 |
+
resps = [""] * len(ids)
|
| 212 |
+
filtered_resps = [""] * len(ids)
|
| 213 |
+
model_outputs = {}
|
| 214 |
+
|
| 215 |
+
metrics_list = config["metric_list"]
|
| 216 |
+
metrics = {}
|
| 217 |
+
for metric in metrics_list:
|
| 218 |
+
metric = metric.get("metric")
|
| 219 |
+
if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
|
| 220 |
+
metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
|
| 221 |
+
if metric in ["byte_perplexity", "bits_per_byte"]:
|
| 222 |
+
metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
|
| 223 |
+
else:
|
| 224 |
+
metrics[f"{metric}_words"] = [x[metric][1] for x in data]
|
| 225 |
+
else:
|
| 226 |
+
metrics[metric] = [x[metric] for x in data]
|
| 227 |
+
|
| 228 |
+
if config["output_type"] == "loglikelihood":
|
| 229 |
+
instance = [x["arguments"][0][0] for x in data]
|
| 230 |
+
labels = [x["arguments"][0][1] for x in data]
|
| 231 |
+
resps = [
|
| 232 |
+
f"log probability of continuation is {x['resps'][0][0][0]} "
|
| 233 |
+
+ "\n\n"
|
| 234 |
+
+ "continuation will {} generated with greedy sampling".format(
|
| 235 |
+
"not be" if not x["resps"][0][0][1] else "be"
|
| 236 |
+
)
|
| 237 |
+
for x in data
|
| 238 |
+
]
|
| 239 |
+
filtered_resps = [
|
| 240 |
+
f"log probability of continuation is {x['filtered_resps'][0][0]} "
|
| 241 |
+
+ "\n\n"
|
| 242 |
+
+ "continuation will {} generated with greedy sampling".format(
|
| 243 |
+
"not be" if not x["filtered_resps"][0][1] else "be"
|
| 244 |
+
)
|
| 245 |
+
for x in data
|
| 246 |
+
]
|
| 247 |
+
elif config["output_type"] == "multiple_choice":
|
| 248 |
+
instance = [x["arguments"][0][0] for x in data]
|
| 249 |
+
choices = [
|
| 250 |
+
"\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
|
| 251 |
+
for x in data
|
| 252 |
+
]
|
| 253 |
+
resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
|
| 254 |
+
filtered_resps = [
|
| 255 |
+
np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
|
| 256 |
+
]
|
| 257 |
+
elif config["output_type"] == "loglikelihood_rolling":
|
| 258 |
+
instance = [x["arguments"][0][0] for x in data]
|
| 259 |
+
resps = [x["resps"][0][0] for x in data]
|
| 260 |
+
filtered_resps = [x["filtered_resps"][0] for x in data]
|
| 261 |
+
elif config["output_type"] == "generate_until":
|
| 262 |
+
instance = [x["arguments"][0][0] for x in data]
|
| 263 |
+
resps = [x["resps"][0][0] for x in data]
|
| 264 |
+
filtered_resps = [x["filtered_resps"][0] for x in data]
|
| 265 |
+
|
| 266 |
+
model_outputs["raw_predictions"] = resps
|
| 267 |
+
model_outputs["filtered_predictions"] = filtered_resps
|
| 268 |
+
|
| 269 |
+
df_data = {
|
| 270 |
+
"id": ids,
|
| 271 |
+
"data": instance,
|
| 272 |
+
}
|
| 273 |
+
if config["output_type"] == "multiple_choice":
|
| 274 |
+
df_data["choices"] = choices
|
| 275 |
+
|
| 276 |
+
tmp_data = {
|
| 277 |
+
"input_len": [len(x) for x in instance],
|
| 278 |
+
"labels": labels,
|
| 279 |
+
"output_type": config["output_type"],
|
| 280 |
+
}
|
| 281 |
+
df_data.update(tmp_data)
|
| 282 |
+
df_data.update(model_outputs)
|
| 283 |
+
df_data.update(metrics)
|
| 284 |
+
|
| 285 |
+
return pd.DataFrame(df_data)
|
| 286 |
+
|
| 287 |
+
def _log_samples_as_artifact(
|
| 288 |
+
self, data: List[Dict[str, Any]], task_name: str
|
| 289 |
+
) -> None:
|
| 290 |
+
import wandb
|
| 291 |
+
|
| 292 |
+
# log the samples as an artifact
|
| 293 |
+
dumped = json.dumps(
|
| 294 |
+
data,
|
| 295 |
+
indent=2,
|
| 296 |
+
default=_handle_non_serializable,
|
| 297 |
+
ensure_ascii=False,
|
| 298 |
+
)
|
| 299 |
+
artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
|
| 300 |
+
with artifact.new_file(
|
| 301 |
+
f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
|
| 302 |
+
) as f:
|
| 303 |
+
f.write(dumped)
|
| 304 |
+
self.run.log_artifact(artifact)
|
| 305 |
+
# artifact.wait()
|
| 306 |
+
|
| 307 |
+
def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
|
| 308 |
+
"""Log evaluation samples to W&B.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
|
| 312 |
+
"""
|
| 313 |
+
task_names: List[str] = [
|
| 314 |
+
x for x in self.task_names if x not in self.group_names
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
ungrouped_tasks = []
|
| 318 |
+
tasks_by_groups = {}
|
| 319 |
+
|
| 320 |
+
for task_name in task_names:
|
| 321 |
+
group_names = self.task_configs[task_name].get("group", None)
|
| 322 |
+
if group_names:
|
| 323 |
+
if isinstance(group_names, str):
|
| 324 |
+
group_names = [group_names]
|
| 325 |
+
|
| 326 |
+
for group_name in group_names:
|
| 327 |
+
if not tasks_by_groups.get(group_name):
|
| 328 |
+
tasks_by_groups[group_name] = [task_name]
|
| 329 |
+
else:
|
| 330 |
+
tasks_by_groups[group_name].append(task_name)
|
| 331 |
+
else:
|
| 332 |
+
ungrouped_tasks.append(task_name)
|
| 333 |
+
|
| 334 |
+
for task_name in ungrouped_tasks:
|
| 335 |
+
eval_preds = samples[task_name]
|
| 336 |
+
|
| 337 |
+
# log the samples as a W&B Table
|
| 338 |
+
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
|
| 339 |
+
self.run.log({f"{task_name}_eval_results": df}, step=self.step)
|
| 340 |
+
|
| 341 |
+
# log the samples as a json file as W&B Artifact
|
| 342 |
+
self._log_samples_as_artifact(eval_preds, task_name)
|
| 343 |
+
|
| 344 |
+
for group, grouped_tasks in tasks_by_groups.items():
|
| 345 |
+
grouped_df = pd.DataFrame()
|
| 346 |
+
for task_name in grouped_tasks:
|
| 347 |
+
eval_preds = samples[task_name]
|
| 348 |
+
df = self._generate_dataset(
|
| 349 |
+
eval_preds, self.task_configs.get(task_name)
|
| 350 |
+
)
|
| 351 |
+
df["group"] = group
|
| 352 |
+
df["task"] = task_name
|
| 353 |
+
grouped_df = pd.concat([grouped_df, df], ignore_index=True)
|
| 354 |
+
|
| 355 |
+
# log the samples as a json file as W&B Artifact
|
| 356 |
+
self._log_samples_as_artifact(eval_preds, task_name)
|
| 357 |
+
|
| 358 |
+
self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
|
lm-evaluation-harness/lm_eval/models/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import (
|
| 2 |
+
anthropic_llms,
|
| 3 |
+
api_models,
|
| 4 |
+
dummy,
|
| 5 |
+
gguf,
|
| 6 |
+
hf_audiolm,
|
| 7 |
+
hf_steered,
|
| 8 |
+
hf_vlms,
|
| 9 |
+
huggingface,
|
| 10 |
+
ibm_watsonx_ai,
|
| 11 |
+
mamba_lm,
|
| 12 |
+
nemo_lm,
|
| 13 |
+
neuron_optimum,
|
| 14 |
+
openai_completions,
|
| 15 |
+
optimum_ipex,
|
| 16 |
+
optimum_lm,
|
| 17 |
+
sglang_causallms,
|
| 18 |
+
sglang_generate_API,
|
| 19 |
+
textsynth,
|
| 20 |
+
vllm_causallms,
|
| 21 |
+
vllm_vlms,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# TODO: implement __all__
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
# enable hf hub transfer if available
|
| 30 |
+
import hf_transfer # type: ignore # noqa
|
| 31 |
+
import huggingface_hub.constants # type: ignore
|
| 32 |
+
|
| 33 |
+
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
| 34 |
+
except ImportError:
|
| 35 |
+
pass
|
lm-evaluation-harness/lm_eval/models/anthropic_llms.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
from lm_eval.api.model import LM
|
| 9 |
+
from lm_eval.api.registry import register_model
|
| 10 |
+
from lm_eval.models.openai_completions import LocalCompletionsAPI
|
| 11 |
+
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
eval_logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def anthropic_completion(
|
| 18 |
+
client, #: anthropic.Anthropic,
|
| 19 |
+
model: str,
|
| 20 |
+
prompt: str,
|
| 21 |
+
max_tokens_to_sample: int,
|
| 22 |
+
temperature: float,
|
| 23 |
+
stop: List[str],
|
| 24 |
+
**kwargs: Any,
|
| 25 |
+
) -> str:
|
| 26 |
+
"""Wrapper function around the Anthropic completion API client with exponential back-off
|
| 27 |
+
in case of RateLimitError.
|
| 28 |
+
|
| 29 |
+
params:
|
| 30 |
+
client: anthropic.Anthropic
|
| 31 |
+
Anthropic API client
|
| 32 |
+
model: str
|
| 33 |
+
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
|
| 34 |
+
prompt: str
|
| 35 |
+
Prompt to feed to the model
|
| 36 |
+
max_tokens_to_sample: int
|
| 37 |
+
Maximum number of tokens to sample from the model
|
| 38 |
+
temperature: float
|
| 39 |
+
Sampling temperature
|
| 40 |
+
stop: List[str]
|
| 41 |
+
List of stop sequences
|
| 42 |
+
kwargs: Any
|
| 43 |
+
Additional model_args to pass to the API client
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
import anthropic
|
| 48 |
+
except ModuleNotFoundError as exception:
|
| 49 |
+
raise type(exception)(
|
| 50 |
+
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 51 |
+
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def _exception_callback(e: Exception, sleep_time: float) -> None:
|
| 55 |
+
eval_logger.warning(
|
| 56 |
+
f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
@retry_on_specific_exceptions(
|
| 60 |
+
on_exceptions=[anthropic.RateLimitError],
|
| 61 |
+
max_retries=None, # retry forever, consider changing
|
| 62 |
+
on_exception_callback=_exception_callback,
|
| 63 |
+
)
|
| 64 |
+
def completion():
|
| 65 |
+
response = client.completions.create(
|
| 66 |
+
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
|
| 67 |
+
model=model,
|
| 68 |
+
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
|
| 69 |
+
# (e.g. gsm8k's ":") may truncate a lot of the input.
|
| 70 |
+
stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
|
| 71 |
+
max_tokens_to_sample=max_tokens_to_sample,
|
| 72 |
+
temperature=temperature,
|
| 73 |
+
**kwargs,
|
| 74 |
+
)
|
| 75 |
+
return response.completion
|
| 76 |
+
|
| 77 |
+
return completion()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def anthropic_chat(
|
| 81 |
+
client, #: anthropic.Anthropic,
|
| 82 |
+
model: str,
|
| 83 |
+
prompt: str,
|
| 84 |
+
max_tokens: int,
|
| 85 |
+
temperature: float,
|
| 86 |
+
stop: List[str],
|
| 87 |
+
**kwargs: Any,
|
| 88 |
+
) -> str:
|
| 89 |
+
"""Wrapper function around the Anthropic completion API client with exponential back-off
|
| 90 |
+
in case of RateLimitError.
|
| 91 |
+
|
| 92 |
+
params:
|
| 93 |
+
client: anthropic.Anthropic
|
| 94 |
+
Anthropic API client
|
| 95 |
+
model: str
|
| 96 |
+
Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
|
| 97 |
+
prompt: str
|
| 98 |
+
Prompt to feed to the model
|
| 99 |
+
max_tokens: int
|
| 100 |
+
Maximum number of tokens to sample from the model
|
| 101 |
+
temperature: float
|
| 102 |
+
Sampling temperature
|
| 103 |
+
stop: List[str]
|
| 104 |
+
List of stop sequences
|
| 105 |
+
kwargs: Any
|
| 106 |
+
Additional model_args to pass to the API client
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
import anthropic
|
| 111 |
+
except ModuleNotFoundError as exception:
|
| 112 |
+
raise type(exception)(
|
| 113 |
+
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 114 |
+
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def _exception_callback(e: Exception, sleep_time: float) -> None:
|
| 118 |
+
eval_logger.warning(
|
| 119 |
+
f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
@retry_on_specific_exceptions(
|
| 123 |
+
on_exceptions=[
|
| 124 |
+
anthropic.RateLimitError,
|
| 125 |
+
anthropic.APIConnectionError,
|
| 126 |
+
anthropic.APIStatusError,
|
| 127 |
+
],
|
| 128 |
+
max_retries=None, # retry forever, consider changing
|
| 129 |
+
on_exception_callback=_exception_callback,
|
| 130 |
+
)
|
| 131 |
+
def messages():
|
| 132 |
+
response = client.messages.create(
|
| 133 |
+
model=model,
|
| 134 |
+
max_tokens=max_tokens,
|
| 135 |
+
temperature=temperature,
|
| 136 |
+
messages=[{"role": "user", "content": f"{prompt}"}],
|
| 137 |
+
**kwargs,
|
| 138 |
+
)
|
| 139 |
+
return response.content[0].text
|
| 140 |
+
|
| 141 |
+
return messages()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@register_model("anthropic-completions")
|
| 145 |
+
class AnthropicLM(LM):
|
| 146 |
+
REQ_CHUNK_SIZE = 20 # TODO: not used
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
batch_size: int = 1,
|
| 151 |
+
model: str = "claude-2.0",
|
| 152 |
+
max_tokens_to_sample: int = 256,
|
| 153 |
+
temperature: float = 0, # defaults to 1
|
| 154 |
+
**kwargs, # top_p, top_k, etc.
|
| 155 |
+
) -> None:
|
| 156 |
+
"""Anthropic API wrapper.
|
| 157 |
+
|
| 158 |
+
:param model: str
|
| 159 |
+
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
|
| 160 |
+
:param max_tokens_to_sample: int
|
| 161 |
+
Maximum number of tokens to sample from the model
|
| 162 |
+
:param temperature: float
|
| 163 |
+
Sampling temperature
|
| 164 |
+
:param kwargs: Any
|
| 165 |
+
Additional model_args to pass to the API client
|
| 166 |
+
"""
|
| 167 |
+
super().__init__()
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
import anthropic
|
| 171 |
+
except ModuleNotFoundError as exception:
|
| 172 |
+
raise type(exception)(
|
| 173 |
+
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 174 |
+
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
self.model = model
|
| 178 |
+
# defaults to os.environ.get("ANTHROPIC_API_KEY")
|
| 179 |
+
self.client = anthropic.Anthropic()
|
| 180 |
+
self.temperature = temperature
|
| 181 |
+
self.max_tokens_to_sample = max_tokens_to_sample
|
| 182 |
+
self.tokenizer = self.client.get_tokenizer()
|
| 183 |
+
self.kwargs = kwargs
|
| 184 |
+
|
| 185 |
+
@property
|
| 186 |
+
def eot_token_id(self):
|
| 187 |
+
# Not sure but anthropic.HUMAN_PROMPT ?
|
| 188 |
+
raise NotImplementedError("No idea about anthropic tokenization.")
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def max_length(self) -> int:
|
| 192 |
+
return 2048
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def max_gen_toks(self) -> int:
|
| 196 |
+
return self.max_tokens_to_sample
|
| 197 |
+
|
| 198 |
+
@property
|
| 199 |
+
def batch_size(self):
|
| 200 |
+
# Isn't used because we override _loglikelihood_tokens
|
| 201 |
+
raise NotImplementedError("No support for logits.")
|
| 202 |
+
|
| 203 |
+
@property
|
| 204 |
+
def device(self):
|
| 205 |
+
# Isn't used because we override _loglikelihood_tokens
|
| 206 |
+
raise NotImplementedError("No support for logits.")
|
| 207 |
+
|
| 208 |
+
def tok_encode(self, string: str) -> List[int]:
|
| 209 |
+
return self.tokenizer.encode(string).ids
|
| 210 |
+
|
| 211 |
+
def tok_decode(self, tokens: List[int]) -> str:
|
| 212 |
+
return self.tokenizer.decode(tokens)
|
| 213 |
+
|
| 214 |
+
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
|
| 215 |
+
raise NotImplementedError("No support for logits.")
|
| 216 |
+
|
| 217 |
+
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
|
| 218 |
+
try:
|
| 219 |
+
import anthropic
|
| 220 |
+
except ModuleNotFoundError as exception:
|
| 221 |
+
raise type(exception)(
|
| 222 |
+
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 223 |
+
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if not requests:
|
| 227 |
+
return []
|
| 228 |
+
|
| 229 |
+
_requests: List[Tuple[str, dict]] = [req.args for req in requests]
|
| 230 |
+
|
| 231 |
+
res = []
|
| 232 |
+
for request in tqdm(_requests, disable=disable_tqdm):
|
| 233 |
+
try:
|
| 234 |
+
inp = request[0]
|
| 235 |
+
request_args = request[1]
|
| 236 |
+
# generation_kwargs
|
| 237 |
+
until = request_args.get("until")
|
| 238 |
+
max_gen_toks = request_args.get("max_gen_toks", self.max_length)
|
| 239 |
+
temperature = request_args.get("temperature", self.temperature)
|
| 240 |
+
response = anthropic_completion(
|
| 241 |
+
client=self.client,
|
| 242 |
+
model=self.model,
|
| 243 |
+
prompt=inp,
|
| 244 |
+
max_tokens_to_sample=max_gen_toks,
|
| 245 |
+
temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
|
| 246 |
+
stop=until, # type: ignore
|
| 247 |
+
**self.kwargs,
|
| 248 |
+
)
|
| 249 |
+
res.append(response)
|
| 250 |
+
|
| 251 |
+
self.cache_hook.add_partial("generate_until", request, response)
|
| 252 |
+
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
|
| 253 |
+
eval_logger.critical(f"Server unreachable: {e.__cause__}")
|
| 254 |
+
break
|
| 255 |
+
except anthropic.APIStatusError as e: # type: ignore # noqa: F821
|
| 256 |
+
eval_logger.critical(f"API error {e.status_code}: {e.message}")
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
return res
|
| 260 |
+
|
| 261 |
+
def _model_call(self, inps):
|
| 262 |
+
# Isn't used because we override _loglikelihood_tokens
|
| 263 |
+
raise NotImplementedError()
|
| 264 |
+
|
| 265 |
+
def _model_generate(self, context, max_length, eos_token_id):
|
| 266 |
+
# Isn't used because we override generate_until
|
| 267 |
+
raise NotImplementedError()
|
| 268 |
+
|
| 269 |
+
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 270 |
+
raise NotImplementedError("No support for logits.")
|
| 271 |
+
|
| 272 |
+
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 273 |
+
raise NotImplementedError("No support for logits.")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@register_model("anthropic-chat", "anthropic-chat-completions")
|
| 277 |
+
class AnthropicChat(LocalCompletionsAPI):
|
| 278 |
+
def __init__(
|
| 279 |
+
self,
|
| 280 |
+
base_url="https://api.anthropic.com/v1/messages",
|
| 281 |
+
tokenizer_backend=None,
|
| 282 |
+
**kwargs,
|
| 283 |
+
):
|
| 284 |
+
super().__init__(
|
| 285 |
+
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
|
| 286 |
+
)
|
| 287 |
+
eval_logger.warning(
|
| 288 |
+
"Chat completions does not support batching. Defaulting to batch size 1."
|
| 289 |
+
)
|
| 290 |
+
self._batch_size = 1
|
| 291 |
+
self.anthropic_version = "2023-06-01"
|
| 292 |
+
eval_logger.warning(
|
| 293 |
+
f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
@cached_property
|
| 297 |
+
def api_key(self):
|
| 298 |
+
"""Override this property to return the API key for the API request."""
|
| 299 |
+
key = os.environ.get("ANTHROPIC_API_KEY", None)
|
| 300 |
+
if key is None:
|
| 301 |
+
raise ValueError(
|
| 302 |
+
"API key not found. Please set the ANTHROPIC_API_KEY environment variable."
|
| 303 |
+
)
|
| 304 |
+
return key
|
| 305 |
+
|
| 306 |
+
@cached_property
|
| 307 |
+
def header(self):
|
| 308 |
+
return {
|
| 309 |
+
"x-api-key": f"{self.api_key}",
|
| 310 |
+
"anthropic-version": self.anthropic_version,
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
def _create_payload(
|
| 314 |
+
self,
|
| 315 |
+
messages: List[Dict],
|
| 316 |
+
generate=True,
|
| 317 |
+
gen_kwargs: dict = None,
|
| 318 |
+
eos="\n\nHuman:",
|
| 319 |
+
**kwargs,
|
| 320 |
+
) -> dict:
|
| 321 |
+
system = (
|
| 322 |
+
messages[0].get("content") if messages[0].get("role") == "system" else None
|
| 323 |
+
)
|
| 324 |
+
if system:
|
| 325 |
+
messages = messages[1:]
|
| 326 |
+
|
| 327 |
+
cleaned_messages = []
|
| 328 |
+
for msg in messages:
|
| 329 |
+
cleaned_msg = {
|
| 330 |
+
"role": msg["role"],
|
| 331 |
+
"content": [
|
| 332 |
+
{"type": msg["type"], msg["type"]: msg["content"]},
|
| 333 |
+
],
|
| 334 |
+
}
|
| 335 |
+
cleaned_messages.append(cleaned_msg)
|
| 336 |
+
|
| 337 |
+
gen_kwargs.pop("do_sample", False)
|
| 338 |
+
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
|
| 339 |
+
temperature = gen_kwargs.pop("temperature", 0)
|
| 340 |
+
stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
|
| 341 |
+
if not isinstance(stop, list):
|
| 342 |
+
stop = [stop]
|
| 343 |
+
|
| 344 |
+
# Filter out empty or whitespace-only stop sequences for Anthropic API
|
| 345 |
+
stop = [s for s in stop if s and s.strip()]
|
| 346 |
+
|
| 347 |
+
out = {
|
| 348 |
+
"messages": cleaned_messages,
|
| 349 |
+
"model": self.model,
|
| 350 |
+
"max_tokens": max_tokens,
|
| 351 |
+
"temperature": temperature,
|
| 352 |
+
"stop_sequences": stop,
|
| 353 |
+
**gen_kwargs,
|
| 354 |
+
}
|
| 355 |
+
if system:
|
| 356 |
+
out["system"] = system
|
| 357 |
+
return out
|
| 358 |
+
|
| 359 |
+
def parse_generations(
|
| 360 |
+
self, outputs: Union[Dict, List[Dict]], **kwargs
|
| 361 |
+
) -> List[str]:
|
| 362 |
+
res = []
|
| 363 |
+
if not isinstance(outputs, list):
|
| 364 |
+
outputs = [outputs]
|
| 365 |
+
for out in outputs:
|
| 366 |
+
for choices in out["content"]:
|
| 367 |
+
res.append(choices["text"])
|
| 368 |
+
return res
|
| 369 |
+
|
| 370 |
+
def tok_encode(
|
| 371 |
+
self,
|
| 372 |
+
string: str,
|
| 373 |
+
left_truncate_len=None,
|
| 374 |
+
add_special_tokens=None,
|
| 375 |
+
**kwargs,
|
| 376 |
+
) -> List[str]:
|
| 377 |
+
return [string]
|
| 378 |
+
|
| 379 |
+
def loglikelihood(self, requests, **kwargs):
|
| 380 |
+
raise NotImplementedError(
|
| 381 |
+
"Anthropic Chat Completions API does not support the return of loglikelihood"
|
| 382 |
+
)
|
lm-evaluation-harness/lm_eval/models/api_models.py
ADDED
|
@@ -0,0 +1,810 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import asyncio
|
| 3 |
+
import copy
|
| 4 |
+
import itertools
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from functools import cached_property
|
| 8 |
+
from typing import (
|
| 9 |
+
TYPE_CHECKING,
|
| 10 |
+
Any,
|
| 11 |
+
Awaitable,
|
| 12 |
+
Callable,
|
| 13 |
+
Dict,
|
| 14 |
+
Iterable,
|
| 15 |
+
List,
|
| 16 |
+
Literal,
|
| 17 |
+
NamedTuple,
|
| 18 |
+
Optional,
|
| 19 |
+
Tuple,
|
| 20 |
+
Union,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import requests
|
| 26 |
+
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
| 27 |
+
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
from tqdm.asyncio import tqdm_asyncio
|
| 30 |
+
except ModuleNotFoundError:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import base64
|
| 35 |
+
from importlib.util import find_spec
|
| 36 |
+
from io import BytesIO
|
| 37 |
+
|
| 38 |
+
from lm_eval import utils
|
| 39 |
+
from lm_eval.api.instance import Instance
|
| 40 |
+
from lm_eval.api.model import TemplateLM
|
| 41 |
+
from lm_eval.models.utils import Collator, chunks, configure_pad_token
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if TYPE_CHECKING:
|
| 45 |
+
from PIL import Image
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
eval_logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# utility class to keep track of json encoded chats
|
| 54 |
+
class JsonChatStr(NamedTuple):
|
| 55 |
+
prompt: str
|
| 56 |
+
|
| 57 |
+
def encode(self, encoding):
|
| 58 |
+
return self.prompt.encode(encoding)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def create_image_prompt(
|
| 62 |
+
imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
|
| 63 |
+
) -> dict:
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
Parameters
|
| 67 |
+
----------
|
| 68 |
+
img : list[PIL.Image.Image]
|
| 69 |
+
The list of images to encode to base64
|
| 70 |
+
chat : dict
|
| 71 |
+
fmt : str, optional
|
| 72 |
+
Any format Pillow understands (e.g. "PNG", "JPEG").
|
| 73 |
+
Defaults to "PNG".
|
| 74 |
+
|
| 75 |
+
Returns
|
| 76 |
+
-------
|
| 77 |
+
dict
|
| 78 |
+
"""
|
| 79 |
+
images = []
|
| 80 |
+
for img in imgs:
|
| 81 |
+
buf = BytesIO()
|
| 82 |
+
img.save(buf, format=fmt)
|
| 83 |
+
img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 84 |
+
img_dict = {
|
| 85 |
+
"type": "image_url",
|
| 86 |
+
"image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "auto"},
|
| 87 |
+
}
|
| 88 |
+
images.append(img_dict)
|
| 89 |
+
|
| 90 |
+
# chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
|
| 91 |
+
# with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
|
| 92 |
+
# currently we do not support few-shots so only one user message
|
| 93 |
+
# text content also has <image> placeholders, which apparently is not necessary for API class (confirm)
|
| 94 |
+
|
| 95 |
+
if isinstance(chat[-1]["content"], list):
|
| 96 |
+
chat[-1]["content"] = images + chat[-1]["content"]
|
| 97 |
+
else:
|
| 98 |
+
text_content = {"type": "text", "text": chat[-1]["content"]}
|
| 99 |
+
chat[-1]["content"] = images + [text_content]
|
| 100 |
+
chat[-1].pop("type")
|
| 101 |
+
return chat
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class TemplateAPI(TemplateLM):
|
| 105 |
+
MULTIMODAL = True
|
| 106 |
+
|
| 107 |
+
def __init__(
|
| 108 |
+
self,
|
| 109 |
+
model: str = None,
|
| 110 |
+
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
|
| 111 |
+
base_url: str = None,
|
| 112 |
+
tokenizer: Optional[str] = None,
|
| 113 |
+
# Loglikelihood tasks require a tokenizer to calculate context lengths,
|
| 114 |
+
# however the requests can be sent as a string if the API doesn't support token inputs.
|
| 115 |
+
# use tokenized_requests=False
|
| 116 |
+
tokenizer_backend: Optional[
|
| 117 |
+
Literal["tiktoken", "huggingface", "None", "none"]
|
| 118 |
+
] = "huggingface",
|
| 119 |
+
truncate: bool = False,
|
| 120 |
+
# number of concurrent requests. More useful if not batching
|
| 121 |
+
num_concurrent: int = 1,
|
| 122 |
+
max_retries: int = 3,
|
| 123 |
+
max_gen_toks: int = 256,
|
| 124 |
+
batch_size: Union[str, int] = 1,
|
| 125 |
+
seed: int = 1234,
|
| 126 |
+
max_length: Optional[int] = 2048,
|
| 127 |
+
add_bos_token: bool = False,
|
| 128 |
+
custom_prefix_token_id: int = None,
|
| 129 |
+
# send the requests as tokens or strings
|
| 130 |
+
tokenized_requests: bool = True,
|
| 131 |
+
trust_remote_code: bool = False,
|
| 132 |
+
revision: Optional[str] = "main",
|
| 133 |
+
use_fast_tokenizer: bool = True,
|
| 134 |
+
verify_certificate: bool = True,
|
| 135 |
+
eos_string: str = None,
|
| 136 |
+
# timeout in seconds
|
| 137 |
+
timeout: int = 300,
|
| 138 |
+
header: Optional[Dict[str, str]] = None,
|
| 139 |
+
max_images: int = 1,
|
| 140 |
+
**kwargs,
|
| 141 |
+
) -> None:
|
| 142 |
+
super().__init__()
|
| 143 |
+
missing_packages = [
|
| 144 |
+
pkg
|
| 145 |
+
for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
|
| 146 |
+
if find_spec(pkg) is None
|
| 147 |
+
]
|
| 148 |
+
if missing_packages:
|
| 149 |
+
raise ModuleNotFoundError(
|
| 150 |
+
f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
|
| 151 |
+
'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
|
| 152 |
+
)
|
| 153 |
+
self.model = model or pretrained
|
| 154 |
+
self.base_url = base_url
|
| 155 |
+
self.tokenizer = tokenizer
|
| 156 |
+
self._header = header
|
| 157 |
+
if not isinstance(batch_size, int) and "auto" in batch_size:
|
| 158 |
+
eval_logger.warning(
|
| 159 |
+
"Automatic batch size is not supported for API models. Defaulting to batch size 1."
|
| 160 |
+
)
|
| 161 |
+
elif int(batch_size) > 1:
|
| 162 |
+
eval_logger.warning(
|
| 163 |
+
"Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
|
| 164 |
+
)
|
| 165 |
+
self._batch_size = int(batch_size) if batch_size != "auto" else 1
|
| 166 |
+
self._truncate = truncate
|
| 167 |
+
self._max_gen_toks = int(max_gen_toks)
|
| 168 |
+
self._seed = int(seed)
|
| 169 |
+
# max_length - 1 as we always have 1 token for generation
|
| 170 |
+
eval_logger.info(f"Using max length {max_length} - 1")
|
| 171 |
+
self.max_length = max_length - 1
|
| 172 |
+
if int(num_concurrent) <= 1:
|
| 173 |
+
eval_logger.info(
|
| 174 |
+
"Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
|
| 175 |
+
)
|
| 176 |
+
self._concurrent = int(num_concurrent)
|
| 177 |
+
self.tokenizer_backend = (
|
| 178 |
+
None if tokenizer_backend in ("None", "none") else tokenizer_backend
|
| 179 |
+
)
|
| 180 |
+
self.add_bos_token = add_bos_token
|
| 181 |
+
self.custom_prefix_token_id = custom_prefix_token_id
|
| 182 |
+
self.tokenized_requests = tokenized_requests
|
| 183 |
+
self.max_retries = int(max_retries)
|
| 184 |
+
self.verify_certificate = verify_certificate
|
| 185 |
+
self._eos_string = eos_string
|
| 186 |
+
self.timeout = int(timeout)
|
| 187 |
+
self.max_images = int(max_images)
|
| 188 |
+
|
| 189 |
+
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
|
| 190 |
+
if self.tokenizer_backend is None:
|
| 191 |
+
self.tokenizer = None
|
| 192 |
+
self.tokenized_requests = False
|
| 193 |
+
else:
|
| 194 |
+
if self.tokenizer is None:
|
| 195 |
+
if self.tokenizer_backend == "huggingface":
|
| 196 |
+
import transformers
|
| 197 |
+
|
| 198 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 199 |
+
self.tokenizer if self.tokenizer else self.model,
|
| 200 |
+
trust_remote_code=trust_remote_code,
|
| 201 |
+
revision=revision,
|
| 202 |
+
use_fast=use_fast_tokenizer,
|
| 203 |
+
)
|
| 204 |
+
# Not used as the API will handle padding but to mirror the behavior of the HFLM
|
| 205 |
+
self.tokenizer = configure_pad_token(self.tokenizer)
|
| 206 |
+
elif self.tokenizer_backend == "tiktoken":
|
| 207 |
+
try:
|
| 208 |
+
import tiktoken
|
| 209 |
+
|
| 210 |
+
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
| 211 |
+
except ModuleNotFoundError as e:
|
| 212 |
+
raise ModuleNotFoundError(
|
| 213 |
+
"Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
|
| 214 |
+
"Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
|
| 215 |
+
) from e
|
| 216 |
+
if "openai" not in self.base_url:
|
| 217 |
+
eval_logger.warning(
|
| 218 |
+
f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
|
| 219 |
+
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
import transformers
|
| 223 |
+
|
| 224 |
+
assert isinstance(tokenizer, str), "tokenizer must be a string"
|
| 225 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 226 |
+
tokenizer,
|
| 227 |
+
trust_remote_code=trust_remote_code,
|
| 228 |
+
revision=revision,
|
| 229 |
+
use_fast=use_fast_tokenizer,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
@abc.abstractmethod
|
| 233 |
+
def _create_payload(
|
| 234 |
+
self,
|
| 235 |
+
messages: Union[List[List[int]], List[dict], List[str], str],
|
| 236 |
+
*,
|
| 237 |
+
generate: bool = True,
|
| 238 |
+
gen_kwargs: Optional[dict] = None,
|
| 239 |
+
seed: int = 1234,
|
| 240 |
+
eos: str = None,
|
| 241 |
+
**kwargs,
|
| 242 |
+
) -> dict:
|
| 243 |
+
"""This method is responsible for creating the json payload that will be sent to the API."""
|
| 244 |
+
raise NotImplementedError
|
| 245 |
+
|
| 246 |
+
def create_message(
|
| 247 |
+
self,
|
| 248 |
+
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 249 |
+
generate=False,
|
| 250 |
+
) -> Union[List[List[int]], List[dict], List[str], str]:
|
| 251 |
+
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
|
| 252 |
+
if isinstance(messages[0], JsonChatStr):
|
| 253 |
+
# for chat completions we need to decode the json string to list[dict,...]
|
| 254 |
+
assert self._batch_size == 1, (
|
| 255 |
+
"non-tokenized chat requests are only supported with batch_size=1"
|
| 256 |
+
)
|
| 257 |
+
# list[dict["role":..., "content":...],...]
|
| 258 |
+
return json.loads(messages[0].prompt)
|
| 259 |
+
|
| 260 |
+
if not self.tokenized_requests:
|
| 261 |
+
# if messages are tokenized:
|
| 262 |
+
if isinstance(messages[0][0], int):
|
| 263 |
+
# assuming decoding is lossless. However, this is only for loglikelihood requests
|
| 264 |
+
# as we need to compute the context length. For generations, we don't need to tokenize.
|
| 265 |
+
messages = self.decode_batch(messages)
|
| 266 |
+
if self._batch_size <= 1:
|
| 267 |
+
# if batch is 1 return str
|
| 268 |
+
return messages[0]
|
| 269 |
+
else:
|
| 270 |
+
# list[str,...]
|
| 271 |
+
return messages
|
| 272 |
+
|
| 273 |
+
# list[list[int], ...]
|
| 274 |
+
return messages
|
| 275 |
+
|
| 276 |
+
@staticmethod
|
| 277 |
+
@abc.abstractmethod
|
| 278 |
+
def parse_logprobs(
|
| 279 |
+
outputs: Union[Any, List[Any]],
|
| 280 |
+
tokens: List[List[int]] = None,
|
| 281 |
+
ctxlen: List[int] = None,
|
| 282 |
+
**kwargs,
|
| 283 |
+
) -> List[Tuple[float, bool]]:
|
| 284 |
+
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
|
| 285 |
+
raise NotImplementedError
|
| 286 |
+
|
| 287 |
+
@staticmethod
|
| 288 |
+
@abc.abstractmethod
|
| 289 |
+
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
|
| 290 |
+
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
|
| 291 |
+
raise NotImplementedError
|
| 292 |
+
|
| 293 |
+
@cached_property
|
| 294 |
+
def api_key(self) -> str:
|
| 295 |
+
"""Override this property to return the API key for the API request."""
|
| 296 |
+
return ""
|
| 297 |
+
|
| 298 |
+
@cached_property
|
| 299 |
+
def header(self) -> dict:
|
| 300 |
+
"""Override this property to return the headers for the API request."""
|
| 301 |
+
return self._header or {"Authorization": f"Bearer {self.api_key}"}
|
| 302 |
+
|
| 303 |
+
@property
|
| 304 |
+
def tokenizer_name(self) -> str:
|
| 305 |
+
"""Must be defined for LM subclasses which implement Chat Templating.
|
| 306 |
+
Should return the name of the tokenizer or chat template used.
|
| 307 |
+
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
|
| 308 |
+
"""
|
| 309 |
+
return ""
|
| 310 |
+
|
| 311 |
+
def apply_chat_template(
|
| 312 |
+
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 313 |
+
) -> Union[str, JsonChatStr]:
|
| 314 |
+
"""Applies a chat template to a list of chat history between user and model."""
|
| 315 |
+
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
|
| 316 |
+
return self.tokenizer.apply_chat_template(
|
| 317 |
+
chat_history,
|
| 318 |
+
tokenize=False,
|
| 319 |
+
add_generation_prompt=add_generation_prompt,
|
| 320 |
+
continue_final_message=not add_generation_prompt,
|
| 321 |
+
)
|
| 322 |
+
else:
|
| 323 |
+
# bit of a hack. We'll load back before sending to the API
|
| 324 |
+
return JsonChatStr(
|
| 325 |
+
json.dumps(
|
| 326 |
+
[{**item, "type": "text"} for item in chat_history],
|
| 327 |
+
ensure_ascii=False,
|
| 328 |
+
)
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
@cached_property
|
| 332 |
+
def eot_token_id(self) -> Optional[int]:
|
| 333 |
+
if self.tokenizer is None:
|
| 334 |
+
return None
|
| 335 |
+
else:
|
| 336 |
+
if self.tokenizer_backend == "huggingface":
|
| 337 |
+
return self.tokenizer.eos_token_id
|
| 338 |
+
elif self.tokenizer_backend == "tiktoken":
|
| 339 |
+
return self.tokenizer.eot_token
|
| 340 |
+
|
| 341 |
+
@cached_property
|
| 342 |
+
def eos_string(self) -> Optional[str]:
|
| 343 |
+
if self._eos_string:
|
| 344 |
+
return self._eos_string
|
| 345 |
+
elif self.tokenizer is not None:
|
| 346 |
+
if self.tokenizer_backend == "huggingface":
|
| 347 |
+
return self.tokenizer.eos_token
|
| 348 |
+
elif self.tokenizer_backend == "tiktoken":
|
| 349 |
+
return self.tokenizer.decode([self.tokenizer.eot_token])
|
| 350 |
+
else:
|
| 351 |
+
eval_logger.warning(
|
| 352 |
+
"Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
|
| 353 |
+
)
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
@cached_property
|
| 357 |
+
def prefix_token_id(self) -> Optional[int]:
|
| 358 |
+
if self.tokenizer is None:
|
| 359 |
+
return None
|
| 360 |
+
else:
|
| 361 |
+
if self.custom_prefix_token_id is not None:
|
| 362 |
+
return self.custom_prefix_token_id
|
| 363 |
+
if self.tokenizer_backend == "huggingface":
|
| 364 |
+
if self.tokenizer.bos_token_id is not None:
|
| 365 |
+
return self.tokenizer.bos_token_id
|
| 366 |
+
return self.tokenizer.eos_token_id
|
| 367 |
+
else:
|
| 368 |
+
return self.tokenizer.eot_token
|
| 369 |
+
|
| 370 |
+
def tok_encode(
|
| 371 |
+
self,
|
| 372 |
+
string: str,
|
| 373 |
+
left_truncate_len: int = None,
|
| 374 |
+
add_special_tokens: bool = False,
|
| 375 |
+
truncation: bool = False,
|
| 376 |
+
**kwargs,
|
| 377 |
+
) -> Union[List[List[int]], List[int], List[str]]:
|
| 378 |
+
if self.tokenizer_backend is None:
|
| 379 |
+
return [string]
|
| 380 |
+
elif self.tokenizer_backend == "huggingface":
|
| 381 |
+
# by default for CausalLM - false or self.add_bos_token is set
|
| 382 |
+
if not add_special_tokens:
|
| 383 |
+
add_special_tokens = False or self.add_bos_token
|
| 384 |
+
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
|
| 385 |
+
string,
|
| 386 |
+
add_special_tokens=add_special_tokens,
|
| 387 |
+
truncation=truncation,
|
| 388 |
+
return_attention_mask=False,
|
| 389 |
+
).input_ids
|
| 390 |
+
|
| 391 |
+
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 392 |
+
if left_truncate_len:
|
| 393 |
+
if not isinstance(string, str):
|
| 394 |
+
encoding = [enc[-left_truncate_len:] for enc in encoding]
|
| 395 |
+
else:
|
| 396 |
+
encoding = encoding[-left_truncate_len:]
|
| 397 |
+
|
| 398 |
+
return encoding
|
| 399 |
+
|
| 400 |
+
else:
|
| 401 |
+
try:
|
| 402 |
+
encoding = self.tokenizer.encode(string)
|
| 403 |
+
except Exception:
|
| 404 |
+
encoding = self.tokenizer.encode_batch(string)
|
| 405 |
+
return encoding
|
| 406 |
+
|
| 407 |
+
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
|
| 408 |
+
if self.tokenizer_backend == "huggingface":
|
| 409 |
+
return self.tokenizer.batch_decode(tokens)
|
| 410 |
+
elif self.tokenizer_backend == "tiktoken":
|
| 411 |
+
return self.tokenizer.decode_batch(tokens)
|
| 412 |
+
|
| 413 |
+
def model_call(
|
| 414 |
+
self,
|
| 415 |
+
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 416 |
+
*,
|
| 417 |
+
generate: bool = True,
|
| 418 |
+
gen_kwargs: Optional[Dict] = None,
|
| 419 |
+
**kwargs,
|
| 420 |
+
) -> Optional[dict]:
|
| 421 |
+
# !!! Copy: shared dict for each request, need new object !!!
|
| 422 |
+
gen_kwargs = copy.deepcopy(gen_kwargs)
|
| 423 |
+
try:
|
| 424 |
+
response = requests.post(
|
| 425 |
+
self.base_url,
|
| 426 |
+
json=self._create_payload(
|
| 427 |
+
self.create_message(messages),
|
| 428 |
+
generate=generate,
|
| 429 |
+
gen_kwargs=gen_kwargs,
|
| 430 |
+
seed=self._seed,
|
| 431 |
+
eos=self.eos_string,
|
| 432 |
+
**kwargs,
|
| 433 |
+
),
|
| 434 |
+
headers=self.header,
|
| 435 |
+
verify=self.verify_certificate,
|
| 436 |
+
)
|
| 437 |
+
if not response.ok:
|
| 438 |
+
eval_logger.warning(
|
| 439 |
+
f"API request failed with error message: {response.text}. Retrying..."
|
| 440 |
+
)
|
| 441 |
+
response.raise_for_status()
|
| 442 |
+
return response.json()
|
| 443 |
+
except RetryError:
|
| 444 |
+
eval_logger.error(
|
| 445 |
+
"API request failed after multiple retries. Please check the API status."
|
| 446 |
+
)
|
| 447 |
+
return None
|
| 448 |
+
|
| 449 |
+
async def amodel_call(
|
| 450 |
+
self,
|
| 451 |
+
session: ClientSession,
|
| 452 |
+
sem: asyncio.Semaphore,
|
| 453 |
+
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 454 |
+
*,
|
| 455 |
+
generate: bool = True,
|
| 456 |
+
cache_keys: list = None,
|
| 457 |
+
ctxlens: Optional[List[int]] = None,
|
| 458 |
+
gen_kwargs: Optional[Dict] = None,
|
| 459 |
+
**kwargs,
|
| 460 |
+
) -> Union[List[str], List[Tuple[float, bool]], None]:
|
| 461 |
+
# !!! Copy: shared dict for each request, need new object !!!
|
| 462 |
+
gen_kwargs = copy.deepcopy(gen_kwargs)
|
| 463 |
+
payload = self._create_payload(
|
| 464 |
+
self.create_message(messages),
|
| 465 |
+
generate=generate,
|
| 466 |
+
gen_kwargs=gen_kwargs,
|
| 467 |
+
seed=self._seed,
|
| 468 |
+
**kwargs,
|
| 469 |
+
)
|
| 470 |
+
cache_method = "generate_until" if generate else "loglikelihood"
|
| 471 |
+
acquired = await sem.acquire()
|
| 472 |
+
try:
|
| 473 |
+
async with session.post(
|
| 474 |
+
self.base_url,
|
| 475 |
+
json=payload,
|
| 476 |
+
headers=self.header,
|
| 477 |
+
) as response:
|
| 478 |
+
if not response.ok:
|
| 479 |
+
error_text = await response.text()
|
| 480 |
+
eval_logger.warning(
|
| 481 |
+
f"API request failed! Status code: {response.status}, "
|
| 482 |
+
f"Response text: {error_text}. Retrying..."
|
| 483 |
+
)
|
| 484 |
+
# raising exception will retry the request
|
| 485 |
+
response.raise_for_status()
|
| 486 |
+
outputs = await response.json()
|
| 487 |
+
answers = (
|
| 488 |
+
self.parse_generations(
|
| 489 |
+
outputs=outputs,
|
| 490 |
+
)
|
| 491 |
+
if generate
|
| 492 |
+
else self.parse_logprobs(
|
| 493 |
+
outputs=outputs,
|
| 494 |
+
tokens=messages,
|
| 495 |
+
ctxlens=ctxlens,
|
| 496 |
+
)
|
| 497 |
+
)
|
| 498 |
+
if cache_keys:
|
| 499 |
+
for res, cache in zip(answers, cache_keys):
|
| 500 |
+
self.cache_hook.add_partial(cache_method, cache, res)
|
| 501 |
+
return answers
|
| 502 |
+
# If the retries also fail
|
| 503 |
+
except BaseException as e:
|
| 504 |
+
eval_logger.error(f"Exception:{repr(e)}, {outputs}, retrying.")
|
| 505 |
+
raise e
|
| 506 |
+
finally:
|
| 507 |
+
if acquired:
|
| 508 |
+
sem.release()
|
| 509 |
+
|
| 510 |
+
def batch_loglikelihood_requests(
|
| 511 |
+
self, chunks: Iterable[List[LogLikelihoodInputs]]
|
| 512 |
+
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
|
| 513 |
+
inputs = []
|
| 514 |
+
ctxlens = []
|
| 515 |
+
cache_keys = []
|
| 516 |
+
for chunk in chunks:
|
| 517 |
+
for cache_key, context_enc, continuation_enc in chunk:
|
| 518 |
+
# max_length - 1 as we always have 1 token for generation
|
| 519 |
+
inp = (context_enc + continuation_enc)[-self.max_length :]
|
| 520 |
+
if len(inp) < len(context_enc + continuation_enc):
|
| 521 |
+
eval_logger.warning(
|
| 522 |
+
f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
|
| 523 |
+
)
|
| 524 |
+
ctxlen = len(context_enc) - max(
|
| 525 |
+
0, len(context_enc) + len(continuation_enc) - self.max_length
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
inputs.append(inp)
|
| 529 |
+
ctxlens.append(ctxlen)
|
| 530 |
+
cache_keys.append(cache_key)
|
| 531 |
+
return inputs, ctxlens, cache_keys
|
| 532 |
+
|
| 533 |
+
async def get_batched_requests(
|
| 534 |
+
self,
|
| 535 |
+
requests: list,
|
| 536 |
+
cache_keys: list,
|
| 537 |
+
*,
|
| 538 |
+
generate: bool = True,
|
| 539 |
+
ctxlens: List[int] = None,
|
| 540 |
+
**kwargs,
|
| 541 |
+
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
|
| 542 |
+
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
| 543 |
+
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
| 544 |
+
sem = asyncio.Semaphore(self._concurrent)
|
| 545 |
+
async with ClientSession(
|
| 546 |
+
connector=conn, timeout=ClientTimeout(total=self.timeout)
|
| 547 |
+
) as session:
|
| 548 |
+
retry_: Callable[..., Awaitable[Any]] = retry(
|
| 549 |
+
stop=stop_after_attempt(self.max_retries),
|
| 550 |
+
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 551 |
+
reraise=True,
|
| 552 |
+
before_sleep=lambda retry_state: eval_logger.info(
|
| 553 |
+
f"Retry attempt {retry_state.attempt_number}"
|
| 554 |
+
),
|
| 555 |
+
)(self.amodel_call)
|
| 556 |
+
# Create tasks for each batch of request
|
| 557 |
+
tasks = [
|
| 558 |
+
asyncio.create_task(
|
| 559 |
+
retry_(
|
| 560 |
+
session=session,
|
| 561 |
+
sem=sem,
|
| 562 |
+
messages=message,
|
| 563 |
+
cache_keys=cache_key,
|
| 564 |
+
generate=generate,
|
| 565 |
+
ctxlens=ctxlen,
|
| 566 |
+
**kwargs,
|
| 567 |
+
)
|
| 568 |
+
)
|
| 569 |
+
for message, cache_key, ctxlen in zip(
|
| 570 |
+
chunks(requests, n=self._batch_size),
|
| 571 |
+
chunks(cache_keys, n=self._batch_size),
|
| 572 |
+
chunks(ctxlens, n=self._batch_size),
|
| 573 |
+
)
|
| 574 |
+
]
|
| 575 |
+
|
| 576 |
+
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
| 577 |
+
|
| 578 |
+
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
|
| 579 |
+
assert self.tokenizer is not None, (
|
| 580 |
+
"Tokenizer is required for loglikelihood tasks to compute context lengths."
|
| 581 |
+
)
|
| 582 |
+
res = []
|
| 583 |
+
|
| 584 |
+
def _collate(req: LogLikelihoodInputs):
|
| 585 |
+
"""Defines the key for the sorted method"""
|
| 586 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 587 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 588 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 589 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 590 |
+
# automatic adaptive batches much much easier to implement
|
| 591 |
+
# - any OOMs will happen right away rather than near the end
|
| 592 |
+
|
| 593 |
+
toks = req[1] + req[2]
|
| 594 |
+
return -len(toks), tuple(toks)
|
| 595 |
+
|
| 596 |
+
re_ord = Collator(
|
| 597 |
+
requests,
|
| 598 |
+
sort_fn=_collate,
|
| 599 |
+
group_by=None,
|
| 600 |
+
)
|
| 601 |
+
# if concurrent then we'll batch in the async context
|
| 602 |
+
chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
|
| 603 |
+
if self._concurrent <= 1:
|
| 604 |
+
pbar = tqdm(desc="Requesting API", total=len(requests))
|
| 605 |
+
for chunk in chunked:
|
| 606 |
+
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])
|
| 607 |
+
|
| 608 |
+
outputs = retry(
|
| 609 |
+
stop=stop_after_attempt(self.max_retries),
|
| 610 |
+
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 611 |
+
reraise=True,
|
| 612 |
+
)(self.model_call)(messages=inputs, generate=False)
|
| 613 |
+
if isinstance(outputs, dict):
|
| 614 |
+
outputs = [outputs]
|
| 615 |
+
for answer_, cache_key in zip(
|
| 616 |
+
self.parse_logprobs(
|
| 617 |
+
outputs=outputs, tokens=inputs, ctxlens=ctxlens
|
| 618 |
+
),
|
| 619 |
+
cache_keys,
|
| 620 |
+
):
|
| 621 |
+
if answer_ is not None:
|
| 622 |
+
res.append(answer_)
|
| 623 |
+
# cache requests that aren't from a loglikelihood_rolling request
|
| 624 |
+
if cache_key is not None:
|
| 625 |
+
self.cache_hook.add_partial(
|
| 626 |
+
"loglikelihood", cache_key, answer_
|
| 627 |
+
)
|
| 628 |
+
pbar.update(1)
|
| 629 |
+
else:
|
| 630 |
+
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
|
| 631 |
+
res = itertools.chain.from_iterable(
|
| 632 |
+
asyncio.run(
|
| 633 |
+
self.get_batched_requests(
|
| 634 |
+
inputs, cache_keys, generate=False, ctxlens=ctxlens
|
| 635 |
+
)
|
| 636 |
+
)
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
return re_ord.get_original(res)
|
| 640 |
+
|
| 641 |
+
def generate_until(
|
| 642 |
+
self, requests: List[Instance], disable_tqdm: bool = False
|
| 643 |
+
) -> List[str]:
|
| 644 |
+
res = []
|
| 645 |
+
|
| 646 |
+
def _collate_gen(_requests):
|
| 647 |
+
# sort by the length of the non-tokenized contexts
|
| 648 |
+
return -len(_requests[0])
|
| 649 |
+
|
| 650 |
+
# Let the API deal with tokenization
|
| 651 |
+
if len(requests[0].args) > 2:
|
| 652 |
+
assert self.tokenizer is None, (
|
| 653 |
+
"tokenizer is not supported for multimodal requests yet!"
|
| 654 |
+
)
|
| 655 |
+
eval_logger.info(
|
| 656 |
+
f"Using max_images {self.max_images}. Set in the model args."
|
| 657 |
+
)
|
| 658 |
+
requests, all_gen_kwargs, auxiliary_args = zip(
|
| 659 |
+
*(req.args for req in requests)
|
| 660 |
+
)
|
| 661 |
+
requests = tuple(
|
| 662 |
+
JsonChatStr(
|
| 663 |
+
json.dumps(
|
| 664 |
+
create_image_prompt(
|
| 665 |
+
y["visual"][: self.max_images], json.loads(x.prompt)
|
| 666 |
+
)
|
| 667 |
+
)
|
| 668 |
+
)
|
| 669 |
+
for x, y in zip(requests, auxiliary_args)
|
| 670 |
+
)
|
| 671 |
+
else:
|
| 672 |
+
requests, all_gen_kwargs = zip(*(req.args for req in requests))
|
| 673 |
+
if self.tokenized_requests:
|
| 674 |
+
encodings_list = self.tok_encode(
|
| 675 |
+
requests, add_special_tokens=self.add_bos_token
|
| 676 |
+
)
|
| 677 |
+
else:
|
| 678 |
+
encodings_list = [None] * len(requests)
|
| 679 |
+
requests = [
|
| 680 |
+
(a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
|
| 681 |
+
]
|
| 682 |
+
|
| 683 |
+
re_ord = Collator(
|
| 684 |
+
requests,
|
| 685 |
+
sort_fn=_collate_gen,
|
| 686 |
+
group_by="gen_kwargs",
|
| 687 |
+
)
|
| 688 |
+
chunked = re_ord.get_batched(
|
| 689 |
+
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
|
| 690 |
+
)
|
| 691 |
+
if not self.tokenized_requests:
|
| 692 |
+
eval_logger.info(
|
| 693 |
+
"Tokenized requests are disabled. Context + generation length is not checked."
|
| 694 |
+
)
|
| 695 |
+
if self._concurrent <= 1:
|
| 696 |
+
pbar = tqdm(desc="Requesting API", total=len(requests))
|
| 697 |
+
for chunk in chunked:
|
| 698 |
+
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
|
| 699 |
+
if self.tokenized_requests:
|
| 700 |
+
max_gen_toks = all_gen_kwargs[0].get(
|
| 701 |
+
"max_gen_toks", self._max_gen_toks
|
| 702 |
+
)
|
| 703 |
+
max_context_len = self.max_length - max_gen_toks
|
| 704 |
+
|
| 705 |
+
encodings_list = [x[-max_context_len:] for x in encodings_list]
|
| 706 |
+
|
| 707 |
+
if any(
|
| 708 |
+
len(x) + max_gen_toks > self.max_length for x in encodings_list
|
| 709 |
+
):
|
| 710 |
+
eval_logger.warning(
|
| 711 |
+
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
req = encodings_list if self.tokenized_requests else contexts
|
| 715 |
+
outputs = retry(
|
| 716 |
+
stop=stop_after_attempt(self.max_retries),
|
| 717 |
+
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 718 |
+
reraise=True,
|
| 719 |
+
)(self.model_call)(
|
| 720 |
+
messages=req,
|
| 721 |
+
generate=True,
|
| 722 |
+
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
|
| 723 |
+
)
|
| 724 |
+
for generated_text, context in zip(
|
| 725 |
+
self.parse_generations(
|
| 726 |
+
outputs=outputs,
|
| 727 |
+
contexts=contexts,
|
| 728 |
+
),
|
| 729 |
+
contexts,
|
| 730 |
+
):
|
| 731 |
+
if generated_text is not None:
|
| 732 |
+
res.append(generated_text)
|
| 733 |
+
|
| 734 |
+
# partial caching
|
| 735 |
+
if context is not None:
|
| 736 |
+
self.cache_hook.add_partial(
|
| 737 |
+
"generate_until",
|
| 738 |
+
(context, all_gen_kwargs[0]),
|
| 739 |
+
generated_text,
|
| 740 |
+
)
|
| 741 |
+
pbar.update(1)
|
| 742 |
+
else:
|
| 743 |
+
for chunk in chunked:
|
| 744 |
+
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
|
| 745 |
+
if self.tokenized_requests:
|
| 746 |
+
max_gen_toks = all_gen_kwargs[0].get(
|
| 747 |
+
"max_gen_toks", self._max_gen_toks
|
| 748 |
+
)
|
| 749 |
+
max_context_len = self.max_length - max_gen_toks
|
| 750 |
+
|
| 751 |
+
encodings_list = [x[-max_context_len:] for x in encodings_list]
|
| 752 |
+
|
| 753 |
+
if any(
|
| 754 |
+
len(x) + max_gen_toks > self.max_length for x in encodings_list
|
| 755 |
+
):
|
| 756 |
+
eval_logger.warning(
|
| 757 |
+
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
req = encodings_list if self.tokenized_requests else contexts
|
| 761 |
+
results = itertools.chain.from_iterable(
|
| 762 |
+
asyncio.run(
|
| 763 |
+
self.get_batched_requests(
|
| 764 |
+
req,
|
| 765 |
+
cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
|
| 766 |
+
generate=True,
|
| 767 |
+
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
|
| 768 |
+
)
|
| 769 |
+
)
|
| 770 |
+
)
|
| 771 |
+
res.extend(results)
|
| 772 |
+
|
| 773 |
+
return re_ord.get_original(res)
|
| 774 |
+
|
| 775 |
+
def loglikelihood_rolling(
|
| 776 |
+
self, requests: List[Instance], disable_tqdm: bool = False
|
| 777 |
+
) -> List[float]:
|
| 778 |
+
loglikelihoods = []
|
| 779 |
+
|
| 780 |
+
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
|
| 781 |
+
rolling_token_windows = list(
|
| 782 |
+
map(
|
| 783 |
+
utils.make_disjoint_window,
|
| 784 |
+
utils.get_rolling_token_windows(
|
| 785 |
+
token_list=self.tok_encode(string),
|
| 786 |
+
prefix_token=self.prefix_token_id,
|
| 787 |
+
# max_seq_len - (1 for context)
|
| 788 |
+
max_seq_len=self.max_length - 1,
|
| 789 |
+
context_len=1,
|
| 790 |
+
),
|
| 791 |
+
)
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
|
| 795 |
+
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
| 796 |
+
|
| 797 |
+
string_nll = self._loglikelihood_tokens(
|
| 798 |
+
rolling_token_windows,
|
| 799 |
+
disable_tqdm=True,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
# discard is_greedy
|
| 803 |
+
string_nll = [x[0] for x in string_nll]
|
| 804 |
+
|
| 805 |
+
string_nll = sum(string_nll)
|
| 806 |
+
loglikelihoods.append(string_nll)
|
| 807 |
+
|
| 808 |
+
# cache this loglikelihood_rolling request
|
| 809 |
+
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
|
| 810 |
+
return loglikelihoods
|
lm-evaluation-harness/lm_eval/models/dummy.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
from lm_eval.api.model import LM
|
| 6 |
+
from lm_eval.api.registry import register_model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_model("dummy")
|
| 10 |
+
class DummyLM(LM):
|
| 11 |
+
def __init__(self) -> None:
|
| 12 |
+
super().__init__()
|
| 13 |
+
|
| 14 |
+
@classmethod
|
| 15 |
+
def create_from_arg_string(cls, arg_string, additional_config=None):
|
| 16 |
+
return cls()
|
| 17 |
+
|
| 18 |
+
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 19 |
+
res = []
|
| 20 |
+
|
| 21 |
+
for _ in tqdm(requests, disable=disable_tqdm):
|
| 22 |
+
res.append((-random.random(), False))
|
| 23 |
+
|
| 24 |
+
return res
|
| 25 |
+
|
| 26 |
+
def generate_until(self, requests, disable_tqdm: bool = False):
|
| 27 |
+
res = []
|
| 28 |
+
|
| 29 |
+
for request in tqdm(requests, disable=disable_tqdm):
|
| 30 |
+
res.append("lol")
|
| 31 |
+
assert request.arguments[0].strip() != ""
|
| 32 |
+
|
| 33 |
+
return res
|
| 34 |
+
|
| 35 |
+
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 36 |
+
res = []
|
| 37 |
+
|
| 38 |
+
for _ in tqdm(requests, disable=disable_tqdm):
|
| 39 |
+
res.append(-random.random())
|
| 40 |
+
|
| 41 |
+
return res
|
lm-evaluation-harness/lm_eval/models/gguf.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
+
from requests.exceptions import RequestException
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
from lm_eval.api.model import LM
|
| 9 |
+
from lm_eval.api.registry import register_model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_result(logprobs, context_length):
|
| 16 |
+
is_greedy = True
|
| 17 |
+
offsets = logprobs["text_offset"]
|
| 18 |
+
tokens = logprobs["tokens"]
|
| 19 |
+
tokens_logprobs = logprobs["token_logprobs"]
|
| 20 |
+
|
| 21 |
+
idx = 0
|
| 22 |
+
while offsets[idx] < context_length:
|
| 23 |
+
idx += 1
|
| 24 |
+
continuation_logprobs = sum(tokens_logprobs[idx:-1])
|
| 25 |
+
for i in range(idx, len(tokens)):
|
| 26 |
+
token = tokens[i]
|
| 27 |
+
top_tokens = logprobs["top_logprobs"][i]
|
| 28 |
+
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
|
| 29 |
+
if top_token != token:
|
| 30 |
+
is_greedy = False
|
| 31 |
+
break
|
| 32 |
+
|
| 33 |
+
return continuation_logprobs, is_greedy
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@register_model("gguf", "ggml")
|
| 37 |
+
class GGUFLM(LM):
|
| 38 |
+
def __init__(self, base_url=None, max_length=2048, **kwargs):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.base_url = base_url
|
| 41 |
+
assert self.base_url, "must pass `base_url` to use GGUF LM!"
|
| 42 |
+
self.logprobs = 10
|
| 43 |
+
self.temperature = 0.0
|
| 44 |
+
self.max_length = max_length
|
| 45 |
+
|
| 46 |
+
def gguf_completion(
|
| 47 |
+
self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
|
| 48 |
+
):
|
| 49 |
+
for _ in range(retries):
|
| 50 |
+
try:
|
| 51 |
+
prompt = context
|
| 52 |
+
request = {
|
| 53 |
+
"prompt": prompt,
|
| 54 |
+
"logprobs": self.logprobs,
|
| 55 |
+
"temperature": self.temperature,
|
| 56 |
+
}
|
| 57 |
+
if continuation:
|
| 58 |
+
prompt += continuation
|
| 59 |
+
request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
|
| 60 |
+
if stop is not None:
|
| 61 |
+
request["stop"] = stop
|
| 62 |
+
response = requests.post(
|
| 63 |
+
f"{self.base_url}/v1/completions", json=request
|
| 64 |
+
)
|
| 65 |
+
response.raise_for_status()
|
| 66 |
+
return response.json()
|
| 67 |
+
except RequestException as e:
|
| 68 |
+
logger.error(f"RequestException: {e}")
|
| 69 |
+
time.sleep(delay) # wait before retrying
|
| 70 |
+
else:
|
| 71 |
+
raise RuntimeError(
|
| 72 |
+
f"Failed to get a valid response after {retries} retries."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 76 |
+
if not requests:
|
| 77 |
+
return []
|
| 78 |
+
res = []
|
| 79 |
+
for context, continuation in tqdm(
|
| 80 |
+
[req.args for req in requests], disable=disable_tqdm
|
| 81 |
+
):
|
| 82 |
+
response = self.gguf_completion(context=context, continuation=continuation)
|
| 83 |
+
if response and "choices" in response and response["choices"]:
|
| 84 |
+
choice = response["choices"][0]
|
| 85 |
+
logprobs = choice.get("logprobs")
|
| 86 |
+
if (
|
| 87 |
+
logprobs
|
| 88 |
+
and "token_logprobs" in logprobs
|
| 89 |
+
and logprobs["token_logprobs"]
|
| 90 |
+
):
|
| 91 |
+
logprob, is_greedy = get_result(logprobs, len(context))
|
| 92 |
+
res.append((logprob, is_greedy))
|
| 93 |
+
else:
|
| 94 |
+
logger.warning(
|
| 95 |
+
"Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
logger.error(
|
| 99 |
+
f"Invalid response for loglikelihood. Response: {response}"
|
| 100 |
+
)
|
| 101 |
+
assert False
|
| 102 |
+
return res
|
| 103 |
+
|
| 104 |
+
def generate_until(self, requests, disable_tqdm: bool = False):
|
| 105 |
+
if not requests:
|
| 106 |
+
return []
|
| 107 |
+
|
| 108 |
+
res = []
|
| 109 |
+
for request in tqdm([req.args for req in requests], disable=disable_tqdm):
|
| 110 |
+
inp = request[0]
|
| 111 |
+
request_args = request[1]
|
| 112 |
+
until = request_args.get("until", ["</s>"])
|
| 113 |
+
response = self.gguf_completion(context=inp, stop=until)
|
| 114 |
+
if response and "choices" in response and response["choices"]:
|
| 115 |
+
choice = response["choices"][0]
|
| 116 |
+
if "text" in choice:
|
| 117 |
+
generated_text = choice["text"].strip()
|
| 118 |
+
res.append(generated_text)
|
| 119 |
+
else:
|
| 120 |
+
logger.error(
|
| 121 |
+
f"Invalid response for greedy_until. Response: {response}"
|
| 122 |
+
)
|
| 123 |
+
res.append(None) # Add default value in case of error
|
| 124 |
+
else:
|
| 125 |
+
logger.error(f"Invalid response for greedy_until. Response: {response}")
|
| 126 |
+
res.append(None) # Add default value in case of error
|
| 127 |
+
return res
|
| 128 |
+
|
| 129 |
+
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 130 |
+
raise NotImplementedError(
|
| 131 |
+
"loglikelihood_rolling not yet supported for GGUF models"
|
| 132 |
+
)
|