henrycastillo committed on
Commit 9413fee · verified · 1 Parent(s): 47d993b

more lm eval harness

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. lm-evaluation-harness/docs/API_guide.md +211 -0
  3. lm-evaluation-harness/docs/CONTRIBUTING.md +83 -0
  4. lm-evaluation-harness/docs/README.md +11 -0
  5. lm-evaluation-harness/docs/chat-template-readme.md +31 -0
  6. lm-evaluation-harness/docs/decontamination.md +76 -0
  7. lm-evaluation-harness/docs/footguns.md +58 -0
  8. lm-evaluation-harness/docs/img/fewshot_example_gpt3.png +3 -0
  9. lm-evaluation-harness/docs/interface.md +170 -0
  10. lm-evaluation-harness/docs/model_guide.md +192 -0
  11. lm-evaluation-harness/docs/new_task_guide.md +521 -0
  12. lm-evaluation-harness/docs/task_guide.md +335 -0
  13. lm-evaluation-harness/examples/lm-eval-overview.ipynb +1240 -0
  14. lm-evaluation-harness/examples/transformer-lens.py +59 -0
  15. lm-evaluation-harness/examples/visualize-wandb.ipynb +172 -0
  16. lm-evaluation-harness/examples/visualize-zeno.ipynb +115 -0
  17. lm-evaluation-harness/lm_eval/__init__.py +21 -0
  18. lm-evaluation-harness/lm_eval/__main__.py +536 -0
  19. lm-evaluation-harness/lm_eval/api/__init__.py +0 -0
  20. lm-evaluation-harness/lm_eval/api/filter.py +56 -0
  21. lm-evaluation-harness/lm_eval/api/group.py +115 -0
  22. lm-evaluation-harness/lm_eval/api/instance.py +38 -0
  23. lm-evaluation-harness/lm_eval/api/metrics.py +629 -0
  24. lm-evaluation-harness/lm_eval/api/model.py +502 -0
  25. lm-evaluation-harness/lm_eval/api/registry.py +196 -0
  26. lm-evaluation-harness/lm_eval/api/samplers.py +232 -0
  27. lm-evaluation-harness/lm_eval/api/task.py +1885 -0
  28. lm-evaluation-harness/lm_eval/caching/__init__.py +0 -0
  29. lm-evaluation-harness/lm_eval/caching/cache.py +59 -0
  30. lm-evaluation-harness/lm_eval/decontamination/__init__.py +0 -0
  31. lm-evaluation-harness/lm_eval/decontamination/archiver.py +174 -0
  32. lm-evaluation-harness/lm_eval/decontamination/decontaminate.py +166 -0
  33. lm-evaluation-harness/lm_eval/decontamination/janitor.py +329 -0
  34. lm-evaluation-harness/lm_eval/evaluator.py +787 -0
  35. lm-evaluation-harness/lm_eval/evaluator_utils.py +554 -0
  36. lm-evaluation-harness/lm_eval/filters/__init__.py +25 -0
  37. lm-evaluation-harness/lm_eval/filters/custom.py +17 -0
  38. lm-evaluation-harness/lm_eval/filters/decontamination.py +25 -0
  39. lm-evaluation-harness/lm_eval/filters/extraction.py +233 -0
  40. lm-evaluation-harness/lm_eval/filters/selection.py +61 -0
  41. lm-evaluation-harness/lm_eval/filters/transformation.py +122 -0
  42. lm-evaluation-harness/lm_eval/loggers/__init__.py +2 -0
  43. lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py +537 -0
  44. lm-evaluation-harness/lm_eval/loggers/utils.py +149 -0
  45. lm-evaluation-harness/lm_eval/loggers/wandb_logger.py +358 -0
  46. lm-evaluation-harness/lm_eval/models/__init__.py +35 -0
  47. lm-evaluation-harness/lm_eval/models/anthropic_llms.py +382 -0
  48. lm-evaluation-harness/lm_eval/models/api_models.py +810 -0
  49. lm-evaluation-harness/lm_eval/models/dummy.py +41 -0
  50. lm-evaluation-harness/lm_eval/models/gguf.py +132 -0
.gitattributes CHANGED
@@ -45,3 +45,5 @@ records/012625_BatchSize/ablations.png filter=lfs diff=lfs merge=lfs -text
45
  records/102924_Optimizers/nanogpt_speedrun81w.png filter=lfs diff=lfs merge=lfs -text
46
  records/102924_Optimizers/nanogpt_speedrun82w.png filter=lfs diff=lfs merge=lfs -text
47
  records/110624_ShortcutsTweaks/nanogpt_speedrun111.png filter=lfs diff=lfs merge=lfs -text
48
+ lm-evaluation-harness/docs/img/fewshot_example_gpt3.png filter=lfs diff=lfs merge=lfs -text
49
+ lm-evaluation-harness/lm_eval/tasks/noreval/noreval.jpg filter=lfs diff=lfs merge=lfs -text
lm-evaluation-harness/docs/API_guide.md ADDED
@@ -0,0 +1,211 @@
1
+ # TemplateAPI Usage Guide
2
+
3
+ The `TemplateAPI` class is a versatile superclass designed to facilitate the integration of various API-based language models into the lm-evaluation-harness framework. This guide explains how to use and extend the `TemplateAPI` class to implement your own API models. If your API is compatible with the OpenAI API, you can use the `local-completions` or the `local-chat-completions` model types (defined [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py)), which can also serve as examples of how to effectively subclass this template.
4
+
5
+ ## Overview
6
+
7
+ The `TemplateAPI` class provides a template for creating API-based model implementations. It handles common functionalities such as:
8
+
9
+ - Tokenization (optional)
10
+ - Batch processing
11
+ - Caching
12
+ - Retrying failed requests
13
+ - Parsing API responses
14
+
15
+ To use this class, you typically need to subclass it and implement specific methods for your API.
16
+
17
+ ## Key Methods to Implement
18
+
19
+ When subclassing `TemplateAPI`, you need to implement the following methods:
20
+
21
+ 1. `_create_payload`: Creates the JSON payload for API requests.
22
+ 2. `parse_logprobs`: Parses log probabilities from API responses.
23
+ 3. `parse_generations`: Parses generated text from API responses.
24
+
25
+ Optional Properties:
26
+
27
+ 4. `header`: Returns the headers for the API request.
28
+ 5. `api_key`: Returns the API key for authentication (if required).
29
+
30
+ You may also need to override other methods or properties depending on your API's specific requirements.
31
+
32
+ > [!NOTE]
33
+ > Currently, loglikelihood and MCQ-based tasks (such as MMLU) are only supported for completion endpoints, not for chat-completion endpoints (those that expect a list of dicts). Completion APIs which support instruct-tuned models can be evaluated with the `--apply_chat_template` option in order to evaluate models using a chat template format while still being able to access the model logits needed for loglikelihood-based tasks.
34
+
35
+ ## TemplateAPI Arguments
36
+
37
+ When initializing a `TemplateAPI` instance or a subclass, you can provide several arguments to customize its behavior. Here's a detailed explanation of some important arguments:
38
+
39
+ - `model` or `pretrained` (str):
40
+ - The name or identifier of the model to use.
41
+ - `model` takes precedence over `pretrained` when both are provided.
42
+
43
+ - `base_url` (str):
44
+ - The base URL for the API endpoint.
45
+
46
+ - `tokenizer` (str, optional):
47
+ - The name or path of the tokenizer to use.
48
+ - If not provided, it defaults to using the same tokenizer name as the model.
49
+
50
+ - `num_concurrent` (int):
51
+ - Number of concurrent requests to make to the API.
52
+ - Useful for APIs that support parallel processing.
53
+ - Default is 1 (sequential processing).
54
+
55
+ - `timeout` (int, optional):
56
+ - Timeout for API requests in seconds.
57
+ - Default is 30.
58
+
59
+ - `tokenized_requests` (bool):
60
+ - Determines whether the input is pre-tokenized. Defaults to `True`.
61
+ - Requests can be sent in either tokenized form (`list[list[int]]`) or as text (`list[str]`, or `str` for batch_size=1).
62
+ - For loglikelihood-based tasks, prompts require tokenization to calculate the context length. If `False`, prompts are decoded back to text before being sent to the API.
63
+ - Not as important for `generate_until` tasks.
64
+ - Ignored for chat formatted inputs (list[dict...]) or if tokenizer_backend is None.
65
+
66
+ - `tokenizer_backend` (str, optional):
67
+ - Required for loglikelihood-based or MCQ tasks.
68
+ - Specifies the tokenizer library to use. Options are "tiktoken", "huggingface", or None.
69
+ - Default is "huggingface".
70
+
71
+ - `max_length` (int, optional):
72
+ - Maximum length of input + output.
73
+ - Default is 2048.
74
+
75
+ - `max_retries` (int, optional):
76
+ - Maximum number of retries for failed API requests.
77
+ - Default is 3.
78
+
79
+ - `max_gen_toks` (int, optional):
80
+ - Maximum number of tokens to generate in completion tasks.
81
+ - Default is 256 or set in task yaml.
82
+
83
+ - `batch_size` (int or str, optional):
84
+ - Number of requests to batch together (if the API supports batching).
85
+ - Can be an integer or "auto" (which defaults to 1 for API models).
86
+ - Default is 1.
87
+
88
+ - `seed` (int, optional):
89
+ - Random seed for reproducibility.
90
+ - Default is 1234.
91
+
92
+ - `add_bos_token` (bool, optional):
93
+ - Whether to add the beginning-of-sequence token to inputs (when tokenizing).
94
+ - Default is False.
95
+
96
+ - `custom_prefix_token_id` (int, optional):
97
+ - Custom token ID to use as a prefix for inputs.
98
+ - If not provided, uses the model's default BOS or EOS token (if `add_bos_token` is True).
99
+
100
+ - `verify_certificate` (bool, optional):
101
+ - Whether to validate the certificate of the API endpoint (if HTTPS).
102
+ - Default is True.
103
+
104
+ - `header` (dict, optional):
105
+ - Custom headers for API requests.
106
+ - If not provided, uses `{"Authorization": f"Bearer {self.api_key}"}` by default.
107
+
108
+ Example usage:
109
+
110
+ ```python
111
+ class MyAPIModel(TemplateAPI):
112
+ def __init__(self, **kwargs):
113
+ super().__init__(
114
+ model="my-model",
115
+ base_url="https://api.mymodel.com/v1/completions",
116
+ tokenizer_backend="huggingface",
117
+ num_concurrent=5,
118
+ max_retries=5,
119
+ batch_size=10,
120
+ **kwargs
121
+ )
122
+
123
+ # Implement other required methods...
124
+ ```
125
+
126
+ When subclassing `TemplateAPI`, you can override these arguments in your `__init__` method to set default values specific to your API. You can also add additional (potentially user-specified) arguments as needed for your specific implementation.
127
+
128
+ ## Example Implementation: OpenAI API
129
+
130
+ The `OpenAICompletionsAPI` and `OpenAIChatCompletion` classes (defined [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py)) demonstrate how to implement API models using the `TemplateAPI` class. Here's a breakdown of the key components:
131
+
132
+ ### 1. Subclassing and Initialization
133
+
134
+ ```python
135
+ @register_model("openai-completions")
136
+ class OpenAICompletionsAPI(LocalCompletionsAPI):
137
+ def __init__(
138
+ self,
139
+ base_url="https://api.openai.com/v1/completions",
140
+ tokenizer_backend="tiktoken",
141
+ **kwargs,
142
+ ):
143
+ super().__init__(
144
+ base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
145
+ )
146
+ ```
147
+
148
+ ### 2. Implementing API Key Retrieval
149
+
150
+ ```python
151
+ @cached_property
152
+ def api_key(self):
153
+ key = os.environ.get("OPENAI_API_KEY", None)
154
+ if key is None:
155
+ raise ValueError(
156
+ "API key not found. Please set the OPENAI_API_KEY environment variable."
157
+ )
158
+ return key
159
+ ```
160
+
161
+ ### 3. Creating the Payload
162
+
163
+ ```python
164
+ def _create_payload(
165
+ self,
166
+ messages: Union[List[List[int]], List[dict], List[str], str],
167
+ generate=False,
168
+ gen_kwargs: Optional[dict] = None,
169
+ **kwargs,
170
+ ) -> dict:
171
+ if generate:
172
+ # ... (implementation for generation)
173
+ else:
174
+ # ... (implementation for log likelihood)
175
+ ```
176
+
177
+ ### 4. Parsing API Responses
178
+
179
+ ```python
180
+ @staticmethod
181
+ def parse_logprobs(
182
+ outputs: Union[Dict, List[Dict]],
183
+ tokens: List[List[int]] = None,
184
+ ctxlens: List[int] = None,
185
+ **kwargs,
186
+ ) -> List[Tuple[float, bool]]:
187
+ # ... (implementation)
188
+
189
+ @staticmethod
190
+ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
191
+ # ... (implementation)
192
+ ```
193
+
194
+ The requests are initiated in the `model_call` or the `amodel_call` methods.
195
+
196
+ ## Implementing Your Own API Model
197
+
198
+ To implement your own API model:
199
+
200
+ 1. Subclass `TemplateAPI` or one of its subclasses (e.g., `LocalCompletionsAPI`).
201
+ 2. Override the `__init__` method if you need to set specific parameters.
202
+ 3. Implement the `_create_payload` and `header` methods to create the appropriate payload for your API.
203
+ 4. Implement the `parse_logprobs` and `parse_generations` methods to parse your API's responses.
204
+ 5. Override the `api_key` property if your API requires authentication.
205
+ 6. Override any other methods as necessary to match your API's behavior.
206
+
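+ Putting these steps together, a minimal subclass might look like the sketch below. This is illustrative only: the endpoint URL, the environment variable name, and the payload/response field names (`"prompt"`, `"choices"`, `"text"`, `"logprobs"`) are assumptions modeled loosely on completion-style APIs, and should be adapted to whatever your API actually expects and returns (you may also need to include the model name in the payload).
+
+ ```python
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from lm_eval.api.registry import register_model
+ from lm_eval.models.api_models import TemplateAPI
+
+
+ @register_model("my-api")
+ class MyAPIModel(TemplateAPI):
+     def __init__(self, base_url="https://api.example.com/v1/completions", **kwargs):
+         super().__init__(base_url=base_url, tokenizer_backend="huggingface", **kwargs)
+
+     @property
+     def api_key(self):
+         # read the key from the environment rather than hard-coding it
+         return os.environ.get("MY_API_KEY", "")
+
+     def _create_payload(
+         self, messages, generate=False, gen_kwargs: Optional[dict] = None, **kwargs
+     ) -> dict:
+         if generate:
+             gen_kwargs = dict(gen_kwargs or {})
+             max_tokens = gen_kwargs.pop("max_gen_toks", 256)
+             return {"prompt": messages, "max_tokens": max_tokens, **gen_kwargs}
+         # loglikelihood requests: ask the API to echo the prompt and return logprobs
+         return {"prompt": messages, "max_tokens": 0, "echo": True, "logprobs": 1}
+
+     @staticmethod
+     def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+         if isinstance(outputs, dict):
+             outputs = [outputs]
+         return [choice["text"] for out in outputs for choice in out["choices"]]
+
+     @staticmethod
+     def parse_logprobs(
+         outputs: Union[Dict, List[Dict]],
+         tokens: List[List[int]] = None,
+         ctxlens: List[int] = None,
+         **kwargs,
+     ) -> List[Tuple[float, bool]]:
+         # sum token logprobs over the continuation (tokens after the context length);
+         # a fuller implementation would also check whether the continuation matched the greedy decode
+         if isinstance(outputs, dict):
+             outputs = [outputs]
+         results = []
+         for out in outputs:
+             for choice, ctxlen in zip(out["choices"], ctxlens):
+                 continuation_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:]
+                 results.append((sum(continuation_logprobs), False))
+         return results
+ ```
+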
207
+ ## Best Practices
208
+
209
+ 1. Use the `@register_model` decorator to register your model with the framework (and import it in `lm_eval/models/__init__.py`!).
210
+ 2. Use environment variables for sensitive information like API keys.
211
+ 3. Properly handle batching and concurrent requests if supported by your API.
lm-evaluation-harness/docs/CONTRIBUTING.md ADDED
@@ -0,0 +1,83 @@
1
+ # Contributing to LM Evaluation Harness
2
+
3
+ Welcome, and thank you for your interest in the LM Evaluation Harness! We appreciate contributions and feedback, value the time you spend with our library, and hope you find it useful!
4
+
5
+ ## Important Resources
6
+
7
+ Information about the LM Evaluation Harness can be found in several places:
8
+
9
+ - Our [documentation pages](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs)
10
+ - We occasionally use [GitHub Milestones](https://github.com/EleutherAI/lm-evaluation-harness/milestones) to track progress toward specific near-term version releases.
11
+ - We maintain a [Project Board](https://github.com/orgs/EleutherAI/projects/25) for tracking current work items and PRs, and for future roadmap items or feature requests.
12
+ - Further discussion and support conversations are located in the #lm-thunderdome channel of the [EleutherAI discord](https://discord.gg/eleutherai).
13
+
14
+ ## Code Style
15
+
16
+ LM Evaluation Harness uses [ruff](https://github.com/astral-sh/ruff) for linting via [pre-commit](https://pre-commit.com/).
17
+
18
+ You can install linters and dev tools via
19
+
20
+ ```pip install lm_eval[dev]``` or ```pip install -e ".[dev]"```
21
+
22
+ Then, run
23
+
24
+ ```pre-commit install```
25
+
26
+ in order to ensure linters and other checks will be run upon committing.
27
+
28
+ ## Testing
29
+
30
+ We use [pytest](https://docs.pytest.org/en/latest/) for running unit tests. All library unit tests can be run via:
31
+
32
+ ```bash
33
+ python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_openvino.py
34
+ ```
35
+
36
+ ## Contributor License Agreement
37
+
38
+ We ask that new contributors agree to a Contributor License Agreement affirming that EleutherAI has the rights to use your contribution to our library.
39
+ First-time pull requests will have a reply added by @CLAassistant containing instructions for how to confirm this, and we require it before merging your PR.
40
+
41
+ ## Contribution Best Practices
42
+
43
+ We recommend a few best practices to make your contributions or reported errors easier to assist with.
44
+
45
+ **For Pull Requests:**
46
+
47
+ - PRs should be titled descriptively, and be opened with a brief description of the scope and intent of the new contribution.
48
+ - New features should have appropriate documentation added alongside them.
49
+ - Aim for code maintainability, and minimize code copying.
50
+ - If contributing a new task, try to share test results on the task using a publicly-available model, and compare to any public results that are available for the task.
51
+
52
+ **For Feature Requests:**
53
+
54
+ - Provide a short paragraph's worth of description. What is the feature you are requesting? What is its motivation, and an example use case of it? How does this differ from what is currently supported?
55
+
56
+ **For Bug Reports**:
57
+
58
+ - Provide a short description of the bug.
59
+ - Provide a *reproducible example*--what is the command you run with our library that results in this error? Have you tried any other steps to resolve it?
60
+ - Provide a *full error traceback* of the error that occurs, if applicable. A one-line error message or small screenshot snippet is unhelpful without the surrounding context.
61
+ - Note what version of the codebase you are using, and any specifics of your environment and setup that may be relevant.
62
+
63
+ **For Requesting New Tasks**:
64
+
65
+ - Provide a 1-2 sentence description of what the task is and what it evaluates.
66
+ - Provide a link to the paper introducing the task.
67
+ - Provide a link to where the dataset can be found.
68
+ - Provide a link to a paper containing results on an open-source model on the task, for use in comparisons and implementation validation.
69
+ - If applicable, link to any codebase that has implemented the task (especially the original publication's codebase, if existent).
70
+
71
+ ## How Can I Get Involved?
72
+
73
+ To quickly get started, we maintain a list of good first issues, which can be found [on our project board](https://github.com/orgs/EleutherAI/projects/25/views/8) or by [filtering GH Issues](https://github.com/EleutherAI/lm-evaluation-harness/issues?q=is%3Aopen+label%3A%22good+first+issue%22+label%3A%22help+wanted%22). These are typically smaller code changes or self-contained features which can be added without extensive familiarity with library internals, and we recommend new contributors consider taking a stab at one of these first if they are feeling uncertain where to begin.
74
+
75
+ There are a number of distinct ways to contribute to LM Evaluation Harness, and all are extremely helpful! A sampling of ways to contribute include:
76
+
77
+ - **Implementing and verifying new evaluation tasks**: Is there a task you'd like to see LM Evaluation Harness support? Consider opening an issue requesting it, or helping add it! Verifying and cross-checking task implementations with their original versions is also a very valuable form of assistance in ensuring standardized evaluation.
78
+ - **Improving documentation** - Improvements to the documentation, or noting pain points / gaps in documentation, are helpful in order for us to improve the user experience of the library and clarity + coverage of documentation.
79
+ - **Testing and devops** - We are very grateful for any assistance in adding tests for the library that can be run for new PRs, and other devops workflows.
80
+ - **Adding new modeling / inference library integrations** - We hope to support a broad range of commonly-used inference libraries popular among the community, and welcome PRs for new integrations, so long as they are documented properly and maintainable.
81
+ - **Proposing or Contributing New Features** - We want LM Evaluation Harness to support a broad range of evaluation usecases. If you have a feature that is not currently supported but desired, feel free to open an issue describing the feature and, if applicable, how you intend to implement it. We would be happy to give feedback on the cleanest way to implement new functionalities and are happy to coordinate with interested contributors via GH discussions or via discord.
82
+
83
+ We hope that this has been helpful, and appreciate your interest in contributing! Further questions can be directed to [our Discord](https://discord.gg/eleutherai).
lm-evaluation-harness/docs/README.md ADDED
@@ -0,0 +1,11 @@
1
+ # Eval Harness Documentation
2
+
3
+ Welcome to the docs for the LM Evaluation Harness!
4
+
5
+ ## Table of Contents
6
+
7
+ * To learn about the public interface of the library, as well as how to evaluate via the command line or as integrated into an external library, see the [Interface](./interface.md).
8
+ * To learn how to add a new inference library, API, or model type, as well as a quick explainer on the ways in which one can evaluate an LM, see the [Model Guide](./model_guide.md).
9
+ * For an extended description of how to extend the library to new model classes served over an API, see the [API Guide](./API_guide.md).
10
+ * For a crash course on adding new tasks to the library, see our [New Task Guide](./new_task_guide.md).
11
+ * To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](./task_guide.md).
lm-evaluation-harness/docs/chat-template-readme.md ADDED
@@ -0,0 +1,31 @@
1
+ # Chat Template Delimiter Handling Update
2
+
3
+ ## Overview
4
+
5
+ This change modifies how delimiters are handled when applying chat templates in the request construction process for likelihood and multiple-choice based tasks. When `apply_chat_template` is set to `True`, the target delimiter is now set to an empty string instead of using the configured delimiter.
6
+
7
+ ## Background
8
+
9
+ By default, the system uses a target delimiter (typically a whitespace " ") between the context and target text when constructing prompts. The full string is constructed as:
10
+
11
+ ```text
12
+ doc_to_text(doc) + target_delimiter + doc_to_target(doc)
13
+ ```
14
+
15
+ While this worked well for base models where we wanted the model to predict a single whitespace followed by the answer, chat models have their own formatting conventions that handle spacing differently.
16
+
17
+ ## The Change
18
+
19
+ - When `apply_chat_template=True`, the target delimiter is now empty ("") instead of the default whitespace
20
+ - This prevents interference between chat template formatting and the default delimiter system
21
+ - Particularly important for multiple choice tasks where the template itself handles spacing
22
+
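+ In pseudocode, the selection logic amounts to the following sketch (a simplification, not the harness's actual internals):
+
+ ```python
+ def build_prompt(doc, doc_to_text, doc_to_target, apply_chat_template=False, target_delimiter=" "):
+     # when a chat template is applied, the template itself is responsible for spacing,
+     # so the configured delimiter is replaced with an empty string
+     delimiter = "" if apply_chat_template else target_delimiter
+     return doc_to_text(doc) + delimiter + doc_to_target(doc)
+ ```
+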
23
+ ## Example
24
+
25
+ ```text
26
+ # Before (with default delimiter " ")
27
+ <user>Question: What color is the sky?\nAnswer:<assistant> blue
28
+
29
+ # After
30
+ <user>Question: What color is the sky?\nAnswer:<assistant>blue
31
+ ```
lm-evaluation-harness/docs/decontamination.md ADDED
@@ -0,0 +1,76 @@
1
+ # Decontamination
2
+
3
+ ## Usage
4
+
5
+ The provided directory should contain
6
+ the ngram files and info.json produced in "Pile Ngram Generation" further down.
7
+
8
+ ```bash
9
+ python -m lm_eval \
10
+ --model gpt2 \
11
+ --device 0 \
12
+ --tasks sciq
13
+ ```
14
+
15
+ ## Background
16
+
17
+ Downstream evaluations test model generalization, and are less useful when test set data also exists in the training set, referred to as leakage or contamination.
18
+
19
+ Filtering your training set against the test set is a good first step, however this isn't always possible, as in the case of a new benchmark or one that wasn't considered prior to model training. When training set filtering isn't possible, it is useful to measure the impact of test set leakage by detecting the contaminated test examples and producing a clean version of the benchmark.
20
+
21
+ The basis for our decontamination procedure can be found in Appendix C of "Language Models are Few-Shot Learners". OpenAI defined a test document as contaminated if any N-gram overlap existed with any training document. They used a range of N values between 8 and 13 depending on dataset, while we just used 13 for simplicity.
22
+
23
+ ## Implementation
24
+
25
+ Contamination detection can be found in `lm_eval/decontaminate.py` with supporting code in `lm_eval/decontamination/`.
26
+
27
+ decontaminate.py does the following:
28
+
29
+ 1. Build dictionaries of all ngrams and their corresponding evaluation/document ids.
30
+ 2. Scan through sorted files containing training set n-grams.
31
+ 3. If a match is found, the corresponding evaluation/document combinations are marked as contaminated.
32
+
33
+ `lm_eval/evaluator.py` can then produce a clean version of the benchmark by excluding the results of contaminated documents. For each metric, a clean version will be shown in the results with a "decontaminate" suffix.
34
+
35
+ This is disabled by default for new tasks, to support decontamination on a task override the "should_decontaminate" and "doc_to_decontamination_query" methods. For more details see the [task guide](task_guide.md).
36
+
37
+ ## Pile Ngram Generation
38
+
39
+ The relevant scripts can be found in `scripts/clean_training_data`, which also import from
40
+ `lm_eval/decontamination/`
41
+
42
+ 1. git clone https://github.com/EleutherAI/lm-evaluation-harness.git
43
+ 2. pip install -r requirements.txt
44
+ 3. Download The Pile from [The Eye](https://the-eye.eu/public/AI/pile/train/)
45
+ 4. Place pile files in "pile" directory under "lm-evaluation-harness" (or create a symlink)
46
+ 5. Run generate_13_grams.
47
+
48
+ ```bash
49
+ export PYTHONHASHSEED=0
50
+ python -m scripts/clean_training_data/generate_13_grams \
51
+ -dir path/to/working/directory \
52
+ -n 13 \
53
+ -buckets 500
54
+ ```
55
+
56
+ Took approximately 4 days for us. We had the time to wait, but this could be scaled out by doing partial pile scans on multiple instances of this script and merging the relevant buckets. We fixed PYTHONHASHSEED to ensure reproducibility of bucket hashing in case you need to stop and start.
57
+
58
+ 6. Sort the generated 13-grams.
59
+
60
+ ```bash
61
+ python -m scripts/clean_training_data/sort_13_gram_buckets \
62
+ -dir path/to/working/directory/output
63
+ ```
64
+
65
+ Took approximately 5 days for us. You could speed this up by spreading the files around to different machines and running the sort script before gathering them together.
66
+
67
+ 7. Compress the sorted 13 grams files and place them together with info.json.
68
+
69
+ This step only takes a few hours.
70
+
71
+ ```bash
72
+ python -m scripts/clean_training_data/compress_and_package \
73
+ -dir path/to/working/directory \
74
+ -output path/to/final/directory \
75
+ -procs 8
76
+ ```
lm-evaluation-harness/docs/footguns.md ADDED
1
+ # Common Pitfalls and Troubleshooting Guide
2
+
3
+ This document highlights common pitfalls and troubleshooting tips when using this library. We'll continue to add more tips as we discover them.
4
+
5
+ ## YAML Configuration Issues
6
+
7
+ ### Newline Characters in YAML (`\n`)
8
+
9
+ **Problem:** When specifying newline characters in YAML, they may be interpreted incorrectly depending on how you format them.
10
+
11
+ ```yaml
12
+ # ❌ WRONG: Single quotes don't process escape sequences
13
+ generation_kwargs:
14
+ until: ['\n'] # Gets parsed as the literal characters '\' and 'n' i.e "\\n"
15
+
16
+ ```
17
+ ```yaml
18
+ # ✅ RIGHT: Use double quotes for escape sequences
19
+ generation_kwargs:
20
+ until: ["\n"] # Gets parsed as an actual newline character
21
+
22
+ ```
23
+
24
+ **Solutions:**
25
+ - Use double quotes for strings containing escape sequences
26
+ - For multiline content, use YAML's block scalars (`|` or `>`)
27
+ - When generating YAML programmatically, be careful with how template engines handle escape sequences
28
+
29
+ ### Quoting in YAML
30
+
31
+ **When to use different types of quotes:**
32
+
33
+ - **No quotes**: Simple values (numbers, booleans, alphanumeric strings without special characters)
34
+ ```yaml
35
+ simple_value: plain text
36
+ number: 42
37
+
38
+ ```
39
+
40
+ - **Single quotes (')**:
41
+ - Preserves literal values
42
+ - Use when you need special characters to be treated literally
43
+ - Escape single quotes by doubling them: `'It''s working'`
44
+ ```yaml
45
+ literal_string: 'The newline character \n is not processed here'
46
+ path: 'C:\Users\name' # Backslashes preserved
47
+
48
+ ```
49
+
50
+ - **Double quotes (")**:
51
+ - Processes escape sequences like `\n`, `\t`, etc.
52
+ - Use for strings that need special characters interpreted
53
+ - Escape double quotes with backslash: `"He said \"Hello\""`
54
+ ```yaml
55
+ processed_string: "First line\nSecond line" # Creates actual newline
56
+ unicode: "Copyright symbol: \u00A9" # Unicode character
57
+
58
+ ```
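+
+ ### Checking how a value will be parsed
+
+ If you are unsure how a YAML scalar will be interpreted, it can help to round-trip it through a YAML loader before committing the config. A minimal sketch using PyYAML (assumed to be available in your environment):
+
+ ```python
+ import yaml
+
+ single = yaml.safe_load("until: ['\\n']")   # YAML source: until: ['\n']
+ double = yaml.safe_load('until: ["\\n"]')   # YAML source: until: ["\n"]
+
+ print(single["until"])  # ['\\n'] -- a literal backslash followed by 'n'
+ print(double["until"])  # ['\n']  -- an actual newline character
+ ```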
lm-evaluation-harness/docs/img/fewshot_example_gpt3.png ADDED

Git LFS Details

  • SHA256: 6af5dc2196248b29260ba443e882725dd6cfc51ef17ad5a4dbab4f8ce6850c75
  • Pointer size: 131 Bytes
  • Size of remote file: 316 kB
lm-evaluation-harness/docs/interface.md ADDED
@@ -0,0 +1,170 @@
1
+ # User Guide
2
+
3
+ This document details the interface exposed by `lm-eval` and provides details on what flags are available to users.
4
+
5
+ ## Command-line Interface
6
+
7
+ A majority of users run the library by cloning it from Github, installing the package as editable, and running the `python -m lm_eval` script.
8
+
9
+ Equivalently, running the library can be done via the `lm-eval` entrypoint at the command line.
10
+
11
+ This mode supports a number of command-line arguments, the details of which can also be seen via running with `-h` or `--help`:
12
+
13
+ - `--model` : Selects which model type or provider is evaluated. Must be a string corresponding to the name of the model type/provider being used. See [the main README](https://github.com/EleutherAI/lm-evaluation-harness/tree/main#model-apis-and-inference-servers) for a full list of enabled model names and supported libraries or APIs.
14
+
15
+ - `--model_args` : Controls parameters passed to the model constructor. Accepts a string containing comma-separated keyword arguments to the model class of the format `"arg1=val1,arg2=val2,..."`, such as, for example, `--model_args pretrained=EleutherAI/pythia-160m,dtype=float32`. For a full list of supported keyword arguments, see the initialization of the `lm_eval.api.model.LM` subclass, e.g. [`HFLM`](https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/models/huggingface.py#L66).
16
+
17
+ - `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups. A list of supported tasks can be viewed with `--tasks list`.
18
+
19
+ - `--num_fewshot` : Sets the number of few-shot examples to place in context. Must be an integer.
20
+
21
+ - `--gen_kwargs` : Takes an argument string in the same format as `--model_args` and creates a dictionary of keyword arguments. These will be passed to the model for all `generate_until` (free-form or greedy generation) tasks, to set options such as the sampling temperature or `top_p` / `top_k`. For a list of which args are supported for each model type, reference the respective library's documentation (for example, the documentation for `transformers.AutoModelForCausalLM.generate()`). These kwargs will be applied to all `generate_until` tasks called--we do not currently support unique gen_kwargs or batch_size values per task in a single run of the library. To control these on a per-task level, set them in that task's YAML file.
22
+
23
+ - `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
24
+
25
+ - `--max_batch_size` : Sets the maximum batch size to try to fit in memory, if `--batch_size auto` is passed.
26
+
27
+ - `--device` : Sets which device to place the model onto. Must be a string, for example, `"cuda", "cuda:0", "cpu", "mps"`. Defaults to "cuda", and can be ignored if running multi-GPU or running a non-local model type.
28
+
29
+ - `--output_path` : A string of the form `dir/file.jsonl` or `dir/`. Provides a path where high-level results will be saved, either into the file named or into the directory named. If `--log_samples` is passed as well, then per-document outputs and metrics will be saved into the directory as well.
30
+
31
+ - `--log_samples` : If this flag is passed, then the model's outputs, and the text fed into the model, will be saved at per-document granularity. Must be used with `--output_path`.
32
+
33
+ - `--limit` : Accepts an integer, or a float between 0.0 and 1.0 . If passed, will limit the number of documents to evaluate to the first X documents (if an integer) per task or first X% of documents per task. Useful for debugging, especially on costly API models.
34
+
35
+ - `--use_cache` : Should be a path where a sqlite db file can be written to. Takes a string of format `/path/to/sqlite_cache_` in order to create a cache db at `/path/to/sqlite_cache_rank{i}.db` for each process (0-NUM_GPUS). This allows results of prior runs to be cached, so that there is no need to re-run results in order to re-score or re-run a given (model, task) pair again.
36
+
37
+ - `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
38
+
39
+ - `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
40
+
41
+ - `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, the prompt and gold target string for the first document of each task are printed.
42
+
43
+ - `--show_config` : If used, prints the full `lm_eval.api.task.TaskConfig` contents (non-default settings in the task YAML file) for each task which was run, at the completion of an evaluation. Useful when one is modifying a task's configuration YAML locally, to transmit the exact configuration used for debugging or for reproducibility purposes.
44
+
45
+ - `--include_path` : Accepts a path to a folder. If passed, then all YAML files containing `lm-eval` compatible task configurations will be added to the task registry as available tasks. Used for when one is writing config files for their own task in a folder other than `lm_eval/tasks/`.
46
+
47
+ - `--system_instruction`: Specifies a system instruction string to prepend to the prompt.
48
+
49
+ - `--apply_chat_template` : This flag specifies whether to apply a chat template to the prompt. It can be used in the following ways:
50
+ - `--apply_chat_template` : When used without an argument, applies the only available chat template to the prompt. For Hugging Face models, if no dedicated chat template exists, the default chat template will be applied.
51
+ - `--apply_chat_template template_name` : If the model has multiple chat templates, apply the specified template to the prompt.
52
+
53
+ For Hugging Face models, the default chat template can be found in the [`default_chat_template`](https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1912) property of the Transformers Tokenizer.
54
+
55
+ - `--fewshot_as_multiturn` : If this flag is on, the Fewshot examples are treated as a multi-turn conversation. Questions are provided as user content and answers are provided as assistant responses. Requires `--num_fewshot` to be set to be greater than 0, and `--apply_chat_template` to be on.
56
+
57
+ - `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
58
+
59
+ - `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
60
+
61
+ - `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```. Also allows for the passing of the step to log things at (passed to `wandb.run.log`), e.g., `--wandb_args step=123`.
62
+
63
+ - `--hf_hub_log_args` : Logs evaluation results to Hugging Face Hub. Accepts a string with the arguments separated by commas. Available arguments:
64
+ - `hub_results_org` - organization name on Hugging Face Hub, e.g., `EleutherAI`. If not provided, the results will be pushed to the owner of the Hugging Face token,
65
+ - `hub_repo_name` - repository name on Hugging Face Hub (deprecated, `details_repo_name` and `results_repo_name` should be used instead), e.g., `lm-eval-results`,
66
+ - `details_repo_name` - repository name on Hugging Face Hub to store details, e.g., `lm-eval-results`,
67
+ - `results_repo_name` - repository name on Hugging Face Hub to store results, e.g., `lm-eval-results`,
68
+ - `push_results_to_hub` - whether to push results to Hugging Face Hub, can be `True` or `False`,
69
+ - `push_samples_to_hub` - whether to push samples results to Hugging Face Hub, can be `True` or `False`. Requires `--log_samples` to be set,
70
+ - `public_repo` - whether the repository is public, can be `True` or `False`,
71
+ - `leaderboard_url` - URL to the leaderboard, e.g., `https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard`.
72
+ - `point_of_contact` - Point of contact for the results dataset, e.g., `[email protected]`.
73
+ - `gated` - whether to gate the details dataset, can be `True` or `False`.
74
+
75
+ - `--metadata`: JSON string to pass to TaskConfig. Used for some tasks which require additional metadata to be passed for processing. E.g., `--metadata '{"key": "value"}'`.
76
+
77
+ ## External Library Usage
78
+
79
+ We also support using the library's external API for use within model training loops or other scripts.
80
+
81
+ `lm_eval` supplies two functions for external import and use: `lm_eval.evaluate()` and `lm_eval.simple_evaluate()`.
82
+
83
+ `simple_evaluate()` can be used by simply creating an `lm_eval.api.model.LM` subclass that implements the methods described in the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs/model_guide.md), and wrapping your custom model in that class as follows:
84
+
85
+ ```python
86
+ import lm_eval
87
+ from lm_eval.utils import setup_logging
88
+ ...
89
+ # initialize logging
90
+ setup_logging("DEBUG") # optional, but recommended; or you can set up logging yourself
91
+ my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
92
+ ...
93
+ # instantiate an LM subclass that takes your initialized model and can run
94
+ # - `Your_LM.loglikelihood()`
95
+ # - `Your_LM.loglikelihood_rolling()`
96
+ # - `Your_LM.generate_until()`
97
+ lm_obj = Your_LM(model=my_model, batch_size=16)
98
+
99
+ # indexes all tasks from the `lm_eval/tasks` subdirectory.
100
+ # Alternatively, you can set `TaskManager(include_path="path/to/my/custom/task/configs")`
101
+ # to include a set of tasks in a separate directory.
102
+ task_manager = lm_eval.tasks.TaskManager()
103
+
104
+ # Setting `task_manager` to the one above is optional and should generally be done
105
+ # if you want to include tasks from paths other than ones in `lm_eval/tasks`.
106
+ # `simple_evaluate` will instantiate its own task_manager if it is set to None here.
107
+ results = lm_eval.simple_evaluate( # call simple_evaluate
108
+ model=lm_obj,
109
+ tasks=["taskname1", "taskname2"],
110
+ num_fewshot=0,
111
+ task_manager=task_manager,
112
+ ...
113
+ )
114
+ ```
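+
+ The returned `results` object is a nested dictionary of scores and run metadata. The exact layout can vary between versions, but as a rough sketch, per-task metrics can typically be inspected like this:
+
+ ```python
+ # results may be None on non-primary ranks in distributed runs
+ if results is not None:
+     for task_name, metrics in results["results"].items():
+         print(task_name, metrics)
+ ```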
115
+
116
+ See the `simple_evaluate()` and `evaluate()` functions in [lm_eval/evaluator.py](../lm_eval/evaluator.py#:~:text=simple_evaluate) for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
117
+
118
+ Additionally, the `evaluate()` function offers the core evaluation functionality provided by the library, but without some of the special handling and simplification + abstraction provided by `simple_evaluate()`.
119
+
120
+ As a brief example usage of `evaluate()`:
121
+
122
+ ```python
123
+ import lm_eval
124
+
125
+ # suppose you've defined a custom lm_eval.api.Task subclass in your own external codebase
126
+ from my_tasks import MyTask1
127
+ ...
128
+
129
+ # create your model (could be running finetuning with some custom modeling code)
130
+ my_model = initialize_my_model()
131
+ ...
132
+
133
+ # instantiate an LM subclass that takes your initialized model and can run
134
+ # - `Your_LM.loglikelihood()`
135
+ # - `Your_LM.loglikelihood_rolling()`
136
+ # - `Your_LM.generate_until()`
137
+ lm_obj = Your_LM(model=my_model, batch_size=16)
138
+
139
+ # optional: the task_manager indexes tasks including ones
140
+ # specified by the user through `include_path`.
141
+ task_manager = lm_eval.tasks.TaskManager(
142
+ include_path="/path/to/custom/yaml"
143
+ )
144
+
145
+ # To get a task dict for `evaluate`
146
+ task_dict = lm_eval.tasks.get_task_dict(
147
+ [
148
+ "mmlu", # A stock task
149
+ "my_custom_task", # A custom task
150
+ {
151
+ "task": ..., # A dict that configures a task
152
+ "doc_to_text": ...,
153
+ },
154
+ MyTask1 # A task object from `lm_eval.task.Task`
155
+ ],
156
+ task_manager # A task manager that allows lm_eval to
157
+ # load the task during evaluation.
158
+ # If none is provided, `get_task_dict`
159
+ # will instantiate one itself, but this
160
+ # only includes the stock tasks so users
161
+ # will need to set this if including
162
+ # custom paths is required.
163
+ )
164
+
165
+ results = lm_eval.evaluate(
166
+ lm=lm_obj,
167
+ task_dict=task_dict,
168
+ ...
169
+ )
170
+ ```
lm-evaluation-harness/docs/model_guide.md ADDED
@@ -0,0 +1,192 @@
1
+ # New Model Guide
2
+
3
+ This guide may be of special interest to users who are using the library outside of the repository, i.e. installing the library from PyPI and calling `lm_eval.evaluator.evaluate()` to evaluate an existing model.
4
+
5
+ In order to properly evaluate a given LM, we require implementation of a wrapper class subclassing the `lm_eval.api.model.LM` class, that defines how the Evaluation Harness should interface with your model. This guide walks through how to write this `LM` subclass via adding it to the library!
6
+
7
+ ## Setup
8
+
9
+ To get started contributing, go ahead and fork the main repo, clone it, create a branch with the name of your model, and install the project requirements in your environment:
10
+
11
+ ```sh
12
+ # After forking...
13
+ git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
14
+ cd lm-evaluation-harness
15
+ git checkout -b <model-type>
16
+ pip install -e ".[dev]"
17
+ ```
18
+
19
+ Now, we'll create a new file where we'll be adding our model:
20
+
21
+ ```sh
22
+ touch lm_eval/models/<my_model_filename>.py
23
+ ```
24
+
25
+ **Tip: this filename should not shadow package names! For example, naming your file `anthropic.py` is disallowed since the API's name on pypi is `anthropic`, but naming it `anthropic_llms.py` works with no problems.**
26
+
27
+ ## Interface
28
+
29
+ All models must subclass the `lm_eval.api.model.LM` class.
30
+
31
+ The LM class enforces a common interface via which we can extract responses from a model:
32
+
33
+ ```python
34
+ class MyCustomLM(LM):
35
+ #...
36
+ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
37
+ #...
38
+
39
+
40
+ def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
41
+ #...
42
+
43
+
44
+ def generate_until(self, requests: list[Instance]) -> list[str]:
45
+ #...
46
+ #...
47
+ ```
48
+
49
+ Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/instance.py) with property `args` of request-dependent type signature described below.
50
+
51
+ We support three types of requests, consisting of different interactions / measurements with an autoregressive LM.
52
+
53
+ All three request types take as input `requests` of type `list[Instance]` that have a matching `Instance.request_type` to the method name.
54
+
55
+ - `generate_until`
56
+ - Each request contains `Instance.args : Tuple[str, dict]` containing 1. an input string to the LM and 2. a dictionary of keyword arguments used to control generation parameters.
57
+ - Using this input and these generation parameters, text will be sampled from the language model (typically until a maximum output length or specific stopping string sequences--for example, `{"until": ["\n\n", "."], "max_gen_toks": 128}`).
58
+ - The generated output text from the model will then be returned.
59
+
60
+ - `loglikelihood`
61
+ - Each request contains `Instance.args : Tuple[str, str]` containing 1. an input string to the LM and 2. a target string on which the loglikelihood of the LM producing this target, conditioned on the input, will be returned.
62
+ - Each request will have, as result, `(ll, is_greedy): Tuple[float, int]` returned, where `ll` is a floating point number representing the log probability of generating the target string conditioned on the input, and `is_greedy` being either the value `0` or `1`, with it being `1` if and only if the target string *would be generated by greedy sampling from the LM* (that is, if the target string is the *most likely* N-token string to be output by the LM given the input. )
63
+
64
+ - `loglikelihood_rolling`
65
+ - Each request contains `Instance.args : Tuple[str]`, which is an input string to the model whose *entire* loglikelihood, conditioned on purely the EOT token, will be calculated.
66
+ - This is used to evaluate *perplexity* on a data distribution.
67
+ - It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
68
+
69
+ To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` ! Additionally, check out `lm_eval.api.model.TemplateLM` for a class that abstracts away some commonly used functions across LM subclasses, or see if your model would lend itself well to subclassing the `lm_eval.models.huggingface.HFLM` class and overriding just the initialization or a couple methods!
70
+
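+ As a purely illustrative sketch (not a useful model), a subclass only needs to return one result per request, in the same order as the incoming `requests` list:
+
+ ```python
+ from lm_eval.api.model import LM
+
+
+ class MyDummyLM(LM):
+     def loglikelihood(self, requests):
+         # requests[i].args == (context, continuation)
+         return [(0.0, False) for _ in requests]
+
+     def loglikelihood_rolling(self, requests):
+         # requests[i].args == (string,); return just the loglikelihood of each string
+         return [0.0 for _ in requests]
+
+     def generate_until(self, requests):
+         # requests[i].args == (context, gen_kwargs)
+         return ["" for _ in requests]
+ ```
+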
71
+ **Tip: be careful of indexing in loglikelihood!**
72
+
73
+ LMs take in tokens in position `[0 1 2 ... N]` and output a probability distribution for token position `N+1`. We provide a simplified graphic here, excerpted from `huggingface.py`:
74
+
75
+ ```text
76
+ # how this all works (illustrated on a causal decoder-only setup):
77
+ # CTX CONT
78
+ # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
79
+ # model \ \
80
+ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
81
+ # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
82
+ ```
83
+
84
+ The final token of the target is not passed into the LM, because we want the LM's predictions *up to but not past* that final target token. For more information, check out https://github.com/EleutherAI/lm-evaluation-harness/issues/942 .
85
+
86
+ ## Registration
87
+
88
+ Congrats on implementing your model! Now it's time to test it out.
89
+
90
+ To make your model usable via the command line interface to `lm-eval` using `python -m lm_eval`, you'll need to tell `lm-eval` what your model's name is.
91
+
92
+ This is done via a *decorator*, `lm_eval.api.registry.register_model`. Using `register_model()`, one can both tell the package what the model's name(s) to be used are when invoking it with `python -m lm_eval --model <name>` and alert `lm-eval` to the model's existence.
93
+
94
+ ```python
95
+ from lm_eval.api.registry import register_model
96
+
97
+ @register_model("<name1>", "<name2>")
98
+ class MyCustomLM(LM):
99
+ ```
100
+
101
+ Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
102
+
103
+ **Tip: be sure to import your model in `lm_eval/models/__init__.py!`**
104
+
105
+ ## Testing
106
+
107
+ We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py .
108
+
109
+ ## Chat Templating
110
+
111
+ Many models are fine-tuned with a [Chat Template](https://huggingface.co/docs/transformers/main/en/chat_templating) in order to enable back-and-forth interaction between a "User"'s queries and the model (often called "Assistant")'s responses. It can be desirable to evaluate fine-tuned models on evaluation tasks while wrapped in the conversational format they expect.
112
+
113
+ In order to make your model optionally compatible with a chat format, three additional methods must be implemented:
114
+
115
+ ```python
116
+ class MyCustomLM(LM):
117
+ #...
118
+ @property
119
+ def tokenizer_name(self) -> str:
120
+ """
121
+ Return the name of the model's tokenizer and/or the accompanying chat template.
122
+ The returned string is used to cache requests.
123
+
124
+ Returns:
125
+ str: The name of the model's tokenizer and/or chat template.
126
+ """
127
+
128
+ def chat_template(self, chat_template: Union[bool, str] = False) -> str:
129
+ """
130
+ Get the appropriate chat template for the model based on the `chat_template` argument.
131
+
132
+ This method returns the chat template string to build the prompt from a chat history.
133
+ The chat template is saved in the evaluation results for reproducibility.
134
+ Boolean arguments should be used with models that have only one chat template,
135
+ while string arguments are used with models that have multiple chat templates.
136
+ For the reference implementation, see HFLM class in `lm_eval.models.huggingface`.
137
+
138
+ Args:
139
+ chat_template (Union[bool, str]): Specifies whether to apply a chat template:
140
+ - If False: Do not apply any chat template.
141
+ - If True: Apply the default chat template.
142
+ - If str: Apply the specified chat template by name.
143
+
144
+ Returns:
145
+ str: The selected chat template in Jinja format.
146
+ """
147
+
148
+ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
149
+ """
150
+ Process a chat history to create a string that can be tokenized and input into the model.
151
+
152
+ Args:
153
+ chat_history (List[Dict[str, str]]): A list of dictionaries representing the chat history,
154
+ where each dictionary has "role" and "content" keys.
155
+
156
+ Returns:
157
+ str: A string representing the chat history that can be tokenized and fed into the model.
158
+ """
159
+ ```
160
+
161
+ - `apply_chat_template`
162
+ - This method performs the bulk of the work required for chat-formatting.
163
+ - As input, a `chat_history: List[Dict[str, str]]` is passed in. This is a transcript of a conversation of a form similar to
164
+
165
+ ```text
166
+ [
167
+ {"system": <user-provided system message such as "You are a helpful math-focused chatbot">},
168
+ {"user": <task example - a few-shot example 'input'>}
169
+ {"assistant": <correct response to the above example>},
170
+ # ... more few-shot examples, potentially
171
+ {"user": <test set query--response on which we will evaluate>},
172
+ ]
173
+ ```
174
+
175
+ which can then be converted into a string input.
176
+ - The output is a string representing this conversation that can be fed into the model.
177
+ - For example, this consists of simply calling `tokenizer.apply_chat_template` for HFLM--see the implementation there for reference.
178
+ - `tokenizer_name`
179
+ - LM Eval Harness supports [caching requests](https://github.com/EleutherAI/lm-evaluation-harness/blob/4902aaaf1f374682f95ac25fe2e13b23faddc91a/lm_eval/__main__.py#L140) that are sent to a model, for faster setup when repeating an already-performed evaluation.
180
+ - However, we don't want to use the cache of chat transcripts rendered using one chat template or system prompt to send to a model with a different template! So, we use this `lm.tokenizer_name` string to distinguish caches for a given model (and chat template) from one another.
181
+ - `chat_template`
182
+ - Chat templates are typically provided as a Jinja template string or a string formatted with str.format to include user and assistant messages in a single prompt. This template string is saved in the evaluation results to ensure reproducibility.
183
+
184
+ If not implemented for a given model type, the flags `--apply_chat_template` , `--fewshot_as_multiturn`, and `--system_instruction` cannot be used.
185
+
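+ For models backed by a Hugging Face tokenizer, `apply_chat_template` can often be a thin wrapper around the tokenizer's own method. A sketch, assuming `self.tokenizer` is a `transformers` tokenizer:
+
+ ```python
+ def apply_chat_template(self, chat_history):
+     # let the tokenizer render the conversation; add_generation_prompt appends the
+     # assistant-turn header so the model continues as the assistant
+     return self.tokenizer.apply_chat_template(
+         chat_history, tokenize=False, add_generation_prompt=True
+     )
+ ```
+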
186
+ ## Other
187
+
188
+ **Pro tip**: In order to make the Evaluation Harness overestimate total runtimes rather than underestimate them, HuggingFace models have built-in support for responding to data points in *descending order by total input length*, via `lm_eval.utils.Reorderer`. Take a look at `lm_eval.models.huggingface.HFLM` to see how this is done, and see if you can implement it in your own model!
189
+
190
+ ## Conclusion
191
+
192
+ After reading this guide, you should be able to add new model APIs or implementations to the Eval Harness library!
lm-evaluation-harness/docs/new_task_guide.md ADDED
@@ -0,0 +1,521 @@
1
+ # New Task Guide
2
+
3
+ `lm-evaluation-harness` is a framework that strives to support a wide range of zero- and few-shot evaluation tasks on autoregressive language models (LMs).
4
+
5
+ This documentation page provides a walkthrough to get started creating your own task, in `lm-eval` versions v0.4.0 and later.
6
+
7
+ A more interactive tutorial is available as a Jupyter notebook [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/examples/lm-eval-overview.ipynb).
8
+
9
+ ## Setup
10
+
11
+ If you haven't already, go ahead and fork the main repo, clone it, create a branch with the name of your task, and install the project requirements in your environment:
12
+
13
+ ```sh
14
+ # After forking...
15
+ git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
16
+ cd lm-evaluation-harness
17
+ git checkout -b <task-name>
18
+ pip install -e ".[dev]"
19
+ ```
20
+
21
+ In this document, we'll walk through the basics of implementing a static benchmark evaluation in two formats: a *generative* task which requires sampling text from a model, such as [`gsm8k`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k.yaml), and a *discriminative*, or *multiple choice*, task where the model picks the most likely of several fixed answer choices, such as [`sciq`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/sciq/sciq.yaml).
22
+
23
+ ## Creating a YAML file
24
+
25
+ To implement a new standard task, we'll need to write a YAML file which configures our task logic. We start by making a new empty YAML file. This file can have any name, but we recommend placing it in a subfolder of `lm_eval/tasks` titled by the dataset or task's shorthand name: for example,
26
+
27
+ ```sh
28
+ touch lm_eval/tasks/<dataset_name>/<my_new_task_name>.yaml
29
+ ```
30
+
31
+ Or, copy the template subfolder we provide from `templates/new_yaml_task`:
32
+
33
+ ```sh
34
+ cp -r templates/new_yaml_task lm_eval/tasks/
35
+ ```
36
+
37
+ and rename the folders and YAML file(s) as desired.
38
+
39
+ ### Selecting and configuring a dataset
40
+
41
+ All data downloading and management is handled through the HuggingFace (**HF**) [`datasets`](https://github.com/huggingface/datasets) API. So, the first thing you should do is check to see if your task's dataset is already provided in their catalog [here](https://huggingface.co/datasets). If it's not in there, please consider adding it to their Hub to make it accessible to a wider user base by following their [new dataset guide](https://github.com/huggingface/datasets/blob/main/ADD_NEW_DATASET.md).
43
+ > [!TIP]
44
+ > To test your task, we recommend enabling verbose logging by setting `export LOGLEVEL=DEBUG` in your shell before running the evaluation script. This will help you debug any issues that may arise.
45
+
+ Once you have a HuggingFace dataset prepared for your task, we want to assign our new YAML to use this dataset:
46
+
47
+ ```yaml
48
+ dataset_path: ... # the name of the dataset on the HF Hub.
49
+ dataset_name: ... # the dataset configuration to use. Leave `null` if your dataset does not require a config to be passed. See https://huggingface.co/docs/datasets/load_hub#configurations for more info.
50
+ dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`.
51
+ ```
52
+
53
+ Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist:
54
+
55
+ ```yaml
56
+ training_split: <split name of training set, or `null`>
57
+ validation_split: <split name of val. set, or `null`>
58
+ test_split: <split name of test set, or `null`>
59
+ ```
60
+
61
+ Tests will run on the `test_split` if it is available, and otherwise evaluate on the `validation_split`.
62
+
63
+ We can also specify from which split the task should retrieve few-shot examples via:
64
+
65
+ ```yaml
66
+ fewshot_split: <split name to draw fewshot examples from, or `null`>
67
+ ```
68
+
69
+ or by hardcoding them, either using the following in the yaml file:
70
+
71
+ ```yaml
72
+ fewshot_config:
+   sampler: first_n
+   samples: [
+     {<sample 1>},
+     {<sample 2>},
+   ]
78
+ ```
79
+
80
+ or by adding the function `list_fewshot_samples` in the associated utils.py file:
81
+
82
+ ```python
83
+ def list_fewshot_samples() -> list[dict]:
+     return [{<sample 1>}, {<sample 2>}]
85
+ ```
86
+
87
+ See `lm_eval/tasks/minerva_math/minerva_math_algebra.yaml` for an example of the latter, and `lm_eval/tasks/gsm8k/gsm8k-cot.yaml` for an example of the former.
88
+
89
+ In this case, each sample must contain the same fields as the samples in the above sets--for example, if `doc_to_text` expects an `input` field when rendering input prompts, these provided samples must include an `input` key.
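+
+ For instance, if a (hypothetical) task's `doc_to_text` reads an `input` column and its `doc_to_target` reads a `target` column, hardcoded samples could look like the following sketch (the field names and contents here are illustrative, not required by the harness):
+
+ ```python
+ # Illustrative only: the field names ("input", "target") are hypothetical and
+ # must match whatever your own doc_to_text / doc_to_target reference.
+ def list_fewshot_samples() -> list[dict]:
+     return [
+         {"input": "What is 2 + 2?", "target": "4"},
+         {"input": "What is 3 * 5?", "target": "15"},
+     ]
+ ```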
90
+
91
+ If neither of the above options is set, we will default to train/validation/test sets, in that order.
92
+
93
+ Finally, our dataset may not be already in the exact format we want. Maybe we have to strip whitespace and special characters via a regex from our dataset's "question" field! Or maybe we just want to rename its columns to match a convention we'll be using for our prompts.
94
+
95
+ Let's create a python file in the directory where we're writing our YAML file:
96
+
97
+ ```bash
98
+ touch lm_eval/tasks/<dataset_name>/utils.py
99
+ ```
100
+
101
+ Now, in `utils.py` we'll write a function to process each split of our dataset (the following example is drawn from [the `hellaswag` task](../lm_eval/tasks/hellaswag/utils.py)):
102
+
103
+ ```python
104
+ import datasets
+
+
+ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+     def _process_doc(doc):
+         # `preprocess` is a text-cleaning helper defined earlier in the same utils.py.
+         ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
+         out_doc = {
+             "query": preprocess(doc["activity_label"] + ": " + ctx),
+             "choices": [preprocess(ending) for ending in doc["endings"]],
+             "gold": int(doc["label"]),
+         }
+         return out_doc
+
+     return dataset.map(_process_doc)
115
+ ```
116
+
117
+ Now, in our YAML config file we'll use the `!function` constructor, and tell the config where our imported Python function will come from. At runtime, before doing anything else we will preprocess our dataset according to this function!
118
+
119
+ ```yaml
120
+ process_docs: !function utils.process_docs
121
+ ```
122
+
123
+ ### Using Local Datasets
124
+
125
+ To load a local dataset for evaluation, you can specify data files in the `dataset_kwargs` field, such as the following for JSON files:
126
+
127
+ ```yaml
128
+ dataset_path: json
129
+ dataset_name: null
130
+ dataset_kwargs:
+   data_files: /path/to/my/json
132
+ ```
133
+
134
+ Or with files already split into separate directories:
135
+
136
+ ```yaml
137
+ dataset_path: arrow
138
+ dataset_kwargs:
+   data_files:
+     train: /path/to/arrow/train/data-00000-of-00001.arrow
+     validation: /path/to/arrow/validation/data-00000-of-00001.arrow
142
+ ```
143
+
144
+ Alternatively, if you have previously downloaded a dataset from huggingface hub (using `save_to_disk()`) and wish to use the local files, you will need to use `data_dir` under `dataset_kwargs` to point to where the directory is.
145
+
146
+ ```yaml
147
+ dataset_path: hellaswag
148
+ dataset_kwargs:
+   data_dir: hellaswag_local/
150
+ ```
151
+
152
+ You can also set `dataset_path` as a directory path in your local system. This will assume that there is a loading script with the same name as the directory. [See datasets docs](https://huggingface.co/docs/datasets/loading#local-loading-script).
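+
+ Under the hood, `dataset_path`, `dataset_name`, and `dataset_kwargs` are essentially forwarded to `datasets.load_dataset`, so a quick way to sanity-check a local-data config is to try the roughly equivalent call yourself (the paths below are placeholders):
+
+ ```python
+ # Sanity-check a local-data config by making the (roughly) equivalent
+ # `datasets.load_dataset` call yourself. The paths are placeholders.
+ import datasets
+
+ # dataset_path: json / dataset_kwargs.data_files: /path/to/my/json
+ ds = datasets.load_dataset("json", data_files="/path/to/my/data.json")
+ print(ds)
+
+ # dataset_path: arrow / dataset_kwargs.data_files: {train: ..., validation: ...}
+ ds = datasets.load_dataset(
+     "arrow",
+     data_files={
+         "train": "/path/to/arrow/train/data-00000-of-00001.arrow",
+         "validation": "/path/to/arrow/validation/data-00000-of-00001.arrow",
+     },
+ )
+ print(ds)
+ ```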
153
+
154
+ ## Writing a Prompt Template
155
+
156
+ The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.
157
+
158
+ To write a prompt, users will use `doc_to_text`, `doc_to_target`, and `doc_to_choice` (Optional when certain conditions are met).
159
+
160
+ `doc_to_text` defines the input string a model will be given while `doc_to_target` and `doc_to_choice` will be used to generate the target text. `doc_to_target` can be either a text string that refers to the target string or an integer that refers to the index of the correct label. When it is set as an index, `doc_to_choice` must also be set with the appropriate list of possible choice strings.
161
+
162
+ ### Basic prompts
163
+
164
+ If a dataset is straightforward enough, users can enter the feature name directly. This assumes that no preprocessing is required. For example, in [Swag](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/swag/swag.yaml#L10-L11), `doc_to_text` and `doc_to_target` are each given the name of a feature.
165
+
166
+ ```yaml
167
+ doc_to_text: startphrase
168
+ doc_to_target: label
169
+ ```
170
+
171
+ Hard-coding is also possible as is the case in [SciQ](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/sciq/sciq.yaml#L11).
172
+
173
+ ```yaml
174
+ doc_to_target: 3
175
+ ```
176
+
177
+ `doc_to_choice` can be directly given a list of strings as options (See [Toxigen](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/toxigen/toxigen.yaml#L11)).
178
+
179
+ ```yaml
180
+ doc_to_choice: ['No', 'Yes']
181
+ ```
182
+
183
+ If a dataset feature is already a list, you can set the name of the feature as `doc_to_choice` (See [Hellaswag](https://github.com/EleutherAI/lm-evaluation-harness/blob/e0eda4d3ffa10e5f65e0976161cd134bec61983a/lm_eval/tasks/hellaswag/hellaswag.yaml#L13)).
184
+
185
+ ```yaml
186
+ doc_to_choice: choices
187
+ ```
188
+
189
+ ### Writing a prompt with Jinja 2
190
+
191
+ We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
192
+
193
+ Take for example the dataset `super_glue/boolq`. As input, we'd like to use the features `passage` and `question` and string them together so that for a sample line `doc`, the model sees something in the format of:
194
+
195
+ ```text
196
+ doc["passage"]
197
+ Question: doc["question"]?
198
+ Answer:
199
+ ```
200
+
201
+ We do this by [writing](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/super_glue/boolq/default.yaml#L9C1-L9C61)
202
+
203
+ ```yaml
204
+ doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
205
+ ```
206
+
207
+ Here, `{{passage}}` will be replaced by `doc["passage"]` and `{{question}}` by `doc["question"]` when rendering the prompt template.
208
+
209
+ Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
210
+
211
+ ```yaml
212
+ doc_to_target: "{{answer}}"
213
+ ```
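+
+ If you want to eyeball what such a template produces before wiring it into a task, you can render it directly with `jinja2` on a toy document (the doc contents below are made up; the harness adds some extra filters of its own, but plain `jinja2` suffices for simple templates like this one):
+
+ ```python
+ # Quick check of what a Jinja prompt template renders to, using a made-up doc.
+ from jinja2 import Template
+
+ doc = {
+     "passage": "The harness supports Jinja templates.",
+     "question": "does the harness support Jinja templates",
+     "answer": "yes",
+ }
+
+ doc_to_text = Template("{{passage}}\nQuestion: {{question}}?\nAnswer:")
+ doc_to_target = Template("{{answer}}")
+
+ print(doc_to_text.render(**doc))
+ # The harness supports Jinja templates.
+ # Question: does the harness support Jinja templates?
+ # Answer:
+ print(doc_to_target.render(**doc))  # yes
+ ```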
214
+
215
+ #### Multiple choice format
216
+
217
+ For tasks which are multiple choice (a fixed, finite set of label words for each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type), we enforce a particular convention on prompt format.
218
+
219
+ > [!WARNING]
220
+ > We add `target_delimiter` between input and target which defaults to " ", such that the full input-output string is `doc_to_text(doc) + target_delimiter + doc_to_target(doc)`. `doc_to_text` and `doc_to_target` should not contain trailing right or left whitespace, respectively. For multiple choice the target will be each choice index concatenated with the delimiter.
221
+
222
+ An annotated example in the case of SciQ is as follows:
223
+
224
+ ```yaml
225
+ doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
226
+ doc_to_target: 3 # this contains the index into the answer choice list of the correct answer.
227
+ doc_to_choice: "{{[distractor1, distractor2, distractor3, correct_answer]}}"
228
+ ```
229
+
230
+ Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
231
+
232
+ The label index can also be sourced from a feature directly. For example, in `super_glue/boolq`, the label index is defined in the feature `label`. We can set `doc_to_target` as simply `label`. The options or verbalizers can be written in the form of a list `["no", "yes"]` that will correspond to the label index.
233
+
234
+ ```yaml
235
+ doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
236
+ doc_to_target: label
237
+ doc_to_choice: ["no", "yes"]
238
+ ```
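+
+ To make the convention above concrete, here is a small sketch of the (context, continuation) pairs whose loglikelihoods get compared for a `multiple_choice` task, using the default `target_delimiter` of a single space (the doc contents are made up):
+
+ ```python
+ # Sketch of the strings compared by loglikelihood for a multiple_choice task:
+ # context = doc_to_text(doc), and each candidate = target_delimiter + choice.
+ doc = {"passage": "Owls are mostly nocturnal.", "question": "are owls nocturnal", "label": 1}
+ choices = ["no", "yes"]
+ target_delimiter = " "  # the default
+
+ context = f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
+ candidates = [(context, target_delimiter + choice) for choice in choices]
+
+ for ctx, continuation in candidates:
+     # The harness scores log P(continuation | ctx) for each pair and picks the
+     # highest-scoring choice; doc["label"] indexes the gold answer.
+     print(repr(ctx), "->", repr(continuation))
+ ```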
239
+
240
+ ### Using Python Functions for Prompts
241
+
242
+ There may be cases where the prompt we want to implement is more easily expressed in Python than in Jinja 2. For this, we can use Python helper functions that are defined in the YAML config. Note that the script containing the function must be in the same directory as the yaml.
243
+
244
+ A good example is WikiText, which requires many regex rules to clean the samples.
245
+
246
+ ```python
247
+ import re
+
+
+ def wikitext_detokenizer(doc):
+     string = doc["page"]
+     # contractions
+     string = string.replace("s '", "s'")
+     string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+     ...
+     string = string.replace(" 's", "'s")
+
+     return string
256
+ ```
257
+
258
+ We can load this function in `doc_to_target` by using the `!function` operator after `doc_to_target`, followed by `<file name>.<function name>`. In the file [wikitext.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/wikitext/wikitext.yaml) we write:
259
+
260
+ ```yaml
261
+ doc_to_target: !function preprocess_wikitext.wikitext_detokenizer
262
+ ```
263
+
264
+ ### Importing a Prompt from Promptsource
265
+
266
+ [Promptsource](https://github.com/bigscience-workshop/promptsource/tree/main/promptsource) is a great repository for crowdsourced prompts for many datasets. We can load these prompts easily by using the `use_prompt` argument and filling it with the format `"promptsource:<name of prompt template>"`. To use this, `doc_to_text` and `doc_to_target` should be left undefined. This will fetch the template of the dataset defined in the YAML file.
267
+
268
+ For example, for Super Glue BoolQ, if we want to use the prompt template `GPT-3 Style`, we can add this to the YAML file.
269
+
270
+ ```yaml
271
+ use_prompt: "promptsource:GPT-3 Style"
272
+ ```
273
+
274
+ If you would like to run evaluation on all prompt templates, you can simply call it this way.
275
+
276
+ ```yaml
277
+ use_prompt: "promptsource:*"
278
+ ```
279
+
280
+ ### Setting metrics
281
+
282
+ You're almost done! Now we need to choose how to score our task.
283
+
284
+ - *If this is a multiple choice task:* do you just want to check your model's accuracy in choosing the correct answer choice?
285
+ - *If this is a generation task:* do you just want to check how often your model outputs *exactly the ground-truth output string provided*?
286
+
287
+ If the answer to the above is no: you'll need to record what scoring metrics to use! Metrics can be listed in the following format:
288
+
289
+ ```yaml
290
+ metric_list:
+   - metric: <name of the metric here>
+     aggregation: <name of the aggregation fn here>
+     higher_is_better: <true or false>
+   - metric: !function script.function
+     aggregation: ...
+     higher_is_better: ...
297
+ ```
298
+
299
+ `aggregation` and `higher_is_better` can optionally be left out, defaulting to the values pre-registered for natively supported metrics; otherwise, they must be defined explicitly (for example, when using a custom metric implemented as a function).
300
+
301
+ For a full list of natively supported metrics and aggregation functions see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval` or `hf_evaluate` is set to `true`.
302
+
303
+ ### Optional, More Advanced Setup
304
+
305
+ Some tasks may require more advanced processing logic than is described in this guide.
306
+
307
+ As a heuristic check:
308
+
309
+ - Does your task require generating multiple free-form outputs per input document?
310
+ - Does your task require complex, multi-step post-processing of generated model outputs?
311
+ - Does your task require subsetting documents on the fly based on their content?
312
+ - Do you expect to compute metrics after applying multiple such processing steps on your model outputs?
313
+ - Does your task rely on metrics that need a custom implementation?
314
+
315
+ For more detail on the task system and advanced features, see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). If none of the above applies to your task, it's time to continue on to checking your task's performance!
316
+
317
+ ### Task name + tags (registering a task)
318
+
319
+ To test a task conveniently, it helps to *register* the task--that is, to give it a name and make the `lm-eval` library aware it exists!
320
+
321
+ If you're writing your YAML file inside the `lm_eval/tasks` folder, you just need to give your task a name! You can do this inside your YAML file:
322
+
323
+ ```yaml
324
+ task: <name of the task>
325
+ ```
326
+
327
+ Including a task name is mandatory.
328
+
329
+ It is often also convenient to label your task with several `tag` values, though this field is optional:
330
+
331
+ ```yaml
332
+ tag:
+   - tag1
+   - tag2
335
+ ```
336
+
337
+ This will add your task to the `tag1` and `tag2` tags, helping users categorize your task and, if desired, run all tasks under one of these tags at once, your task included.
338
+
339
+ If your task is not in the `lm_eval/tasks` folder, you'll need to tell the Eval Harness where to look for YAML files.
340
+
341
+ You can do this via the `--include_path` argument in `__main__.py`. This command will be used to initialize the `TaskManager` object which you can also use for your custom scripts.
342
+
343
+ ```python
344
+ task_manager = TaskManager(args.verbosity, include_path=args.include_path)
345
+ ```
346
+
347
+ Passing `--tasks /path/to/yaml/file` is also accepted.
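+
+ For example, a custom evaluation script might look roughly like the sketch below (the YAML directory, task name, and model are placeholders, and exact keyword arguments may differ slightly between `lm-eval` versions):
+
+ ```python
+ # Rough sketch of using a TaskManager with an extra YAML directory in a custom
+ # script. Paths, task names, and the model are placeholders; keyword arguments
+ # may differ slightly between lm-eval versions.
+ import lm_eval
+ from lm_eval.tasks import TaskManager
+
+ task_manager = TaskManager(include_path="/path/to/my/yaml/dir")
+
+ results = lm_eval.simple_evaluate(
+     model="hf",
+     model_args="pretrained=EleutherAI/pythia-160m",
+     tasks=["<my_new_task_name>"],
+     num_fewshot=0,
+     task_manager=task_manager,
+ )
+ print(results["results"])
+ ```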
348
+
349
+ ### Advanced Group Configs
350
+
351
+ While `tag` values are helpful when you want to be able to quickly and conveniently run a set of related tasks via `--tasks my_tag_name`, often, we wish to implement more complex logic. For example, the MMLU benchmark contains 57 *subtasks* that must all be *averaged* together in order to report a final 'MMLU score'.
352
+
353
+ Groupings of tasks might also use particular variants of a task--for example, we might want to default to evaluating a task as 5-shot when called as part of a given grouping, but not have a preference for number of shots when evaluating it as a standalone.
354
+
355
+ We implement this via **groups**, which are distinct from tags. Groups can be implemented via *group config* YAML files, which are laid out similarly to, but slightly differently from, tasks' YAML configs.
356
+
357
+ The most basic form of group can be defined via a YAML config similar to the following:
358
+
359
+ ```yaml
360
+ group: nli_tasks
+ task:
+   - cb
+   - anli_r1
+   - rte
+ metadata:
+   version: 1.0
367
+ ```
368
+
369
+ This will behave almost identically to a `tag` that includes these 3 tasks, but with one key distinction: we'll print the `nli_tasks` group as a row (with no associated metrics) in our table of outputs, and visually show that these 3 tasks appear under its subheader.
370
+
371
+ Now, let's assume we actually want to report an aggregate score for `nli_tasks`. We would instead use a YAML config like the following:
372
+
373
+ ```yaml
374
+ group: nli_tasks
+ task:
+   - cb
+   - anli_r1
+   - rte
+ aggregate_metric_list:
+   - metric: acc
+     aggregation: mean
+     weight_by_size: true # defaults to `true`. Set this to `false` to do a "macro" average (taking each subtask's average accuracy, and summing those accuracies and dividing by 3)--by default we do a "micro" average (retain all subtasks' per-document accuracies, and take the mean over all documents' accuracies to get our aggregate mean).
+ metadata:
+   version: 1.0
385
+ ```
386
+
387
+ Similar to our `metric_list` for listing out the metrics we want to calculate for a given task, we use an `aggregate_metric_list` field to specify which metric name to aggregate across subtasks, what aggregation function to use, and whether we should micro- or macro- average these metrics. See [./task_guide.md](./task_guide.md) for a full list of related sub-keys.
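+
+ The micro/macro distinction is easy to see with a toy example (the subtask names, accuracies, and sizes below are invented):
+
+ ```python
+ # Toy illustration (invented numbers) of micro vs. macro averaging across
+ # subtasks, i.e. weight_by_size: true vs. false.
+ subtasks = {
+     "subtask_a": {"acc": 0.50, "n_docs": 100},
+     "subtask_b": {"acc": 0.90, "n_docs": 900},
+ }
+
+ # Micro average (weight_by_size: true): weight each subtask by its size.
+ micro = sum(s["acc"] * s["n_docs"] for s in subtasks.values()) / sum(
+     s["n_docs"] for s in subtasks.values()
+ )
+
+ # Macro average (weight_by_size: false): plain mean of subtask accuracies.
+ macro = sum(s["acc"] for s in subtasks.values()) / len(subtasks)
+
+ print(micro)  # ~0.86
+ print(macro)  # ~0.70
+ ```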
388
+
389
+ > [!TIP]
+ > Currently, we predominantly support only the aggregation of group metrics that use `mean` (either micro- or macro-averaged) over their subtasks. If you require more complex aggregation rules, you may want to perform aggregation offline.
390
+
391
+ Group configs can be fairly complex! We can perform various operations, such as defining new subtask(s) inline in our group YAML, overriding an existing task's specific config values, or nesting existing groups within our new group.
392
+
393
+ For example, let's build a config for evaluating MMLU and a few natural language inference tasks. For MMLU, we can write the name for the benchmark as a subtask written under `task`. You can configure the parameters such as `num_fewshot`. If the task being configured is a group such as `mmlu` or `super_glue`, the parameter set will be applied to all of the subtasks.
394
+
395
+ ```yaml
396
+ group: nli_and_mmlu
+ task:
+   - group: nli_tasks
+     task:
+       - cb
+       - anli_r1
+       - rte
+     aggregate_metric_list:
+       - metric: acc
+         aggregation: mean
+         higher_is_better: true
+   - task: mmlu
+     num_fewshot: 2
409
+ ```
410
+
411
+ ### Configuring python classes
412
+
413
+ There can be occasions when yaml-based tasks cannot accommodate how a task is handled. LM-Eval supports manually implementing tasks, as was done prior to `0.4.x`. To register the task, you can simply make a yaml with the name of the task in `task` and the class object in `class` using the `!function` prefix.
414
+
415
+ ```yaml
416
+ task: squadv2
417
+ class: !function task.SQuAD2
418
+ ```
419
+
420
+ This also applies to building group configurations with subtasks that are python classes.
421
+
422
+ ```yaml
423
+ group: scrolls
+ task:
+   - task: scrolls_qasper
+     class: !function task.Qasper
+   - task: scrolls_quality
+     class: !function task.QuALITY
+   - task: scrolls_narrativeqa
+     class: !function task.NarrativeQA
+   ...
432
+ ```
433
+
434
+ You can also pass a custom argument to your class by accepting `config` in the custom class constructor.
435
+ Here's how to do it:
436
+
437
+ ```yaml
438
+ task: 20_newsgroups
439
+ class: !function task.Unitxt
440
+ recipe: card=cards.20_newsgroups,template=templates.classification.multi_class.title
441
+ ```
442
+
443
+ In this example, `recipe` is the custom argument for the `Unitxt` class.
444
+
445
+ ## Beautifying Table Display
446
+
447
+ To avoid conflicts, each task needs to be registered with a unique name. Because of this, slight variations of a task are still counted as unique tasks and need to be named uniquely. This can be done by appending an additional label that refers to the variation, as in MMLU, where the tasks evaluated with the FLAN template are differentiated from the default by the prefix `mmlu_flan_*`. Printing the full task names can easily clutter the results table at the end of the evaluation, especially when you have a long list of tasks or are using a benchmark that comprises many tasks. To make the table more legible, you can use `task_alias` and `group_alias` to provide an alternative task name and group name that will be printed instead. For example, in `mmlu_abstract_algebra.yaml` we set `task_alias` to `abstract_algebra`. In group configs, a `group_alias` for a group can also be set.
448
+
449
+ ```yaml
450
+ "dataset_name": "abstract_algebra"
451
+ "description": "The following are multiple choice questions (with answers) about abstract\
452
+ \ algebra.\n\n"
453
+ "include": "_default_template_yaml"
454
+ "task": "mmlu_abstract_algebra"
455
+ "task_alias": "abstract_algebra"
456
+ ```
457
+
458
+ ## Checking validity
459
+
460
+ After registering your task, you can now check on your data downloading and verify that the few-shot samples look as intended. Run the following command with your desired args:
461
+
462
+ ```bash
463
+ python -m scripts.write_out \
464
+ --output_base_path <path> \
465
+ --tasks <your-task-name> \
466
+ --sets <train | val | test> \
467
+ --num_fewshot K \
468
+ --num_examples N \
469
+ ```
470
+
471
+ Open the file specified at the `--output_base_path <path>` and ensure it passes
472
+ a simple eye test.
473
+
474
+ ## Versioning
475
+
476
+ One key feature in LM Evaluation Harness is the ability to version tasks and groups--that is, mark them with a specific version number that can be bumped whenever a breaking change is made.
477
+
478
+ This version info can be provided by adding the following to your new task or group config file:
479
+
480
+ ```yaml
481
+ metadata:
+   version: 0
483
+ ```
484
+
485
+ Now, whenever a change needs to be made to your task in the future, please increase the version number by 1 so that users can differentiate the different task iterations and versions.
486
+
487
+ If you are incrementing a task's version, please also consider adding a changelog to the task's README.md noting the date, PR number, what version you have updated to, and a one-liner describing the change.
488
+
489
+ For example:
490
+
491
+ - \[Dec 25, 2023\] (PR #999) Version 0.0 -> 1.0: Fixed a bug with answer extraction that led to underestimated performance.
492
+
493
+ ## Checking performance + equivalence
494
+
495
+ It's now time to check models' performance on your task! In the evaluation harness, we intend to support a wide range of evaluation tasks and setups, but prioritize the inclusion of already-proven benchmarks following the precise evaluation setups in the literature where possible.
496
+
497
+ To enable this, we provide a checklist that should be completed when contributing a new task, to enable accurate book-keeping and to ensure that tasks added to the library are well-tested and, where applicable, precedented.
498
+
499
+ ### Task Validity Checklist
500
+
501
+ The checklist is the following:
502
+
503
+ For adding novel benchmarks/datasets to the library:
504
+
505
+ - [ ] Is the task an existing benchmark in the literature?
506
+ - [ ] Have you referenced the original paper that introduced the task?
507
+ - [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
508
+
509
+ If other tasks on this dataset are already supported:
510
+
511
+ - [ ] Is the "Main" variant of this task clearly denoted?
512
+ - [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
513
+ - [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
514
+
515
+ It is recommended to include a filled-out copy of this checklist in the README.md for the subfolder you are creating, if you have created a new subfolder in `lm_eval/tasks`.
516
+
517
+ **Finally, please add a short description of your task(s), along with a link to its subfolder in lm_eval/tasks, to [`lm_eval/tasks/README.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md) so that users can discover your task in the library, and follow the link to your README for more information about the variants supported, their task names, and the original source of the dataset and/or evaluation setup.**
518
+
519
+ ## Submitting your task
520
+
521
+ You're all set! Now push your work and make a pull request to the `main` branch! Thanks for the contribution :). If there are any questions, please leave a message in the `#lm-thunderdome` channel on the EAI discord!
lm-evaluation-harness/docs/task_guide.md ADDED
@@ -0,0 +1,335 @@
1
+ # Task Configuration
2
+
3
+ The `lm-evaluation-harness` is meant to be an extensible and flexible framework within which many different evaluation tasks can be defined. All tasks in the new version of the harness are built around a YAML configuration file format.
4
+
5
+ These YAML configuration files, along with the current codebase commit hash, are intended to be shareable such that providing the YAML config enables another researcher to precisely replicate the evaluation setup used by another, in the case that the prompt or setup differs from standard `lm-eval` task implementations.
6
+
7
+ While adding a standard evaluation task on a new dataset can be occasionally as simple as swapping out a Hugging Face dataset path in an existing file, more specialized evaluation setups also exist. Here we'll provide a crash course on the more advanced logic implementable in YAML form available to users.
8
+
9
+ If your intended task relies on features beyond what is described in this guide, we'd love to hear about it! Feel free to open an issue describing the scenario on Github, create a PR to the project with a proposed implementation, or ask in the `#lm-thunderdome` channel on the EleutherAI discord.
10
+
11
+ ## Configurations
12
+
13
+ Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
14
+
15
+ ### Parameters
16
+
17
+ Task naming + registration:
18
+
19
+ - **task** (`str`, defaults to None) — name of the task.
20
+ - **task_alias** (`str`, defaults to None) - Alias of the task name that will be printed in the final table results.
21
+ - **tag** (`str`, *optional*) — name of the task tag(s) a task belongs to. Enables one to run all tasks with a specified tag name at once.
22
+
23
+ Dataset configuration options:
24
+
25
+ - **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
26
+ - **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
27
+ - **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
28
+ - **custom_dataset** (`Callable`, *optional*) - A function that returns a `dict[str, datasets.Dataset]` (<split_name>, dataset) object. This can be used to load a dataset from a custom source or to preprocess the dataset in a way that is not supported by the `datasets` library. Will have access to the `metadata` field if defined (from config and passed to TaskManager), and `model_args` from runtime (if using `evaluate`).
29
+ - **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
30
+ - **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
31
+ - **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
32
+ - **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. It is asserted that this is not `None` if `num_fewshot` > 0.
33
+ - **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to that expected by a prompt template.
34
+
35
+ Prompting / in-context formatting options:
36
+
37
+ - **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. If defined, this will overwrite `doc_to_text`, `doc_to_target`, and `doc_to_choice`.
38
+ - **description** (`str`, *optional*) — An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as `"The following are questions (with answers) about {{subject}}.\n\n"`. No delimiters or spacing are inserted between the description and the first few-shot example.
39
+ - **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate input for the model.
40
+ - **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the answer choice list of the correct answer.
41
+ - **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
42
+ - **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
43
+ - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
44
+ - **gen_prefix** (`str`, *optional*) — String to append after the <|assistant|> token. For example, if the task is to generate a question, the gen_prefix could be "The answer is: " to prompt the model to generate an answer to the question. If not using a chat template then this string will be appended to the end of the prompt.
45
+
46
+ Runtime configuration options:
47
+
48
+ - **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
49
+ - **batch_size** (`int`, *optional*, defaults to 1) — Batch size.
50
+
51
+ Scoring details:
52
+
53
+ - **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See docs for expected format.
54
+ - **output_type** (`str`, *optional*, defaults to "generate_until") — Selects the type of model output for the given task. Options are `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
55
+ - **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
56
+ - **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through model for each sample. Can be used for cases such as self-consistency.
57
+ - **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.
58
+ - **should_decontaminate** (`bool`, *optional*, defaults to False) - Whether to decontaminate or not.
59
+ - **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
60
+
61
+ Other:
62
+
63
+ - **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field that is used to denote the version of the yaml config. Other special metadata keys are: `num_fewshot`, to override the printed `n-shot` table column for a task. Will also be passed to the `custom_dataset` function if defined.
64
+
65
+ ## Filters
66
+
67
+ A key component of the `lm-evaluation-harness` library is the `Filter` object. In a typical evaluation run of the harness, we take the formatted inputs and run them through our LM, with the appropriate output type (greedy or free-form generation, or loglikelihood-based comparative scoring).
68
+
69
+ After getting scores or output text from our LM on each `Instance` or document in the dataset, we then need to feed these responses into a metric or scoring function to return scores to a user.
70
+
71
+ However, certain tasks may require more complex behavior than directly turning over model outputs to a metric function. For example, we may want to post-process our output text by truncating it or extracting a model's answer, we may want to ensemble over multiple "takes" on a given document, et cetera.
72
+
73
+ **Detailed Aside**:
74
+ We do such post-processing by operating on *responses*, which are stored after running an LM on an `Instance` from the task in `Instance.resps`.
75
+
76
+ `resps` is a `List[str]` for each instance, and we pass a `List[List[<expected return type from model>]]` to our filters that is a list of `[instance.resps for instance in instances]`.
77
+
78
+ Our filters, after completing a pipeline, must return a `List[<expected return type from model>]`, whose elements we then unpack and store in `Instance.filtered_resps` for the corresponding instances. Thus, we take as input a list of responses from our model for each doc, and must return a single response *not wrapped in a list* for each doc.
79
+ **End Aside**
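+
+ To make those shapes concrete, here is a plain-Python illustration (not the actual `Filter` base class) of a pipeline step that keeps only the first response per document:
+
+ ```python
+ # Plain-Python illustration of the filter data shapes (not the actual Filter
+ # base class): input is one list of responses per document, output is one
+ # (unwrapped) entry per document.
+ from typing import List
+
+
+ def take_first(resps: List[List[str]]) -> List[str]:
+     # resps == [instance.resps for instance in instances]
+     return [per_doc_responses[0] for per_doc_responses in resps]
+
+
+ model_outputs = [
+     ["The answer is 42", "The answer is 41"],  # doc 0: two samples
+     ["The answer is 7"],                       # doc 1: one sample
+ ]
+ print(take_first(model_outputs))  # ['The answer is 42', 'The answer is 7']
+ ```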
80
+
81
+ A full list of supported filter operations can be found in `lm_eval/filters/__init__.py`. Contributions of new filter types are welcome!
82
+
83
+ ### Multiple Filter Pipelines
84
+
85
+ Tasks need not be limited to a single filter pipeline. We enable users to run multiple, distinct, filter pipelines on *the same model outputs* generated in one run on a task.
86
+
87
+ As a case study, let's look at an implementation of solving the Gsm8k math word problem benchmark in `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`. Here, we are emulating the setup used by [Self-Consistency Improves Chain of Thought Prompting](https://arxiv.org/abs/2203.11171), in which evaluation is performed by generating N chain-of-thought outputs from a model via temperature-based sampling, then selecting the answers output by the model at the end of the chains of thought, then majority voting across all those numeric answers.
88
+
89
+ Within our YAML file:
90
+
91
+ ```yaml
92
+ ...
+ repeats: 64
+ filter_list:
+   - name: "score-first"
+     filter:
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "take_first"
+   - name: "maj@64"
+     filter:
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "majority_vote"
+       - function: "take_first"
+   - name: "maj@8"
+     filter:
+       - function: "take_first_k"
+         k: 8
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "majority_vote"
+       - function: "take_first"
114
+ ```
115
+
116
+ We are able to provide multiple different filter pipelines, each with their own name and list of filters to apply in sequence.
117
+
118
+ Our first filter pipeline implements
119
+
120
+ - applying a regex to the model generations (extracting the number within the phrase "The answer is (number)")
121
+ - selecting only the first out of the 64 model answers
122
+
123
+ Then scoring this single answer.
124
+
125
+ ```yaml
126
+   - name: "score-first"
+     filter:
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "take_first"
131
+ ```
132
+
133
+ Our second filter pipeline, "maj@64", does majority voting across all 64 answers via:
134
+
135
+ - applying the same regex to all responses, to get the numerical answer from the model for each of the 64 responses per problem
136
+ - applying majority voting to all responses, which then returns a length-1 `[<majority answer>]` list for each
137
+ - taking the first element of this length-1 list, to then score the sole response `<majority answer>` for each document.
138
+
139
+ ```yaml
140
+   - name: "maj@64"
+     filter:
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "majority_vote"
+       - function: "take_first"
146
+ ```
147
+
148
+ Our final filter pipeline, "maj@8", does majority voting across the first 8 of the model's responses per document via:
149
+
150
+ - subsetting the len-64 list of responses `[answer1, answer2, ..., answer64]` to `[answer1, answer2, ..., answer8]` for each document
151
+ - performing the same sequence of filters on these new sets of 8 responses, for each document.
152
+
153
+ ```yaml
154
+   - name: "maj@8"
+     filter:
+       - function: "take_first_k"
+         k: 8
+       - function: "regex"
+         regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
+       - function: "majority_vote"
+       - function: "take_first"
162
+ ```
163
+
164
+ Thus, given the 64 responses from our LM on each document, we can report metrics on these responses in these 3 different ways, as defined by our filter pipelines.
165
+
166
+ ### Adding a custom filter
167
+
168
+ Just as you can add a custom model with the `register_model` decorator, you can do the same with filters. For example:
169
+
170
+ ```python
171
+ from lm_eval.api.filter import Filter
172
+ from lm_eval.api.registry import register_filter
173
+
174
+ @register_filter("new_filter")
175
+ class NewFilter(Filter):
+     ...
177
+ ```
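+
+ A slightly fuller sketch might look like the following. The filter name and behaviour here are made up; the `apply(resps, docs)` signature mirrors the built-in filters (see `lm_eval/filters/` for real examples):
+
+ ```python
+ # A made-up example filter. The `apply(resps, docs)` signature mirrors the
+ # built-in filters in lm_eval/filters/.
+ from lm_eval.api.filter import Filter
+ from lm_eval.api.registry import register_filter
+
+
+ @register_filter("strip_whitespace")
+ class StripWhitespaceFilter(Filter):
+     def apply(self, resps, docs):
+         # resps: one list of model responses per document.
+         return [[resp.strip() for resp in per_doc] for per_doc in resps]
+ ```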
178
+
179
+ ## Embedded Python Code
180
+
181
+ You can use Python functions for certain arguments by using the `!function` operator after the argument name, followed by `<filename>.<pythonfunctionname>`. This feature can be used for the following arguments:
182
+
183
+ 1. `doc_to_text`
184
+ 2. `doc_to_target`
185
+ 3. `doc_to_choice`
186
+ 4. `aggregation` for a `metric` in `metric_list`
187
+
188
+ ## (No Longer Recommended) Direct `Task` Subclassing
189
+
190
+ The prior implementation method of new tasks was to subclass `Task`. While we intend to migrate all tasks to the new YAML implementation option going forward, it remains possible to subclass the Task class and implement custom logic. For more information, see `docs/task_guide.md` in v0.3.0 of the `lm-evaluation-harness`.
191
+
192
+ ## Including a Base YAML
193
+
194
+ You can base a YAML on another YAML file as a template. This can be handy when you just need to change the prompt for `doc_to_text` but keep the rest the same, or change `filters` to compare which works better. Simply use `include` in the YAML file and write the name of the template you want to base it on. This assumes that the base template is in the same directory; otherwise, you will need to define the full path.
195
+
196
+ ```yaml
197
+ include: <YAML filename or with full path>
198
+ ...
199
+ ```
200
+
201
+ You can find an example of how to use this feature at [gsm8k-cot-self-consistency.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml), where it is based on [gsm8k-cot.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot.yaml).
202
+
203
+ ## Passing Arguments to Metrics
204
+
205
+ Metrics can be defined in the `metric_list` argument when building the YAML config. Multiple metrics can be listed along with any auxiliary arguments. For example, setting the [`exact_match` metric](https://github.com/huggingface/evaluate/tree/main/metrics/exact_match), auxiliary arguments such as `ignore_case`, `ignore_punctuation`, `regexes_to_ignore` can be listed as well. They will be added to the metric function as `kwargs`. Some metrics have predefined values for `aggregation` and `higher_is_better` so listing the metric name only can be sufficient.
206
+
207
+ ```yaml
208
+ metric_list:
+   - metric: acc
+   - metric: exact_match
+     aggregation: mean
+     higher_is_better: true
+     ignore_case: true
+     ignore_punctuation: false
+     regexes_to_ignore:
+       - ","
+       - "\\$"
218
+ ```
219
+
220
+ ### Natively Supported Metrics
221
+
222
+ Here we list all metrics currently supported natively in `lm-eval`:
223
+
224
+ Metrics:
225
+
226
+ - `acc` (accuracy)
227
+ - `acc_norm` (length-normalized accuracy)
228
+ - `acc_mutual_info` (baseline loglikelihood - normalized accuracy)
229
+ - `perplexity`
230
+ - `word_perplexity` (perplexity per word)
231
+ - `byte_perplexity` (perplexity per byte)
232
+ - `bits_per_byte`
233
+ - `matthews_corrcoef` (Matthews correlation coefficient)
234
+ - `f1` (F1 score)
235
+ - `bleu`
236
+ - `chrf`
237
+ - `ter`
238
+
239
+ Aggregation functions:
240
+
241
+ - `mean`
242
+ - `median`
243
+ - `perplexity`
244
+ - `weighted_perplexity`
245
+ - `bits_per_byte`
246
+
247
+ ### Adding a Multiple Choice Metric
248
+
249
+ Adding a multiple choice metric has a few steps. To get it working you need to:
250
+
251
+ 1. register a metric function
252
+ 2. register an aggregation function
253
+ 3. update the `Task` definition to make sure the correct arguments are passed
254
+
255
+ The default metric and aggregation functions are in `lm_eval/api/metrics.py`, and you can add a function there if it's for general use. The metrics are towards the bottom of the file and look like this:
256
+
257
+ ```python
258
+ @register_metric(
+     metric="mcc",
+     higher_is_better=True,
+     output_type="multiple_choice",
+     aggregation="matthews_corrcoef",
+ )
+ def mcc_fn(items):  # This is a passthrough function
+     return items
266
+ ```
267
+
268
+ Note that many of these are passthrough functions, and for multiple choice (at least) this function is never actually called.
269
+
270
+ Aggregation functions are defined towards the top of the file, here's an example:
271
+
272
+ ```python
273
+ @register_aggregation("matthews_corrcoef")
+ def matthews_corrcoef(items):
+     unzipped_list = list(zip(*items))
+     golds = unzipped_list[0]
+     preds = unzipped_list[1]
+     return sklearn.metrics.matthews_corrcoef(golds, preds)
279
+ ```
280
+
281
+ This function returns a single numeric value. The input is defined in `Task.process_results` in `lm_eval/api/task.py`. There's a section that looks like this:
282
+
283
+ ```python
284
+ result_dict = {
+     **({"acc": acc} if "acc" in use_metric else {}),
+     **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+     **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
+     **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
+     **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
+ }
291
+ ```
292
+
293
+ The value here determines the input to the aggregation function, though the name used matches the metric function. These metrics all have simple needs and just need the accuracy or gold and predicted values, but immediately below this there are examples of metrics with more complicated needs you can use as reference.
294
+
295
+ ## Good Reference Tasks
296
+
297
+ Contributing a new task can be daunting! Luckily, much of the work has often been done for you in a different, similarly evaluated task. Good examples of task implementations to study include:
298
+
299
+ Multiple choice tasks:
300
+
301
+ - SciQ (`lm_eval/tasks/sciq/sciq.yaml`)
302
+
303
+ Corpus perplexity evaluations:
304
+
305
+ - Wikitext (`lm_eval/tasks/wikitext/wikitext.yaml`)
306
+
307
+ Generative tasks:
308
+
309
+ - GSM8k (`lm_eval/tasks/gsm8k/gsm8k.yaml`)
310
+
311
+ Tasks using complex filtering:
312
+
313
+ - GSM8k with CoT (+ with Self-Consistency): (`lm_eval/tasks/gsm8k/gsm8k-cot.yaml` ; `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`)
314
+
315
+ # Group Configuration
316
+
317
+ When evaluating a language model, it is not unusual to test across a number of tasks that may not be related to one another in order to assess a variety of capabilities. To this end, it may be cumbersome to have to list the set of tasks or add a new group name to each yaml of each individual task.
318
+
319
+ To solve this, we can create a **group** yaml config. This is a config that contains the names of the tasks that should be included in a particular group. The config consists of two main keys: a `group` key which denotes the name of the group (as it would be called from the command line, e.g. `mmlu`) and a `task` key which is where we can list the tasks. The tasks listed in `task` are the task names that have been registered. A good example of a group yaml config can be found at [../lm_eval/tasks/mmlu/default/_mmlu.yaml]. See also the [New Task Guide](./new_task_guide.md) for a more in-depth and tutorial-esque explanation of how to write complex GroupConfigs.
320
+
321
+ ## Configurations
322
+
323
+ Groups are configured via the `GroupConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
324
+
325
+ ### Parameters
326
+
327
+ - **group** (`str`, defaults to `None`) — name of the group. Used to invoke it from the command line.
328
+ - **group_alias** (`str`, defaults to `None`) - Alternative name for the group that will be printed in the table output.
329
+ - **task** (`Union[str, list]`, defaults to `None`) - List of tasks that constitute the group.
330
+ - **aggregate_metric_list** (`list`, defaults to `None`) - similar to `metric_list` in TaskConfigs, provide a list of configurations for metrics that should be aggregated across subtasks. Leaving empty will result in no aggregation being performed for this group. Keys for each list entry are:
331
+ - `metric: str` - the name of the metric to aggregate over (all subtasks must report a metric holding this name.)
332
+ - `aggregation: str` - what aggregation function to apply to aggregate these per-subtask metrics. **currently, only `mean` is supported.**
333
+ - `weight_by_size: bool = True` whether to perform micro- averaging (`True`) or macro- (`False`) averaging of subtasks' accuracy scores when reporting the group's metric. MMLU, for example, averages over per-document accuracies (the *micro average*), resulting in the same accuracy as if one simply concatenated all 57 subjects into a single dataset and evaluated accuracy on that dataset.
334
+ - `filter_list: Union[str, List[str]] = "none"` - what filter keys one should match on to aggregate results. For example, if trying to aggregate over the `exact_match` metric using `strict-match` filter for `bbh_cot_zeroshot`, then set this to be `filter_list: "strict-match"`.
335
+ - **metadata** (`dict`, *optional*) - As with TaskConfigs, a field where extra config metadata can be passed. Set the `num_fewshot` key within this to override the printed n_shot value in a results table for your group, for example.
lm-evaluation-harness/examples/lm-eval-overview.ipynb ADDED
@@ -0,0 +1,1240 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Qw83KAePAhaS"
7
+ },
8
+ "source": [
9
+ "# Releasing LM-Evaluation-Harness v0.4.0"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {
15
+ "id": "Z7k2vq1iAdqr"
16
+ },
17
+ "source": [
18
+ "With the vast amount of work done in the field today, it helps to have a tool that people can use easily to share their results and use to check others to ensure reported numbers are valid. The LM Evaluation Harness is one such tool the community has used extensively. We want to continue to support the community and with that in mind, we’re excited to announce a major update on the LM Evaluation Harness to further our goal for open and accessible AI research."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {
24
+ "id": "0gDoM0AJAvEc"
25
+ },
26
+ "source": [
27
+ "Our refactor stems from our desires to make the following believed best practices easier to carry out. \n",
28
+ "\n",
29
+ "1. Never copy results from other papers\n",
30
+ "2. Always share your exact prompts\n",
31
+ "3. Always provide model outputs\n",
32
+ "4. Qualitatively review a small batch of outputs before running evaluation jobs at scale\n",
33
+ "\n",
34
+ "We also wanted to make the library a better experience to use and to contribute or design evaluations within. New features in the new release that serve this purpose include:\n",
35
+ "\n",
36
+ "1. Faster Evaluation Runtimes (accelerated data-parallel inference with HF Transformers + Accelerate, and commonly used or faster inference libraries such as vLLM and Llama-CPP)\n",
37
+ "2. Easier addition and sharing of new tasks (YAML-based task config formats, allowing single-file sharing of custom tasks)\n",
38
+ "3. More configurability, for more advanced workflows and easier operation with modifying prompts\n",
39
+ "4. Better logging of data at runtime and post-hoc"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {
45
+ "id": "nnwsOpjda_YW"
46
+ },
47
+ "source": [
48
+ "In this notebook we will be going through a short tutorial on how things work."
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {
54
+ "id": "zAov81vTbL2K"
55
+ },
56
+ "source": [
57
+ "## Install LM-Eval"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 1,
63
+ "metadata": {
64
+ "colab": {
65
+ "base_uri": "https://localhost:8080/"
66
+ },
67
+ "id": "8hiosGzq_qZg",
68
+ "outputId": "6ab73e5e-1f54-417e-a388-07e0d870b132"
69
+ },
70
+ "outputs": [
71
+ {
72
+ "name": "stdout",
73
+ "output_type": "stream",
74
+ "text": [
75
+ "Collecting git+https://github.com/EleutherAI/lm-evaluation-harness.git@big-refactor\n",
76
+ " Cloning https://github.com/EleutherAI/lm-evaluation-harness.git (to revision big-refactor) to /tmp/pip-req-build-tnssql5s\n",
77
+ " Running command git clone --filter=blob:none --quiet https://github.com/EleutherAI/lm-evaluation-harness.git /tmp/pip-req-build-tnssql5s\n",
78
+ " Running command git checkout -b big-refactor --track origin/big-refactor\n",
79
+ " Switched to a new branch 'big-refactor'\n",
80
+ " Branch 'big-refactor' set up to track remote branch 'big-refactor' from 'origin'.\n",
81
+ " Resolved https://github.com/EleutherAI/lm-evaluation-harness.git to commit 42f486ee49b65926a444cb0620870a39a5b4b0a8\n",
82
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
83
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
84
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
85
+ "Collecting accelerate>=0.21.0 (from lm-eval==1.0.0)\n",
86
+ " Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)\n",
87
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m261.4/261.4 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
88
+ "\u001b[?25hCollecting evaluate (from lm-eval==1.0.0)\n",
89
+ " Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)\n",
90
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
91
+ "\u001b[?25hCollecting datasets>=2.0.0 (from lm-eval==1.0.0)\n",
92
+ " Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n",
93
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
94
+ "\u001b[?25hCollecting jsonlines (from lm-eval==1.0.0)\n",
95
+ " Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)\n",
96
+ "Requirement already satisfied: numexpr in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (2.8.7)\n",
97
+ "Collecting peft>=0.2.0 (from lm-eval==1.0.0)\n",
98
+ " Downloading peft-0.6.2-py3-none-any.whl (174 kB)\n",
99
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
100
+ "\u001b[?25hCollecting pybind11>=2.6.2 (from lm-eval==1.0.0)\n",
101
+ " Downloading pybind11-2.11.1-py3-none-any.whl (227 kB)\n",
102
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.7/227.7 kB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
103
+ "\u001b[?25hCollecting pytablewriter (from lm-eval==1.0.0)\n",
104
+ " Downloading pytablewriter-1.2.0-py3-none-any.whl (111 kB)\n",
105
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.1/111.1 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
106
+ "\u001b[?25hCollecting rouge-score>=0.0.4 (from lm-eval==1.0.0)\n",
107
+ " Downloading rouge_score-0.1.2.tar.gz (17 kB)\n",
108
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
109
+ "Collecting sacrebleu>=1.5.0 (from lm-eval==1.0.0)\n",
110
+ " Downloading sacrebleu-2.3.2-py3-none-any.whl (119 kB)\n",
111
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.7/119.7 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
112
+ "\u001b[?25hRequirement already satisfied: scikit-learn>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (1.2.2)\n",
113
+ "Collecting sqlitedict (from lm-eval==1.0.0)\n",
114
+ " Downloading sqlitedict-2.1.0.tar.gz (21 kB)\n",
115
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
116
+ "Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (2.1.0+cu118)\n",
117
+ "Collecting tqdm-multiprocess (from lm-eval==1.0.0)\n",
118
+ " Downloading tqdm_multiprocess-0.0.11-py3-none-any.whl (9.8 kB)\n",
119
+ "Requirement already satisfied: transformers>=4.1 in /usr/local/lib/python3.10/dist-packages (from lm-eval==1.0.0) (4.35.2)\n",
120
+ "Collecting zstandard (from lm-eval==1.0.0)\n",
121
+ " Downloading zstandard-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n",
122
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m29.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
123
+ "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (1.23.5)\n",
124
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (23.2)\n",
125
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (5.9.5)\n",
126
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (6.0.1)\n",
127
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->lm-eval==1.0.0) (0.19.4)\n",
128
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (9.0.0)\n",
129
+ "Collecting pyarrow-hotfix (from datasets>=2.0.0->lm-eval==1.0.0)\n",
130
+ " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n",
131
+ "Collecting dill<0.3.8,>=0.3.0 (from datasets>=2.0.0->lm-eval==1.0.0)\n",
132
+ " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n",
133
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
134
+ "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (1.5.3)\n",
135
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (2.31.0)\n",
136
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (4.66.1)\n",
137
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (3.4.1)\n",
138
+ "Collecting multiprocess (from datasets>=2.0.0->lm-eval==1.0.0)\n",
139
+ " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n",
140
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m19.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
141
+ "\u001b[?25hRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (2023.6.0)\n",
142
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.0.0->lm-eval==1.0.0) (3.8.6)\n",
143
+ "Collecting responses<0.19 (from evaluate->lm-eval==1.0.0)\n",
144
+ " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
145
+ "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft>=0.2.0->lm-eval==1.0.0) (0.4.0)\n",
146
+ "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (1.4.0)\n",
147
+ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (3.8.1)\n",
148
+ "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval==1.0.0) (1.16.0)\n",
149
+ "Collecting portalocker (from sacrebleu>=1.5.0->lm-eval==1.0.0)\n",
150
+ " Downloading portalocker-2.8.2-py3-none-any.whl (17 kB)\n",
151
+ "Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (2023.6.3)\n",
152
+ "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (0.9.0)\n",
153
+ "Collecting colorama (from sacrebleu>=1.5.0->lm-eval==1.0.0)\n",
154
+ " Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
155
+ "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval==1.0.0) (4.9.3)\n",
156
+ "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (1.11.3)\n",
157
+ "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (1.3.2)\n",
158
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval==1.0.0) (3.2.0)\n",
159
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.13.1)\n",
160
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (4.5.0)\n",
161
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (1.12)\n",
162
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.2.1)\n",
163
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (3.1.2)\n",
164
+ "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.8->lm-eval==1.0.0) (2.1.0)\n",
165
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.1->lm-eval==1.0.0) (0.15.0)\n",
166
+ "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonlines->lm-eval==1.0.0) (23.1.0)\n",
167
+ "Requirement already satisfied: setuptools>=38.3.0 in /usr/local/lib/python3.10/dist-packages (from pytablewriter->lm-eval==1.0.0) (67.7.2)\n",
168
+ "Collecting DataProperty<2,>=1.0.1 (from pytablewriter->lm-eval==1.0.0)\n",
169
+ " Downloading DataProperty-1.0.1-py3-none-any.whl (27 kB)\n",
170
+ "Collecting mbstrdecoder<2,>=1.0.0 (from pytablewriter->lm-eval==1.0.0)\n",
171
+ " Downloading mbstrdecoder-1.1.3-py3-none-any.whl (7.8 kB)\n",
172
+ "Collecting pathvalidate<4,>=2.3.0 (from pytablewriter->lm-eval==1.0.0)\n",
173
+ " Downloading pathvalidate-3.2.0-py3-none-any.whl (23 kB)\n",
174
+ "Collecting tabledata<2,>=1.3.1 (from pytablewriter->lm-eval==1.0.0)\n",
175
+ " Downloading tabledata-1.3.3-py3-none-any.whl (11 kB)\n",
176
+ "Collecting tcolorpy<1,>=0.0.5 (from pytablewriter->lm-eval==1.0.0)\n",
177
+ " Downloading tcolorpy-0.1.4-py3-none-any.whl (7.9 kB)\n",
178
+ "Collecting typepy[datetime]<2,>=1.3.2 (from pytablewriter->lm-eval==1.0.0)\n",
179
+ " Downloading typepy-1.3.2-py3-none-any.whl (31 kB)\n",
180
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (3.3.2)\n",
181
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (6.0.4)\n",
182
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (4.0.3)\n",
183
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.9.2)\n",
184
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.4.0)\n",
185
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.0.0->lm-eval==1.0.0) (1.3.1)\n",
186
+ "Requirement already satisfied: chardet<6,>=3.0.4 in /usr/local/lib/python3.10/dist-packages (from mbstrdecoder<2,>=1.0.0->pytablewriter->lm-eval==1.0.0) (5.2.0)\n",
187
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (3.4)\n",
188
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (2.0.7)\n",
189
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.0.0->lm-eval==1.0.0) (2023.7.22)\n",
190
+ "Requirement already satisfied: python-dateutil<3.0.0,>=2.8.0 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm-eval==1.0.0) (2.8.2)\n",
191
+ "Requirement already satisfied: pytz>=2018.9 in /usr/local/lib/python3.10/dist-packages (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm-eval==1.0.0) (2023.3.post1)\n",
192
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.8->lm-eval==1.0.0) (2.1.3)\n",
193
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge-score>=0.0.4->lm-eval==1.0.0) (8.1.7)\n",
194
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.8->lm-eval==1.0.0) (1.3.0)\n",
195
+ "Building wheels for collected packages: lm-eval, rouge-score, sqlitedict\n",
196
+ " Building wheel for lm-eval (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
197
+ " Created wheel for lm-eval: filename=lm_eval-1.0.0-py3-none-any.whl size=994254 sha256=88356155b19f2891981ecef948326ad6ce8ca40a6009378410ec20d0e225995a\n",
198
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-9v6ye7h3/wheels/17/01/26/599c0779e9858a70a73fa8a306699b5b9a868f820c225457b0\n",
199
+ " Building wheel for rouge-score (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
200
+ " Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24933 sha256=6bb0d44e4881972c43ce194e7cb65233d309758cb15f0dec54590d3d2efcfc36\n",
201
+ " Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n",
202
+ " Building wheel for sqlitedict (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
203
+ " Created wheel for sqlitedict: filename=sqlitedict-2.1.0-py3-none-any.whl size=16863 sha256=5747f7dd73ddf3d8fbcebf51b5e4f718fabe1e94bccdf16d2f22a2e65ee7fdf4\n",
204
+ " Stored in directory: /root/.cache/pip/wheels/79/d6/e7/304e0e6cb2221022c26d8161f7c23cd4f259a9e41e8bbcfabd\n",
205
+ "Successfully built lm-eval rouge-score sqlitedict\n",
206
+ "Installing collected packages: sqlitedict, zstandard, tcolorpy, pybind11, pyarrow-hotfix, portalocker, pathvalidate, mbstrdecoder, jsonlines, dill, colorama, typepy, tqdm-multiprocess, sacrebleu, rouge-score, responses, multiprocess, accelerate, datasets, DataProperty, tabledata, peft, evaluate, pytablewriter, lm-eval\n",
207
+ "Successfully installed DataProperty-1.0.1 accelerate-0.24.1 colorama-0.4.6 datasets-2.15.0 dill-0.3.7 evaluate-0.4.1 jsonlines-4.0.0 lm-eval-1.0.0 mbstrdecoder-1.1.3 multiprocess-0.70.15 pathvalidate-3.2.0 peft-0.6.2 portalocker-2.8.2 pyarrow-hotfix-0.6 pybind11-2.11.1 pytablewriter-1.2.0 responses-0.18.0 rouge-score-0.1.2 sacrebleu-2.3.2 sqlitedict-2.1.0 tabledata-1.3.3 tcolorpy-0.1.4 tqdm-multiprocess-0.0.11 typepy-1.3.2 zstandard-0.22.0\n"
208
+ ]
209
+ }
210
+ ],
211
+ "source": [
212
+ "# Install LM-Eval\n",
213
+ "!pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": 2,
219
+ "metadata": {
220
+ "colab": {
221
+ "base_uri": "https://localhost:8080/",
222
+ "height": 0,
223
+ "referenced_widgets": [
224
+ "a1d3a8aa016544a78e8821c8f6199e06",
225
+ "f61ed33fad754146bdd2ac9db1ba1c48",
226
+ "bfa0af6aeff344c6845e1080a878e92e",
227
+ "fd1ad9e0367d4004aae853b91c3a7617",
228
+ "6b2d90209ec14230b3d58a74ac9b83bf",
229
+ "a73f357065d34d7baf0453ae4a8d75e2",
230
+ "46f521b73fd943c081c648fd873ebc0a",
231
+ "7c5689bc13684db8a22681f41863dddd",
232
+ "48763b6233374554ae76035c0483066f",
233
+ "4986a21eb560448fa79f4b25cde48951",
234
+ "aed3acd2f2d74003b44079c333a0698e"
235
+ ]
236
+ },
237
+ "id": "uyO5MaKkZyah",
238
+ "outputId": "d46e8096-5086-4e49-967e-ea33d4a2a335"
239
+ },
240
+ "outputs": [
241
+ {
242
+ "data": {
243
+ "application/vnd.jupyter.widget-view+json": {
244
+ "model_id": "a1d3a8aa016544a78e8821c8f6199e06",
245
+ "version_major": 2,
246
+ "version_minor": 0
247
+ },
248
+ "text/plain": [
249
+ "Downloading builder script: 0%| | 0.00/5.67k [00:00<?, ?B/s]"
250
+ ]
251
+ },
252
+ "metadata": {},
253
+ "output_type": "display_data"
254
+ }
255
+ ],
256
+ "source": []
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {
261
+ "id": "8rfUeX6n_wkK"
262
+ },
263
+ "source": [
264
+ "## Create new evaluation tasks with config-based tasks\n",
265
+ "\n",
266
+ "Even within the same task, many works have reported numbers based on different evaluation choices. Some report on the test set, some on the validation set, and some even on a subset of the training set. Others use specialized prompts and verbalizers. We introduce YAML configs to let users easily create such variations. By configuring evaluations through YAML, the refactored LM-Eval takes the methods of the `Task` object and makes them configurable via attributes in the config file. There, users can define the tasks they want by setting the name of the HF dataset (local tasks are also possible), the dataset splits to use, and much more. Key prompting configurations such as `doc_to_text`, previously implemented as a method of the same name, are now configurable with Jinja2, allowing high-level scripting to transform an HF dataset row into the text string fed to the model.\n",
267
+ "\n"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "markdown",
272
+ "metadata": {
273
+ "id": "HYFUhhfOSJKe"
274
+ },
275
+ "source": [
276
+ "A core feature of LM-Eval is configuring tasks with YAML configs. With configs, you can fill in preset fields to easily set up a task.\n",
277
+ "\n",
278
+ "Here, we write a demo YAML config for a multiple-choice evaluation of BoolQ:"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 3,
284
+ "metadata": {
285
+ "id": "bg3dGROW-V39"
286
+ },
287
+ "outputs": [],
288
+ "source": [
289
+ "YAML_boolq_string = \"\"\"\n",
290
+ "task: demo_boolq\n",
291
+ "dataset_path: super_glue\n",
292
+ "dataset_name: boolq\n",
293
+ "output_type: multiple_choice\n",
294
+ "training_split: train\n",
295
+ "validation_split: validation\n",
296
+ "doc_to_text: \"{{passage}}\\nQuestion: {{question}}?\\nAnswer:\"\n",
297
+ "doc_to_target: label\n",
298
+ "doc_to_choice: [\"no\", \"yes\"]\n",
299
+ "should_decontaminate: true\n",
300
+ "doc_to_decontamination_query: passage\n",
301
+ "metric_list:\n",
302
+ " - metric: acc\n",
303
+ "\"\"\"\n",
304
+ "with open(\"boolq.yaml\", \"w\") as f:\n",
305
+ " f.write(YAML_boolq_string)"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "metadata": {},
311
+ "source": [
312
+ "We can now run evaluation on this task by pointing to the config file we've just created:"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 4,
318
+ "metadata": {
319
+ "id": "LOUHK7PtQfq4"
320
+ },
321
+ "outputs": [
322
+ {
323
+ "name": "stdout",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "2023-11-29:11:54:55,156 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
327
+ "2023-11-29 11:54:55.942051: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
328
+ "2023-11-29 11:54:55.942108: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
329
+ "2023-11-29 11:54:55.942142: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
330
+ "2023-11-29 11:54:57.066802: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
331
+ "2023-11-29:11:55:00,954 INFO [__main__.py:132] Verbosity set to INFO\n",
332
+ "2023-11-29:11:55:11,038 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
333
+ "2023-11-29:11:55:11,038 INFO [__main__.py:143] Including path: ./\n",
334
+ "2023-11-29:11:55:11,046 INFO [__main__.py:205] Selected Tasks: ['demo_boolq']\n",
335
+ "2023-11-29:11:55:11,047 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
336
+ "2023-11-29:11:55:11,110 INFO [huggingface.py:120] Using device 'cuda'\n",
337
+ "config.json: 100% 571/571 [00:00<00:00, 2.87MB/s]\n",
338
+ "model.safetensors: 100% 5.68G/5.68G [00:32<00:00, 173MB/s]\n",
339
+ "tokenizer_config.json: 100% 396/396 [00:00<00:00, 2.06MB/s]\n",
340
+ "tokenizer.json: 100% 2.11M/2.11M [00:00<00:00, 11.6MB/s]\n",
341
+ "special_tokens_map.json: 100% 99.0/99.0 [00:00<00:00, 555kB/s]\n",
342
+ "2023-11-29:11:56:18,658 WARNING [task.py:614] [Task: demo_boolq] metric acc is defined, but aggregation is not. using default aggregation=mean\n",
343
+ "2023-11-29:11:56:18,658 WARNING [task.py:626] [Task: demo_boolq] metric acc is defined, but higher_is_better is not. using default higher_is_better=True\n",
344
+ "Downloading builder script: 100% 30.7k/30.7k [00:00<00:00, 59.0MB/s]\n",
345
+ "Downloading metadata: 100% 38.7k/38.7k [00:00<00:00, 651kB/s]\n",
346
+ "Downloading readme: 100% 14.8k/14.8k [00:00<00:00, 37.3MB/s]\n",
347
+ "Downloading data: 100% 4.12M/4.12M [00:00<00:00, 55.1MB/s]\n",
348
+ "Generating train split: 100% 9427/9427 [00:00<00:00, 15630.89 examples/s]\n",
349
+ "Generating validation split: 100% 3270/3270 [00:00<00:00, 20002.56 examples/s]\n",
350
+ "Generating test split: 100% 3245/3245 [00:00<00:00, 20866.19 examples/s]\n",
351
+ "2023-11-29:11:56:22,315 INFO [task.py:355] Building contexts for task on rank 0...\n",
352
+ "2023-11-29:11:56:22,322 INFO [evaluator.py:319] Running loglikelihood requests\n",
353
+ "100% 20/20 [00:04<00:00, 4.37it/s]\n",
354
+ "fatal: not a git repository (or any of the parent directories): .git\n",
355
+ "hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
356
+ "| Tasks |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
357
+ "|----------|-------|------|-----:|------|----:|---|-----:|\n",
358
+ "|demo_boolq|Yaml |none | 0|acc | 1|± | 0|\n",
359
+ "\n"
360
+ ]
361
+ }
362
+ ],
363
+ "source": [
364
+ "%env LOGLEVEL=DEBUG\n",
365
+ "!lm_eval \\\n",
366
+ " --model hf \\\n",
367
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
368
+ " --include_path ./ \\\n",
369
+ " --tasks demo_boolq \\\n",
370
+ " --limit 10"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "markdown",
375
+ "metadata": {
376
+ "id": "LOUHK7PtQfq4"
377
+ },
378
+ "source": [
379
+ "Often, tasks are part of a larger group used to measure different capabilities. The dynamism of the field today means new dimensions of evaluation can emerge that mix and match new and older tasks alike. In LM-Eval, we can also group tasks and then call the group name to easily evaluate a whole set of tasks. In this instance, let's evaluate the tag `yes_or_no_tasks`, which comprises the tasks `demo_boolq` and `demo_cola`: multiple-choice tasks whose options are `yes` and `no`, as the name suggests.\n",
380
+ "\n",
381
+ "<!-- Making new groups is easier than ever, allowing users to work bottom-up by making individual tasks and linking them to a group, or top-down by making a new group that lists existing tasks.\n",
382
+ "\n",
383
+ "We also show the aggregate across samples, in addition to the aggregation across subtasks. This may come in handy when certain groups should be aggregated as a single task. -->\n",
384
+ "\n",
385
+ "\n"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 5,
391
+ "metadata": {
392
+ "id": "fthNg3ywO-kA"
393
+ },
394
+ "outputs": [],
395
+ "source": [
396
+ "YAML_cola_string = \"\"\"\n",
397
+ "tag: yes_or_no_tasks\n",
398
+ "task: demo_cola\n",
399
+ "dataset_path: glue\n",
400
+ "dataset_name: cola\n",
401
+ "output_type: multiple_choice\n",
402
+ "training_split: train\n",
403
+ "validation_split: validation\n",
404
+ "doc_to_text: \"{{sentence}}\\nQuestion: Does this sentence make sense?\\nAnswer:\"\n",
405
+ "doc_to_target: label\n",
406
+ "doc_to_choice: [\"no\", \"yes\"]\n",
407
+ "should_decontaminate: true\n",
408
+ "doc_to_decontamination_query: sentence\n",
409
+ "metric_list:\n",
410
+ " - metric: acc\n",
411
+ "\"\"\"\n",
412
+ "with open(\"cola.yaml\", \"w\") as f:\n",
413
+ " f.write(YAML_cola_string)"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 6,
419
+ "metadata": {
420
+ "id": "XceRKCuuDtbn"
421
+ },
422
+ "outputs": [
423
+ {
424
+ "name": "stdout",
425
+ "output_type": "stream",
426
+ "text": [
427
+ "2023-11-29:11:56:33,016 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
428
+ "2023-11-29 11:56:33.852995: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
429
+ "2023-11-29 11:56:33.853050: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
430
+ "2023-11-29 11:56:33.853087: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
431
+ "2023-11-29 11:56:35.129047: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
432
+ "2023-11-29:11:56:38,546 INFO [__main__.py:132] Verbosity set to INFO\n",
433
+ "2023-11-29:11:56:47,509 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
434
+ "2023-11-29:11:56:47,509 INFO [__main__.py:143] Including path: ./\n",
435
+ "2023-11-29:11:56:47,517 INFO [__main__.py:205] Selected Tasks: ['yes_or_no_tasks']\n",
436
+ "2023-11-29:11:56:47,520 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
437
+ "2023-11-29:11:56:47,550 INFO [huggingface.py:120] Using device 'cuda'\n",
438
+ "2023-11-29:11:57:08,743 WARNING [task.py:614] [Task: demo_cola] metric acc is defined, but aggregation is not. using default aggregation=mean\n",
439
+ "2023-11-29:11:57:08,743 WARNING [task.py:626] [Task: demo_cola] metric acc is defined, but higher_is_better is not. using default higher_is_better=True\n",
440
+ "Downloading builder script: 100% 28.8k/28.8k [00:00<00:00, 52.7MB/s]\n",
441
+ "Downloading metadata: 100% 28.7k/28.7k [00:00<00:00, 51.9MB/s]\n",
442
+ "Downloading readme: 100% 27.9k/27.9k [00:00<00:00, 48.0MB/s]\n",
443
+ "Downloading data: 100% 377k/377k [00:00<00:00, 12.0MB/s]\n",
444
+ "Generating train split: 100% 8551/8551 [00:00<00:00, 19744.58 examples/s]\n",
445
+ "Generating validation split: 100% 1043/1043 [00:00<00:00, 27057.01 examples/s]\n",
446
+ "Generating test split: 100% 1063/1063 [00:00<00:00, 22705.17 examples/s]\n",
447
+ "2023-11-29:11:57:11,698 INFO [task.py:355] Building contexts for task on rank 0...\n",
448
+ "2023-11-29:11:57:11,704 INFO [evaluator.py:319] Running loglikelihood requests\n",
449
+ "100% 20/20 [00:03<00:00, 5.15it/s]\n",
450
+ "fatal: not a git repository (or any of the parent directories): .git\n",
451
+ "hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
452
+ "| Tasks |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
453
+ "|---------------|-------|------|-----:|------|----:|---|-----:|\n",
454
+ "|yes_or_no_tasks|N/A |none | 0|acc | 0.7|± |0.1528|\n",
455
+ "| - demo_cola |Yaml |none | 0|acc | 0.7|± |0.1528|\n",
456
+ "\n",
457
+ "| Groups |Version|Filter|n-shot|Metric|Value| |Stderr|\n",
458
+ "|---------------|-------|------|-----:|------|----:|---|-----:|\n",
459
+ "|yes_or_no_tasks|N/A |none | 0|acc | 0.7|± |0.1528|\n",
460
+ "\n"
461
+ ]
462
+ }
463
+ ],
464
+ "source": [
465
+ "# !accelerate launch --no_python\n",
466
+ "%env LOGLEVEL=DEBUG\n",
467
+ "!lm_eval \\\n",
468
+ " --model hf \\\n",
469
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
470
+ " --include_path ./ \\\n",
471
+ " --tasks yes_or_no_tasks \\\n",
472
+ " --limit 10 \\\n",
473
+ " --output output/yes_or_no_tasks/ \\\n",
474
+ " --log_samples"
475
+ ]
476
+ },
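+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that in this demo only `demo_cola` carries the `tag: yes_or_no_tasks` field, which is why the group table above lists a single subtask. As a minimal, untested sketch, prepending the same `tag` line to the `boolq.yaml` we wrote earlier should pull `demo_boolq` into the group on the next run as well:\n",
+ "\n",
+ "```python\n",
+ "# Sketch: add the group tag to the existing boolq.yaml so that both\n",
+ "# demo_boolq and demo_cola are reported under yes_or_no_tasks.\n",
+ "with open(\"boolq.yaml\") as f:\n",
+ " boolq_yaml = f.read()\n",
+ "with open(\"boolq.yaml\", \"w\") as f:\n",
+ " f.write(\"tag: yes_or_no_tasks\\n\" + boolq_yaml)\n",
+ "```"
+ ]
+ },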
477
+ {
478
+ "cell_type": "markdown",
479
+ "metadata": {
480
+ "id": "XceRKCuuDtbn"
481
+ },
482
+ "source": [
483
+ "## Edit Prompt Templates Quickly\n",
484
+ "\n",
485
+ "The following is a YAML config for evaluating the specific subtask `high_school_geography` from MMLU. It uses the standard prompt, where we take the answer letter with the highest likelihood as the model's prediction."
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 7,
491
+ "metadata": {
492
+ "id": "GTFvdt9kSlBG"
493
+ },
494
+ "outputs": [],
495
+ "source": [
496
+ "YAML_mmlu_geo_string = \"\"\"\n",
497
+ "task: demo_mmlu_high_school_geography\n",
498
+ "dataset_path: cais/mmlu\n",
499
+ "dataset_name: high_school_geography\n",
500
+ "description: \"The following are multiple choice questions (with answers) about high school geography.\\n\\n\"\n",
501
+ "test_split: test\n",
502
+ "fewshot_split: dev\n",
503
+ "fewshot_config:\n",
504
+ " sampler: first_n\n",
505
+ "output_type: multiple_choice\n",
506
+ "doc_to_text: \"{{question.strip()}}\\nA. {{choices[0]}}\\nB. {{choices[1]}}\\nC. {{choices[2]}}\\nD. {{choices[3]}}\\nAnswer:\"\n",
507
+ "doc_to_choice: [\"A\", \"B\", \"C\", \"D\"]\n",
508
+ "doc_to_target: answer\n",
509
+ "metric_list:\n",
510
+ " - metric: acc\n",
511
+ " aggregation: mean\n",
512
+ " higher_is_better: true\n",
513
+ " - metric: acc_norm\n",
514
+ " aggregation: mean\n",
515
+ " higher_is_better: true\n",
516
+ "\"\"\"\n",
517
+ "with open(\"mmlu_high_school_geography.yaml\", \"w\") as f:\n",
518
+ " f.write(YAML_mmlu_geo_string)"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": 8,
524
+ "metadata": {
525
+ "id": "jyKOfCsKb-xy"
526
+ },
527
+ "outputs": [
528
+ {
529
+ "name": "stdout",
530
+ "output_type": "stream",
531
+ "text": [
532
+ "2023-11-29:11:57:23,598 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
533
+ "2023-11-29 11:57:24.719750: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
534
+ "2023-11-29 11:57:24.719806: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
535
+ "2023-11-29 11:57:24.719847: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
536
+ "2023-11-29 11:57:26.656125: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
537
+ "2023-11-29:11:57:31,563 INFO [__main__.py:132] Verbosity set to INFO\n",
538
+ "2023-11-29:11:57:40,541 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
539
+ "2023-11-29:11:57:40,541 INFO [__main__.py:143] Including path: ./\n",
540
+ "2023-11-29:11:57:40,558 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography']\n",
541
+ "2023-11-29:11:57:40,559 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
542
+ "2023-11-29:11:57:40,589 INFO [huggingface.py:120] Using device 'cuda'\n",
543
+ "Downloading builder script: 100% 5.84k/5.84k [00:00<00:00, 17.7MB/s]\n",
544
+ "Downloading metadata: 100% 106k/106k [00:00<00:00, 892kB/s] \n",
545
+ "Downloading readme: 100% 39.7k/39.7k [00:00<00:00, 631kB/s]\n",
546
+ "Downloading data: 100% 166M/166M [00:01<00:00, 89.0MB/s]\n",
547
+ "Generating auxiliary_train split: 100% 99842/99842 [00:07<00:00, 12536.83 examples/s]\n",
548
+ "Generating test split: 100% 198/198 [00:00<00:00, 1439.20 examples/s]\n",
549
+ "Generating validation split: 100% 22/22 [00:00<00:00, 4181.76 examples/s]\n",
550
+ "Generating dev split: 100% 5/5 [00:00<00:00, 36.25 examples/s]\n",
551
+ "2023-11-29:11:58:09,798 INFO [task.py:355] Building contexts for task on rank 0...\n",
552
+ "2023-11-29:11:58:09,822 INFO [evaluator.py:319] Running loglikelihood requests\n",
553
+ "100% 40/40 [00:05<00:00, 7.86it/s]\n",
554
+ "fatal: not a git repository (or any of the parent directories): .git\n",
555
+ "hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
556
+ "| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
557
+ "|-------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
558
+ "|demo_mmlu_high_school_geography|Yaml |none | 0|acc | 0.3|± |0.1528|\n",
559
+ "| | |none | 0|acc_norm| 0.3|± |0.1528|\n",
560
+ "\n"
561
+ ]
562
+ }
563
+ ],
564
+ "source": [
565
+ "# !accelerate launch --no_python\n",
566
+ "%env LOGLEVEL=DEBUG\n",
567
+ "!lm_eval \\\n",
568
+ " --model hf \\\n",
569
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
570
+ " --include_path ./ \\\n",
571
+ " --tasks demo_mmlu_high_school_geography \\\n",
572
+ " --limit 10 \\\n",
573
+ " --output output/mmlu_high_school_geography/ \\\n",
574
+ " --log_samples"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "markdown",
579
+ "metadata": {
580
+ "id": "jyKOfCsKb-xy"
581
+ },
582
+ "source": [
583
+ "We could also evaluate this task in a different way. For example, instead of scoring the loglikelihood of the answer letters, we can evaluate the choices themselves as the continuation. This is done by simply changing `doc_to_choice` from a list of letters to the corresponding `choices` field from the HF dataset. We write `\"{{choices}}\"` so that the string is interpreted as a Jinja2 template that acquires the list from the HF dataset directly.\n",
584
+ "\n",
585
+ "Another convenient feature here: since we're only modifying `doc_to_choice` and the rest of the config is the same as the task above, we can use the above configuration as a template by adding `include: mmlu_high_school_geography.yaml` to load the config from that file. We'll need to give the new task a unique name so it doesn't collide with the existing YAML config we're including; in this case we'll simply name it `demo_mmlu_high_school_geography_continuation`. `doc_to_text` is repeated here just for the sake of clarity."
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": 9,
591
+ "metadata": {
592
+ "id": "lqElwU54TaK-"
593
+ },
594
+ "outputs": [],
595
+ "source": [
596
+ "YAML_mmlu_geo_string = \"\"\"\n",
597
+ "include: mmlu_high_school_geography.yaml\n",
598
+ "task: demo_mmlu_high_school_geography_continuation\n",
599
+ "doc_to_text: \"{{question.strip()}}\\nA. {{choices[0]}}\\nB. {{choices[1]}}\\nC. {{choices[2]}}\\nD. {{choices[3]}}\\nAnswer:\"\n",
600
+ "doc_to_choice: \"{{choices}}\"\n",
601
+ "\"\"\"\n",
602
+ "with open(\"mmlu_high_school_geography_continuation.yaml\", \"w\") as f:\n",
603
+ " f.write(YAML_mmlu_geo_string)"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": 10,
609
+ "metadata": {
610
+ "id": "-_CVnDirdy7j"
611
+ },
612
+ "outputs": [
613
+ {
614
+ "name": "stdout",
615
+ "output_type": "stream",
616
+ "text": [
617
+ "2023-11-29:11:58:21,284 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
618
+ "2023-11-29 11:58:22.850159: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
619
+ "2023-11-29 11:58:22.850219: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
620
+ "2023-11-29 11:58:22.850254: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
621
+ "2023-11-29 11:58:24.948103: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
622
+ "2023-11-29:11:58:28,460 INFO [__main__.py:132] Verbosity set to INFO\n",
623
+ "2023-11-29:11:58:37,935 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
624
+ "2023-11-29:11:58:37,935 INFO [__main__.py:143] Including path: ./\n",
625
+ "2023-11-29:11:58:37,969 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography_continuation']\n",
626
+ "2023-11-29:11:58:37,972 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
627
+ "2023-11-29:11:58:38,008 INFO [huggingface.py:120] Using device 'cuda'\n",
628
+ "2023-11-29:11:58:59,758 INFO [task.py:355] Building contexts for task on rank 0...\n",
629
+ "2023-11-29:11:58:59,777 INFO [evaluator.py:319] Running loglikelihood requests\n",
630
+ "100% 40/40 [00:02<00:00, 16.23it/s]\n",
631
+ "fatal: not a git repository (or any of the parent directories): .git\n",
632
+ "hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
633
+ "| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
634
+ "|--------------------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
635
+ "|demo_mmlu_high_school_geography_continuation|Yaml |none | 0|acc | 0.1|± |0.1000|\n",
636
+ "| | |none | 0|acc_norm| 0.2|± |0.1333|\n",
637
+ "\n"
638
+ ]
639
+ }
640
+ ],
641
+ "source": [
642
+ "# !accelerate launch --no_python\n",
643
+ "%env LOGLEVEL=DEBUG\n",
644
+ "!lm_eval \\\n",
645
+ " --model hf \\\n",
646
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
647
+ " --include_path ./ \\\n",
648
+ " --tasks demo_mmlu_high_school_geography_continuation \\\n",
649
+ " --limit 10 \\\n",
650
+ " --output output/mmlu_high_school_geography_continuation/ \\\n",
651
+ " --log_samples"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "markdown",
656
+ "metadata": {
657
+ "id": "-_CVnDirdy7j"
658
+ },
659
+ "source": [
660
+ "If we take a look at the samples, we can see that it is in fact evaluating the continuation based on the choices rather than the letters."
661
+ ]
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": 11,
666
+ "metadata": {
667
+ "id": "duBDqC6PAdjL"
668
+ },
669
+ "outputs": [
670
+ {
671
+ "data": {
672
+ "application/javascript": "\n ((filepath) => {{\n if (!google.colab.kernel.accessAllowed) {{\n return;\n }}\n google.colab.files.view(filepath);\n }})(\"/content/output/mmlu_high_school_geography_continuation/pretrained__EleutherAI__pythia-2.8b_demo_mmlu_high_school_geography_continuation.jsonl\")",
673
+ "text/plain": [
674
+ "<IPython.core.display.Javascript object>"
675
+ ]
676
+ },
677
+ "metadata": {},
678
+ "output_type": "display_data"
679
+ }
680
+ ],
681
+ "source": [
682
+ "from google.colab import files\n",
683
+ "\n",
684
+ "\n",
685
+ "files.view(\n",
686
+ " \"output/mmlu_high_school_geography_continuation/pretrained__EleutherAI__pythia-2.8b_demo_mmlu_high_school_geography_continuation.jsonl\"\n",
687
+ ")"
688
+ ]
689
+ },
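+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you are not running in Colab, the logged samples are ordinary JSON-lines files, so a plain-Python sketch like the following works anywhere (the exact filename depends on the model and task names used in the run above):\n",
+ "\n",
+ "```python\n",
+ "import json\n",
+ "\n",
+ "path = (\n",
+ " \"output/mmlu_high_school_geography_continuation/\"\n",
+ " \"pretrained__EleutherAI__pythia-2.8b_demo_mmlu_high_school_geography_continuation.jsonl\"\n",
+ ")\n",
+ "# Each line of the file is one evaluated document.\n",
+ "with open(path) as f:\n",
+ " first_sample = json.loads(f.readline())\n",
+ "print(list(first_sample.keys()))\n",
+ "print(str(first_sample)[:500])\n",
+ "```"
+ ]
+ },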
690
+ {
691
+ "cell_type": "markdown",
692
+ "metadata": {
693
+ "id": "6p0-KPwAgK5j"
694
+ },
695
+ "source": [
696
+ "## Closer Look at YAML Fields\n",
697
+ "\n",
698
+ "To prepare a task we can simply fill in a YAML config with the relevant information.\n",
699
+ "\n",
700
+ "`output_type`\n",
701
+ "The currently provided evaluation types are the following:\n",
702
+ "1. `loglikelihood`: Evaluates the loglikelihood of a continuation, conditioned on some input string.\n",
703
+ "2. `loglikelihood_rolling`: Evaluates the loglikelihood of producing a string, conditioned on the empty string. (Used for perplexity evaluations.)\n",
704
+ "3. `multiple_choice`: Compares the loglikelihood the model assigns to each of a fixed set of answer choices.\n",
705
+ "4. `greedy_until`: The model generates output greedily (this can be configured to use beam search and other generation-related parameters).\n",
706
+ "\n",
707
+ "The core prompt revolves around three fields:\n",
708
+ "1. `doc_to_text`: Denotes the prompt template that will be used as input to the model.\n",
709
+ "2. `doc_to_choice`: The available choices that will be used as continuations for the model. This is used when the `output_type` is `multiple_choice`, and can otherwise be left as `None`.\n",
710
+ "3. `doc_to_target`: When `output_type` is `multiple_choice`, this can be an index that corresponds to the correct answer, or the answer string itself (which must be one of the entries in `doc_to_choice`). For other tasks, this is expected to be a string. You can fill this field with a feature name from the HF dataset, so long as the resulting feature follows the conditions described above.\n",
711
+ "\n",
712
+ "These three fields can be expressed as strings, column names from the source dataset, or as Jinja2 templates that can use fields from the source dataset as variables.\n"
713
+ ]
714
+ },
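+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make these fields concrete for a generation-style task, here is a rough, untested sketch of a `greedy_until` config written in the same style as the YAMLs above. The task name is hypothetical and the metric choice is illustrative only (GSM8K answers normally need an answer-extraction filter before `exact_match` is meaningful); generation parameters such as stop sequences can also be configured, so check the task guide in `docs/` for the exact schema.\n",
+ "\n",
+ "```python\n",
+ "# Illustrative sketch only; not executed in this notebook.\n",
+ "YAML_gen_sketch = \"\"\"\n",
+ "task: demo_gsm8k_greedy\n",
+ "dataset_path: gsm8k\n",
+ "dataset_name: main\n",
+ "output_type: greedy_until\n",
+ "training_split: train\n",
+ "test_split: test\n",
+ "doc_to_text: \"Question: {{question}}\\nAnswer:\"\n",
+ "doc_to_target: \"{{answer}}\"\n",
+ "metric_list:\n",
+ " - metric: exact_match\n",
+ "\"\"\"\n",
+ "with open(\"demo_gsm8k_greedy.yaml\", \"w\") as f:\n",
+ " f.write(YAML_gen_sketch)\n",
+ "```"
+ ]
+ },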
715
+ {
716
+ "cell_type": "markdown",
717
+ "metadata": {
718
+ "id": "6p0-KPwAgK5j"
719
+ },
720
+ "source": [
721
+ "## What if Jinja is not Sufficient?\n",
722
+ "\n",
723
+ "There can be times when the Jinja2 templating language is not enough to build the prompt we have in mind. There are a few ways to work around this limitation:\n",
724
+ "\n",
725
+ "1. Use the `!function` operator in the prompt-related fields to pass a Python function that takes the dataset row as input and outputs the prompt component.\n",
726
+ "2. Perform a transformation on the dataset beforehand."
727
+ ]
728
+ },
729
+ {
730
+ "cell_type": "markdown",
731
+ "metadata": {},
732
+ "source": [
733
+ "Below, we show an example of using `!function` to create `doc_to_text` from a Python function:"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "execution_count": 12,
739
+ "metadata": {
740
+ "colab": {
741
+ "base_uri": "https://localhost:8080/"
742
+ },
743
+ "id": "DYZ5c0JhR1lJ",
744
+ "outputId": "ca945235-fb9e-4f17-8bfa-78e7d6ec1490"
745
+ },
746
+ "outputs": [
747
+ {
748
+ "name": "stdout",
749
+ "output_type": "stream",
750
+ "text": [
751
+ "2023-11-29:11:59:08,312 INFO [utils.py:160] NumExpr defaulting to 2 threads.\n",
752
+ "2023-11-29 11:59:09.348327: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
753
+ "2023-11-29 11:59:09.348387: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
754
+ "2023-11-29 11:59:09.348421: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
755
+ "2023-11-29 11:59:10.573752: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
756
+ "2023-11-29:11:59:14,044 INFO [__main__.py:132] Verbosity set to INFO\n",
757
+ "2023-11-29:11:59:23,654 WARNING [__main__.py:138] --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.\n",
758
+ "2023-11-29:11:59:23,654 INFO [__main__.py:143] Including path: ./\n",
759
+ "2023-11-29:11:59:23,678 INFO [__main__.py:205] Selected Tasks: ['demo_mmlu_high_school_geography_function_prompt']\n",
760
+ "2023-11-29:11:59:23,679 WARNING [evaluator.py:93] generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks.\n",
761
+ "2023-11-29:11:59:23,708 INFO [huggingface.py:120] Using device 'cuda'\n",
762
+ "2023-11-29:11:59:44,516 INFO [task.py:355] Building contexts for task on rank 0...\n",
763
+ "2023-11-29:11:59:44,524 INFO [evaluator.py:319] Running loglikelihood requests\n",
764
+ "100% 40/40 [00:02<00:00, 15.41it/s]\n",
765
+ "fatal: not a git repository (or any of the parent directories): .git\n",
766
+ "hf (pretrained=EleutherAI/pythia-2.8b), gen_kwargs: (), limit: 10.0, num_fewshot: None, batch_size: 1\n",
767
+ "| Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|\n",
768
+ "|-----------------------------------------------|-------|------|-----:|--------|----:|---|-----:|\n",
769
+ "|demo_mmlu_high_school_geography_function_prompt|Yaml |none | 0|acc | 0.1|± |0.1000|\n",
770
+ "| | |none | 0|acc_norm| 0.2|± |0.1333|\n",
771
+ "\n"
772
+ ]
773
+ }
774
+ ],
775
+ "source": [
776
+ "YAML_mmlu_geo_string = \"\"\"\n",
777
+ "include: mmlu_high_school_geography.yaml\n",
778
+ "task: demo_mmlu_high_school_geography_function_prompt\n",
779
+ "doc_to_text: !function utils.doc_to_text\n",
780
+ "doc_to_choice: \"{{choices}}\"\n",
781
+ "\"\"\"\n",
782
+ "with open(\"demo_mmlu_high_school_geography_function_prompt.yaml\", \"w\") as f:\n",
783
+ " f.write(YAML_mmlu_geo_string)\n",
784
+ "\n",
785
+ "DOC_TO_TEXT = \"\"\"\n",
786
+ "def doc_to_text(x):\n",
787
+ " question = x[\"question\"].strip()\n",
788
+ " choices = x[\"choices\"]\n",
789
+ " option_a = choices[0]\n",
790
+ " option_b = choices[1]\n",
791
+ " option_c = choices[2]\n",
792
+ " option_d = choices[3]\n",
793
+ " return f\"{question}\\\\nA. {option_a}\\\\nB. {option_b}\\\\nC. {option_c}\\\\nD. {option_d}\\\\nAnswer:\"\n",
794
+ "\"\"\"\n",
795
+ "with open(\"utils.py\", \"w\") as f:\n",
796
+ " f.write(DOC_TO_TEXT)\n",
797
+ "\n",
798
+ "!lm_eval \\\n",
799
+ " --model hf \\\n",
800
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
801
+ " --include_path ./ \\\n",
802
+ " --tasks demo_mmlu_high_school_geography_function_prompt \\\n",
803
+ " --limit 10 \\\n",
804
+ " --output output/demo_mmlu_high_school_geography_function_prompt/ \\\n",
805
+ " --log_samples"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "markdown",
810
+ "metadata": {},
811
+ "source": [
812
+ "Next, we'll also show how to do this by preprocessing the dataset as needed, using the `process_docs` config field.\n",
813
+ "\n",
814
+ "We will write a function that modifies each document in our evaluation split, adding a field that is suitable to use in `doc_to_text`."
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "code",
819
+ "execution_count": null,
820
+ "metadata": {},
821
+ "outputs": [],
822
+ "source": [
823
+ "YAML_mmlu_geo_string = \"\"\"\n",
824
+ "include: mmlu_high_school_geography.yaml\n",
825
+ "task: demo_mmlu_high_school_geography_function_prompt_2\n",
826
+ "process_docs: !function utils_process_docs.process_docs\n",
827
+ "doc_to_text: \"{{input}}\"\n",
828
+ "doc_to_choice: \"{{choices}}\"\n",
829
+ "\"\"\"\n",
830
+ "with open(\"demo_mmlu_high_school_geography_process_docs.yaml\", \"w\") as f:\n",
831
+ " f.write(YAML_mmlu_geo_string)\n",
832
+ "\n",
833
+ "DOC_TO_TEXT = \"\"\"\n",
834
+ "def process_docs(dataset):\n",
835
+ " def _process_doc(x):\n",
836
+ " question = x[\"question\"].strip()\n",
837
+ " choices = x[\"choices\"]\n",
838
+ " option_a = choices[0]\n",
839
+ " option_b = choices[1]\n",
840
+ " option_c = choices[2]\n",
841
+ " option_d = choices[3]\n",
842
+ " x[\"input\"] = f\"{question}\\\\nA. {option_a}\\\\nB. {option_b}\\\\nC. {option_c}\\\\nD. {option_d}\\\\nAnswer:\"\n",
843
+ " return x\n",
844
+ "\n",
845
+ " return dataset.map(_process_doc)\n",
846
+ "\"\"\"\n",
847
+ "\n",
848
+ "with open(\"utils_process_docs.py\", \"w\") as f:\n",
849
+ " f.write(DOC_TO_TEXT)\n",
850
+ "\n",
851
+ "!lm_eval \\\n",
852
+ " --model hf \\\n",
853
+ " --model_args pretrained=EleutherAI/pythia-2.8b \\\n",
854
+ " --include_path ./ \\\n",
855
+ " --tasks demo_mmlu_high_school_geography_function_prompt_2 \\\n",
856
+ " --limit 10 \\\n",
857
+ " --output output/demo_mmlu_high_school_geography_function_prompt_2/ \\\n",
858
+ " --log_samples"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "markdown",
863
+ "metadata": {},
864
+ "source": [
865
+ "We hope that this explainer gives you a sense of what can be done with, and how to work with, LM-Evaluation-Harness v0.4.0!\n",
866
+ "\n",
867
+ "For more information, check out our documentation pages in the `docs/` folder, and if you have questions, please raise them in GitHub issues, or in #lm-thunderdome or #release-discussion on the EleutherAI discord server."
868
+ ]
869
+ }
870
+ ],
871
+ "metadata": {
872
+ "accelerator": "GPU",
873
+ "colab": {
874
+ "collapsed_sections": [
875
+ "zAov81vTbL2K"
876
+ ],
877
+ "gpuType": "T4",
878
+ "provenance": []
879
+ },
880
+ "kernelspec": {
881
+ "display_name": "Python 3",
882
+ "name": "python3"
883
+ },
884
+ "language_info": {
885
+ "name": "python"
886
+ },
887
+ "widgets": {
888
+ "application/vnd.jupyter.widget-state+json": {
889
+ "state": {
890
+ "46f521b73fd943c081c648fd873ebc0a": {
891
+ "model_module": "@jupyter-widgets/controls",
892
+ "model_module_version": "1.5.0",
893
+ "model_name": "DescriptionStyleModel",
894
+ "state": {
895
+ "_model_module": "@jupyter-widgets/controls",
896
+ "_model_module_version": "1.5.0",
897
+ "_model_name": "DescriptionStyleModel",
898
+ "_view_count": null,
899
+ "_view_module": "@jupyter-widgets/base",
900
+ "_view_module_version": "1.2.0",
901
+ "_view_name": "StyleView",
902
+ "description_width": ""
903
+ }
904
+ },
905
+ "48763b6233374554ae76035c0483066f": {
906
+ "model_module": "@jupyter-widgets/controls",
907
+ "model_module_version": "1.5.0",
908
+ "model_name": "ProgressStyleModel",
909
+ "state": {
910
+ "_model_module": "@jupyter-widgets/controls",
911
+ "_model_module_version": "1.5.0",
912
+ "_model_name": "ProgressStyleModel",
913
+ "_view_count": null,
914
+ "_view_module": "@jupyter-widgets/base",
915
+ "_view_module_version": "1.2.0",
916
+ "_view_name": "StyleView",
917
+ "bar_color": null,
918
+ "description_width": ""
919
+ }
920
+ },
921
+ "4986a21eb560448fa79f4b25cde48951": {
922
+ "model_module": "@jupyter-widgets/base",
923
+ "model_module_version": "1.2.0",
924
+ "model_name": "LayoutModel",
925
+ "state": {
926
+ "_model_module": "@jupyter-widgets/base",
927
+ "_model_module_version": "1.2.0",
928
+ "_model_name": "LayoutModel",
929
+ "_view_count": null,
930
+ "_view_module": "@jupyter-widgets/base",
931
+ "_view_module_version": "1.2.0",
932
+ "_view_name": "LayoutView",
933
+ "align_content": null,
934
+ "align_items": null,
935
+ "align_self": null,
936
+ "border": null,
937
+ "bottom": null,
938
+ "display": null,
939
+ "flex": null,
940
+ "flex_flow": null,
941
+ "grid_area": null,
942
+ "grid_auto_columns": null,
943
+ "grid_auto_flow": null,
944
+ "grid_auto_rows": null,
945
+ "grid_column": null,
946
+ "grid_gap": null,
947
+ "grid_row": null,
948
+ "grid_template_areas": null,
949
+ "grid_template_columns": null,
950
+ "grid_template_rows": null,
951
+ "height": null,
952
+ "justify_content": null,
953
+ "justify_items": null,
954
+ "left": null,
955
+ "margin": null,
956
+ "max_height": null,
957
+ "max_width": null,
958
+ "min_height": null,
959
+ "min_width": null,
960
+ "object_fit": null,
961
+ "object_position": null,
962
+ "order": null,
963
+ "overflow": null,
964
+ "overflow_x": null,
965
+ "overflow_y": null,
966
+ "padding": null,
967
+ "right": null,
968
+ "top": null,
969
+ "visibility": null,
970
+ "width": null
971
+ }
972
+ },
973
+ "6b2d90209ec14230b3d58a74ac9b83bf": {
974
+ "model_module": "@jupyter-widgets/base",
975
+ "model_module_version": "1.2.0",
976
+ "model_name": "LayoutModel",
977
+ "state": {
978
+ "_model_module": "@jupyter-widgets/base",
979
+ "_model_module_version": "1.2.0",
980
+ "_model_name": "LayoutModel",
981
+ "_view_count": null,
982
+ "_view_module": "@jupyter-widgets/base",
983
+ "_view_module_version": "1.2.0",
984
+ "_view_name": "LayoutView",
985
+ "align_content": null,
986
+ "align_items": null,
987
+ "align_self": null,
988
+ "border": null,
989
+ "bottom": null,
990
+ "display": null,
991
+ "flex": null,
992
+ "flex_flow": null,
993
+ "grid_area": null,
994
+ "grid_auto_columns": null,
995
+ "grid_auto_flow": null,
996
+ "grid_auto_rows": null,
997
+ "grid_column": null,
998
+ "grid_gap": null,
999
+ "grid_row": null,
1000
+ "grid_template_areas": null,
1001
+ "grid_template_columns": null,
1002
+ "grid_template_rows": null,
1003
+ "height": null,
1004
+ "justify_content": null,
1005
+ "justify_items": null,
1006
+ "left": null,
1007
+ "margin": null,
1008
+ "max_height": null,
1009
+ "max_width": null,
1010
+ "min_height": null,
1011
+ "min_width": null,
1012
+ "object_fit": null,
1013
+ "object_position": null,
1014
+ "order": null,
1015
+ "overflow": null,
1016
+ "overflow_x": null,
1017
+ "overflow_y": null,
1018
+ "padding": null,
1019
+ "right": null,
1020
+ "top": null,
1021
+ "visibility": null,
1022
+ "width": null
1023
+ }
1024
+ },
1025
+ "7c5689bc13684db8a22681f41863dddd": {
1026
+ "model_module": "@jupyter-widgets/base",
1027
+ "model_module_version": "1.2.0",
1028
+ "model_name": "LayoutModel",
1029
+ "state": {
1030
+ "_model_module": "@jupyter-widgets/base",
1031
+ "_model_module_version": "1.2.0",
1032
+ "_model_name": "LayoutModel",
1033
+ "_view_count": null,
1034
+ "_view_module": "@jupyter-widgets/base",
1035
+ "_view_module_version": "1.2.0",
1036
+ "_view_name": "LayoutView",
1037
+ "align_content": null,
1038
+ "align_items": null,
1039
+ "align_self": null,
1040
+ "border": null,
1041
+ "bottom": null,
1042
+ "display": null,
1043
+ "flex": null,
1044
+ "flex_flow": null,
1045
+ "grid_area": null,
1046
+ "grid_auto_columns": null,
1047
+ "grid_auto_flow": null,
1048
+ "grid_auto_rows": null,
1049
+ "grid_column": null,
1050
+ "grid_gap": null,
1051
+ "grid_row": null,
1052
+ "grid_template_areas": null,
1053
+ "grid_template_columns": null,
1054
+ "grid_template_rows": null,
1055
+ "height": null,
1056
+ "justify_content": null,
1057
+ "justify_items": null,
1058
+ "left": null,
1059
+ "margin": null,
1060
+ "max_height": null,
1061
+ "max_width": null,
1062
+ "min_height": null,
1063
+ "min_width": null,
1064
+ "object_fit": null,
1065
+ "object_position": null,
1066
+ "order": null,
1067
+ "overflow": null,
1068
+ "overflow_x": null,
1069
+ "overflow_y": null,
1070
+ "padding": null,
1071
+ "right": null,
1072
+ "top": null,
1073
+ "visibility": null,
1074
+ "width": null
1075
+ }
1076
+ },
1077
+ "a1d3a8aa016544a78e8821c8f6199e06": {
1078
+ "model_module": "@jupyter-widgets/controls",
1079
+ "model_module_version": "1.5.0",
1080
+ "model_name": "HBoxModel",
1081
+ "state": {
1082
+ "_dom_classes": [],
1083
+ "_model_module": "@jupyter-widgets/controls",
1084
+ "_model_module_version": "1.5.0",
1085
+ "_model_name": "HBoxModel",
1086
+ "_view_count": null,
1087
+ "_view_module": "@jupyter-widgets/controls",
1088
+ "_view_module_version": "1.5.0",
1089
+ "_view_name": "HBoxView",
1090
+ "box_style": "",
1091
+ "children": [
1092
+ "IPY_MODEL_f61ed33fad754146bdd2ac9db1ba1c48",
1093
+ "IPY_MODEL_bfa0af6aeff344c6845e1080a878e92e",
1094
+ "IPY_MODEL_fd1ad9e0367d4004aae853b91c3a7617"
1095
+ ],
1096
+ "layout": "IPY_MODEL_6b2d90209ec14230b3d58a74ac9b83bf"
1097
+ }
1098
+ },
1099
+ "a73f357065d34d7baf0453ae4a8d75e2": {
1100
+ "model_module": "@jupyter-widgets/base",
1101
+ "model_module_version": "1.2.0",
1102
+ "model_name": "LayoutModel",
1103
+ "state": {
1104
+ "_model_module": "@jupyter-widgets/base",
1105
+ "_model_module_version": "1.2.0",
1106
+ "_model_name": "LayoutModel",
1107
+ "_view_count": null,
1108
+ "_view_module": "@jupyter-widgets/base",
1109
+ "_view_module_version": "1.2.0",
1110
+ "_view_name": "LayoutView",
1111
+ "align_content": null,
1112
+ "align_items": null,
1113
+ "align_self": null,
1114
+ "border": null,
1115
+ "bottom": null,
1116
+ "display": null,
1117
+ "flex": null,
1118
+ "flex_flow": null,
1119
+ "grid_area": null,
1120
+ "grid_auto_columns": null,
1121
+ "grid_auto_flow": null,
1122
+ "grid_auto_rows": null,
1123
+ "grid_column": null,
1124
+ "grid_gap": null,
1125
+ "grid_row": null,
1126
+ "grid_template_areas": null,
1127
+ "grid_template_columns": null,
1128
+ "grid_template_rows": null,
1129
+ "height": null,
1130
+ "justify_content": null,
1131
+ "justify_items": null,
1132
+ "left": null,
1133
+ "margin": null,
1134
+ "max_height": null,
1135
+ "max_width": null,
1136
+ "min_height": null,
1137
+ "min_width": null,
1138
+ "object_fit": null,
1139
+ "object_position": null,
1140
+ "order": null,
1141
+ "overflow": null,
1142
+ "overflow_x": null,
1143
+ "overflow_y": null,
1144
+ "padding": null,
1145
+ "right": null,
1146
+ "top": null,
1147
+ "visibility": null,
1148
+ "width": null
1149
+ }
1150
+ },
1151
+ "aed3acd2f2d74003b44079c333a0698e": {
1152
+ "model_module": "@jupyter-widgets/controls",
1153
+ "model_module_version": "1.5.0",
1154
+ "model_name": "DescriptionStyleModel",
1155
+ "state": {
1156
+ "_model_module": "@jupyter-widgets/controls",
1157
+ "_model_module_version": "1.5.0",
1158
+ "_model_name": "DescriptionStyleModel",
1159
+ "_view_count": null,
1160
+ "_view_module": "@jupyter-widgets/base",
1161
+ "_view_module_version": "1.2.0",
1162
+ "_view_name": "StyleView",
1163
+ "description_width": ""
1164
+ }
1165
+ },
1166
+ "bfa0af6aeff344c6845e1080a878e92e": {
1167
+ "model_module": "@jupyter-widgets/controls",
1168
+ "model_module_version": "1.5.0",
1169
+ "model_name": "FloatProgressModel",
1170
+ "state": {
1171
+ "_dom_classes": [],
1172
+ "_model_module": "@jupyter-widgets/controls",
1173
+ "_model_module_version": "1.5.0",
1174
+ "_model_name": "FloatProgressModel",
1175
+ "_view_count": null,
1176
+ "_view_module": "@jupyter-widgets/controls",
1177
+ "_view_module_version": "1.5.0",
1178
+ "_view_name": "ProgressView",
1179
+ "bar_style": "success",
1180
+ "description": "",
1181
+ "description_tooltip": null,
1182
+ "layout": "IPY_MODEL_7c5689bc13684db8a22681f41863dddd",
1183
+ "max": 5669,
1184
+ "min": 0,
1185
+ "orientation": "horizontal",
1186
+ "style": "IPY_MODEL_48763b6233374554ae76035c0483066f",
1187
+ "value": 5669
1188
+ }
1189
+ },
1190
+ "f61ed33fad754146bdd2ac9db1ba1c48": {
1191
+ "model_module": "@jupyter-widgets/controls",
1192
+ "model_module_version": "1.5.0",
1193
+ "model_name": "HTMLModel",
1194
+ "state": {
1195
+ "_dom_classes": [],
1196
+ "_model_module": "@jupyter-widgets/controls",
1197
+ "_model_module_version": "1.5.0",
1198
+ "_model_name": "HTMLModel",
1199
+ "_view_count": null,
1200
+ "_view_module": "@jupyter-widgets/controls",
1201
+ "_view_module_version": "1.5.0",
1202
+ "_view_name": "HTMLView",
1203
+ "description": "",
1204
+ "description_tooltip": null,
1205
+ "layout": "IPY_MODEL_a73f357065d34d7baf0453ae4a8d75e2",
1206
+ "placeholder": "​",
1207
+ "style": "IPY_MODEL_46f521b73fd943c081c648fd873ebc0a",
1208
+ "value": "Downloading builder script: 100%"
1209
+ }
1210
+ },
1211
+ "fd1ad9e0367d4004aae853b91c3a7617": {
1212
+ "model_module": "@jupyter-widgets/controls",
1213
+ "model_module_version": "1.5.0",
1214
+ "model_name": "HTMLModel",
1215
+ "state": {
1216
+ "_dom_classes": [],
1217
+ "_model_module": "@jupyter-widgets/controls",
1218
+ "_model_module_version": "1.5.0",
1219
+ "_model_name": "HTMLModel",
1220
+ "_view_count": null,
1221
+ "_view_module": "@jupyter-widgets/controls",
1222
+ "_view_module_version": "1.5.0",
1223
+ "_view_name": "HTMLView",
1224
+ "description": "",
1225
+ "description_tooltip": null,
1226
+ "layout": "IPY_MODEL_4986a21eb560448fa79f4b25cde48951",
1227
+ "placeholder": "​",
1228
+ "style": "IPY_MODEL_aed3acd2f2d74003b44079c333a0698e",
1229
+ "value": " 5.67k/5.67k [00:00&lt;00:00, 205kB/s]"
1230
+ }
1231
+ }
1232
+ },
1233
+ "version_major": 2,
1234
+ "version_minor": 0
1235
+ }
1236
+ }
1237
+ },
1238
+ "nbformat": 4,
1239
+ "nbformat_minor": 0
1240
+ }
lm-evaluation-harness/examples/transformer-lens.py ADDED
@@ -0,0 +1,59 @@
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformer_lens import HookedTransformer
6
+ from transformers import AutoConfig
7
+
8
+ from lm_eval import evaluator
9
+ from lm_eval.models.huggingface import HFLM
10
+
11
+
12
+ def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
13
+ class HFLikeModelAdapter(nn.Module):
14
+ """Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""
15
+
16
+ def __init__(self, model: HookedTransformer):
17
+ super().__init__()
18
+ self.model = model
19
+ self.tokenizer = model.tokenizer
20
+ self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
21
+ self.device = model.cfg.device
22
+ self.tie_weights = lambda: self
23
+
24
+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
25
+ output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
26
+ # Make sure output has the expected .logits attribute
27
+ if not hasattr(output, "logits"):
28
+ if isinstance(output, torch.Tensor):
29
+ output.logits = output
30
+ return output
31
+
32
+ # Only delegate specific attributes we know we need
33
+ def to(self, *args, **kwargs):
34
+ return self.model.to(*args, **kwargs)
35
+
36
+ def eval(self):
37
+ self.model.eval()
38
+ return self
39
+
40
+ def train(self, mode=True):
41
+ self.model.train(mode)
42
+ return self
43
+
44
+ model = HFLikeModelAdapter(lens_model)
45
+ warnings.filterwarnings("ignore", message="Failed to get model SHA for")
46
+ results = evaluator.simple_evaluate(
47
+ model=HFLM(pretrained=model, tokenizer=model.tokenizer),
48
+ tasks=tasks,
49
+ verbosity="WARNING",
50
+ **kwargs,
51
+ )
52
+ return results
53
+
54
+
55
+ if __name__ == "__main__":
56
+ # Load base model
57
+ model = HookedTransformer.from_pretrained("pythia-70m")
58
+ res = evaluate_lm_eval(model, tasks=["arc_easy"])
59
+ print(res["results"])
lm-evaluation-harness/examples/visualize-wandb.ipynb ADDED
@@ -0,0 +1,172 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fc477b96-adee-4829-a9d7-a5eb990df358",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Visualizing Results in Weights and Biases\n",
9
+ "\n",
10
+ "With the Weights and Biases integration, you can now spend more time extracting deeper insights into your evaluation results. The integration is designed to streamline the process of logging and visualizing experiment results using the Weights & Biases (W&B) platform.\n",
11
+ "\n",
12
+ "The integration provide functionalities\n",
13
+ "\n",
14
+ "- to automatically log the evaluation results,\n",
15
+ "- log the samples as W&B Tables for easy visualization,\n",
16
+ "- log the `results.json` file as an artifact for version control,\n",
17
+ "- log the `<task_name>_eval_samples.json` file if the samples are logged,\n",
18
+ "- generate a comprehensive report for analysis and visualization with all the important metric,\n",
19
+ "- log task and cli configs,\n",
20
+ "- and more out of the box like the command used to run the evaluation, GPU/CPU counts, timestamp, etc.\n",
21
+ "\n",
22
+ "The integration is super easy to use with the eval harness. Let's see how!"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "3851439a-bff4-41f2-bf21-1b3d8704913b",
29
+ "metadata": {
30
+ "scrolled": true
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "# Install this project if you did not already have it.\n",
35
+ "# This is all that is needed to be installed to start using Weights and Biases\n",
36
+ "\n",
37
+ "!pip -qq install -e ..[wandb]"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "id": "8507fd7e-3b99-4a92-89fa-9eaada74ba91",
43
+ "metadata": {},
44
+ "source": [
45
+ "# Run the Eval Harness\n",
46
+ "\n",
47
+ "Run the eval harness as usual with a `wandb_args` flag. This flag is used to provide arguments for initializing a wandb run ([wandb.init](https://docs.wandb.ai/ref/python/init)) as comma separated string arguments.\n",
48
+ "\n",
49
+ "If `wandb_args` flag is used, the metrics and all other goodness will be automatically logged to Weights and Biases. In the stdout, you will find the link to the W&B run page as well as link to the generated report."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "id": "eec5866e-f01e-42f8-8803-9d77472ef991",
55
+ "metadata": {},
56
+ "source": [
57
+ "## Set your API Key\n",
58
+ "\n",
59
+ "Before you can use W&B, you need to authenticate your machine with an authentication key. Visit https://wandb.ai/authorize to get one."
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "d824d163-71a9-4313-935d-f1d56397841c",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "import wandb\n",
70
+ "\n",
71
+ "\n",
72
+ "wandb.login()"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "124e4a34-1547-4bed-bc09-db012bacbda6",
78
+ "metadata": {},
79
+ "source": [
80
+ "> Note that if you are using command line you can simply authenticate your machine by doing `wandb login` in your terminal. For more info check out the [documentation](https://docs.wandb.ai/quickstart#2-log-in-to-wb)."
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "id": "abc6f6b6-179a-4aff-ada9-f380fb74df6e",
86
+ "metadata": {},
87
+ "source": [
88
+ "## Run and log to W&B"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "bd0a8130-a97b-451a-acd2-3f9885b88643",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "!lm_eval \\\n",
99
+ " --model hf \\\n",
100
+ " --model_args pretrained=microsoft/phi-2,trust_remote_code=True \\\n",
101
+ " --tasks hellaswag,mmlu_abstract_algebra \\\n",
102
+ " --device cuda:0 \\\n",
103
+ " --batch_size 8 \\\n",
104
+ " --output_path output/phi-2 \\\n",
105
+ " --limit 10 \\\n",
106
+ " --wandb_args project=lm-eval-harness-integration \\\n",
107
+ " --log_samples"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "id": "e974cabdbe70b667",
113
+ "metadata": {},
114
+ "source": []
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "5178ca9445b844e4",
119
+ "metadata": {},
120
+ "source": [
121
+ "W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "c6a421b2cf3ddac5",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "import lm_eval\n",
132
+ "from lm_eval.loggers import WandbLogger\n",
133
+ "\n",
134
+ "\n",
135
+ "results = lm_eval.simple_evaluate(\n",
136
+ " model=\"hf\",\n",
137
+ " model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
138
+ " tasks=\"hellaswag,mmlu_abstract_algebra\",\n",
139
+ " log_samples=True,\n",
140
+ ")\n",
141
+ "\n",
142
+ "wandb_logger = WandbLogger(\n",
143
+ " project=\"lm-eval-harness-integration\", job_type=\"eval\"\n",
144
+ ") # or empty if wandb.init(...) already called before\n",
145
+ "wandb_logger.post_init(results)\n",
146
+ "wandb_logger.log_eval_result()\n",
147
+ "wandb_logger.log_eval_samples(results[\"samples\"]) # if log_samples"
148
+ ]
149
+ }
150
+ ],
151
+ "metadata": {
152
+ "kernelspec": {
153
+ "display_name": "Python 3 (ipykernel)",
154
+ "language": "python",
155
+ "name": "python3"
156
+ },
157
+ "language_info": {
158
+ "codemirror_mode": {
159
+ "name": "ipython",
160
+ "version": 3
161
+ },
162
+ "file_extension": ".py",
163
+ "mimetype": "text/x-python",
164
+ "name": "python",
165
+ "nbconvert_exporter": "python",
166
+ "pygments_lexer": "ipython3",
167
+ "version": "3.10.12"
168
+ }
169
+ },
170
+ "nbformat": 4,
171
+ "nbformat_minor": 5
172
+ }
lm-evaluation-harness/examples/visualize-zeno.ipynb ADDED
@@ -0,0 +1,115 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Visualizing Results in Zeno\n",
8
+ "\n",
9
+ "Benchmarking your models is the first step towards making sure your model performs well.\n",
10
+ "However, looking at the data behind the benchmark, slicing the data into subsets, and comparing models on individual instances can help you even more in evaluating and quantifying the behavior of your AI system.\n",
11
+ "\n",
12
+ "All of this can be done in [Zeno](https://zenoml.com)!\n",
13
+ "Zeno is super easy to use with the eval harness, let's explore how you can easily upload and visualize your eval results.\n"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# Install this project if you did not already do that. This is all that needs to be installed for you to be able to visualize your data in Zeno!\n",
23
+ "!pip install -e ..\n",
24
+ "!pip install -e ..[zeno]"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "# Run the Eval Harness\n",
32
+ "\n",
33
+ "To visualize the results, run the eval harness with the `log_samples` and `output_path` flags. We expect `output_path` to contain multiple folders that represent individual model names. You can thus run your evaluation on any number of tasks and models and upload all of the results as projects on Zeno.\n"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "!lm_eval \\\n",
43
+ " --model hf \\\n",
44
+ " --model_args pretrained=EleutherAI/gpt-neo-2.7B \\\n",
45
+ " --tasks hellaswag,wikitext \\\n",
46
+ " --batch_size 8 \\\n",
47
+ " --device mps \\\n",
48
+ " --log_samples \\\n",
49
+ " --output_path output/gpt-neo-2.7B \\\n",
50
+ " --limit 10"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "# Set your API Key\n",
58
+ "\n",
59
+ "This is so you can be authenticated with Zeno.\n",
60
+ "If you don't already have a Zeno account, first create an account on [Zeno Hub](https://hub.zenoml.com).\n",
61
+ "After logging in to Zeno Hub, generate your API key by clicking on your profile at the bottom left to navigate to your account page.\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "%env ZENO_API_KEY=YOUR_API_KEY"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {},
76
+ "source": [
77
+ "# Visualize Eval Results\n",
78
+ "\n",
79
+ "You can now use the `zeno_visualize` script to upload the results to Zeno.\n",
80
+ "\n",
81
+ "This will use all subfolders in `data_path` as different models and upload all tasks within these model folders to Zeno. If you run the eval harness on multiple tasks, the `project_name` will be used as a prefix and one project will be created per task.\n"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "!python ../scripts/zeno_visualize.py --data_path output --project_name \"Zeno Upload Test\""
91
+ ]
92
+ }
93
+ ],
94
+ "metadata": {
95
+ "kernelspec": {
96
+ "display_name": "zeno_projects",
97
+ "language": "python",
98
+ "name": "python3"
99
+ },
100
+ "language_info": {
101
+ "codemirror_mode": {
102
+ "name": "ipython",
103
+ "version": 3
104
+ },
105
+ "file_extension": ".py",
106
+ "mimetype": "text/x-python",
107
+ "name": "python",
108
+ "nbconvert_exporter": "python",
109
+ "pygments_lexer": "ipython3",
110
+ "version": "3.10.11"
111
+ }
112
+ },
113
+ "nbformat": 4,
114
+ "nbformat_minor": 2
115
+ }
lm-evaluation-harness/lm_eval/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ import logging
2
+ import os
3
+
4
+
5
+ __version__ = "0.4.9.1"
6
+
7
+
8
+ # Lazy-load .evaluator module to improve CLI startup
9
+ def __getattr__(name):
10
+ if name == "evaluate":
11
+ from .evaluator import evaluate
12
+
13
+ return evaluate
14
+ elif name == "simple_evaluate":
15
+ from .evaluator import simple_evaluate
16
+
17
+ return simple_evaluate
18
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
19
+
20
+
21
+ __all__ = ["evaluate", "simple_evaluate", "__version__"]
lm-evaluation-harness/lm_eval/__main__.py ADDED
@@ -0,0 +1,536 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from functools import partial
7
+ from pathlib import Path
8
+ from typing import Union
9
+
10
+
11
+ def try_parse_json(value: str) -> Union[str, dict, None]:
12
+ if value is None:
13
+ return None
14
+ try:
15
+ return json.loads(value)
16
+ except json.JSONDecodeError:
17
+ if "{" in value:
18
+ raise argparse.ArgumentTypeError(
19
+ f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
20
+ )
21
+ return value
22
+
23
+
24
+ def _int_or_none_list_arg_type(
25
+ min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
26
+ ):
27
+ def parse_value(item):
28
+ item = item.strip().lower()
29
+ if item == "none":
30
+ return None
31
+ try:
32
+ return int(item)
33
+ except ValueError:
34
+ raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
35
+
36
+ items = [parse_value(v) for v in value.split(split_char)]
37
+ num_items = len(items)
38
+
39
+ if num_items == 1:
40
+ # Makes downstream handling the same for single and multiple values
41
+ items = items * max_len
42
+ elif num_items < min_len or num_items > max_len:
43
+ raise argparse.ArgumentTypeError(
44
+ f"Argument requires {max_len} integers or None, separated by '{split_char}'"
45
+ )
46
+ elif num_items != max_len:
47
+ logging.warning(
48
+ f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
49
+ "Missing values will be filled with defaults."
50
+ )
51
+ default_items = [parse_value(v) for v in defaults.split(split_char)]
52
+ items.extend(
53
+ default_items[num_items:]
54
+ ) # extend items list with missing defaults
55
+
56
+ return items
57
+
58
+
59
+ def check_argument_types(parser: argparse.ArgumentParser):
60
+ """
61
+ Check to make sure all CLI args are typed, raises error if not
62
+ """
63
+ for action in parser._actions:
64
+ if action.dest != "help" and not action.const:
65
+ if action.type is None:
66
+ raise ValueError(
67
+ f"Argument '{action.dest}' doesn't have a type specified."
68
+ )
69
+ else:
70
+ continue
71
+
72
+
73
+ def setup_parser() -> argparse.ArgumentParser:
74
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
75
+ parser.add_argument(
76
+ "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
77
+ )
78
+ parser.add_argument(
79
+ "--tasks",
80
+ "-t",
81
+ default=None,
82
+ type=str,
83
+ metavar="task1,task2",
84
+ help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
85
+ )
86
+ parser.add_argument(
87
+ "--model_args",
88
+ "-a",
89
+ default="",
90
+ type=try_parse_json,
91
+ help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
92
+ )
93
+ parser.add_argument(
94
+ "--num_fewshot",
95
+ "-f",
96
+ type=int,
97
+ default=None,
98
+ metavar="N",
99
+ help="Number of examples in few-shot context",
100
+ )
101
+ parser.add_argument(
102
+ "--batch_size",
103
+ "-b",
104
+ type=str,
105
+ default=1,
106
+ metavar="auto|auto:N|N",
107
+ help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
108
+ )
109
+ parser.add_argument(
110
+ "--max_batch_size",
111
+ type=int,
112
+ default=None,
113
+ metavar="N",
114
+ help="Maximal batch size to try with --batch_size auto.",
115
+ )
116
+ parser.add_argument(
117
+ "--device",
118
+ type=str,
119
+ default=None,
120
+ help="Device to use (e.g. cuda, cuda:0, cpu).",
121
+ )
122
+ parser.add_argument(
123
+ "--output_path",
124
+ "-o",
125
+ default=None,
126
+ type=str,
127
+ metavar="DIR|DIR/file.json",
128
+ help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
129
+ )
130
+ parser.add_argument(
131
+ "--limit",
132
+ "-L",
133
+ type=float,
134
+ default=None,
135
+ metavar="N|0<N<1",
136
+ help="Limit the number of examples per task. "
137
+ "If <1, limit is a percentage of the total number of examples.",
138
+ )
139
+ parser.add_argument(
140
+ "--samples",
141
+ "-E",
142
+ default=None,
143
+ type=str,
144
+ metavar="/path/to/json",
145
+ help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
146
+ )
147
+ parser.add_argument(
148
+ "--use_cache",
149
+ "-c",
150
+ type=str,
151
+ default=None,
152
+ metavar="DIR",
153
+ help="A path to a sqlite db file for caching model responses. `None` if not caching.",
154
+ )
155
+ parser.add_argument(
156
+ "--cache_requests",
157
+ type=str,
158
+ default=None,
159
+ choices=["true", "refresh", "delete"],
160
+ help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
161
+ )
162
+ parser.add_argument(
163
+ "--check_integrity",
164
+ action="store_true",
165
+ help="Whether to run the relevant part of the test suite for the tasks.",
166
+ )
167
+ parser.add_argument(
168
+ "--write_out",
169
+ "-w",
170
+ action="store_true",
171
+ default=False,
172
+ help="Prints the prompt for the first few documents.",
173
+ )
174
+ parser.add_argument(
175
+ "--log_samples",
176
+ "-s",
177
+ action="store_true",
178
+ default=False,
179
+ help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
180
+ )
181
+ parser.add_argument(
182
+ "--system_instruction",
183
+ type=str,
184
+ default=None,
185
+ help="System instruction to be used in the prompt",
186
+ )
187
+ parser.add_argument(
188
+ "--apply_chat_template",
189
+ type=str,
190
+ nargs="?",
191
+ const=True,
192
+ default=False,
193
+ help=(
194
+ "If True, apply chat template to the prompt. "
195
+ "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
196
+ "To apply a specific template from the available list of templates, provide the template name as an argument. "
197
+ "E.g. `--apply_chat_template template_name`"
198
+ ),
199
+ )
200
+ parser.add_argument(
201
+ "--fewshot_as_multiturn",
202
+ action="store_true",
203
+ default=False,
204
+ help="If True, uses the fewshot as a multi-turn conversation",
205
+ )
206
+ parser.add_argument(
207
+ "--show_config",
208
+ action="store_true",
209
+ default=False,
210
+ help="If True, shows the the full config of all tasks at the end of the evaluation.",
211
+ )
212
+ parser.add_argument(
213
+ "--include_path",
214
+ type=str,
215
+ default=None,
216
+ metavar="DIR",
217
+ help="Additional path to include if there are external tasks to include.",
218
+ )
219
+ parser.add_argument(
220
+ "--gen_kwargs",
221
+ type=try_parse_json,
222
+ default=None,
223
+ help=(
224
+ "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
225
+ """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--verbosity",
230
+ "-v",
231
+ type=str.upper,
232
+ default=None,
233
+ metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
234
+ help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
235
+ )
236
+ parser.add_argument(
237
+ "--wandb_args",
238
+ type=str,
239
+ default="",
240
+ help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
241
+ )
242
+ parser.add_argument(
243
+ "--wandb_config_args",
244
+ type=str,
245
+ default="",
246
+ help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
247
+ )
248
+ parser.add_argument(
249
+ "--hf_hub_log_args",
250
+ type=str,
251
+ default="",
252
+ help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
253
+ )
254
+ parser.add_argument(
255
+ "--predict_only",
256
+ "-x",
257
+ action="store_true",
258
+ default=False,
259
+ help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
260
+ )
261
+ default_seed_string = "0,1234,1234,1234"
262
+ parser.add_argument(
263
+ "--seed",
264
+ type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
265
+ default=default_seed_string, # for backward compatibility
266
+ help=(
267
+ "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
268
+ "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
269
+ "respectively, or a single integer to set the same seed for all four.\n"
270
+ f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
271
+ "(for backward compatibility).\n"
272
+ "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
273
+ "Here numpy's seed is not set since the second value is `None`.\n"
274
+ "E.g, `--seed 42` sets all four seeds to 42."
275
+ ),
276
+ )
277
+ parser.add_argument(
278
+ "--trust_remote_code",
279
+ action="store_true",
280
+ help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
281
+ )
282
+ parser.add_argument(
283
+ "--confirm_run_unsafe_code",
284
+ action="store_true",
285
+ help="Confirm that you understand the risks of running unsafe code for tasks that require it",
286
+ )
287
+ parser.add_argument(
288
+ "--metadata",
289
+ type=json.loads,
290
+ default=None,
291
+ help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
292
+ )
293
+ return parser
294
+
295
+
296
+ def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
297
+ check_argument_types(parser)
298
+ return parser.parse_args()
299
+
300
+
301
+ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
302
+ if not args:
303
+ # we allow for args to be passed externally, else we parse them ourselves
304
+ parser = setup_parser()
305
+ args = parse_eval_args(parser)
306
+
307
+ # defer loading `lm_eval` submodules for faster CLI load
308
+ from lm_eval import evaluator, utils
309
+ from lm_eval.evaluator import request_caching_arg_to_dict
310
+ from lm_eval.loggers import EvaluationTracker, WandbLogger
311
+ from lm_eval.tasks import TaskManager
312
+ from lm_eval.utils import (
313
+ handle_non_serializable,
314
+ make_table,
315
+ simple_parse_args_string,
316
+ )
317
+
318
+ if args.wandb_args:
319
+ wandb_args_dict = simple_parse_args_string(args.wandb_args)
320
+ wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
321
+ wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
322
+
323
+ utils.setup_logging(args.verbosity)
324
+ eval_logger = logging.getLogger(__name__)
325
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
326
+
327
+ # update the evaluation tracker args with the output path and the HF token
328
+ if args.output_path:
329
+ args.hf_hub_log_args += f",output_path={args.output_path}"
330
+ if os.environ.get("HF_TOKEN", None):
331
+ args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
332
+ evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
333
+ evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
334
+
335
+ if args.predict_only:
336
+ args.log_samples = True
337
+ if (args.log_samples or args.predict_only) and not args.output_path:
338
+ raise ValueError(
339
+ "Specify --output_path if providing --log_samples or --predict_only"
340
+ )
341
+
342
+ if args.fewshot_as_multiturn and args.apply_chat_template is False:
343
+ raise ValueError(
344
+ "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
345
+ )
346
+
347
+ if args.include_path is not None:
348
+ eval_logger.info(f"Including path: {args.include_path}")
349
+ metadata = (
350
+ simple_parse_args_string(args.model_args)
351
+ if isinstance(args.model_args, str)
352
+ else args.model_args
353
+ if isinstance(args.model_args, dict)
354
+ else {}
355
+ ) | (
356
+ args.metadata
357
+ if isinstance(args.metadata, dict)
358
+ else simple_parse_args_string(args.metadata)
359
+ )
360
+
361
+ task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
362
+
363
+ if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
364
+ eval_logger.warning(
365
+ "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
366
+ )
367
+
368
+ if args.limit:
369
+ eval_logger.warning(
370
+ " --limit SHOULD ONLY BE USED FOR TESTING."
371
+ "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
372
+ )
373
+ if args.samples:
374
+ assert args.limit is None, (
375
+ "If --samples is not None, then --limit must be None."
376
+ )
377
+ if (samples := Path(args.samples)).is_file():
378
+ args.samples = json.loads(samples.read_text())
379
+ else:
380
+ args.samples = json.loads(args.samples)
381
+
382
+ if args.tasks is None:
383
+ eval_logger.error("Need to specify task to evaluate.")
384
+ sys.exit()
385
+ elif args.tasks == "list":
386
+ print(task_manager.list_all_tasks())
387
+ sys.exit()
388
+ elif args.tasks == "list_groups":
389
+ print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
390
+ sys.exit()
391
+ elif args.tasks == "list_tags":
392
+ print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
393
+ sys.exit()
394
+ elif args.tasks == "list_subtasks":
395
+ print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
396
+ sys.exit()
397
+ else:
398
+ if os.path.isdir(args.tasks):
399
+ import glob
400
+
401
+ task_names = []
402
+ yaml_path = os.path.join(args.tasks, "*.yaml")
403
+ for yaml_file in glob.glob(yaml_path):
404
+ config = utils.load_yaml_config(yaml_file)
405
+ task_names.append(config)
406
+ else:
407
+ task_list = args.tasks.split(",")
408
+ task_names = task_manager.match_tasks(task_list)
409
+ for task in [task for task in task_list if task not in task_names]:
410
+ if os.path.isfile(task):
411
+ config = utils.load_yaml_config(task)
412
+ task_names.append(config)
413
+ task_missing = [
414
+ task for task in task_list if task not in task_names and "*" not in task
415
+ ] # we don't want errors if a wildcard ("*") task name was used
416
+
417
+ if task_missing:
418
+ missing = ", ".join(task_missing)
419
+ eval_logger.error(
420
+ f"Tasks were not found: {missing}\n"
421
+ f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
422
+ )
423
+ raise ValueError(
424
+ f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
425
+ )
426
+
427
+ # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
428
+ if args.trust_remote_code:
429
+ eval_logger.info(
430
+ "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
431
+ )
432
+ # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
433
+ # because it's already been determined based on the prior env var before launching our
434
+ # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
435
+ import datasets
436
+ from packaging.version import parse as vparse
437
+
438
+ if vparse(datasets.__version__) < vparse("4.0.0"):
439
+ datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
440
+
441
+ if isinstance(args.model_args, dict):
442
+ args.model_args["trust_remote_code"] = True
443
+ else:
444
+ args.model_args = args.model_args + ",trust_remote_code=True"
445
+ (
446
+ eval_logger.info(f"Selected Tasks: {task_names}")
447
+ if eval_logger.getEffectiveLevel() >= logging.INFO
448
+ else print(f"Selected Tasks: {task_names}")
449
+ )
450
+
451
+ request_caching_args = request_caching_arg_to_dict(
452
+ cache_requests=args.cache_requests
453
+ )
454
+
455
+ results = evaluator.simple_evaluate(
456
+ model=args.model,
457
+ model_args=args.model_args,
458
+ tasks=task_names,
459
+ num_fewshot=args.num_fewshot,
460
+ batch_size=args.batch_size,
461
+ max_batch_size=args.max_batch_size,
462
+ device=args.device,
463
+ use_cache=args.use_cache,
464
+ limit=args.limit,
465
+ samples=args.samples,
466
+ check_integrity=args.check_integrity,
467
+ write_out=args.write_out,
468
+ log_samples=args.log_samples,
469
+ evaluation_tracker=evaluation_tracker,
470
+ system_instruction=args.system_instruction,
471
+ apply_chat_template=args.apply_chat_template,
472
+ fewshot_as_multiturn=args.fewshot_as_multiturn,
473
+ gen_kwargs=args.gen_kwargs,
474
+ task_manager=task_manager,
475
+ predict_only=args.predict_only,
476
+ random_seed=args.seed[0],
477
+ numpy_random_seed=args.seed[1],
478
+ torch_random_seed=args.seed[2],
479
+ fewshot_random_seed=args.seed[3],
480
+ confirm_run_unsafe_code=args.confirm_run_unsafe_code,
481
+ metadata=metadata,
482
+ **request_caching_args,
483
+ )
484
+
485
+ if results is not None:
486
+ if args.log_samples:
487
+ samples = results.pop("samples")
488
+ dumped = json.dumps(
489
+ results, indent=2, default=handle_non_serializable, ensure_ascii=False
490
+ )
491
+ if args.show_config:
492
+ print(dumped)
493
+
494
+ batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
495
+
496
+ # Add W&B logging
497
+ if args.wandb_args:
498
+ try:
499
+ wandb_logger.post_init(results)
500
+ wandb_logger.log_eval_result()
501
+ if args.log_samples:
502
+ wandb_logger.log_eval_samples(samples)
503
+ except Exception as e:
504
+ eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
505
+
506
+ evaluation_tracker.save_results_aggregated(
507
+ results=results, samples=samples if args.log_samples else None
508
+ )
509
+
510
+ if args.log_samples:
511
+ for task_name, config in results["configs"].items():
512
+ evaluation_tracker.save_results_samples(
513
+ task_name=task_name, samples=samples[task_name]
514
+ )
515
+
516
+ if (
517
+ evaluation_tracker.push_results_to_hub
518
+ or evaluation_tracker.push_samples_to_hub
519
+ ):
520
+ evaluation_tracker.recreate_metadata_card()
521
+
522
+ print(
523
+ f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
524
+ f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
525
+ )
526
+ print(make_table(results))
527
+ if "groups" in results:
528
+ print(make_table(results, "groups"))
529
+
530
+ if args.wandb_args:
531
+ # Tear down wandb run once all the logging is done.
532
+ wandb_logger.run.finish()
533
+
534
+
535
+ if __name__ == "__main__":
536
+ cli_evaluate()
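
The `--seed` flag above is parsed by `_int_or_none_list_arg_type` via `partial(..., 3, 4, default_seed_string)`: a single value is broadcast to all four seeds, and missing trailing values are filled from the defaults. A small sketch of that behaviour, assuming the private helper is imported from `lm_eval.__main__` (illustrative only):

```python
from functools import partial

from lm_eval.__main__ import _int_or_none_list_arg_type

# Mirrors the CLI wiring: partial(_int_or_none_list_arg_type, 3, 4, "0,1234,1234,1234")
seed_parser = partial(_int_or_none_list_arg_type, 3, 4, "0,1234,1234,1234")

print(seed_parser("42"))           # [42, 42, 42, 42] -- single value broadcast to all four seeds
print(seed_parser("0,None,8,52"))  # [0, None, 8, 52]
print(seed_parser("0,None,8"))     # [0, None, 8, 1234] -- missing value filled from the defaults
```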
lm-evaluation-harness/lm_eval/api/__init__.py ADDED
File without changes
lm-evaluation-harness/lm_eval/api/filter.py ADDED
@@ -0,0 +1,56 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Iterable, List, Union
4
+
5
+ from lm_eval.api.instance import Instance
6
+
7
+
8
+ class Filter(ABC):
9
+ """
10
+ Filter classes operate on a per-task level.
11
+ They take all model outputs (`instance.resps` for all `task.instances`)
12
+ across all instances of a task, and perform operations.
13
+ In a single run, one can configure any number of separate filters or lists of filters.
14
+
15
+ """
16
+
17
+ def __init__(self, **kwargs) -> None:
18
+ """
19
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
20
+ """
21
+
22
+ @abstractmethod
23
+ def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
24
+ """
25
+ Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
26
+ Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
27
+ if passed [<inst.resps for instance 0>, <inst.resps for instance 1>], it should return
28
+ [<filtered resps for instance 0>, <filtered resps for instance 1>]
29
+ """
30
+ return resps
31
+
32
+
33
+ @dataclass
34
+ class FilterEnsemble:
35
+ """
36
+ FilterEnsemble creates a pipeline applying multiple filters.
37
+ Its intended usage is to stack multiple post-processing steps in order.
38
+ `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
39
+ pipeline separately.
40
+ """
41
+
42
+ name: str
43
+ filters: List[Callable[[], Filter]]
44
+
45
+ def apply(self, instances: List[Instance]) -> None:
46
+ resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
47
+ resps, docs = list(resps), list(docs)
48
+
49
+ for f in self.filters:
50
+ # apply filters in sequence
51
+ resps = f().apply(resps, docs)
52
+
53
+ # add the end results after filtering to filtered_requests of their respective source instances.
54
+ # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
55
+ for inst, resp in zip(instances, resps):
56
+ inst.filtered_resps[self.name] = resp
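
To make the `Filter`/`FilterEnsemble` contract above concrete, here is a hypothetical custom filter and how an ensemble would chain it; `LowercaseFilter` and the ensemble name are invented for illustration:

```python
from lm_eval.api.filter import Filter, FilterEnsemble


class LowercaseFilter(Filter):
    def apply(self, resps, docs):
        # `resps` holds one list of model responses per instance, in input order.
        return [[r.lower() for r in instance_resps] for instance_resps in resps]


# `filters` expects zero-argument callables returning Filter objects
# (List[Callable[[], Filter]]), so passing the class itself works.
ensemble = FilterEnsemble(name="lowercase", filters=[LowercaseFilter])

# ensemble.apply(task.instances) would then populate
# inst.filtered_resps["lowercase"] for every instance.
```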
lm-evaluation-harness/lm_eval/api/group.py ADDED
@@ -0,0 +1,115 @@
1
+ import abc
2
+ from dataclasses import asdict, dataclass
3
+ from inspect import getsource
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class AggMetricConfig(dict):
9
+ metric: Optional[str] = None
10
+ aggregation: Optional[str] = "mean"
11
+ weight_by_size: Optional[str] = False
12
+ # list of filter names which should be incorporated into the aggregated metric.
13
+ filter_list: Optional[Union[str, list]] = "none"
14
+
15
+ def __post_init__(self):
16
+ if self.aggregation != "mean" and not callable(self.aggregation):
17
+ raise ValueError(
18
+ f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
19
+ )
20
+
21
+ if isinstance(self.filter_list, str):
22
+ self.filter_list = [self.filter_list]
23
+
24
+
25
+ @dataclass
26
+ class GroupConfig(dict):
27
+ group: Optional[str] = None
28
+ group_alias: Optional[str] = None
29
+ task: Optional[Union[str, list]] = None
30
+ aggregate_metric_list: Optional[
31
+ Union[List[AggMetricConfig], AggMetricConfig, dict]
32
+ ] = None
33
+ metadata: Optional[dict] = (
34
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
35
+ )
36
+
37
+ def __getitem__(self, item):
38
+ return getattr(self, item)
39
+
40
+ def __setitem__(self, item, value):
41
+ return setattr(self, item, value)
42
+
43
+ def __post_init__(self):
44
+ if self.aggregate_metric_list is not None:
45
+ if isinstance(self.aggregate_metric_list, dict):
46
+ self.aggregate_metric_list = [self.aggregate_metric_list]
47
+
48
+ self.aggregate_metric_list = [
49
+ AggMetricConfig(**item) if isinstance(item, dict) else item
50
+ for item in self.aggregate_metric_list
51
+ ]
52
+
53
+ def to_dict(self, keep_callable: bool = False) -> dict:
54
+ """dumps the current config as a dictionary object, as a printable format.
55
+ null fields will not be printed.
56
+ Used for dumping results alongside full task configuration
57
+
58
+ :return: dict
59
+ A printable dictionary version of the GroupConfig object.
60
+
61
+ # TODO: should any default value in the TaskConfig not be printed?
62
+ """
63
+ cfg_dict = asdict(self)
64
+ # serialize any callable values (e.g. custom aggregation functions) so the config stays printable
65
+ for k, v in list(cfg_dict.items()):
66
+ if callable(v):
67
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
68
+ return cfg_dict
69
+
70
+ def serialize_function(
71
+ self, value: Union[Callable, str], keep_callable=False
72
+ ) -> Union[Callable, str]:
73
+ """Serializes a given function or string.
74
+
75
+ If 'keep_callable' is True, the original callable is returned.
76
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
77
+ """
78
+ if keep_callable:
79
+ return value
80
+ else:
81
+ try:
82
+ return getsource(value)
83
+ except (TypeError, OSError):
84
+ return str(value)
85
+
86
+
87
+ class ConfigurableGroup(abc.ABC):
88
+ def __init__(
89
+ self,
90
+ config: Optional[dict] = None,
91
+ ) -> None:
92
+ self._config = GroupConfig(**config)
93
+
94
+ @property
95
+ def group(self):
96
+ return self._config.group
97
+
98
+ @property
99
+ def group_alias(self):
100
+ return self._config.group_alias
101
+
102
+ @property
103
+ def version(self):
104
+ return self._config.version
105
+
106
+ @property
107
+ def config(self):
108
+ return self._config.to_dict()
109
+
110
+ @property
111
+ def group_name(self) -> Any:
112
+ return self._config.group
113
+
114
+ def __repr__(self):
115
+ return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
lm-evaluation-harness/lm_eval/api/instance.py ADDED
@@ -0,0 +1,38 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal, Optional, Tuple
3
+
4
+
5
+ OutputType = Literal[
6
+ "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
7
+ ]
8
+
9
+
10
+ @dataclass
11
+ class Instance:
12
+ request_type: OutputType
13
+ doc: dict
14
+ arguments: tuple
15
+ idx: int
16
+ metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
17
+ default_factory=lambda: (None, None, None)
18
+ )
19
+ resps: list = field(default_factory=list)
20
+ filtered_resps: dict = field(default_factory=dict)
21
+
22
+ # initialized after init
23
+ task_name: Optional[str] = None
24
+ doc_id: Optional[int] = None
25
+ repeats: Optional[int] = None
26
+
27
+ def __post_init__(self) -> None:
28
+ # unpack metadata field
29
+ self.task_name, self.doc_id, self.repeats = self.metadata
30
+
31
+ @property
32
+ def args(self):
33
+ """
34
+ Returns (string,) where `string` is the string to calculate loglikelihood over
35
+ """
36
+ return (
37
+ self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
38
+ )
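
A small illustrative example of the `Instance` dataclass above, showing how the `metadata` tuple is unpacked in `__post_init__` and what the `args` property returns (all values are made up):

```python
from lm_eval.api.instance import Instance

inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2+2=?", "answer": "4"},
    arguments=("2+2=?", " 4"),   # (context, continuation) for a loglikelihood request
    idx=0,
    metadata=("my_task", 0, 1),  # unpacked into task_name, doc_id, repeats
)

print(inst.task_name, inst.doc_id, inst.repeats)  # my_task 0 1
print(inst.args)                                  # ('2+2=?', ' 4')
```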
lm-evaluation-harness/lm_eval/api/metrics.py ADDED
@@ -0,0 +1,629 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ import random
5
+ import re
6
+ import string
7
+ from collections.abc import Iterable
8
+ from typing import Callable, List, Optional, Sequence, TypeVar
9
+
10
+ import numpy as np
11
+ import sacrebleu
12
+
13
+ from lm_eval.api.registry import register_aggregation, register_metric
14
+
15
+
16
+ T = TypeVar("T")
17
+
18
+ eval_logger = logging.getLogger(__name__)
19
+
20
+
21
+ # Register Aggregations First
22
+ @register_aggregation("bypass")
23
+ def bypass_agg(arr):
24
+ return 999
25
+
26
+
27
+ @register_aggregation("nanmean")
28
+ def nanmean(arr):
29
+ if len(arr) == 0 or all(np.isnan(arr)):
30
+ return np.nan
31
+ return np.nanmean(arr)
32
+
33
+
34
+ @register_aggregation("mean")
35
+ def mean(arr):
36
+ return sum(arr) / len(arr)
37
+
38
+
39
+ @register_aggregation("median")
40
+ def median(arr):
41
+ return arr[len(arr) // 2]
42
+
43
+
44
+ # Certain metrics must be calculated across all documents in a benchmark.
45
+ # We use them as aggregation metrics, paired with no-op passthrough metric fns.
46
+ @register_aggregation("perplexity")
47
+ def perplexity(items):
48
+ return math.exp(-mean(items))
49
+
50
+
51
+ @register_aggregation("weighted_perplexity")
52
+ def weighted_perplexity(items):
53
+ return math.exp(-weighted_mean(items))
54
+
55
+
56
+ @register_aggregation("bits_per_byte")
57
+ def bits_per_byte(items):
58
+ return -weighted_mean(items) / math.log(2)
59
+
60
+
61
+ @register_aggregation("f1")
62
+ def f1_score(items):
63
+ from sklearn.metrics import f1_score
64
+
65
+ unzipped_list = list(zip(*items))
66
+ golds = unzipped_list[0]
67
+ preds = unzipped_list[1]
68
+ fscore = f1_score(golds, preds)
69
+
70
+ return np.max(fscore)
71
+
72
+
73
+ @register_aggregation("matthews_corrcoef")
74
+ def matthews_corrcoef(items):
75
+ from sklearn.metrics import matthews_corrcoef
76
+
77
+ unzipped_list = list(zip(*items))
78
+ golds = unzipped_list[0]
79
+ preds = unzipped_list[1]
80
+ return matthews_corrcoef(golds, preds)
81
+
82
+
83
+ @register_aggregation("bleu")
84
+ def bleu(items):
85
+ """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
86
+ for evaluating a generated sentence to a reference sentence. It counts matching
87
+ n-grams in the candidate translation to n-grams in the reference text, where
88
+ 1-gram or unigram would be each token and a bigram comparison would be each
89
+ word pair. The comparison is made regardless of word order
90
+ Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
91
+ Paper: https://www.aclweb.org/anthology/P02-1040/
92
+
93
+ Higher is better
94
+ """
95
+ refs = list(zip(*items))[0]
96
+ preds = list(zip(*items))[1]
97
+ refs, preds = _sacreformat(refs, preds)
98
+ return sacrebleu.corpus_bleu(preds, refs).score
99
+
100
+
101
+ @register_aggregation("chrf")
102
+ def chrf(items):
103
+ """chrF++ is a tool for automatic evaluation of machine translation output
104
+ based on character n-gram precision and recall enhanced with word n-grams.
105
+ Source: https://github.com/m-popovic/chrF
106
+ Paper: https://www.aclweb.org/anthology/W15-3049.pdf
107
+
108
+ Higher is better # TODO I think
109
+ """
110
+ refs = list(zip(*items))[0]
111
+ preds = list(zip(*items))[1]
112
+ refs, preds = _sacreformat(refs, preds)
113
+ return sacrebleu.corpus_chrf(preds, refs).score
114
+
115
+
116
+ @register_aggregation("ter")
117
+ def ter(items):
118
+ """Translation Error Rate is an error metric for machine translation that
119
+ measures the number of edits required to change a system output into one
120
+ of the references
121
+ Source: http://www.cs.umd.edu/~snover/tercom/
122
+ Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
123
+
124
+ Lower is better
125
+ """
126
+ refs = list(zip(*items))[0]
127
+ preds = list(zip(*items))[1]
128
+ refs, preds = _sacreformat(refs, preds)
129
+ return sacrebleu.corpus_ter(preds, refs).score
130
+
131
+
132
+ @register_aggregation("brier_score")
133
+ def brier_score(items): # Aggregation: computes the Brier score from (gold, probability-vector) pairs
134
+ gold, predictions = list(zip(*items))
135
+ bs, num_class = np.array(predictions).shape
136
+
137
+ gold = list(gold)
138
+ gold_one_hot = np.eye(num_class)[gold]
139
+ return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
140
+
141
+
142
+ @register_metric(
143
+ metric="brier_score",
144
+ higher_is_better=False,
145
+ output_type=["multiple_choice"],
146
+ aggregation="brier_score",
147
+ )
148
+ def brier_score_fn(items): # This is a passthrough function
149
+ return items
150
+
151
+
152
+ @register_metric(
153
+ metric="acc",
154
+ higher_is_better=True,
155
+ output_type=["loglikelihood", "multiple_choice"],
156
+ aggregation="mean",
157
+ )
158
+ def acc_fn(items): # This is a passthrough function
159
+ return items
160
+
161
+
162
+ @register_metric(
163
+ metric="acc_norm",
164
+ higher_is_better=True,
165
+ output_type=["loglikelihood", "multiple_choice"],
166
+ aggregation="mean",
167
+ )
168
+ def acc_norm_fn(items): # This is a passthrough function
169
+ return items
170
+
171
+
172
+ @register_metric(
173
+ metric="acc_mutual_info",
174
+ higher_is_better=True,
175
+ output_type="multiple_choice",
176
+ aggregation="mean",
177
+ )
178
+ def acc_mutual_info_fn(items): # This is a passthrough function
179
+ return items
180
+
181
+
182
+ ### the code used in the `exact_match_hf_evaluate` function is ported from
183
+ ### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
184
+ ### which is under the apache license.
185
+
186
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
187
+
188
+ # Licensed under the Apache License, Version 2.0 (the "License");
189
+ # you may not use this file except in compliance with the License.
190
+ # You may obtain a copy of the License at
191
+
192
+ # http://www.apache.org/licenses/LICENSE-2.0
193
+
194
+
195
+ # Unless required by applicable law or agreed to in writing, software
196
+ # distributed under the License is distributed on an "AS IS" BASIS,
197
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
198
+ # See the License for the specific language governing permissions and
199
+ # limitations under the License.
200
+ def exact_match_hf_evaluate(
201
+ predictions,
202
+ references,
203
+ regexes_to_ignore=None,
204
+ ignore_case=False,
205
+ ignore_punctuation=False,
206
+ ignore_numbers=False,
207
+ ):
208
+ if regexes_to_ignore is not None:
209
+ for s in regexes_to_ignore:
210
+ predictions = np.array([re.sub(s, "", x) for x in predictions])
211
+ references = np.array([re.sub(s, "", x) for x in references])
212
+ else:
213
+ predictions = np.asarray(predictions)
214
+ references = np.asarray(references)
215
+
216
+ if ignore_case:
217
+ predictions = np.char.lower(predictions)
218
+ references = np.char.lower(references)
219
+
220
+ if ignore_punctuation:
221
+ repl_table = string.punctuation.maketrans("", "", string.punctuation)
222
+ predictions = np.char.translate(predictions, table=repl_table)
223
+ references = np.char.translate(references, table=repl_table)
224
+
225
+ if ignore_numbers:
226
+ repl_table = string.digits.maketrans("", "", string.digits)
227
+ predictions = np.char.translate(predictions, table=repl_table)
228
+ references = np.char.translate(references, table=repl_table)
229
+
230
+ score_list = predictions == references
231
+
232
+ return {"exact_match": np.mean(score_list)}
233
+
234
+
235
+ ###
236
+
237
+
238
+ @register_metric(
239
+ metric="exact_match",
240
+ higher_is_better=True,
241
+ output_type="generate_until",
242
+ aggregation="mean",
243
+ )
244
+ def exact_match_fn(**kwargs):
245
+ return exact_match_hf_evaluate(**kwargs)
246
+
247
+
248
+ @register_metric(
249
+ metric="perplexity",
250
+ higher_is_better=False,
251
+ output_type="loglikelihood",
252
+ aggregation="perplexity",
253
+ )
254
+ def perplexity_fn(items): # This is a passthrough function
255
+ return items
256
+
257
+
258
+ @register_metric(
259
+ metric="word_perplexity",
260
+ higher_is_better=False,
261
+ output_type="loglikelihood_rolling",
262
+ aggregation="weighted_perplexity",
263
+ )
264
+ def word_perplexity_fn(items): # This is a passthrough function
265
+ return items
266
+
267
+
268
+ @register_metric(
269
+ metric="byte_perplexity",
270
+ higher_is_better=False,
271
+ output_type="loglikelihood_rolling",
272
+ aggregation="weighted_perplexity",
273
+ )
274
+ def byte_perplexity_fn(items): # This is a passthrough function
275
+ return items
276
+
277
+
278
+ @register_metric(
279
+ metric="bits_per_byte",
280
+ higher_is_better=False,
281
+ output_type="loglikelihood_rolling",
282
+ aggregation="bits_per_byte",
283
+ )
284
+ def bits_per_byte_fn(items): # This is a passthrough function
285
+ return items
286
+
287
+
288
+ def pop_stddev(arr):
289
+ mu = mean(arr)
290
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
291
+
292
+
293
+ def sample_stddev(arr: Sequence[T]) -> float:
294
+ mu = mean(arr)
295
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
296
+
297
+
298
+ def mean_stderr(arr):
299
+ return sample_stddev(arr) / math.sqrt(len(arr))
300
+
301
+
302
+ @register_metric(
303
+ metric="bypass",
304
+ higher_is_better=True,
305
+ output_type=["loglikelihood", "multiple_choice", "generate_until"],
306
+ aggregation="bypass",
307
+ )
308
+ def bypass(items):
309
+ return None
310
+
311
+
312
+ @register_metric(
313
+ metric="mcc",
314
+ higher_is_better=True,
315
+ output_type="multiple_choice",
316
+ aggregation="matthews_corrcoef",
317
+ )
318
+ def mcc_fn(items): # This is a passthrough function
319
+ return items
320
+
321
+
322
+ @register_metric(
323
+ metric="f1",
324
+ higher_is_better=True,
325
+ output_type="multiple_choice",
326
+ aggregation="f1",
327
+ )
328
+ def f1_fn(items): # This is a passthrough function
329
+ return items
330
+
331
+
332
+ @register_metric(
333
+ metric="bleu",
334
+ higher_is_better=True,
335
+ output_type="generate_until",
336
+ aggregation="bleu",
337
+ )
338
+ def bleu_fn(items): # This is a passthrough function
339
+ return items
340
+
341
+
342
+ @register_metric(
343
+ metric="chrf",
344
+ higher_is_better=True,
345
+ output_type="generate_until",
346
+ aggregation="chrf",
347
+ )
348
+ def chrf_fn(items): # This is a passthrough function
349
+ return items
350
+
351
+
352
+ @register_metric(
353
+ metric="ter",
354
+ higher_is_better=True,
355
+ output_type="generate_until",
356
+ aggregation="ter",
357
+ )
358
+ def ter_fn(items): # This is a passthrough function
359
+ return items
360
+
361
+
362
+ @register_metric(
363
+ metric="acc_all",
364
+ higher_is_better=True,
365
+ output_type="loglikelihood",
366
+ aggregation="mean",
367
+ )
368
+ def acc_all(items):
369
+ # Only count as correct if all answers are labeled correctly for each question
370
+ question_scoring_dict = {}
371
+ preds = list(zip(*items))[0]
372
+ docs = list(zip(*items))[1]
373
+
374
+ for doc, pred in zip(docs, preds):
375
+ paragraph_id = doc["idx"]["paragraph"]
376
+ question_id = doc["idx"]["question"]
377
+ if (paragraph_id, question_id) not in question_scoring_dict:
378
+ question_scoring_dict[(paragraph_id, question_id)] = []
379
+
380
+ gold_label = doc["label"] == 1
381
+
382
+ question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
383
+ acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
384
+ return acc
385
+
386
+
387
+ def acc_all_stderr(items):
388
+ # Only count as correct if all answers are labeled correctly for each question
389
+ question_scoring_dict = {}
390
+ preds = list(zip(*items))[0]
391
+ docs = list(zip(*items))[1]
392
+
393
+ for doc, pred in zip(docs, preds):
394
+ question_id = doc["idx"]["question"]
395
+ if question_id not in question_scoring_dict:
396
+ question_scoring_dict[question_id] = []
397
+
398
+ gold_label = doc["label"] == 1
399
+ question_scoring_dict[question_id].append(gold_label == pred)
400
+
401
+ acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
402
+ return acc
403
+
404
+
405
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
406
+ """Compute max metric between prediction and each ground truth."""
407
+ scores_for_ground_truths = []
408
+ for ground_truth in ground_truths:
409
+ score = metric_fn(prediction, ground_truth)
410
+ scores_for_ground_truths.append(score)
411
+ return max(scores_for_ground_truths)
412
+
413
+
414
+ def weighted_mean(items):
415
+ a, b = zip(*items)
416
+ return sum(a) / sum(b)
417
+
418
+
419
+ def is_non_str_iterable(obj):
420
+ return isinstance(obj, Iterable) and not isinstance(obj, str)
421
+
422
+
423
+ def _sacreformat(refs, preds):
424
+ """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
425
+ # Sacrebleu expects (List[str], List[List[str]])
426
+ # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
427
+
428
+ # Note [ref1_stream] is the first reference for each pred.
429
+ # So lists are size N and (M, N) for N preds and M possible refs for each pred
430
+ # This is a different order of dimensions than I would expect
431
+
432
+ # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
433
+ # Must become List[List[str]] with the inner list corresponding to preds
434
+ if not is_non_str_iterable(refs):
435
+ refs = list(refs)
436
+ if not is_non_str_iterable(refs[0]):
437
+ refs = [[ref] for ref in refs]
438
+ refs = list(zip(*refs))
439
+ # Note the number of refs in each ref list must match the number of preds
440
+
441
+ # We expect preds to be List[str] or List[List[str]]. Must become List[str]
442
+ if not is_non_str_iterable(preds):
443
+ preds = list(preds)
444
+ if is_non_str_iterable(preds[0]):
445
+ assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
446
+ preds = [pred[0] for pred in preds]
447
+
448
+ return refs, preds
449
+
450
+
451
+ # stderr stuff
452
+
453
+
454
+ class _bootstrap_internal:
455
+ """
456
+ Pool worker: `(i, xs)` → `n` bootstrap replicates
457
+ of `f(xs)` using an RNG seeded with `i`.
458
+ """
459
+
460
+ def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
461
+ self.f = f
462
+ self.n = n
463
+
464
+ def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
465
+ i, xs = v
466
+ rnd = random.Random()
467
+ rnd.seed(i)
468
+ res = []
469
+ for _ in range(self.n):
470
+ res.append(self.f(rnd.choices(xs, k=len(xs))))
471
+ return res
472
+
473
+
474
+ def _bootstrap_internal_no_mp(
475
+ f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
476
+ ) -> list[float]:
477
+ """
478
+ Single-process fallback: compute `iters` bootstrap replicates
479
+ of the statistic `f(xs)`, chunked (≤ 1000 draws).
480
+ """
481
+ res = []
482
+ chunk_size = min(1000, iters)
483
+ from tqdm import tqdm
484
+
485
+ print(f"bootstrapping for stddev: {f.__name__}")
486
+
487
+ # A single loop replaces the multiprocessing pool.
488
+ for i in tqdm(range(iters // chunk_size)):
489
+ rnd = random.Random(i)
490
+ for _ in range(chunk_size):
491
+ res.append(f(rnd.choices(xs, k=len(xs))))
492
+
493
+ return res
494
+
495
+
496
+ def bootstrap_stderr(
497
+ f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
498
+ ) -> float:
499
+ """
500
+ Bootstrap estimate of the standard error of statistic `f(xs)`
501
+ using up to `iters` resamples, chunked (≤ 1000 draws).
502
+
503
+ Executes in parallel unless the env var `DISABLE_MULTIPROC` is set.
504
+ """
505
+ if not os.getenv("DISABLE_MULTIPROC"):
506
+ import multiprocessing as mp
507
+
508
+ # this gives a biased estimate of the stderr (i.e. with the mean, it gives something
509
+ # equivalent to stderr calculated without Bessel's correction in the stddev.
510
+ # Unfortunately, I haven't been able to figure out what the right correction is
511
+ # to make the bootstrap unbiased - I considered multiplying by sqrt(n/(n-1)) but
512
+ # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
513
+ # Thankfully, shouldn't matter because our samples are pretty big usually anyways
514
+ res = []
515
+ chunk_size = min(1000, iters)
516
+ from tqdm import tqdm
517
+
518
+ print("bootstrapping for stddev:", f.__name__)
519
+ with mp.Pool(mp.cpu_count()) as pool:
520
+ for bootstrap in tqdm(
521
+ pool.imap(
522
+ _bootstrap_internal(f, chunk_size),
523
+ [(i, xs) for i in range(iters // chunk_size)],
524
+ ),
525
+ total=iters // chunk_size,
526
+ ):
527
+ # sample w replacement
528
+ res.extend(bootstrap)
529
+ else:
530
+ res = _bootstrap_internal_no_mp(f, xs, iters)
531
+
532
+ return sample_stddev(res)
533
+
534
+
535
+ def stderr_for_metric(
536
+ metric: Callable[[Sequence[T]], float], bootstrap_iters: int
537
+ ) -> Optional[Callable[[Sequence[T]], float]]:
538
+ """
539
+ Return a function that estimates the standard error of `metric(xs)`.
540
+
541
+ * If `bootstrap_iters > 0` and the metric is in the pre-approved
542
+ bootstrappable list, use `bootstrap_stderr` with that many draws.
543
+ * If the metric has a closed-form SE (e.g. `mean`, `acc_all`), use it.
544
+ * Otherwise, return `None`.
545
+ """
546
+
547
+ if bootstrap_iters <= 0:
548
+ # return no function (don't compute stderr) if bootstrap iters = 0
549
+ return None
550
+
551
+ bootstrappable = [
552
+ median,
553
+ matthews_corrcoef,
554
+ f1_score,
555
+ perplexity,
556
+ bleu,
557
+ chrf,
558
+ ter,
559
+ nanmean,
560
+ ]
561
+
562
+ if metric in bootstrappable:
563
+ return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
564
+
565
+ stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
566
+
567
+ return stderr.get(metric, None)
568
+
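To make the dispatch in `stderr_for_metric` concrete, a small hedged usage sketch (the scores are invented; `mean` is the aggregation defined earlier in this module):

```python
from lm_eval.api.metrics import mean, stderr_for_metric

scores = [1, 0, 1, 1, 0, 1, 1, 1]  # per-sample 0/1 accuracies, made up

# `mean` has a closed-form standard error, so a mean_stderr function is returned
# instead of a bootstrap; bootstrap_iters only matters for bootstrappable metrics.
se_fn = stderr_for_metric(metric=mean, bootstrap_iters=100_000)
if se_fn is not None:
    print(se_fn(scores))  # sample stddev / sqrt(n)
```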
569
+
570
+ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
571
+ # Used to aggregate bootstrapped stderrs across subtasks in a group,
572
+ # when we are weighting by the size of each subtask.
573
+ #
574
+
575
+ assert len(stderrs) == len(sizes)
576
+
577
+ # formula source: https://en.wikipedia.org/wiki/Pooled_variance
578
+ # and: https://stats.stackexchange.com/a/4841331
579
+ # this empirically seems to match running `stderr_for_metric` on all instances
580
+ # from the subtasks concatenated with each other.
581
+ pooled_sample_var = (
582
+ sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
583
+ ) / (sum(sizes) - len(sizes))
584
+
585
+ return np.sqrt(pooled_sample_var / sum(sizes))
586
+
587
+
588
+ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
589
+ assert metrics is not None, (
590
+ "Need to pass a list of each subtask's metric for this stderr aggregation"
591
+ )
592
+ assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
593
+
594
+ # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
595
+ # This formula depends on sample means.
596
+ # removed because it seems to give erroneously huge stderrs for groupings of tasks
597
+ # and does not seem to match up with bootstrap-calculated stderrs for groups.
598
+
599
+ ### don't use this unless a statistician has told you it's the right thing to do ###
600
+
601
+ # accumulators: we'll aggregate pairwise N - 1 times
602
+ variance = stderrs[0] ** 2
603
+ curr_size = sizes[0]
604
+ curr_score = metrics[0]
605
+
606
+ for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
607
+ curr_score = ((curr_score * curr_size) + (score * size)) / (
608
+ curr_size + size
609
+ ) # NOTE: this assumes our aggregation fn is "mean"
610
+
611
+ variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
612
+ curr_size + size - 1
613
+ ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
614
+ curr_score - score
615
+ ) ** 2
616
+
617
+ return np.sqrt(variance)
618
+
619
+
620
+ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
621
+ # A helper function that is used to aggregate
622
+ # subtask scores cross-task.
623
+ # TODO: does not hold for non-mean aggregations
624
+ if not weight_by_size:
625
+ sizes = [1] * len(sizes)
626
+
627
+ assert len(metrics) == len(sizes)
628
+
629
+ return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
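As a quick numeric illustration of the two group-level helpers above (all values invented), the size-weighted group score and pooled stderr come out as follows:

```python
import numpy as np

metrics = [0.70, 0.50, 0.90]      # per-subtask scores (made up)
sizes = [100, 200, 50]            # per-subtask sample counts
stderrs = [0.045, 0.035, 0.042]   # per-subtask stderrs

# Same formula as aggregate_subtask_metrics(..., weight_by_size=True)
group_score = sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)
print(round(group_score, 4))  # 0.6143

# Same formula as pooled_sample_stderr
pooled_var = sum((n - 1) * se**2 * n for n, se in zip(sizes, stderrs)) / (
    sum(sizes) - len(sizes)
)
print(round(float(np.sqrt(pooled_var / sum(sizes))), 4))  # ~0.0245
```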
lm-evaluation-harness/lm_eval/api/model.py ADDED
@@ -0,0 +1,502 @@
1
+ import abc
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ from lm_eval import utils
11
+
12
+
13
+ if TYPE_CHECKING:
14
+ from sqlitedict import SqliteDict
15
+
16
+ from lm_eval.api.instance import Instance
17
+
18
+
19
+ eval_logger = logging.getLogger(__name__)
20
+
21
+ T = TypeVar("T", bound="LM")
22
+
23
+
24
+ class LM(abc.ABC):
25
+ def __init__(self) -> None:
26
+ """Defines the interface that should be implemented by all LM subclasses.
27
+ LMs are assumed to take text (strings) as input and yield strings as output
28
+ (inputs/outputs should be tokenization-agnostic.)
29
+
30
+ """
31
+ # set rank and world size to a single process, by default.
32
+ self._rank = 0
33
+ self._world_size = 1
34
+ self.cache_hook: "CacheHook" = CacheHook(None)
35
+
36
+ @abc.abstractmethod
37
+ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
38
+ """Compute log-likelihood of generating a continuation from a context.
39
+ Downstream tasks should attempt to use loglikelihood instead of other
40
+ LM calls whenever possible.
41
+
42
+ :param requests: list[Instance]
43
+ A list of Instance objects, with property `args` which returns a tuple (context, continuation).
44
+ `context: str`
45
+ Context string. Implementations of LM must be able to handle an
46
+ empty context string.
47
+ `continuation: str`
48
+ The continuation over which log likelihood will be calculated. If
49
+ there is a word boundary, the space should be in the continuation.
50
+ For example, context="hello" continuation=" world" is correct.
51
+
52
+ :return: list[tuple[float, bool]]
53
+ A list of pairs (logprob, isgreedy)
54
+ `logprob: float`
55
+ The log probability of `continuation`.
56
+ `isgreedy`:
57
+ Whether `continuation` would be generated by greedy sampling from `context`.
58
+ """
59
+ pass
60
+
61
+ @abc.abstractmethod
62
+ def loglikelihood_rolling(self, requests) -> list[float]:
63
+ """Compute full log-likelihood of a string, with no truncation, for perplexity computation
64
+ - We will use the full max context length of the model.
65
+ - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
66
+ the max context length.
67
+ - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
68
+ which may simply concatenate multiple documents together.
69
+ - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
70
+ multiple chunks, the last input will still have a full-sized context.
71
+ Example:
72
+ Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
73
+ Prefix: BOS/EOS
74
+ Max context length: 4
75
+ Resulting input/prediction pairs:
76
+
77
+ INPUT: BOS 0 1 2
78
+ PRED: 0 1 2 3
79
+
80
+ INPUT: 3 4 5 6
81
+ PRED: 4 5 6 7
82
+
83
+ INPUT: 5 6 7 8
84
+ PRED: 8 9
85
+
86
+ Observe that:
87
+ 1. Each token is predicted exactly once
88
+ 2. For the last pair, we provide the full context, but only score the last two tokens
89
+
90
+ :param requests: list[Instance]
91
+ A list of Instance objects with property `args` which returns a tuple (context,).
92
+ string: str
93
+ String for which we are computing overall loglikelihood
94
+ :return: list[tuple[float]]
95
+ A list of tuples (logprob,)
96
+ logprob: float
97
+ The log probability of `context` conditioned on the BOS/EOS token.
98
+ Can also be overridden for custom cases by `prefix_token_id`.
99
+ """
100
+ pass
101
+
102
+ # TODO: Add an optional max length
103
+ @abc.abstractmethod
104
+ def generate_until(self, requests) -> list[str]:
105
+ """Generate greedily until a stopping sequence
106
+
107
+ :param requests: list[Instance]
108
+ A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
109
+ context: str
110
+ Context string
111
+ gen_kwargs: dict
112
+ A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
113
+ :return: list[str]
114
+ A list of model generated continuations.
115
+ continuation: str
116
+ The generated continuation.
117
+ """
118
+ pass
119
+
120
+ def apply_chat_template(
121
+ self, chat_history: list[dict[str, str]], add_generation_prompt=True
122
+ ) -> str:
123
+ """
124
+ Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
125
+
126
+ :param chat_history: list[dict[str, str]]
127
+ A list of dictionaries with keys 'role' and 'content'.
128
+ Values are strings representing the role name and the content of the message, respectively.
129
+ :param add_generation_prompt: bool
130
+ Whether to append an assistant generation prefix (e.g. <|assistant|>) to the chat history. Set to False when prefilling an assistant message.
131
+ :return: str
132
+ A string representing the chat history in a format that can be used as input to the LM.
133
+ """
134
+ raise NotImplementedError(
135
+ "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
136
+ )
137
+
138
+ @classmethod
139
+ def create_from_arg_string(
140
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
141
+ ) -> T:
142
+ """
143
+ Creates an instance of the LM class using the given argument string and additional config.
144
+
145
+ Parameters:
146
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
147
+ - additional_config: Optional dictionary containing additional configuration parameters.
148
+
149
+ Returns:
150
+ - Instance of the LM class.
151
+ """
152
+ additional_config = {} if additional_config is None else additional_config
153
+ args = utils.simple_parse_args_string(arg_string)
154
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
155
+ return cls(**args, **args2)
156
+
157
+ @classmethod
158
+ def create_from_arg_obj(
159
+ cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
160
+ ) -> T:
161
+ """
162
+ Creates an instance of the LM class using the given arg_obj
163
+
164
+ Parameters:
165
+ - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
166
+ - additional_config: Optional dictionary containing additional configuration parameters.
167
+
168
+ Returns:
169
+ - Instance of the LM class.
170
+ """
171
+
172
+ additional_config = {} if additional_config is None else additional_config
+ additional_config = {
+ k: v for k, v in additional_config.items() if v is not None
+ }
175
+
176
+ return cls(**arg_dict, **additional_config)
177
+
178
+ @property
179
+ def rank(self):
180
+ # used in the case of parallelism. Hardcoded to
181
+ # ensure no errors arise using API models which do
182
+ # not support multi-device parallelism nor expect it.
183
+ return self._rank
184
+
185
+ @property
186
+ def world_size(self):
187
+ # used in the case of parallelism. Hardcoded to
188
+ # ensure no errors arise using API models which do
189
+ # not support multi-device parallelism nor expect it.
190
+ return self._world_size
191
+
192
+ @property
193
+ def tokenizer_name(self) -> str:
194
+ """Must be defined for LM subclasses which implement Chat Templating.
195
+ Should return the name of the tokenizer or chat template used.
196
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
197
+ """
198
+ raise NotImplementedError(
199
+ "To use this model with chat templates, please implement the 'tokenizer_name' property."
200
+ )
201
+
202
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
203
+ """Returns the chat template structure for user/assistant messages if a template is provided.
204
+ This method is intended to be overridden in a subclass to define a specific chat template format.
205
+ For models that do not support chat templates, this method returns an empty string by default.
206
+ """
207
+
208
+ return ""
209
+
210
+ def set_cache_hook(self, cache_hook: "CacheHook") -> None:
211
+ self.cache_hook = cache_hook
212
+
213
+
214
+ ### SQLite-based caching of LM responses
215
+ def hash_args(attr: str, args: Iterable[Any]) -> str:
216
+ dat = json.dumps([attr] + list(args))
217
+ return hashlib.sha256(dat.encode("utf-8")).hexdigest()
218
+
219
+
220
+ class CacheHook:
221
+ def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
222
+ if cachinglm is None:
223
+ self.dbdict: Optional["SqliteDict"] = None
224
+ return
225
+
226
+ self.dbdict = cachinglm.dbdict
227
+
228
+ def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
229
+ if self.dbdict is None:
230
+ return
231
+ hsh = hash_args(attr, req)
232
+ self.dbdict[hsh] = res
233
+
234
+
235
+ class CachingLM:
236
+ def __init__(self, lm: LM, cache_db: str) -> None:
237
+ """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
238
+
239
+ :param lm: LM
240
+ Underlying LM
241
+ :param cache_db: str
242
+ Path to cache db
243
+ """
244
+ from sqlitedict import SqliteDict
245
+
246
+ self.lm: LM = lm
247
+ self.cache_db: str = cache_db
248
+ if os.path.dirname(cache_db):
249
+ os.makedirs(os.path.dirname(cache_db), exist_ok=True)
250
+ self.dbdict = SqliteDict(cache_db, autocommit=True)
251
+
252
+ # add hook to lm
253
+ lm.set_cache_hook(self.get_cache_hook())
254
+
255
+ def __getattr__(self, attr: str) -> Any:
256
+ lm_attr = getattr(self.lm, attr)
257
+ if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
258
+ eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
259
+ return lm_attr
260
+
261
+ def _fn(requests: list["Instance"]) -> list["Instance"]:
262
+ res = []
263
+ remaining_reqs = []
264
+ warned = False
265
+ # figure out which ones are cached and which ones are new
266
+ eval_logger.info(
267
+ f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
268
+ )
269
+ for req in tqdm(requests, desc="Checking cached requests"):
270
+ hsh = hash_args(attr, req.args)
271
+ if attr == "generate_until" and req.args[1].get("do_sample", False):
272
+ # when we are doing non-greedy generation, don't use the cache
273
+ # (else every "randomly sampled" generation would be identical for repeats > 1).
274
+ if not warned:
275
+ eval_logger.warning(
276
+ f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
277
+ )
278
+ warned = True
279
+ res.append(None)
280
+ remaining_reqs.append(req)
281
+ elif hsh in self.dbdict:
282
+ ob = self.dbdict[hsh]
283
+
284
+ assert ob is not None
285
+
286
+ res.append(ob)
287
+ else:
288
+ res.append(None)
289
+ remaining_reqs.append(req)
290
+ eval_logger.info(
291
+ f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
292
+ )
293
+ if remaining_reqs:
294
+ # actually run the LM on the requests that do not have cached results
295
+ rem_res = getattr(self.lm, attr)(remaining_reqs)
296
+ else:
297
+ rem_res = []
298
+
299
+ # stick the new ones back into the list and also cache any of the new ones
300
+ resptr = 0
301
+ for req, r in zip(remaining_reqs, rem_res):
302
+ while res[resptr] is not None:
303
+ resptr += 1
304
+
305
+ res[resptr] = r
306
+
307
+ # caching
308
+ hsh = hash_args(attr, req.args)
309
+ self.dbdict[hsh] = r
310
+ self.dbdict.commit()
311
+
312
+ return res
313
+
314
+ return _fn
315
+
316
+ def get_cache_hook(self) -> "CacheHook":
317
+ return CacheHook(self)
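A hedged sketch of how `CachingLM` is meant to wrap a model. The cache path is arbitrary, and `DummyLM` is assumed to be the trivial placeholder model shipped in `lm_eval/models/dummy.py`:

```python
from lm_eval.api.model import CachingLM
from lm_eval.models.dummy import DummyLM  # assumed built-in placeholder model

# Wrapping installs a CacheHook on the underlying model and opens the SQLite db.
lm = CachingLM(DummyLM(), "eval_cache/dummy.db")  # arbitrary cache path

# Cached request types (loglikelihood, loglikelihood_rolling, generate_until) are
# served from the cache where possible; other attributes pass straight through
# to the wrapped model via __getattr__.
# results = lm.generate_until(requests)   # requests: list[Instance], built by a task
```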
318
+
319
+
320
+ class TemplateLM(LM):
321
+ """
322
+ A class acting as intermediary between the LM base class
323
+ and boilerplate often included in other LM subclasses.
324
+ """
325
+
326
+ tokenizer = None
327
+
328
+ @property
329
+ @abc.abstractmethod
330
+ def eot_token_id(self):
331
+ pass
332
+
333
+ @property
334
+ def prefix_token_id(self):
335
+ # it is used as prefix for loglikelihood
336
+ return self.eot_token_id
337
+
338
+ @abc.abstractmethod
339
+ def tok_encode(self, string: str, **kwargs) -> list[int]:
340
+ """
341
+ Tokenize a string using the model's tokenizer and return a list of token IDs.
342
+ """
343
+ pass
344
+
345
+ @abc.abstractmethod
346
+ def _loglikelihood_tokens(
347
+ self, requests: list["Instance"], **kwargs
348
+ ) -> list[tuple[float, bool]]:
349
+ pass
350
+
351
+ def _encode_pair(
352
+ self, context: str, continuation: str
353
+ ) -> tuple[list[int], list[int]]:
354
+ import transformers
355
+
356
+ n_spaces = len(context) - len(context.rstrip())
357
+ if n_spaces > 0:
358
+ continuation = context[-n_spaces:] + continuation
359
+ context = context[:-n_spaces]
360
+
361
+ model_class = getattr(self, "AUTO_MODEL_CLASS", None)
362
+
363
+ if model_class == transformers.AutoModelForSeq2SeqLM:
364
+ context_enc = self.tok_encode(context)
365
+ continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
366
+ else:
367
+ whole_enc = self.tok_encode(context + continuation)
368
+ context_enc = self.tok_encode(context)
369
+
370
+ context_enc_len = len(context_enc)
371
+ continuation_enc = whole_enc[context_enc_len:]
372
+
373
+ return context_enc, continuation_enc
374
+
375
+ def loglikelihood(
376
+ self, requests: list["Instance"], disable_tqdm: bool = False
377
+ ) -> list[tuple[float, bool]]:
378
+ new_reqs = []
379
+ for context, continuation in [req.args for req in requests]:
380
+ if context == "":
381
+ # BOS or EOS as context
382
+ context_enc, continuation_enc = (
383
+ [self.prefix_token_id],
384
+ self.tok_encode(continuation),
385
+ )
386
+ else:
387
+ context_enc, continuation_enc = self._encode_pair(context, continuation)
388
+
389
+ new_reqs.append(((context, continuation), context_enc, continuation_enc))
390
+
391
+ return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
392
+
393
+ @abc.abstractmethod
394
+ def loglikelihood_rolling(
395
+ self, requests, disable_tqdm: bool = False
396
+ ) -> list[float]:
397
+ pass
398
+
399
+ @abc.abstractmethod
400
+ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
401
+ pass
402
+
403
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
404
+ """
405
+ Set and get the appropriate chat template for the model.
406
+ This method sets the tokenizer's chat_template and returns the template string for reproducibility.
407
+
408
+ The template selection logic is adapted from the Transformers library's `apply_chat_template`
409
+ method in the Tokenizer class. The original implementation can be found at:
410
+ https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
411
+
412
+ This method ensures that the right template is chosen based on the following:
413
+ 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
414
+ 1. If the model's tokenizer has multiple templates:
415
+ a. Use the specified template if it exists in the dictionary.
416
+ b. Use the default template from the list if no specific template is provided.
417
+ c. Raise an error if no default template exists and no specific template is provided.
418
+ 2. If the model's tokenizer has a single template or no template:
419
+ a. Use the tokenizer's chat template if available.
420
+ b. Fall back to the default chat template if no tokenizer chat template exists.
421
+
422
+ Args:
423
+ chat_template (Union[bool, str]): Specifies the chat template to use.
424
+ - If False or None, no template is applied.
425
+ - If True, the default or only available template is used.
426
+ - If a string, the template with the matching name is used.
427
+
428
+ Returns:
429
+ Optional[str]: The selected chat template, or None if no template is applied.
430
+ """
431
+ if self.tokenizer is None:
432
+ return ""
433
+
434
+ if chat_template is False or chat_template is None:
435
+ eval_logger.warning(
436
+ "model.chat_template was called with the chat_template set to False or None. "
437
+ "Therefore no chat template will be applied. Make sure this is an intended behavior."
438
+ )
439
+ return None
440
+
441
+ # Convert boolean chat_template to None to ensure compatibility with the adapted logic
442
+ if isinstance(chat_template, bool):
443
+ chat_template = None
444
+ using_default_template = False
445
+
446
+ # First, handle the cases when the model has a dict of multiple templates
447
+ try:
448
+ template = (
449
+ self.tokenizer.chat_template or self.tokenizer.default_chat_template
450
+ )
451
+ except AttributeError:
452
+ return None
453
+
454
+ if isinstance(template, dict):
455
+ using_default_dict = self.tokenizer.chat_template is None
456
+
457
+ if chat_template is not None:
458
+ if chat_template in template:
459
+ selected_template = template[chat_template]
460
+ if using_default_dict:
461
+ using_default_template = True
462
+ else:
463
+ raise ValueError(
464
+ f"The specified chat template '{chat_template}' is not available. "
465
+ f"Available template names are {sorted(template.keys())}."
466
+ )
467
+ else:
468
+ # If user didn't pass a chat template, use the default template from the dict
469
+ if "default" in template:
470
+ selected_template = template["default"]
471
+ using_default_template = True
472
+ else:
473
+ raise ValueError(
474
+ "This model has multiple chat templates with no default specified! Please either pass a chat "
475
+ "template or the name of the template you wish to use to the `chat_template` argument. Available "
476
+ f"template names are {sorted(template.keys())}."
477
+ )
478
+
479
+ # Cases when the model has a single template or no template
480
+ else:
481
+ # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
482
+ if isinstance(chat_template, str):
483
+ eval_logger.warning(
484
+ "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
485
+ "Using the tokenizer's chat template or the default template instead."
486
+ )
487
+ if self.tokenizer.chat_template is not None:
488
+ selected_template = self.tokenizer.chat_template
489
+ else:
490
+ selected_template = self.tokenizer.default_chat_template
491
+ using_default_template = True
492
+
493
+ if using_default_template:
494
+ eval_logger.warning(
495
+ "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
496
+ "very error-prone, because models are often trained with templates different from the class default! "
497
+ "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
498
+ "point any code depending on them will stop working. We recommend setting a valid chat template before "
499
+ "then to ensure that this model continues working without issues."
500
+ )
501
+
502
+ return selected_template
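To show what the abstract `LM` interface above expects from a subclass, here is a deliberately trivial sketch (the class name and constant outputs are inventions for illustration, not part of the harness):

```python
import random

from lm_eval.api.model import LM


class ToyRandomLM(LM):
    """Illustrative only: random log-likelihoods and a fixed generation per request."""

    def __init__(self, seed: int = 1234) -> None:
        super().__init__()
        self.rng = random.Random(seed)

    def loglikelihood(self, requests):
        # one (logprob, is_greedy) pair per request
        return [(-self.rng.random(), False) for _ in requests]

    def loglikelihood_rolling(self, requests):
        # one summed logprob per request
        return [-self.rng.random() for _ in requests]

    def generate_until(self, requests):
        # one continuation string per request
        return ["lorem ipsum" for _ in requests]
```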
lm-evaluation-harness/lm_eval/api/registry.py ADDED
@@ -0,0 +1,196 @@
1
+ import logging
2
+ from typing import Callable, Dict, Union
3
+
4
+ import evaluate as hf_evaluate
5
+
6
+ from lm_eval.api.model import LM
7
+
8
+
9
+ eval_logger = logging.getLogger(__name__)
10
+
11
+ MODEL_REGISTRY = {}
12
+
13
+
14
+ def register_model(*names):
15
+ # either pass a list or a single alias.
16
+ # function receives them as a tuple of strings
17
+
18
+ def decorate(cls):
19
+ for name in names:
20
+ assert issubclass(cls, LM), (
21
+ f"Model '{name}' ({cls.__name__}) must extend LM class"
22
+ )
23
+
24
+ assert name not in MODEL_REGISTRY, (
25
+ f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
26
+ )
27
+
28
+ MODEL_REGISTRY[name] = cls
29
+ return cls
30
+
31
+ return decorate
32
+
33
+
34
+ def get_model(model_name):
35
+ try:
36
+ return MODEL_REGISTRY[model_name]
37
+ except KeyError:
38
+ raise ValueError(
39
+ f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
40
+ )
41
+
42
+
43
+ TASK_REGISTRY = {}
44
+ GROUP_REGISTRY = {}
45
+ ALL_TASKS = set()
46
+ func2task_index = {}
47
+
48
+
49
+ def register_task(name):
50
+ def decorate(fn):
51
+ assert name not in TASK_REGISTRY, (
52
+ f"task named '{name}' conflicts with existing registered task!"
53
+ )
54
+
55
+ TASK_REGISTRY[name] = fn
56
+ ALL_TASKS.add(name)
57
+ func2task_index[fn.__name__] = name
58
+ return fn
59
+
60
+ return decorate
61
+
62
+
63
+ def register_group(name):
64
+ def decorate(fn):
65
+ func_name = func2task_index[fn.__name__]
66
+ if name in GROUP_REGISTRY:
67
+ GROUP_REGISTRY[name].append(func_name)
68
+ else:
69
+ GROUP_REGISTRY[name] = [func_name]
70
+ ALL_TASKS.add(name)
71
+ return fn
72
+
73
+ return decorate
74
+
75
+
76
+ OUTPUT_TYPE_REGISTRY = {}
77
+ METRIC_REGISTRY = {}
78
+ METRIC_AGGREGATION_REGISTRY = {}
79
+ AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
80
+ HIGHER_IS_BETTER_REGISTRY = {}
81
+ FILTER_REGISTRY = {}
82
+
83
+ DEFAULT_METRIC_REGISTRY = {
84
+ "loglikelihood": [
85
+ "perplexity",
86
+ "acc",
87
+ ],
88
+ "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
89
+ "multiple_choice": ["acc", "acc_norm"],
90
+ "generate_until": ["exact_match"],
91
+ }
92
+
93
+
94
+ def register_metric(**args):
95
+ # TODO: do we want to enforce a certain interface to registered metrics?
96
+ def decorate(fn):
97
+ assert "metric" in args
98
+ name = args["metric"]
99
+
100
+ for key, registry in [
101
+ ("metric", METRIC_REGISTRY),
102
+ ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
103
+ ("aggregation", METRIC_AGGREGATION_REGISTRY),
104
+ ]:
105
+ if key in args:
106
+ value = args[key]
107
+ assert value not in registry, (
108
+ f"{key} named '{value}' conflicts with existing registered {key}!"
109
+ )
110
+
111
+ if key == "metric":
112
+ registry[name] = fn
113
+ elif key == "aggregation":
114
+ registry[name] = AGGREGATION_REGISTRY[value]
115
+ else:
116
+ registry[name] = value
117
+
118
+ return fn
119
+
120
+ return decorate
121
+
122
+
123
+ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
124
+ if not hf_evaluate_metric:
125
+ if name in METRIC_REGISTRY:
126
+ return METRIC_REGISTRY[name]
127
+ else:
128
+ eval_logger.warning(
129
+ f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
130
+ )
131
+
132
+ try:
133
+ metric_object = hf_evaluate.load(name)
134
+ return metric_object.compute
135
+ except Exception:
136
+ eval_logger.error(
137
+ f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
138
+ )
139
+
140
+
141
+ def register_aggregation(name: str):
142
+ def decorate(fn):
143
+ assert name not in AGGREGATION_REGISTRY, (
144
+ f"aggregation named '{name}' conflicts with existing registered aggregation!"
145
+ )
146
+
147
+ AGGREGATION_REGISTRY[name] = fn
148
+ return fn
149
+
150
+ return decorate
151
+
152
+
153
+ def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
154
+ try:
155
+ return AGGREGATION_REGISTRY[name]
156
+ except KeyError:
157
+ eval_logger.warning(f"{name} not a registered aggregation metric!")
158
+
159
+
160
+ def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
161
+ try:
162
+ return METRIC_AGGREGATION_REGISTRY[name]
163
+ except KeyError:
164
+ eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
165
+
166
+
167
+ def is_higher_better(metric_name) -> bool:
168
+ try:
169
+ return HIGHER_IS_BETTER_REGISTRY[metric_name]
170
+ except KeyError:
171
+ eval_logger.warning(
172
+ f"higher_is_better not specified for metric '{metric_name}'!"
173
+ )
174
+
175
+
176
+ def register_filter(name):
177
+ def decorate(cls):
178
+ if name in FILTER_REGISTRY:
179
+ eval_logger.info(
180
+ f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
181
+ )
182
+ FILTER_REGISTRY[name] = cls
183
+ return cls
184
+
185
+ return decorate
186
+
187
+
188
+ def get_filter(filter_name: Union[str, Callable]) -> Callable:
189
+ try:
190
+ return FILTER_REGISTRY[filter_name]
191
+ except KeyError as e:
192
+ if callable(filter_name):
193
+ return filter_name
194
+ else:
195
+ eval_logger.warning(f"filter `{filter_name}` is not registered!")
196
+ raise e
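A hedged sketch of how the registries above compose. The aggregation and metric names here are invented; note that the aggregation must be registered before any metric that refers to it, since `register_metric` resolves the name via `AGGREGATION_REGISTRY` at decoration time:

```python
from lm_eval.api.registry import register_aggregation, register_metric


@register_aggregation("median_of_items")  # hypothetical aggregation name
def median_of_items(items):
    items = sorted(items)
    mid = len(items) // 2
    return items[mid] if len(items) % 2 else (items[mid - 1] + items[mid]) / 2


@register_metric(
    metric="toy_score",  # hypothetical metric name
    higher_is_better=True,
    output_type="generate_until",
    aggregation="median_of_items",
)
def toy_score(items):
    # passthrough, like most per-sample metrics: aggregation happens later
    return items
```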
lm-evaluation-harness/lm_eval/api/samplers.py ADDED
@@ -0,0 +1,232 @@
1
+ import logging
2
+ import warnings
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Iterable, Optional, Union
5
+
6
+ import datasets
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from random import Random
11
+
12
+ from lm_eval.api.task import ConfigurableTask, Task
13
+
14
+ eval_logger = logging.getLogger("lm-eval")
15
+
16
+
17
+ class ContextSampler:
18
+ def __init__(
19
+ self,
20
+ docs: list[dict],
21
+ task: Union["Task", "ConfigurableTask"],
22
+ fewshot_indices: Optional[Iterable] = None,
23
+ rnd: Optional["Random"] = None,
24
+ ) -> None:
25
+ self.rnd = rnd
26
+ if not self.rnd:
27
+ raise ValueError(
28
+ "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
29
+ )
30
+
31
+ self.task = task
32
+ self.config = task._config
33
+
34
+ self.target_delimiter = self.config.target_delimiter
35
+ self.fewshot_delimiter = self.config.fewshot_delimiter
36
+
37
+ if (
38
+ self.config.fewshot_config is not None
39
+ and self.config.fewshot_config.get("doc_to_text", None) is not None
40
+ ):
41
+ self.doc_to_text = partial(
42
+ self.task.doc_to_text,
43
+ doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
44
+ )
45
+ else:
46
+ self.doc_to_text = self.task.doc_to_text
47
+
48
+ if (
49
+ self.config.fewshot_config is not None
50
+ and self.config.fewshot_config.get("doc_to_target", None) is not None
51
+ ):
52
+ self.doc_to_target = partial(
53
+ self.task.doc_to_target,
54
+ doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
55
+ )
56
+ else:
57
+ self.doc_to_target = self.task.doc_to_target
58
+
59
+ if (
60
+ self.config.fewshot_config is not None
61
+ and self.config.fewshot_config.get("doc_to_choice", None) is not None
62
+ ):
63
+ self.doc_to_choice = partial(
64
+ self.task.doc_to_choice,
65
+ doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
66
+ )
67
+ else:
68
+ self.doc_to_choice = self.task.doc_to_choice
69
+
70
+ self.docs = docs # HF dataset split, provided by task._fewshot_docs()
71
+ if fewshot_indices: # subset few-shot docs from
72
+ if not isinstance(self.docs, datasets.Dataset):
73
+ raise ValueError(
74
+ "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
75
+ )
76
+ self.docs = self.docs.select(fewshot_indices)
77
+
78
+ def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
79
+ # draw an extra fewshot sample if using same split as evaluating on
80
+ prefix = gen_prefix + " " if gen_prefix else ""
81
+ n_samples = (
82
+ num_fewshot + 1
83
+ if self.config.fewshot_split == self.config.test_split
84
+ else num_fewshot
85
+ )
86
+
87
+ # draw `n_samples` docs from fewshot_docs
88
+ fewshotex = self.sample(n_samples)
89
+
90
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
91
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
92
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
93
+
94
+ labeled_examples = ""
95
+ for doc in selected_docs:
96
+ doc_content = self.doc_to_text(doc)
97
+ doc_target = self.doc_to_target(doc)
98
+ if self.config.doc_to_choice is None or isinstance(doc_content, str):
99
+ labeled_examples += doc_content
100
+ else:
101
+ labeled_examples += self.doc_to_choice(doc)[doc_content]
102
+
103
+ if doc_target != "":
104
+ if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
105
+ # TODO: add logger warn once here.
106
+ warnings.warn(
107
+ "Both target_delimiter and target start with a space. This may cause issues.",
108
+ Warning,
109
+ stacklevel=2,
110
+ )
111
+ labeled_examples += self.target_delimiter
112
+ labeled_examples += prefix
113
+ labeled_examples += (
114
+ str(doc_target[0])
115
+ if isinstance(doc_target, list)
116
+ else doc_target
117
+ if self.config.doc_to_choice is None or isinstance(doc_target, str)
118
+ else str(self.doc_to_choice(doc)[doc_target])
119
+ )
120
+ labeled_examples += self.fewshot_delimiter
121
+
122
+ return labeled_examples
123
+
124
+ def get_chat_context(
125
+ self,
126
+ doc: dict,
127
+ num_fewshot: int,
128
+ fewshot_as_multiturn: bool = False,
129
+ gen_prefix: Optional[str] = None,
130
+ ):
131
+ # TODO: Do we need any other delimiter
132
+ prefix = gen_prefix + " " if gen_prefix else ""
133
+ chat_history = []
134
+ # draw an extra fewshot sample if using same split as evaluating on
135
+ n_samples = (
136
+ num_fewshot + 1
137
+ if self.config.fewshot_split == self.config.test_split
138
+ else num_fewshot
139
+ )
140
+ # draw `n_samples` docs from fewshot_docs
141
+ fewshotex = self.sample(n_samples)
142
+
143
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
144
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
145
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
146
+
147
+ if fewshot_as_multiturn:
148
+ for doc in selected_docs:
149
+ doc_content = self.doc_to_text(doc)
150
+ doc_target = self.doc_to_target(doc)
151
+ chat_history.append(
152
+ {
153
+ "role": "user",
154
+ "content": doc_content
155
+ if self.config.doc_to_choice is None
156
+ or isinstance(doc_content, str)
157
+ else self.doc_to_choice(doc)[doc_content],
158
+ }
159
+ )
160
+ chat_history.append(
161
+ {
162
+ "role": "assistant",
163
+ "content": prefix + str(doc_target[0])
164
+ if isinstance(doc_target, list)
165
+ else prefix + doc_target
166
+ if self.config.doc_to_choice is None
167
+ or isinstance(doc_target, str)
168
+ else prefix + str(self.doc_to_choice(doc)[doc_target]),
169
+ }
170
+ )
171
+ else:
172
+ # get fewshot context as one user turn
173
+ chat_history.append(
174
+ {
175
+ "role": "user",
176
+ "content": self.get_context(
177
+ doc, num_fewshot, gen_prefix=gen_prefix
178
+ ),
179
+ }
180
+ )
181
+
182
+ return chat_history
183
+
184
+ def sample(self, n: int):
185
+ """
186
+ Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
187
+ """
188
+
189
+ return self.rnd.sample(self.docs, n)
190
+
191
+
192
+ class FirstNSampler(ContextSampler):
193
+ def sample(self, n: int):
194
+ """
195
+ Draw the first `n` samples in order from the specified split.
196
+ Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
197
+ """
198
+ assert n <= len(self.docs), (
199
+ f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
200
+ )
201
+ return self.docs[:n]
202
+
203
+
204
+ class BalancedSampler(ContextSampler):
205
+ def sample(self, n: int) -> None:
206
+ """
207
+ TODO: this should return approximately class-balanced samples from our fewshot examples.
208
+ TODO: what order should they be in? maybe random?
209
+ """
210
+
211
+ pass
212
+
213
+
214
+ class ManualSampler(ContextSampler):
215
+ def sample(self, n: int) -> None:
216
+ """ """
217
+ pass
218
+
219
+
220
+ SAMPLER_REGISTRY = {
221
+ "default": ContextSampler,
222
+ "first_n": FirstNSampler,
223
+ }
224
+
225
+
226
+ def get_sampler(name: str):
227
+ try:
228
+ return SAMPLER_REGISTRY[name]
229
+ except KeyError:
230
+ raise ValueError(
231
+ f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
232
+ )
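Finally, a short hedged sketch of how the sampler registry above is consumed; in the harness the docs, task object, and RNG come from `ConfigurableTask`, so they are only placeholders here:

```python
import random

from lm_eval.api.samplers import get_sampler

# "first_n" keeps a canonical ordering of shots (MMLU-style); "default" samples randomly.
sampler_cls = get_sampler("first_n")

# Constructing one requires a task object and its few-shot docs, e.g. (placeholders):
# sampler = sampler_cls(docs=fewshot_docs, task=task, rnd=random.Random(1234))
# context = sampler.get_context(doc, num_fewshot=5)
```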
lm-evaluation-harness/lm_eval/api/task.py ADDED
@@ -0,0 +1,1885 @@
1
+ import abc
2
+ import ast
3
+ import logging
4
+ import random
5
+ import re
6
+ from collections.abc import Callable
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from inspect import getsource
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Mapping,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import datasets
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+
27
+ from lm_eval import utils
28
+ from lm_eval.api import samplers
29
+ from lm_eval.api.instance import Instance, OutputType
30
+ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
31
+ from lm_eval.api.registry import (
32
+ AGGREGATION_REGISTRY,
33
+ DEFAULT_METRIC_REGISTRY,
34
+ get_aggregation,
35
+ get_metric,
36
+ get_metric_aggregation,
37
+ is_higher_better,
38
+ )
39
+ from lm_eval.caching.cache import load_from_cache, save_to_cache
40
+ from lm_eval.filters import build_filter_ensemble
41
+ from lm_eval.prompts import get_prompt
42
+
43
+
44
+ ALL_OUTPUT_TYPES = [
45
+ "loglikelihood",
46
+ "multiple_choice",
47
+ "loglikelihood_rolling",
48
+ "generate_until",
49
+ ]
50
+
51
+ eval_logger = logging.getLogger(__name__)
52
+
53
+
54
+ @dataclass
55
+ class TaskConfig(dict):
56
+ # task naming/registry
57
+ task: Optional[str] = None
58
+ task_alias: Optional[str] = None
59
+ tag: Optional[Union[str, list]] = None
60
+ # HF dataset options.
61
+ # which dataset to use,
62
+ # and what splits for what purpose
63
+ custom_dataset: Optional[Callable] = None
64
+ dataset_path: Optional[str] = None
65
+ dataset_name: Optional[str] = None
66
+ dataset_kwargs: Optional[dict] = None
67
+ training_split: Optional[str] = None
68
+ validation_split: Optional[str] = None
69
+ test_split: Optional[str] = None
70
+ fewshot_split: Optional[str] = (
71
+ None # TODO: assert that this is not None if num_fewshot > 0 (?); assert whether this is the same split as the one being evaluated (?)
72
+ )
73
+ # formatting / prompting options.
74
+ # see docs/advanced_task_guide.md for more info
75
+ process_docs: Optional[Callable] = None
76
+ doc_to_text: Optional[Union[Callable, str]] = None
77
+ doc_to_target: Optional[Union[Callable, str]] = None
78
+ doc_to_image: Union[Callable, str] = None
79
+ doc_to_audio: Union[Callable, str] = None
80
+ unsafe_code: bool = False
81
+ doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
82
+ process_results: Optional[Union[Callable, str]] = None
83
+ use_prompt: Optional[str] = None
84
+ description: str = ""
85
+ target_delimiter: str = " "
86
+ fewshot_delimiter: str = "\n\n"
87
+ fewshot_config: Optional[dict] = None
88
+ # runtime configuration options
89
+ num_fewshot: Optional[int] = None
90
+ # scoring options
91
+ metric_list: Optional[list] = None
92
+ output_type: OutputType = "generate_until"
93
+ generation_kwargs: Optional[dict] = None
94
+ repeats: int = 1
95
+ filter_list: Optional[Union[str, list]] = None
96
+ should_decontaminate: bool = False
97
+ doc_to_decontamination_query: Optional[str] = None
98
+ gen_prefix: Optional[str] = None
99
+ metadata: Optional[dict] = (
100
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
101
+ )
102
+
103
+ def __post_init__(self) -> None:
104
+ if self.generation_kwargs is not None:
105
+ if self.output_type != "generate_until":
106
+ eval_logger.warning(
107
+ f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
108
+ )
109
+
110
+ if "temperature" in self.generation_kwargs:
111
+ self.generation_kwargs["temperature"] = float(
112
+ self.generation_kwargs["temperature"]
113
+ )
114
+
115
+ if "until" not in self.generation_kwargs:
116
+ eval_logger.warning(
117
+ f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
118
+ )
119
+ self.generation_kwargs["until"] = [self.fewshot_delimiter]
120
+ else:
121
+ if self.output_type == "generate_until":
122
+ # ensure that we greedily generate in absence of explicit arguments otherwise
123
+ self.generation_kwargs = {
124
+ "until": (
125
+ None
126
+ if self.fewshot_delimiter is None
127
+ else [self.fewshot_delimiter]
128
+ ),
129
+ "do_sample": False,
130
+ "temperature": 0,
131
+ }
132
+ eval_logger.warning(
133
+ f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
134
+ )
135
+
136
+ def __getitem__(self, item):
137
+ return getattr(self, item)
138
+
139
+ def __setitem__(self, item, value):
140
+ return setattr(self, item, value)
141
+
142
+ def to_dict(self, keep_callable: bool = False) -> dict:
143
+ """dumps the current config as a dictionary object, as a printable format.
144
+ null fields will not be printed.
145
+ Used for dumping results alongside full task configuration
146
+
147
+ :return: dict
148
+ A printable dictionary version of the TaskConfig object.
149
+
150
+ # TODO: should any default value in the TaskConfig not be printed?
151
+ """
152
+ cfg_dict = asdict(self)
153
+ # remove values that are `None`
154
+ for k, v in list(cfg_dict.items()):
155
+ if v is None:
156
+ cfg_dict.pop(k)
157
+ elif k == "metric_list":
158
+ for metric_dict in v:
159
+ for metric_key, metric_value in metric_dict.items():
160
+ if callable(metric_value):
161
+ metric_dict[metric_key] = self.serialize_function(
162
+ metric_value, keep_callable=keep_callable
163
+ )
164
+ cfg_dict[k] = v
165
+ elif callable(v):
166
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
167
+ return cfg_dict
168
+
169
+ def serialize_function(
170
+ self, value: Union[Callable, str], keep_callable=False
171
+ ) -> Union[Callable, str]:
172
+ """Serializes a given function or string.
173
+
174
+ If 'keep_callable' is True, the original callable is returned.
175
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
176
+ """
177
+ if keep_callable:
178
+ return value
179
+ else:
180
+ try:
181
+ return getsource(value)
182
+ except (TypeError, OSError):
183
+ return str(value)
184
+
185
+
186
+ class Task(abc.ABC):
187
+ """A task represents an entire benchmark including its dataset, problems,
188
+ answers, and evaluation methods. See BoolQ for a simple example implementation
189
+
190
+ A `doc` can be any python object which represents one instance of evaluation.
191
+ This is usually a dictionary e.g.
192
+ {"question": ..., "answer": ...} or
193
+ {"question": ..., question, answer)
194
+ """
195
+
196
+ VERSION: Optional[Union[int, str]] = None
197
+
198
+ # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
199
+ # or a path to a custom `datasets` loading script.
200
+ DATASET_PATH: Optional[str] = None
201
+
202
+ # The name of a subset within `DATASET_PATH`.
203
+ DATASET_NAME: Optional[str] = None
204
+
205
+ OUTPUT_TYPE: Optional[OutputType] = None
206
+
207
+ def __init__(
208
+ self,
209
+ data_dir: Optional[str] = None,
210
+ cache_dir: Optional[str] = None,
211
+ download_mode: Optional[datasets.DownloadMode] = None,
212
+ config: Optional[Mapping] = None, # Union[dict, TaskConfig]
213
+ ) -> None:
214
+ """
215
+ :param data_dir: str
216
+ Stores the path to a local folder containing the `Task`'s data files.
217
+ Use this to specify the path to manually downloaded data (usually when
218
+ the dataset is not publicly accessible).
219
+ :param cache_dir: str
220
+ The directory to read/write the `Task` dataset. This follows the
221
+ HuggingFace `datasets` API with the default cache directory located at:
222
+ `~/.cache/huggingface/datasets`
223
+ NOTE: You can change the cache location globally for a given process
224
+ to another directory:
225
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
226
+ :param download_mode: datasets.DownloadMode
227
+ How to treat pre-existing `Task` downloads and data.
228
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
229
+ Reuse download and reuse dataset.
230
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
231
+ Reuse download with fresh dataset.
232
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
233
+ Fresh download and fresh dataset.
234
+ """
235
+ self.download(data_dir, cache_dir, download_mode)
236
+ self._training_docs: Optional[list] = None
237
+ self._fewshot_docs: Optional[list] = None
238
+ self._instances: Optional[List[Instance]] = None
239
+
240
+ self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
241
+
242
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
243
+ self.fewshot_rnd: Optional[random.Random] = (
244
+ None # purposely induce errors in case of improper usage
245
+ )
246
+
247
+ def download(
248
+ self,
249
+ data_dir: Optional[str] = None,
250
+ cache_dir: Optional[str] = None,
251
+ download_mode=None,
252
+ ) -> None:
253
+ """Downloads and returns the task dataset.
254
+ Override this method to download the dataset from a custom API.
255
+
256
+ :param data_dir: str
257
+ Stores the path to a local folder containing the `Task`'s data files.
258
+ Use this to specify the path to manually downloaded data (usually when
259
+ the dataset is not publicly accessible).
260
+ :param cache_dir: str
261
+ The directory to read/write the `Task` dataset. This follows the
262
+ HuggingFace `datasets` API with the default cache directory located at:
263
+ `~/.cache/huggingface/datasets`
264
+ NOTE: You can change the cache location globally for a given process
265
+ by setting the shell environment variable, `HF_DATASETS_CACHE`,
266
+ to another directory:
267
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
268
+ :param download_mode: datasets.DownloadMode
269
+ How to treat pre-existing `Task` downloads and data.
270
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
271
+ Reuse download and reuse dataset.
272
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
273
+ Reuse download with fresh dataset.
274
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
275
+ Fresh download and fresh dataset.
276
+ """
277
+ self.dataset = datasets.load_dataset(
278
+ path=self.DATASET_PATH,
279
+ name=self.DATASET_NAME,
280
+ data_dir=data_dir,
281
+ cache_dir=cache_dir,
282
+ download_mode=download_mode,
283
+ )
284
+
285
+ @property
286
+ def config(self) -> TaskConfig:
287
+ """Returns the TaskConfig associated with this class."""
288
+ return self._config
289
+
290
+ @abc.abstractmethod
291
+ def has_training_docs(self):
292
+ """Whether the task has a training set"""
293
+ pass
294
+
295
+ @abc.abstractmethod
296
+ def has_validation_docs(self):
297
+ """Whether the task has a validation set"""
298
+ pass
299
+
300
+ @abc.abstractmethod
301
+ def has_test_docs(self):
302
+ """Whether the task has a test set"""
303
+ pass
304
+
305
+ def training_docs(self) -> Iterable:
306
+ """
307
+ :return: Iterable[obj]
308
+ An iterable of any object that doc_to_text can handle
309
+ """
310
+ return []
311
+
312
+ def validation_docs(self) -> Iterable:
313
+ """
314
+ :return: Iterable[obj]
315
+ An iterable of any object that doc_to_text can handle
316
+ """
317
+ return []
318
+
319
+ def test_docs(self) -> Iterable:
320
+ """
321
+ :return: Iterable[obj]
322
+ An iterable of any object that doc_to_text can handle
323
+ """
324
+ return []
325
+
326
+ def fewshot_docs(self) -> Iterable:
327
+ """
328
+ :return: Iterable[obj]
329
+ An iterable of any object that doc_to_text can handle
330
+ """
331
+ if self.has_training_docs():
332
+ return self.training_docs()
333
+ elif self.has_validation_docs():
334
+ return self.validation_docs()
335
+ else:
336
+ if self.config.get("num_fewshot", 0) > 0:
337
+ eval_logger.warning(
338
+ f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
339
+ ", using test_docs as fewshot_docs but this is not recommended."
340
+ )
341
+ return self.test_docs()
342
+
343
+ def _process_doc(self, doc: dict) -> dict:
344
+ """
345
+ Override this to process (detokenize, strip, replace, etc.) individual
346
+ documents. This can be used in a map over documents of a data split.
347
+ E.g. `map(self._process_doc, self.dataset["validation"])`
348
+
349
+ :return: dict
350
+ The processed version of the specified `doc`.
351
+ """
352
+ return doc
353
+
354
+ @property
355
+ def instances(self) -> List[Instance]:
356
+ """After calling `task.build_all_requests()`, tasks
357
+ maintain a list of the dataset instances which will be evaluated.
358
+ """
359
+ return self._instances
360
+
361
+ def fewshot_examples(self, k, rnd):
362
+ if self._training_docs is None:
363
+ self._training_docs = list(self.training_docs())
364
+
365
+ return rnd.sample(self._training_docs, k)
366
+
367
+ def doc_to_decontamination_query(self, doc):
368
+ raise NotImplementedError(
369
+ "Override doc_to_decontamination_query with document specific decontamination query."
370
+ )
371
+
372
+ @abc.abstractmethod
373
+ def doc_to_text(self, doc):
374
+ pass
375
+
376
+ @abc.abstractmethod
377
+ def doc_to_target(self, doc):
378
+ pass
379
+
380
+ # not an abstractmethod because not every language-only task has to implement this
381
+ def doc_to_image(self, doc):
382
+ raise NotImplementedError
383
+
384
+ def doc_to_audio(self, doc):
385
+ raise NotImplementedError
386
+
387
+ def doc_to_prefix(self, doc):
388
+ return ""
389
+
390
+ def build_all_requests(
391
+ self,
392
+ *,
393
+ limit: Union[int, None] = None,
394
+ samples: Optional[List[int]] = None,
395
+ rank: int = 0,
396
+ world_size: int = 1,
397
+ cache_requests: bool = False,
398
+ rewrite_requests_cache: bool = False,
399
+ system_instruction: Optional[str] = None,
400
+ apply_chat_template: bool = False,
401
+ fewshot_as_multiturn: bool = False,
402
+ chat_template: Optional[Callable] = None,
403
+ tokenizer_name: str = "",
404
+ ) -> None:
405
+ """Build a set of Instances for a task, and store them in task.instances"""
406
+
407
+ # used with caching
408
+ og_limit = limit
409
+
410
+ cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
411
+ cache_key += "-chat_template" if apply_chat_template else ""
412
+ cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
413
+ cache_key += (
414
+ f"-system_prompt_hash{utils.hash_string(system_instruction)}"
415
+ if system_instruction is not None
416
+ else ""
417
+ )
418
+ cache_key += f"-tokenizer{tokenizer_name}"
419
+
420
+ cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
421
+
422
+ if cache_requests and cached_instances and not rewrite_requests_cache:
423
+ cached_instances = cached_instances[:limit]
424
+
425
+ flattened_instances = [
426
+ instance
427
+ for instance_group in cached_instances
428
+ for instance in instance_group
429
+ ]
430
+
431
+ self._instances = flattened_instances
432
+ return
433
+
434
+ eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
435
+
436
+ instances = []
437
+
438
+ # process all documents when caching is specified for simplicity
439
+ if (
440
+ cache_requests
441
+ and (not cached_instances or rewrite_requests_cache)
442
+ and limit is not None
443
+ ):
444
+ limit = None
445
+
446
+ doc_id_docs = list(
447
+ self.doc_iterator(
448
+ rank=rank, limit=limit, samples=samples, world_size=world_size
449
+ )
450
+ )
451
+
452
+ num_docs = len(doc_id_docs)
453
+
454
+ for doc_id, doc in tqdm(
455
+ doc_id_docs,
456
+ total=num_docs,
457
+ ):
458
+ # sample fewshot context #TODO: need to offset doc_id by rank now!
459
+ fewshot_ctx = self.fewshot_context(
460
+ doc,
461
+ num_fewshot=0
462
+ if self.config.num_fewshot is None
463
+ else self.config.num_fewshot,
464
+ system_instruction=system_instruction,
465
+ apply_chat_template=apply_chat_template,
466
+ fewshot_as_multiturn=fewshot_as_multiturn,
467
+ chat_template=chat_template,
468
+ gen_prefix=self.doc_to_prefix(doc),
469
+ )
470
+
471
+ # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
472
+ inst = self.construct_requests(
473
+ doc=doc,
474
+ ctx=fewshot_ctx,
475
+ metadata=(self.config["task"], doc_id, self.config.repeats),
476
+ apply_chat_template=apply_chat_template,
477
+ chat_template=chat_template,
478
+ )
479
+
480
+ if not isinstance(inst, list):
481
+ inst = [inst]
482
+
483
+ instances.append(inst)
484
+
485
+ # now flatten; this allows slicing to work with pickles
486
+
487
+ sliced_instances = instances[:og_limit]
488
+
489
+ flattened_instances = [
490
+ instance
491
+ for instance_group in sliced_instances
492
+ for instance in instance_group
493
+ ]
494
+
495
+ self._instances = flattened_instances
496
+
497
+ if len(self._instances) == 0:
498
+ raise ValueError("task.build_requests() did not find any docs!")
499
+
500
+ if cache_requests and (not cached_instances or rewrite_requests_cache):
501
+ save_to_cache(file_name=cache_key, obj=instances)
502
+
503
+ @abc.abstractmethod
504
+ def construct_requests(self, doc, ctx, **kwargs):
505
+ """Uses RequestFactory to construct Requests and returns an iterable of
506
+ Requests which will be sent to the LM.
507
+
508
+ :param doc:
509
+ The document as returned from training_docs, validation_docs, or test_docs.
510
+ :param ctx: str
511
+ The context string, generated by fewshot_context. This includes the natural
512
+ language description, as well as the few shot examples, and the question
513
+ part of the document for `doc`.
514
+ :param doc_idx: int
515
+ The index of a document within `self.test_docs()` or `self.validation_docs()`,
516
+ whichever is the main split used.
517
+ :param repeats: int
518
+ TODO: update this docstring
519
+ The number of times each instance in a dataset is inferred on. Defaults to 1,
520
+ can be increased for techniques like majority voting.
521
+ """
522
+ pass
523
+
524
+ @abc.abstractmethod
525
+ def process_results(self, doc, results):
526
+ """Take a single document and the LM results and evaluates, returning a
527
+ dict where keys are the names of submetrics and values are the values of
528
+ the metric for that one document
529
+
530
+ :param doc:
531
+ The document as returned from training_docs, validation_docs, or test_docs.
532
+ :param results:
533
+ The results of the requests created in construct_requests.
534
+ """
535
+ pass
536
+
537
+ @abc.abstractmethod
538
+ def aggregation(self):
539
+ """
540
+ :returns: {str: [metric_score] -> float}
541
+ A dictionary where keys are the names of submetrics and values are
542
+ functions that aggregate a list of metric scores
543
+ """
544
+ pass
545
+
546
+ @abc.abstractmethod
547
+ def higher_is_better(self):
548
+ """
549
+ :returns: {str: bool}
550
+ A dictionary where keys are the names of submetrics and values are
551
+ whether a higher value of the submetric is better
552
+ """
553
+ pass
554
+
555
+ def get_config(self, key: str) -> Any:
556
+ return getattr(self._config, key, None)
557
+
558
+ @classmethod
559
+ def count_bytes(cls, doc):
560
+ """Used for byte-level perplexity metrics in rolling loglikelihood"""
561
+ return len(doc.encode("utf-8"))
562
+
563
+ @classmethod
564
+ def count_words(cls, doc):
565
+ """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
566
+ return len(re.split(r"\s+", doc))
567
+
568
+ @utils.positional_deprecated
569
+ def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
570
+ """Returns a fewshot context string that is made up of a prepended description
571
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
572
+
573
+ :param doc: str
574
+ The document as returned from training_docs, validation_docs, or test_docs.
575
+ :param num_fewshot: int
576
+ The number of fewshot examples to provide in the returned context string.
577
+ :param rnd: random.Random
578
+ The pseudo-random number generator used to randomly sample examples.
579
+ WARNING: This is currently a required arg although it's optionalized with a default `None`.
580
+ :param description: str
581
+ The task's description that will be prepended to the fewshot examples.
582
+ :returns: str
583
+ The fewshot context.
584
+ """
585
+ if rnd is None:
586
+ if self.fewshot_rnd is not None:
587
+ rnd = self.fewshot_rnd
588
+ else:
589
+ raise ValueError(
590
+ "A `random.Random` generator argument must be provided to `rnd`"
591
+ )
592
+
593
+ description = description if description else ""
594
+
595
+ if num_fewshot == 0:
596
+ labeled_examples = ""
597
+ else:
598
+ # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
599
+ if self.has_training_docs():
600
+ fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
601
+ else:
602
+ if self._fewshot_docs is None:
603
+ self._fewshot_docs = list(
604
+ self.validation_docs()
605
+ if self.has_validation_docs()
606
+ else self.test_docs()
607
+ )
608
+
609
+ fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
610
+
611
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
612
+ fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
613
+
614
+ labeled_examples = (
615
+ "\n\n".join(
616
+ [
617
+ self.doc_to_text(doc) + self.doc_to_target(doc)
618
+ for doc in fewshotex
619
+ ]
620
+ )
621
+ + "\n\n"
622
+ )
623
+
624
+ example = self.doc_to_text(doc)
625
+ return description + labeled_examples + example
626
+
627
+ def apply_filters(self) -> Optional[List[Instance]]:
628
+ """Iterates over FilterEnsembles and applies them to instances"""
629
+ if hasattr(self, "_filters"):
630
+ for f in self._filters:
631
+ f.apply(self._instances)
632
+ else:
633
+ eval_logger.warning("No filter defined, passing through instances")
634
+ return self._instances
635
+
636
+ def dump_config(self) -> dict:
637
+ """Returns the config as a dictionary."""
638
+ # TODO: this should only return the overrides applied to a non-YAML task's configuration.
639
+ # (num_fewshot)
640
+ return self.config.to_dict()
641
+
642
+ def set_config(self, key: str, value: Any, update: bool = False) -> None:
643
+ """Set or update the configuration for a given key."""
644
+ if key is None:
645
+ raise ValueError("Key must be provided.")
646
+
647
+ if update:
648
+ current_value = getattr(self._config, key, {})
649
+ if not isinstance(current_value, dict):
650
+ raise TypeError(
651
+ f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
652
+ )
653
+ current_value.update(value)
654
+ else:
655
+ setattr(self._config, key, value)
656
+
657
+ def override_metric(self, metric_name: str) -> None:
658
+ """
659
+ Override the default metrics used for evaluation with custom metrics.
660
+
661
+ Parameters:
662
+ - metric_name (str): The name of the custom metric to use as the override. It must be registered in api.metrics.
663
+ """
664
+ (
665
+ self._metric_fn_list,
666
+ self._aggregation_list,
667
+ self._metric_fn_kwargs,
668
+ self._higher_is_better,
669
+ ) = ({}, {}, {}, {})
670
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
671
+ self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
672
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
673
+ self._metric_fn_kwargs[metric_name] = {}
674
+ if not isinstance(self, ConfigurableTask):
675
+ self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
676
+ self.aggregation = lambda: {
677
+ metric_name: get_metric_aggregation(metric_name)
678
+ }
679
+ setattr(self._config, "metric_list", [{"metric": metric_name}])
680
+ setattr(self._config, "process_results", None)
681
+
682
+ def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
683
+ self.fewshot_rnd = random.Random(seed)
684
+ if hasattr(self, "sampler"):
685
+ self.sampler.rnd = self.fewshot_rnd
686
+
687
+ @property
688
+ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
689
+ if self.has_test_docs():
690
+ return self.test_docs()
691
+ elif self.has_validation_docs():
692
+ return self.validation_docs()
693
+ else:
694
+ raise ValueError(
695
+ f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
696
+ )
697
+
698
+ def doc_iterator(
699
+ self,
700
+ *,
701
+ rank: int = 0,
702
+ limit: Union[int, None] = None,
703
+ world_size: int = 1,
704
+ samples: Optional[List[int]] = None,
705
+ ) -> Iterator[Tuple[int, Any]]:
706
+ if samples:
707
+ n = len(self.eval_docs)
708
+ assert all([e < n for e in samples]), (
709
+ f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
710
+ )
711
+ eval_logger.info(
712
+ f"{self.config.task}: Evaluating on {len(samples)} examples"
713
+ )
714
+ doc_iterator = utils.create_iterator(
715
+ enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
716
+ rank=int(rank),
717
+ limit=None, # limit does not matter here since we are selecting samples directly
718
+ world_size=int(world_size),
719
+ )
720
+ else:
721
+ limit = int(limit) if limit else None
722
+ doc_iterator = utils.create_iterator(
723
+ enumerate(self.eval_docs),
724
+ rank=int(rank),
725
+ limit=limit,
726
+ world_size=int(world_size),
727
+ )
728
+ return doc_iterator
729
+
730
+
731
+ class ConfigurableTask(Task):
732
+ VERSION = "Yaml"
733
+ OUTPUT_TYPE = None
734
+ CONFIG = None
735
+
736
+ def __init__(
737
+ self,
738
+ data_dir=None,
739
+ cache_dir=None,
740
+ download_mode=None,
741
+ config: Optional[dict] = None,
742
+ ) -> None: # TODO no super() call here
743
+ # Get pre-configured attributes
744
+ self._config = self.CONFIG
745
+
746
+ # Use new configurations if there was no preconfiguration
747
+ if self.config is None:
748
+ self._config = TaskConfig(**config)
749
+ # Overwrite configs
750
+ else:
751
+ if config is not None:
752
+ self._config.__dict__.update(config)
753
+
754
+ if self.config is None:
755
+ raise ValueError(
756
+ "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
757
+ )
758
+
759
+ if isinstance(self.config.metadata, dict):
760
+ if "version" in self.config.metadata:
761
+ self.VERSION = self.config.metadata["version"]
762
+
763
+ if self.config.output_type is not None:
764
+ if self.config.output_type not in ALL_OUTPUT_TYPES:
765
+ raise ValueError(
766
+ f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
767
+ )
768
+ self.OUTPUT_TYPE = self.config.output_type
769
+
770
+ if self.config.doc_to_image is not None:
771
+ # mark the task as requiring multimodality.
772
+ self.MULTIMODAL = True
773
+
774
+ if self.config.doc_to_audio:
775
+ # mark the task as requiring multimodality.
776
+ self.MULTIMODAL = True
777
+
778
+ if self.config.unsafe_code is not False:
779
+ self.UNSAFE_CODE = True
780
+
781
+ if self.config.dataset_path is not None:
782
+ self.DATASET_PATH = self.config.dataset_path
783
+
784
+ if self.config.dataset_name is not None:
785
+ self.DATASET_NAME = self.config.dataset_name
786
+
787
+ self._metric_fn_list = {}
788
+ self._metric_fn_kwargs = {}
789
+ self._aggregation_list = {}
790
+ self._higher_is_better = {}
791
+
792
+ if self.config.metric_list is None:
793
+ # TODO: handle this in TaskConfig.__post_init__ ?
794
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
795
+
796
+ for metric_name in _metric_list:
797
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
798
+ self._metric_fn_kwargs[metric_name] = {}
799
+ self._aggregation_list[metric_name] = get_metric_aggregation(
800
+ metric_name
801
+ )
802
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
803
+ else:
804
+ for metric_config in self.config.metric_list:
805
+ if "metric" not in metric_config:
806
+ raise ValueError(
807
+ "'metric' key not provided for an entry in 'metric_list', must be specified!"
808
+ )
809
+ metric_name = metric_config["metric"]
810
+ kwargs = {
811
+ key: metric_config[key]
812
+ for key in metric_config
813
+ if key
814
+ not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
815
+ }
816
+ hf_evaluate_metric = (
817
+ "hf_evaluate" in metric_config
818
+ and metric_config["hf_evaluate"] is True
819
+ )
820
+
821
+ if self.config.process_results is not None:
822
+ self._metric_fn_list[metric_name] = None
823
+ self._metric_fn_kwargs[metric_name] = {}
824
+ elif callable(metric_name):
825
+ metric_fn = metric_name.__call__
826
+ metric_name = metric_name.__name__
827
+ self._metric_fn_list[metric_name] = metric_fn
828
+ self._metric_fn_kwargs[metric_name] = kwargs
829
+ else:
830
+ self._metric_fn_list[metric_name] = get_metric(
831
+ metric_name, hf_evaluate_metric
832
+ )
833
+ self._metric_fn_kwargs[metric_name] = kwargs
834
+
835
+ if "aggregation" in metric_config:
836
+ agg_name = metric_config["aggregation"]
837
+ if isinstance(agg_name, str):
838
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
839
+ elif callable(agg_name): # noqa: E721
840
+ self._aggregation_list[metric_name] = metric_config[
841
+ "aggregation"
842
+ ]
843
+ else:
844
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
845
+ metric_agg = get_metric_aggregation(metric_name)
846
+ eval_logger.warning(
847
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
848
+ f"using default "
849
+ f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
850
+ )
851
+ self._aggregation_list[metric_name] = metric_agg
852
+
853
+ if "higher_is_better" in metric_config:
854
+ self._higher_is_better[metric_name] = metric_config[
855
+ "higher_is_better"
856
+ ]
857
+ else:
858
+ eval_logger.warning(
859
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
860
+ f"using default "
861
+ f"higher_is_better={is_higher_better(metric_name)}"
862
+ )
863
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
864
+
865
+ self.download(self.config.dataset_kwargs)
866
+ self._training_docs = None
867
+ self._fewshot_docs = None
868
+
869
+ if self.config.filter_list is not None:
870
+ self._filters = []
871
+ for filter_config in self.config.filter_list:
872
+ filter_name = filter_config["name"]
873
+ filter_functions = filter_config["filter"]
874
+ components = []
875
+ for function in filter_functions:
876
+ kwargs = {
877
+ key: function[key] for key in function if key != "function"
878
+ }
879
+ components.append([function["function"], kwargs])
880
+ filter_pipeline = build_filter_ensemble(filter_name, components)
881
+ self._filters.append(filter_pipeline)
882
+ else:
883
+ # TODO: handle repeats in a more general way rather than just discarding
884
+ eval_logger.debug(
885
+ "No custom filters defined. Using default 'take_first' filter for handling repeats."
886
+ )
887
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
888
+
889
+ if self.config.use_prompt is not None:
890
+ eval_logger.info(f"loading prompt {self.config.use_prompt}")
891
+ self.prompt = get_prompt(
892
+ self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
893
+ )
894
+ else:
895
+ self.prompt = None
896
+
897
+ if self.fewshot_docs() is not None:
898
+ self.fewshot_rnd = (
899
+ random.Random()
900
+ ) # setting with no seed, to be overridden at a later time
901
+ config_sampler: Union[str, Callable] = (
902
+ self.config.fewshot_config.get("sampler", "default")
903
+ if self.config.fewshot_config
904
+ else "default"
905
+ )
906
+ if isinstance(config_sampler, str):
907
+ self.sampler = samplers.get_sampler(config_sampler)(
908
+ list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
909
+ )
910
+ elif callable(config_sampler) and issubclass(
911
+ config_sampler, samplers.ContextSampler
912
+ ):
913
+ self.sampler = config_sampler(
914
+ docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
915
+ )
916
+ else:
917
+ raise TypeError(
918
+ f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
919
+ f"not {type(config_sampler)}"
920
+ )
921
+
922
+ self.task_docs = self.eval_docs
923
+
924
+ # Test One Doc
925
+ self.features = list(self.task_docs.features.keys())
926
+ self.multiple_input = 0
927
+ self.multiple_target = 0
928
+ test_doc = self.task_docs[0]
929
+ test_text = self.doc_to_text(test_doc)
930
+ test_target = self.doc_to_target(test_doc)
931
+
932
+ if self.config.doc_to_choice is not None:
933
+ test_choice = self.doc_to_choice(test_doc)
934
+ if not isinstance(test_choice, list):
935
+ eval_logger.error("doc_to_choice must return list")
936
+ else:
937
+ num_choice = len(test_choice)
938
+
939
+ if isinstance(test_text, int):
940
+ eval_logger.debug(
941
+ "doc_to_text returned an int. Assuming multiple inputs."
942
+ )
943
+ self.multiple_input = num_choice
944
+ else:
945
+ test_choice = None
946
+
947
+ if isinstance(test_target, list):
948
+ eval_logger.debug(
949
+ "doc_to_target returned a list. Assuming multiple targets."
950
+ )
951
+ self.multiple_target = len(test_target)
952
+ else:
953
+ if (isinstance(test_target, int)) and (test_choice is not None):
954
+ test_target = test_choice[test_target]
955
+ else:
956
+ test_target = str(test_target)
957
+
958
+ if test_choice is not None:
959
+ check_choices = test_choice
960
+ else:
961
+ check_choices = [test_target]
962
+ if self.config.doc_to_choice is not None:
963
+ for choice in check_choices:
964
+ choice_has_whitespace = True if choice[0].isspace() else False
965
+ delimiter_has_whitespace = (
966
+ True
967
+ if self.config.target_delimiter.rstrip()
968
+ != self.config.target_delimiter
969
+ else False
970
+ )
971
+
972
+ if delimiter_has_whitespace and choice_has_whitespace:
973
+ eval_logger.debug(
974
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
975
+ )
976
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
977
+ eval_logger.debug(
978
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
979
+ )
980
+
981
+ def download(
982
+ self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
983
+ ) -> None:
984
+ from packaging.version import parse as vparse
985
+
986
+ if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
987
+ dataset_kwargs.pop("trust_remote_code", None)
988
+ if isinstance(self.config.custom_dataset, Callable):
989
+ eval_logger.warning(
990
+ f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
991
+ + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
992
+ )
993
+ self.dataset = self.config.custom_dataset(
994
+ **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
995
+ )
996
+ else:
997
+ self.dataset = datasets.load_dataset(
998
+ path=self.DATASET_PATH,
999
+ name=self.DATASET_NAME,
1000
+ **dataset_kwargs if dataset_kwargs is not None else {},
1001
+ )
1002
+
1003
+ def has_training_docs(self) -> bool:
1004
+ if self.config.training_split is not None:
1005
+ return True
1006
+ else:
1007
+ return False
1008
+
1009
+ def has_validation_docs(self) -> bool:
1010
+ if self.config.validation_split is not None:
1011
+ return True
1012
+ else:
1013
+ return False
1014
+
1015
+ def has_test_docs(self) -> bool:
1016
+ if self.config.test_split is not None:
1017
+ return True
1018
+ else:
1019
+ return False
1020
+
1021
+ def training_docs(self) -> datasets.Dataset:
1022
+ if self.has_training_docs():
1023
+ if self.config.process_docs is not None:
1024
+ return self.config.process_docs(
1025
+ self.dataset[self.config.training_split]
1026
+ )
1027
+ return self.dataset[self.config.training_split]
1028
+
1029
+ def validation_docs(self) -> datasets.Dataset:
1030
+ if self.has_validation_docs():
1031
+ if self.config.process_docs is not None:
1032
+ return self.config.process_docs(
1033
+ self.dataset[self.config.validation_split]
1034
+ )
1035
+ return self.dataset[self.config.validation_split]
1036
+
1037
+ def test_docs(self) -> datasets.Dataset:
1038
+ if self.has_test_docs():
1039
+ if self.config.process_docs is not None:
1040
+ return self.config.process_docs(self.dataset[self.config.test_split])
1041
+ return self.dataset[self.config.test_split]
1042
+
1043
+ def fewshot_docs(self):
1044
+ if self.config.fewshot_split is not None:
1045
+ if self.config.process_docs is not None:
1046
+ return self.config.process_docs(self.dataset[self.config.fewshot_split])
1047
+ return self.dataset[self.config.fewshot_split]
1048
+ elif (
1049
+ self.config.fewshot_config is not None
1050
+ and self.config.fewshot_config.get("samples", None) is not None
1051
+ ):
1052
+ if isinstance(self.config.fewshot_config["samples"], list):
1053
+ return self.config.fewshot_config["samples"]
1054
+ elif callable(self.config.fewshot_config["samples"]):
1055
+ return self.config.fewshot_config["samples"]()
1056
+ else:
1057
+ raise Exception(
1058
+ "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
1059
+ )
1060
+ else:
1061
+ if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
1062
+ eval_logger.warning(
1063
+ f"[Task: {self.config.task}] "
1064
+ "num_fewshot > 0 but fewshot_split is None. "
1065
+ "using preconfigured rule."
1066
+ )
1067
+ return super().fewshot_docs()
1068
+
1069
+ @staticmethod
1070
+ def append_target_question(
1071
+ labeled_examples: List[Dict[str, str]],
1072
+ question: str,
1073
+ fewshot_as_multiturn: bool = False,
1074
+ gen_prefix: Optional[str] = None,
1075
+ ) -> None:
1076
+ """Adds a target question to the labeled examples list.
1077
+ If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
1078
+ Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
1079
+ """
1080
+ if not fewshot_as_multiturn:
1081
+ # if no messages or last message is system, append as new user entry
1082
+ if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
1083
+ labeled_examples.append({"role": "user", "content": question})
1084
+ # if last message is user, append to it to avoid two user messages in a row
1085
+ else:
1086
+ labeled_examples[-1]["content"] += question
1087
+ else:
1088
+ # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
1089
+ labeled_examples.append({"role": "user", "content": question})
1090
+ if gen_prefix:
1091
+ labeled_examples.append({"role": "assistant", "content": gen_prefix})
1092
+
1093
+ @utils.positional_deprecated
1094
+ def fewshot_context(
1095
+ self,
1096
+ doc: dict,
1097
+ num_fewshot: int,
1098
+ system_instruction: Optional[str] = None,
1099
+ apply_chat_template: bool = False,
1100
+ fewshot_as_multiturn: bool = False,
1101
+ chat_template: Optional[Callable] = None,
1102
+ gen_prefix: Optional[str] = None,
1103
+ ) -> Union[str, List[str]]:
1104
+ """Returns a fewshot context string that is made up of a prepended description
1105
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
1106
+
1107
+ :param doc: str
1108
+ The document as returned from training_docs, validation_docs, or test_docs.
1109
+ :param num_fewshot: int
1110
+ The number of fewshot examples to provide in the returned context string.
1111
+ :param system_instruction: str
1112
+ System instruction to be applied to the prompt.
1113
+ :param apply_chat_template: bool
1114
+ Whether to apply the chat template to the fewshot context.
1115
+ :param fewshot_as_multiturn: bool
1116
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
1117
+ :param chat_template:
1118
+ callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
1119
+ :param gen_prefix:
1120
+ String to append after the <|assistant|> token.
1121
+ :returns: str
1122
+ The fewshot context.
1123
+ """
1124
+ if apply_chat_template:
1125
+ labeled_examples = []
1126
+ else:
1127
+ labeled_examples = ""
1128
+
1129
+ # get task description
1130
+ if description := self.config.description:
1131
+ description = utils.apply_template(self.config.description, doc)
1132
+
1133
+ # create system prompt based on the provided system instruction and description
1134
+ if system_instruction is not None and description:
1135
+ system_prompt = (
1136
+ f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
1137
+ )
1138
+ elif system_instruction is not None:
1139
+ system_prompt = system_instruction
1140
+ elif description:
1141
+ system_prompt = description
1142
+ else:
1143
+ system_prompt = ""
1144
+
1145
+ # add system prompt if specified
1146
+ if system_prompt:
1147
+ if apply_chat_template:
1148
+ labeled_examples.append({"role": "system", "content": system_prompt})
1149
+ else:
1150
+ labeled_examples = system_prompt
1151
+ # if few-shot - append examples after the system prompt
1152
+ if num_fewshot > 0:
1153
+ if apply_chat_template:
1154
+ labeled_examples.extend(
1155
+ self.sampler.get_chat_context(
1156
+ doc,
1157
+ num_fewshot,
1158
+ fewshot_as_multiturn,
1159
+ gen_prefix=gen_prefix,
1160
+ )
1161
+ )
1162
+ else:
1163
+ labeled_examples += self.sampler.get_context(
1164
+ doc, num_fewshot, gen_prefix=gen_prefix
1165
+ )
1166
+
1167
+ example = self.doc_to_text(doc)
1168
+ if apply_chat_template:
1169
+ if self.multiple_input:
1170
+ # TODO: append prefill?
1171
+ if not labeled_examples:
1172
+ return ""
1173
+ return chat_template(labeled_examples)
1174
+ if isinstance(example, str):
1175
+ self.append_target_question(
1176
+ labeled_examples,
1177
+ example,
1178
+ fewshot_as_multiturn,
1179
+ gen_prefix=gen_prefix,
1180
+ )
1181
+ # for loglikelihood create a list of questions with appended choices
1182
+ elif isinstance(example, list):
1183
+ labeled_examples_list = []
1184
+ # copy chat history for each example and append the answer
1185
+ for ex in example:
1186
+ chat = deepcopy(labeled_examples)
1187
+ self.append_target_question(
1188
+ chat,
1189
+ ex,
1190
+ fewshot_as_multiturn,
1191
+ gen_prefix=gen_prefix,
1192
+ )
1193
+ # TODO: append prefill?
1194
+ labeled_examples_list.append(
1195
+ chat_template(
1196
+ chat,
1197
+ add_generation_prompt=False if gen_prefix else True,
1198
+ )
1199
+ )
1200
+ return labeled_examples_list
1201
+ # if example is an integer, append the choice or convert to string
1202
+ elif isinstance(example, int):
1203
+ if self.config.doc_to_choice is not None:
1204
+ choices = self.doc_to_choice(doc)
1205
+ self.append_target_question(
1206
+ labeled_examples,
1207
+ choices[example],
1208
+ fewshot_as_multiturn,
1209
+ gen_prefix=gen_prefix,
1210
+ )
1211
+ else:
1212
+ self.append_target_question(
1213
+ labeled_examples,
1214
+ str(example),
1215
+ fewshot_as_multiturn,
1216
+ gen_prefix=gen_prefix,
1217
+ )
1218
+ # return lm.apply_chat_template(labeled_examples)
1219
+ return chat_template(
1220
+ labeled_examples,
1221
+ add_generation_prompt=False if gen_prefix else True,
1222
+ )
1223
+ else:
1224
+ prefix = (
1225
+ self.config.target_delimiter + gen_prefix
1226
+ if gen_prefix is not None
1227
+ else ""
1228
+ )
1229
+ if self.multiple_input:
1230
+ return labeled_examples
1231
+ if isinstance(example, str):
1232
+ return labeled_examples + example + prefix
1233
+ elif isinstance(example, list):
1234
+ return [labeled_examples + ex + prefix for ex in example]
1235
+ elif isinstance(example, int):
1236
+ if self.config.doc_to_choice is not None:
1237
+ choices = self.doc_to_choice(doc)
1238
+ return labeled_examples + choices[example] + prefix
1239
+ else:
1240
+ return labeled_examples + str(example) + prefix
1241
+
1242
+ def apply_filters(self) -> Optional[List[Instance]]:
1243
+ """Iterates over FilterEnsembles and applies them to instances"""
1244
+ if hasattr(self, "_filters"):
1245
+ for f in self._filters:
1246
+ f.apply(self._instances)
1247
+ else:
1248
+ eval_logger.warning("No filter defined, passing through instances")
1249
+ return self._instances
1250
+
1251
+ def should_decontaminate(self):
1252
+ return self.config.should_decontaminate
1253
+
1254
+ def doc_to_decontamination_query(self, doc: dict):
1255
+ if self.config.should_decontaminate:
1256
+ if self.config.doc_to_decontamination_query is None:
1257
+ return self.doc_to_text(doc)
1258
+ else:
1259
+ doc_to_decontamination_query = self.config.doc_to_decontamination_query
1260
+ if doc_to_decontamination_query in self.features:
1261
+ return doc[doc_to_decontamination_query]
1262
+ elif callable(doc_to_decontamination_query):
1263
+ return doc_to_decontamination_query(doc)
1264
+ else:
1265
+ return ast.literal_eval(
1266
+ utils.apply_template(
1267
+ self.config.doc_to_decontamination_query, doc
1268
+ )
1269
+ )
1270
+
1271
+ def _process_doc(self, doc: dict) -> dict:
1272
+ """
1273
+ Override this to process (detokenize, strip, replace, etc.) individual
1274
+ documents. This can be used in a map over documents of a data split.
1275
+ E.g. `map(self._process_doc, self.dataset["validation"])`
1276
+
1277
+ :return: dict
1278
+ The processed version of the specified `doc`.
1279
+ """
1280
+ return doc
1281
+
1282
+ def doc_to_text(self, doc, doc_to_text=None):
1283
+ if self.prompt is not None:
1284
+ doc_to_text = self.prompt
1285
+ elif doc_to_text is not None:
1286
+ doc_to_text = doc_to_text
1287
+ else:
1288
+ doc_to_text = self.config.doc_to_text
1289
+
1290
+ if isinstance(doc_to_text, int):
1291
+ return doc_to_text
1292
+ elif isinstance(doc_to_text, str):
1293
+ if doc_to_text in self.features:
1294
+ # if self.config.doc_to_choice is not None:
1295
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
1296
+ # else:
1297
+ return doc[doc_to_text]
1298
+ else:
1299
+ text_string = utils.apply_template(doc_to_text, doc)
1300
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
1301
+ return ast.literal_eval(text_string)
1302
+ else:
1303
+ return text_string
1304
+ elif callable(doc_to_text):
1305
+ return doc_to_text(doc)
1306
+ # Used when applying a Promptsource template
1307
+ elif hasattr(doc_to_text, "apply"):
1308
+ applied_prompt = doc_to_text.apply(doc)
1309
+ if len(applied_prompt) == 2:
1310
+ return applied_prompt[0]
1311
+ else:
1312
+ eval_logger.warning("Applied prompt returns empty string")
1313
+ return self.config.fewshot_delimiter
1314
+ else:
1315
+ print(type(doc_to_text))
1316
+ raise TypeError
1317
+
1318
+ def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
1319
+ if self.prompt is not None:
1320
+ doc_to_target = self.prompt
1321
+ elif doc_to_target is not None:
1322
+ doc_to_target = doc_to_target
1323
+ else:
1324
+ doc_to_target = self.config.doc_to_target
1325
+
1326
+ if isinstance(doc_to_target, int):
1327
+ return doc_to_target
1328
+ elif isinstance(doc_to_target, str):
1329
+ if doc_to_target in self.features:
1330
+ # if self.config.doc_to_choice is not None:
1331
+ # return self.doc_to_choice(doc)[doc[doc_to_target]]
1332
+ # else:
1333
+ return doc[doc_to_target]
1334
+ else:
1335
+ target_string = utils.apply_template(doc_to_target, doc)
1336
+ if target_string.isdigit() and self._config.doc_to_choice is not None:
1337
+ return ast.literal_eval(target_string)
1338
+ elif (
1339
+ len(target_string) >= 2
1340
+ and (target_string[0] == "[")
1341
+ and (target_string[-1] == "]")
1342
+ ):
1343
+ try:
1344
+ return ast.literal_eval(target_string)
1345
+ except (SyntaxError, ValueError):
1346
+ return target_string
1347
+ else:
1348
+ return target_string
1349
+ elif isinstance(doc_to_target, list):
1350
+ return doc_to_target
1351
+ elif callable(doc_to_target):
1352
+ return doc_to_target(doc)
1353
+ # Used when applying a Promptsource template
1354
+ elif hasattr(doc_to_target, "apply"):
1355
+ applied_prompt = doc_to_target.apply(doc)
1356
+ if len(applied_prompt) == 2:
1357
+ return applied_prompt[1]
1358
+ else:
1359
+ eval_logger.warning("Applied prompt returns empty string")
1360
+ return self.config.fewshot_delimiter
1361
+ else:
1362
+ raise TypeError
1363
+
1364
+ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
1365
+ if self.prompt is not None:
1366
+ doc_to_choice = self.prompt
1367
+ elif doc_to_choice is not None:
1368
+ doc_to_choice = doc_to_choice
1369
+ elif self.config.doc_to_choice is None:
1370
+ eval_logger.error("doc_to_choice was called but not set in config")
1371
+ else:
1372
+ doc_to_choice = self.config.doc_to_choice
1373
+
1374
+ if isinstance(doc_to_choice, str):
1375
+ if doc_to_choice in self.features:
1376
+ return doc[doc_to_choice]
1377
+ else:
1378
+ return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
1379
+ elif isinstance(doc_to_choice, list):
1380
+ return doc_to_choice
1381
+ elif isinstance(doc_to_choice, dict):
1382
+ return list(doc_to_choice.values())
1383
+ elif callable(doc_to_choice):
1384
+ return doc_to_choice(doc)
1385
+ elif hasattr(doc_to_choice, "get_answer_choices_list"):
1386
+ return doc_to_choice.get_answer_choices_list(doc)
1387
+ else:
1388
+ raise TypeError
1389
+
1390
+ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
1391
+ if doc_to_image is not None:
1392
+ doc_to_image = doc_to_image
1393
+ elif self.config.doc_to_image is not None:
1394
+ doc_to_image = self.config.doc_to_image
1395
+ else:
1396
+ return None
1397
+
1398
+ if isinstance(doc_to_image, list):
1399
+ image_feature = [
1400
+ self.doc_to_image(doc, feature) for feature in doc_to_image
1401
+ ]
1402
+ return [feature for feature in image_feature if feature is not None]
1403
+ elif isinstance(doc_to_image, str):
1404
+ if doc_to_image in self.features:
1405
+ return doc[doc_to_image]
1406
+ else:
1407
+ return ast.literal_eval(utils.apply_template(doc_to_image, doc))
1408
+ elif callable(doc_to_image):
1409
+ return doc_to_image(doc)
1410
+ else:
1411
+ return None
1412
+
1413
+ def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
1414
+ if doc_to_audio is not None:
1415
+ doc_to_audio = doc_to_audio
1416
+ elif self.config.doc_to_audio is not None:
1417
+ doc_to_audio = self.config.doc_to_audio
1418
+ else:
1419
+ return None
1420
+
1421
+ if isinstance(doc_to_audio, list):
1422
+ audio_feature = [
1423
+ self.doc_to_audio(doc, feature) for feature in doc_to_audio
1424
+ ]
1425
+ return [feature for feature in audio_feature if feature is not None]
1426
+ elif isinstance(doc_to_audio, str):
1427
+ if doc_to_audio in self.features:
1428
+ return doc[doc_to_audio]
1429
+ else:
1430
+ return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
1431
+ elif callable(doc_to_audio):
1432
+ return doc_to_audio(doc)
1433
+ else:
1434
+ return None
1435
+
1436
+ def doc_to_prefix(self, doc):
1437
+ if (gen_prefix := self.config.gen_prefix) is not None:
1438
+ if gen_prefix in self.features:
1439
+ return doc[gen_prefix]
1440
+ else:
1441
+ return utils.apply_template(gen_prefix, doc)
1442
+ return None
1443
+
1444
+ def construct_requests(
1445
+ self, doc: dict, ctx: str, **kwargs
1446
+ ) -> Union[List[Instance], Instance]:
1447
+ apply_chat_template = kwargs.pop("apply_chat_template", False)
1448
+ chat_template: Callable | None = kwargs.pop("chat_template", None)
1449
+
1450
+ aux_arguments = None
1451
+
1452
+ if self.OUTPUT_TYPE == "loglikelihood":
1453
+ arguments = (ctx, self.doc_to_target(doc))
1454
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1455
+ arguments = (self.doc_to_target(doc),)
1456
+ elif self.OUTPUT_TYPE == "multiple_choice":
1457
+ choices = self.doc_to_choice(doc)
1458
+ target_delimiter = self.config.target_delimiter
1459
+ if apply_chat_template:
1460
+ target_delimiter = ""
1461
+ if self.multiple_input:
1462
+ # If there are multiple inputs, choices are placed in the ctx
1463
+ # apply chat_template to choices if apply_chat_template
1464
+ cont = self.doc_to_target(doc)
1465
+
1466
+ arguments = [
1467
+ (
1468
+ ctx
1469
+ + (
1470
+ chat_template([{"role": "user", "content": choice}])
1471
+ if apply_chat_template
1472
+ else choice
1473
+ ),
1474
+ f"{target_delimiter}{cont}",
1475
+ )
1476
+ for choice in choices
1477
+ ]
1478
+ else:
1479
+ # Otherwise they are placed in the continuation
1480
+ arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
1481
+
1482
+ # TODO: we should raise a warning telling users this will at most ~2x runtime.
1483
+ if "acc_mutual_info" in self._metric_fn_list.keys():
1484
+ # if we are calculating multiple choice accuracy
1485
+ # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
1486
+
1487
+ # here mutual info refers to calculating
1488
+ # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
1489
+ # in other words normalizing by subtracting the unconditional logprob of each choice.
1490
+ # TODO: should these be strided? will have to modify the processing in process_results if so
1491
+ aux_arguments = [
1492
+ ("", f"{target_delimiter}{choice}") for choice in choices
1493
+ ]
1494
+
1495
+ arguments.extend(aux_arguments)
1496
+
1497
+ elif self.OUTPUT_TYPE == "generate_until":
1498
+ arguments = (ctx, deepcopy(self.config.generation_kwargs))
1499
+
1500
+ multimodal_arg = {}
1501
+ if (
1502
+ self.config.doc_to_image
1503
+ ): # TODO: ensure that non-multimodal tasks aren't getting visual args
1504
+ multimodal_arg = {
1505
+ **multimodal_arg,
1506
+ **{"visual": self.doc_to_image(doc)},
1507
+ }
1508
+
1509
+ if (
1510
+ self.config.doc_to_audio
1511
+ ): # TODO: ensure that non-multimodal tasks aren't getting audio args
1512
+ multimodal_arg = {
1513
+ **multimodal_arg,
1514
+ **{"audio": self.doc_to_audio(doc)},
1515
+ }
1516
+
1517
+ if bool(multimodal_arg):
1518
+ if isinstance(arguments, list):
1519
+ arguments = [arg + (multimodal_arg,) for arg in arguments]
1520
+ else:
1521
+ arguments = arguments + (multimodal_arg,)
1522
+
1523
+ if self.OUTPUT_TYPE == "multiple_choice":
1524
+ request_list = [
1525
+ Instance(
1526
+ request_type="loglikelihood",
1527
+ doc=doc,
1528
+ arguments=arg,
1529
+ idx=i,
1530
+ **kwargs,
1531
+ )
1532
+ for i, arg in enumerate(arguments)
1533
+ ]
1534
+
1535
+ return request_list
1536
+
1537
+ return Instance(
1538
+ request_type=self.OUTPUT_TYPE,
1539
+ doc=doc,
1540
+ arguments=arguments,
1541
+ idx=0,
1542
+ **kwargs,
1543
+ )
1544
+
1545
+ def process_results(self, doc, results):
1546
+ if callable(self.config.process_results):
1547
+ return self.config.process_results(doc, results)
1548
+
1549
+ result_dict = {}
1550
+ use_metric = list(self._metric_fn_list.keys())
1551
+ if self.OUTPUT_TYPE == "loglikelihood":
1552
+ results = results[0]
1553
+ ll, is_greedy = results
1554
+ return {
1555
+ **({"perplexity": ll} if "perplexity" in use_metric else {}),
1556
+ **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
1557
+ }
1558
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1559
+ (loglikelihood,) = results
1560
+ _words = self.count_words(self.doc_to_target(doc))
1561
+ _bytes = self.count_bytes(self.doc_to_target(doc))
1562
+ return {
1563
+ **(
1564
+ {"word_perplexity": (loglikelihood, _words)}
1565
+ if "word_perplexity" in use_metric
1566
+ else {}
1567
+ ),
1568
+ **(
1569
+ {"byte_perplexity": (loglikelihood, _bytes)}
1570
+ if "byte_perplexity" in use_metric
1571
+ else {}
1572
+ ),
1573
+ **(
1574
+ {"bits_per_byte": (loglikelihood, _bytes)}
1575
+ if "bits_per_byte" in use_metric
1576
+ else {}
1577
+ ),
1578
+ }
1579
+ elif self.OUTPUT_TYPE == "multiple_choice":
1580
+ lls, is_greedy = zip(*results)
1581
+
1582
+ # retrieve choices in List[str] form, to compute choice lengths, etc.
1583
+ choices = self.doc_to_choice(doc)
1584
+ completion_len = np.array([float(len(i)) for i in choices])
1585
+
1586
+ if (
1587
+ 2 * len(choices) == len(lls)
1588
+ and "acc_mutual_info" in self._metric_fn_list.keys()
1589
+ ):
1590
+ # then we are doing mutual info.
1591
+ # this stores the "dryrun" / unconditional answer loglikelihoods
1592
+ # as we extend the args list with unconditional ("", continuation) pairs
1593
+ lls_unconditional = lls[len(choices) :]
1594
+ if len(lls_unconditional) != len(choices):
1595
+ raise ValueError
1596
+ # and this stores our "regular" conditional loglikelihoods
1597
+ lls = lls[: len(choices)]
1598
+
1599
+ pred = np.argmax(lls)
1600
+ pred_norm = np.argmax(lls / completion_len)
1601
+
1602
+ if self.multiple_input:
1603
+ gold = self.doc_to_text(doc)
1604
+ else:
1605
+ gold = self.doc_to_target(doc)
1606
+
1607
+ gold_index_error = False
1608
+ if isinstance(gold, list):
1609
+ gold = [i if i < len(choices) else -100 for i in gold]
1610
+ if -100 in gold:
1611
+ gold_index_error = True
1612
+ else:
1613
+ if isinstance(gold, int):
1614
+ gold = gold if gold < len(choices) else -100
1615
+ elif isinstance(gold, str):
1616
+ gold = choices.index(gold) if gold in choices else -100
1617
+
1618
+ if gold == -100:
1619
+ gold_index_error = True
1620
+
1621
+ if gold_index_error:
1622
+ eval_logger.warning(
1623
+ f"Label index was not in within range of available choices,"
1624
+ f"Sample:\n\n{doc}\n\n"
1625
+ )
1626
+
1627
+ if self.multiple_target:
1628
+ acc = 1.0 if pred in gold else 0.0
1629
+ acc_norm = 1.0 if pred_norm in gold else 0.0
1630
+ exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
1631
+ else:
1632
+ acc = 1.0 if pred == gold else 0.0
1633
+ acc_norm = 1.0 if pred_norm == gold else 0.0
1634
+ # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
1635
+ exact_match = int(is_greedy[gold]) if gold != -100 else 0
1636
+
1637
+ prob_norm = utils.softmax(lls)
1638
+
1639
+ # TODO use keyword arguments to the metric?
1640
+ # gold, pred, norm stuff, the original lls,
1641
+ result_dict = {
1642
+ **({"acc": acc} if "acc" in use_metric else {}),
1643
+ **({"f1": (gold, pred)} if "f1" in use_metric else {}),
1644
+ **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
1645
+ **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
1646
+ **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
1647
+ **(
1648
+ {"brier_score": (gold, prob_norm)}
1649
+ if "brier_score" in use_metric
1650
+ else {}
1651
+ ),
1652
+ }
1653
+
1654
+ if "acc_mutual_info" in use_metric:
1655
+ lls_mutual_info = [
1656
+ ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
1657
+ ]
1658
+ acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
1659
+ result_dict["acc_mutual_info"] = acc_mutual_info
1660
+
1661
+ elif self.OUTPUT_TYPE == "generate_until":
1662
+ gold = self.doc_to_target(doc)
1663
+ result = results[0]
1664
+ if self.config.doc_to_choice is not None:
1665
+ # If you set doc_to_choice,
1666
+ # it assumes that doc_to_target returns a number.
1667
+ choices = self.doc_to_choice(doc)
1668
+ gold = choices[gold]
1669
+ # we expect multiple_targets to be a list.
1670
+ elif self.multiple_target:
1671
+ gold = list(gold)
1672
+ # TODO: handle this better
1673
+ elif type(gold) is not type(result) and not (
1674
+ "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
1675
+ ):
1676
+ # cast gold to the same type as result
1677
+ gold = type(result)(gold)
1678
+
1679
+ for metric in self._metric_fn_list.keys():
1680
+ if self.multiple_target:
1681
+ # in the case where we have multiple targets,
1682
+ # return true if any are true
1683
+ # TODO: this may break for multiple_target, non zero-or-1 metrics
1684
+ scores = []
1685
+ if not isinstance(gold, list):
1686
+ # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
1687
+ # print(gold)
1688
+ gold = [gold]
1689
+ if metric == "exact_match":
1690
+ result = [result for _ in range(len(gold))]
1691
+ scores = self._metric_fn_list[metric](
1692
+ references=gold,
1693
+ predictions=result,
1694
+ **self._metric_fn_kwargs[metric],
1695
+ )[metric]
1696
+ result_score = 1.0 if scores > 0.0 else 0.0
1697
+ else:
1698
+ for gold_option in gold:
1699
+ try:
1700
+ result_score = self._metric_fn_list[metric](
1701
+ references=[gold_option],
1702
+ predictions=[result],
1703
+ **self._metric_fn_kwargs[metric],
1704
+ )
1705
+ except (
1706
+ TypeError
1707
+ ): # TODO: this is hacky and I don't want to do it
1708
+ result_score = self._metric_fn_list[metric](
1709
+ [gold_option, result]
1710
+ )
1711
+ if isinstance(result_score, dict):
1712
+ # TODO: this handles the case where HF evaluate returns a dict.
1713
+ result_score = result_score[metric]
1714
+ scores.append(result_score)
1715
+ if any(scores):
1716
+ result_score = 1.0
1717
+ else:
1718
+ result_score = 0.0
1719
+ else:
1720
+ try:
1721
+ result_score = self._metric_fn_list[metric](
1722
+ references=[gold],
1723
+ predictions=[result],
1724
+ **self._metric_fn_kwargs[metric],
1725
+ )
1726
+ except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
1727
+ result_score = self._metric_fn_list[metric]([gold, result])
1728
+ if isinstance(result_score, dict):
1729
+ # TODO: this handles the case where HF evaluate returns a dict.
1730
+ # This allows for multiple metrics to be returned from the same function
1731
+ for k, v in result_score.items():
1732
+ result_dict[k] = v
1733
+ else:
1734
+ result_dict[metric] = result_score
1735
+ else:
1736
+ raise ValueError(
1737
+ f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
1738
+ "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
1739
+ )
1740
+
1741
+ return result_dict
1742
+
1743
+ def aggregation(self) -> dict:
1744
+ return self._aggregation_list
1745
+
1746
+ def higher_is_better(self) -> dict:
1747
+ return self._higher_is_better
1748
+
1749
+ def get_config(self, key: str) -> Any:
1750
+ return getattr(self._config, key, None)
1751
+
1752
+ @property
1753
+ def task_name(self) -> Any:
1754
+ return getattr(self.config, "task", None)
1755
+
1756
+ def __repr__(self):
1757
+ return (
1758
+ f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
1759
+ f"output_type={self.OUTPUT_TYPE},"
1760
+ f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
1761
+ f"num_samples={len(self.eval_docs)})"
1762
+ )
1763
+
1764
+
1765
+ class MultipleChoiceTask(Task):
1766
+ OUTPUT_TYPE = "loglikelihood"
1767
+
1768
+ def doc_to_target(self, doc: dict) -> str:
1769
+ return " " + doc["choices"][doc["gold"]]
1770
+
1771
+ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
1772
+ # TODO: add mutual info here?
1773
+ return [
1774
+ Instance(
1775
+ request_type="loglikelihood",
1776
+ doc=doc,
1777
+ arguments=(ctx, " {}".format(choice)),
1778
+ idx=i,
1779
+ **kwargs,
1780
+ )
1781
+ for i, choice in enumerate(doc["choices"])
1782
+ ]
1783
+
1784
+ def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
1785
+ results = [
1786
+ res[0] for res in results
1787
+ ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
1788
+ gold = doc["gold"]
1789
+
1790
+ acc = 1.0 if np.argmax(results) == gold else 0.0
1791
+ completion_len = np.array([float(len(i)) for i in doc["choices"]])
1792
+ acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
1793
+
1794
+ return {
1795
+ "acc": acc,
1796
+ "acc_norm": acc_norm,
1797
+ }
1798
+
1799
+ def higher_is_better(self) -> dict:
1800
+ return {
1801
+ "acc": True,
1802
+ "acc_norm": True,
1803
+ }
1804
+
1805
+ def aggregation(self) -> dict:
1806
+ return {
1807
+ "acc": mean,
1808
+ "acc_norm": mean,
1809
+ }
1810
+
1811
+
1812
+ class PerplexityTask(Task):
1813
+ OUTPUT_TYPE = "loglikelihood_rolling"
1814
+
1815
+ def has_training_docs(self) -> bool:
1816
+ return False
1817
+
1818
+ def fewshot_examples(self, k: int, rnd) -> List:
1819
+ if k != 0:
1820
+ raise ValueError(
1821
+ "The number of fewshot examples must be 0 for perplexity tasks."
1822
+ )
1823
+ return []
1824
+
1825
+ def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
1826
+ if num_fewshot != 0:
1827
+ raise ValueError(
1828
+ "The number of fewshot examples must be 0 for perplexity tasks."
1829
+ )
1830
+
1831
+ return ""
1832
+
1833
+ def higher_is_better(self) -> dict:
1834
+ return {
1835
+ "word_perplexity": False,
1836
+ "byte_perplexity": False,
1837
+ "bits_per_byte": False,
1838
+ }
1839
+
1840
+ def doc_to_decontamination_query(self, doc):
1841
+ return doc
1842
+
1843
+ def doc_to_text(self, doc) -> str:
1844
+ return ""
1845
+
1846
+ def doc_to_target(self, doc):
1847
+ return doc
1848
+
1849
+ def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
1850
+ if bool(ctx):
1851
+ raise ValueError
1852
+
1853
+ return Instance(
1854
+ request_type=self.OUTPUT_TYPE,
1855
+ doc=doc,
1856
+ arguments=(self.doc_to_target(doc),),
1857
+ idx=0,
1858
+ **kwargs,
1859
+ )
1860
+
1861
+ def process_results(self, doc: dict, results: Tuple[float]) -> dict:
1862
+ (loglikelihood,) = results
1863
+ words = self.count_words(self.doc_to_target(doc))
1864
+ bytes_ = self.count_bytes(self.doc_to_target(doc))
1865
+ return {
1866
+ "word_perplexity": (loglikelihood, words),
1867
+ "byte_perplexity": (loglikelihood, bytes_),
1868
+ "bits_per_byte": (loglikelihood, bytes_),
1869
+ }
1870
+
1871
+ def aggregation(self) -> dict:
1872
+ return {
1873
+ "word_perplexity": weighted_perplexity,
1874
+ "byte_perplexity": weighted_perplexity,
1875
+ "bits_per_byte": bits_per_byte,
1876
+ }
1877
+
1878
+ @classmethod
1879
+ def count_bytes(cls, doc) -> int:
1880
+ return len(doc.encode("utf-8"))
1881
+
1882
+ @classmethod
1883
+ def count_words(cls, doc) -> int:
1884
+ """Downstream tasks with custom word boundaries should override this!"""
1885
+ return len(re.split(r"\s+", doc))
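The acc / acc_norm scoring in MultipleChoiceTask.process_results is easiest to follow on toy numbers. The sketch below uses made-up loglikelihoods and choices (not taken from any real task) to show how the length normalization can change which choice is selected:

    import numpy as np

    # Hypothetical document and per-choice summed loglikelihoods.
    doc = {"choices": ["Paris", "Lyon", "the city of Marseille"], "gold": 0}
    loglikelihoods = np.array([-6.0, -6.5, -9.0])

    acc = 1.0 if np.argmax(loglikelihoods) == doc["gold"] else 0.0
    # acc_norm divides each loglikelihood by the choice's character length,
    # so longer continuations are not penalized simply for being longer.
    completion_len = np.array([float(len(c)) for c in doc["choices"]])
    acc_norm = 1.0 if np.argmax(loglikelihoods / completion_len) == doc["gold"] else 0.0
    print(acc, acc_norm)  # 1.0 0.0: normalization favors the long third choice here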
lm-evaluation-harness/lm_eval/caching/__init__.py ADDED
File without changes
lm-evaluation-harness/lm_eval/caching/cache.py ADDED
@@ -0,0 +1,59 @@
1
+ import hashlib
2
+ import logging
3
+ import os
4
+
5
+ import dill
6
+
7
+
8
+ eval_logger = logging.getLogger(__name__)
9
+
10
+
11
+ MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
12
+
13
+ OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
14
+
15
+
16
+ PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
17
+
18
+ # This should be sufficient for uniqueness
19
+ HASH_INPUT = "EleutherAI-lm-evaluation-harness"
20
+
21
+ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
22
+
23
+ FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
24
+
25
+
26
+ def load_from_cache(file_name: str, cache: bool = False):
27
+ if not cache:
28
+ return
29
+ try:
30
+ path = f"{PATH}/{file_name}{FILE_SUFFIX}"
31
+
32
+ with open(path, "rb") as file:
33
+ cached_task_dict = dill.loads(file.read())
34
+ return cached_task_dict
35
+
36
+ except Exception:
37
+ eval_logger.debug(f"{file_name} is not cached, generating...")
38
+ pass
39
+
40
+
41
+ def save_to_cache(file_name, obj):
42
+ if not os.path.exists(PATH):
43
+ os.mkdir(PATH)
44
+
45
+ file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
46
+
47
+ eval_logger.debug(f"Saving {file_path} to cache...")
48
+ with open(file_path, "wb") as file:
49
+ file.write(dill.dumps(obj))
50
+
51
+
52
+ # NOTE the "key" param is to allow for flexibility
53
+ def delete_cache(key: str = ""):
54
+ files = os.listdir(PATH)
55
+
56
+ for file in files:
57
+ if file.startswith(key) and file.endswith(FILE_SUFFIX):
58
+ file_path = f"{PATH}/{file}"
59
+ os.unlink(file_path)
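A minimal usage sketch for the cache helpers above; the file name and cached object are arbitrary placeholders:

    from lm_eval.caching.cache import load_from_cache, save_to_cache

    save_to_cache(file_name="example_task_requests", obj={"docs": [1, 2, 3]})

    # load_from_cache only attempts a read when cache=True and returns None on a miss.
    cached = load_from_cache(file_name="example_task_requests", cache=True)
    print(cached)  # {'docs': [1, 2, 3]}

The cache directory defaults to a .cache folder next to the module and can be redirected with the LM_HARNESS_CACHE_PATH environment variable.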
lm-evaluation-harness/lm_eval/decontamination/__init__.py ADDED
File without changes
lm-evaluation-harness/lm_eval/decontamination/archiver.py ADDED
@@ -0,0 +1,174 @@
1
+ import datetime
2
+ import io
3
+ import json
4
+ import mmap
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import jsonlines
10
+ import tqdm
11
+ import zstandard
12
+
13
+
14
+ def json_serial(obj: Any) -> str:
15
+ """JSON serializer for objects not serializable by default json code"""
16
+
17
+ if isinstance(obj, (datetime.datetime,)):
18
+ return obj.isoformat()
19
+ raise TypeError("Type %s not serializable" % type(obj))
20
+
21
+
22
+ # Modified version of lm_dataformat Archive for single file.
23
+ class Archive:
24
+ def __init__(self, file_path: str, compression_level: int = 3) -> None:
25
+ self.file_path = file_path
26
+ dir_name = os.path.dirname(file_path)
27
+ if dir_name:
28
+ os.makedirs(dir_name, exist_ok=True)
29
+ self.fh = open(self.file_path, "wb")
30
+ self.cctx = zstandard.ZstdCompressor(level=compression_level)
31
+ self.compressor = self.cctx.stream_writer(self.fh)
32
+
33
+ def add_data(self, data, meta=None) -> None:
34
+ if meta is None:
35
+ meta = {}
36
+ self.compressor.write(
37
+ json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
38
+ "UTF-8"
39
+ )
40
+ + b"\n"
41
+ )
42
+
43
+ def commit(self) -> None:
44
+ self.compressor.flush(zstandard.FLUSH_FRAME)
45
+ self.fh.flush()
46
+ self.fh.close()
47
+
48
+
49
+ # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
50
+ class Reader:
51
+ def __init__(self) -> None:
52
+ pass
53
+
54
+ def read(
55
+ self,
56
+ file,
57
+ get_meta: bool = False,
58
+ autojoin_paragraphs: bool = True,
59
+ para_joiner: str = "\n\n",
60
+ ):
61
+ with open(file, "rb") as fh:
62
+ self.fh = fh
63
+ cctx = zstandard.ZstdDecompressor()
64
+ reader = io.BufferedReader(cctx.stream_reader(fh))
65
+ rdr = jsonlines.Reader(reader)
66
+ for ob in rdr:
67
+ # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
68
+ if isinstance(ob, str):
69
+ assert not get_meta
70
+ yield ob
71
+ continue
72
+
73
+ text = ob["text"]
74
+
75
+ if autojoin_paragraphs and isinstance(text, list):
76
+ text = para_joiner.join(text)
77
+
78
+ if get_meta:
79
+ yield text, (ob["meta"] if "meta" in ob else {})
80
+ else:
81
+ yield text
82
+
83
+
84
+ class TextArchive:
85
+ def __init__(self, file_path, mode: str = "rb+") -> None:
86
+ self.file_path = file_path
87
+ dir_name = os.path.dirname(file_path)
88
+ if dir_name:
89
+ os.makedirs(dir_name, exist_ok=True)
90
+
91
+ if not os.path.exists(file_path):
92
+ Path(file_path).touch()
93
+
94
+ self.fh = open(self.file_path, mode)
95
+
96
+ def add_data(self, data) -> None:
97
+ self.fh.write(data.encode("UTF-8") + b"\n")
98
+
99
+ def commit(self) -> None:
100
+ self.fh.flush()
101
+ self.fh.close()
102
+
103
+
104
+ class TextReader:
105
+ def __init__(self, file_path) -> None:
106
+ self.file_path = file_path
107
+
108
+ # Optimized mmap read with infrequent tqdm updates to maintain speed
109
+ # Tested up to 250MB/s.
110
+ def read_tqdm(self, update_frequency: int = 10000):
111
+ current_file_position = 0
112
+ line_counter = 0
113
+ with (
114
+ open(self.file_path, "r", encoding="utf-8") as fh,
115
+ tqdm.tqdm(
116
+ total=os.path.getsize(self.file_path),
117
+ dynamic_ncols=True,
118
+ unit="byte",
119
+ unit_scale=1,
120
+ ) as progress,
121
+ ):
122
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
123
+ for line in iter(mmap_obj.readline, b""):
124
+ line = line.decode("utf-8")
125
+ line_counter += 1
126
+ if line_counter == update_frequency:
127
+ new_file_pos = mmap_obj.tell()
128
+ bytes_read = new_file_pos - current_file_position
129
+ current_file_position = new_file_pos
130
+ progress.update(bytes_read)
131
+ line_counter = 0
132
+ yield line[:-1]
133
+
134
+ def read_and_tell(self):
135
+ current_file_position = 0
136
+ with open(self.file_path, "r", encoding="utf8") as fh:
137
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
138
+ for line in iter(mmap_obj.readline, b""):
139
+ line = line.decode("utf-8")
140
+ new_file_pos = mmap_obj.tell()
141
+ raw_bytes_read = new_file_pos - current_file_position
142
+ current_file_position = new_file_pos
143
+ yield line[:-1], raw_bytes_read
144
+
145
+ def read(self):
146
+ with open(self.file_path, "r", encoding="utf8") as fh:
147
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
148
+ for line in iter(mmap_obj.readline, b""):
149
+ line = line.decode("utf-8")
150
+ yield line[:-1]
151
+
152
+ def read_slow(self):
153
+ with open(self.file_path, "r", encoding="utf8") as fh:
154
+ while True:
155
+ line = fh.readline()
156
+ if line == -1 or line == "":
157
+ break
158
+ else:
159
+ yield line[:-1]
160
+
161
+
162
+ # Optimized for speed. Decompresses the archive in shell before
163
+ # using the mmap'd TextReader.
164
+ class ZStdTextReader:
165
+ def __init__(self, file) -> None:
166
+ self.file = file
167
+
168
+ def read_tqdm(self):
169
+ decompressed_file = self.file[:-4]
170
+ print("Decompressing file, please wait...")
171
+ os.system(f"zstd -d {self.file}") # linux decompress is faster
172
+ reader = TextReader(decompressed_file)
173
+ yield from reader.read_tqdm()
174
+ os.remove(decompressed_file)
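As a rough usage sketch of the archive classes above (the path is a placeholder), a document set can be written with Archive and streamed back with Reader:

    from lm_eval.decontamination.archiver import Archive, Reader

    archive = Archive("data/example.jsonl.zst")
    archive.add_data("first document", meta={"source": "demo"})
    archive.add_data("second document")
    archive.commit()  # flush the zstd frame and close the underlying file

    reader = Reader()
    for text, meta in reader.read("data/example.jsonl.zst", get_meta=True):
        print(text, meta)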
lm-evaluation-harness/lm_eval/decontamination/decontaminate.py ADDED
@@ -0,0 +1,166 @@
1
+ import collections
2
+ import glob
3
+ import json
4
+ import os
5
+ import pickle
6
+ import random
7
+ import time
8
+
9
+ from .archiver import ZStdTextReader
10
+ from .janitor import Janitor, word_ngrams
11
+
12
+
13
+ # Was used for testing the evaluator decoupled from the full logic below
14
+ def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
15
+ simulated_overlap = 0.1
16
+ contaminated = int(len(docs) * simulated_overlap)
17
+ return random.sample(range(len(docs)), contaminated)
18
+
19
+
20
+ # Returns a dictionary containing all overlapping documents in each
21
+ # task. In the standard use case, an overlap occurs when any of the 13-grams
22
+ # found in the task document exist in the training set documents.
23
+ #
24
+ # To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
25
+ # scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
26
+ # files. These should exist in the "ngrams_path" provided to this function.
27
+
28
+
29
+ # Algorithm:
30
+ # 1. Build lookups for each dataset {ngram: list(document_ids)}
31
+ # 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
32
+ # 3. Full scan the 13-grams from the training set against the merged lookup,
33
+ # saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
34
+ # 4. Strip the task_set from the dictionary keys and return
35
+ #
36
+ # We cache the task+set lookups as well as the overlaps.
37
+ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
38
+ # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
39
+
40
+ info_dict_path = os.path.join(ngrams_path, "info.json")
41
+ info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
42
+ ngrams_n_size = info_dict["ngram_size"]
43
+
44
+ janitor = Janitor()
45
+
46
+ # Build lookup for each dataset first in case we use different task combinations later
47
+ print("Building Lookups...")
48
+ start = time.perf_counter()
49
+
50
+ def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
51
+ return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
52
+
53
+ lookups = {}
54
+ duplicates = {} # (task_name, task_set): set(doc_ids)}
55
+ sets_to_decontaminate = len(docs_by_task_set.keys())
56
+
57
+ for (task_name, task_set), docs in docs_by_task_set.items():
58
+ if not os.path.exists(f"data/{task_name}"):
59
+ os.mkdir(f"data/{task_name}")
60
+
61
+ # Check if we've decontaminated this combination before
62
+ overlaps_dump_path = get_overlaps_dump_path(
63
+ task_name, task_set, ngrams_n_size, limit
64
+ )
65
+ if os.path.exists(overlaps_dump_path):
66
+ duplicates[(task_name, task_set)] = pickle.load(
67
+ open(overlaps_dump_path, "rb")
68
+ )
69
+ sets_to_decontaminate -= 1
70
+ continue
71
+ else:
72
+ duplicates[(task_name, task_set)] = set()
73
+
74
+ # Build/load the task lookup {ngram: set(documents)}.
75
+ task_set_lookup_path = (
76
+ f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
77
+ )
78
+ if os.path.exists(task_set_lookup_path):
79
+ print(f"{task_set_lookup_path} available, loading...")
80
+ lookups[(task_name, task_set)] = pickle.load(
81
+ open(task_set_lookup_path, "rb")
82
+ )
83
+ else:
84
+ print(f"{task_set_lookup_path} not available, building...")
85
+ lookup = collections.defaultdict(set)
86
+
87
+ for doc_id, document in enumerate(docs):
88
+ ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
89
+ for ngram in ngrams:
90
+ lookup[ngram].add(doc_id)
91
+
92
+ pickle.dump(lookup, open(task_set_lookup_path, "wb"))
93
+ lookups[(task_name, task_set)] = lookup
94
+
95
+ elapsed = time.perf_counter() - start
96
+ print(f"Building lookups took {elapsed:0.5f} seconds.")
97
+
98
+ matched_ngrams = []
99
+
100
+ if sets_to_decontaminate > 0:
101
+ print("Merging lookups...")
102
+ start = time.perf_counter()
103
+ merged_lookup = collections.defaultdict(list)
104
+ for (task_name, task_set), lookup in lookups.items():
105
+ for ngram, doc_ids in lookup.items():
106
+ merged_lookup[ngram].append((task_name, task_set, doc_ids))
107
+
108
+ elapsed = time.perf_counter() - start
109
+ print(f"Merging lookups took {elapsed:0.5f} seconds.")
110
+
111
+ print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
112
+ files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
113
+ print(files)
114
+
115
+ for file in files:
116
+ start = time.perf_counter()
117
+ print(f"Scanning {file}")
118
+ reader = ZStdTextReader(file)
119
+ total_ngrams = 0
120
+ unique_ngrams = 0
121
+ matching_unique = 0
122
+ non_matching_unique = 0
123
+
124
+ current_ngram = ""
125
+ for line in reader.read_tqdm(): # Scan training set ngrams file
126
+ total_ngrams += 1
127
+ [ngram, document_id] = line.rsplit(" ", 1)
128
+ if (
129
+ ngram != current_ngram
130
+ ): # Only need to match the ngram once in training set
131
+ unique_ngrams += 1
132
+ current_ngram = ngram
133
+ if ngram in merged_lookup:
134
+ matched_ngrams.append(ngram) # For logging
135
+ matching_unique += 1
136
+ for task_name, task_set, doc_ids in merged_lookup[ngram]:
137
+ task_doc_set = duplicates[(task_name, task_set)]
138
+ for doc_id in doc_ids: # Record contamination across all relevant task/set combos
139
+ task_doc_set.add(doc_id)
140
+ del merged_lookup[ngram] # No point matching again
141
+ else:
142
+ non_matching_unique += 1
143
+
144
+ print(f"Total Ngrams: {total_ngrams}")
145
+ print(f"Unique Ngrams: {unique_ngrams}")
146
+ print(f"Unique Matching: {matching_unique}")
147
+ print(f"Unique Non Matching: {non_matching_unique}")
148
+ print("Matched ngrams:")
149
+ for ngram in matched_ngrams:
150
+ print(ngram)
151
+
152
+ elapsed = time.perf_counter() - start
153
+ print(f"Read took {elapsed:0.5f} seconds.")
154
+ print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
155
+
156
+ print(duplicates)
157
+
158
+ # Dump overlaps separately
159
+ for (task_name, task_set), doc_ids in duplicates.items():
160
+ overlaps_dump_path = get_overlaps_dump_path(
161
+ task_name, task_set, ngrams_n_size, limit
162
+ )
163
+ pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
164
+
165
+ # Strip task set and return
166
+ return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
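Step 1 of the algorithm sketched in the comments above (the per-task {ngram: set(doc_ids)} lookup) can be reproduced on toy documents; the 3-gram size here is only for readability, whereas the real pipeline uses 13-grams:

    import collections

    from lm_eval.decontamination.janitor import Janitor, word_ngrams

    docs = ["the quick brown fox", "the quick brown dog jumps"]
    janitor = Janitor()
    lookup = collections.defaultdict(set)
    for doc_id, document in enumerate(docs):
        for ngram in word_ngrams(janitor.normalize_string(document), 3):
            lookup[ngram].add(doc_id)

    print(dict(lookup))  # e.g. {'the quick brown': {0, 1}, 'quick brown fox': {0}, ...}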
lm-evaluation-harness/lm_eval/decontamination/janitor.py ADDED
@@ -0,0 +1,329 @@
1
+ import pickle
2
+ import re
3
+ import string
4
+ import traceback
5
+ from typing import Iterator, List, Sequence, Tuple, TypeVar
6
+
7
+
8
+ # This is a cpp module.
9
+ # See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
10
+
11
+ try:
12
+ import janitor_util
13
+
14
+ JANITOR_CPP = True
15
+ except Exception:
16
+ print("WARNING: C++ module could not be loaded. Janitor running in python mode")
17
+ traceback.print_exc()
18
+ JANITOR_CPP = False
19
+
20
+ T = TypeVar("T")
21
+
22
+
23
+ # Implementation from nltk source
24
+ # https://www.nltk.org/_modules/nltk/util.html
25
+ def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
26
+ history = []
27
+ while n > 1:
28
+ # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
29
+ try:
30
+ next_item = next(sequence)
31
+ except StopIteration:
32
+ # no more data, terminate the generator
33
+ return
34
+ history.append(next_item)
35
+ n -= 1
36
+ for item in sequence:
37
+ history.append(item)
38
+ yield tuple(history)
39
+ del history[0]
40
+
41
+
42
+ def word_ngrams(s: str, n: int) -> Iterator[str]:
43
+ """Splits a string into ngram words"""
44
+ tokens = s.split() # not a generator :(
45
+ ngram_seqs = form_ngrams(iter(tokens), n)
46
+ return (" ".join(ngram) for ngram in ngram_seqs)
47
+
48
+
49
+ # Does character sequences only - combined faster function to play around with later
50
+ # def word_ngrams_indices_combined(sequence, n):
51
+ # current_word = ""
52
+ # history = []
53
+ # gap = False;
54
+ # start = 0
55
+ # end = 0
56
+ # for character in sequence:
57
+ # if character == " ":
58
+ # if not gap:
59
+ # gap = True
60
+ # history.append(current_word)
61
+ # end += len(current_word) - 1
62
+ # current_word = ""
63
+ # if len(history) == n:
64
+ # yield (tuple(history), start, end)
65
+ # del history[0]
66
+ # start = end + 1
67
+ # end = start
68
+ # else:
69
+ # gap = False
70
+ # current_word += character
71
+
72
+
73
+ # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
74
+ def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
75
+ """Splits a string on whitespaces and records the indices of each in the original string.
76
+ @:return generator((word, (start_idx, end_idx)), ...)
77
+ """
78
+ return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
79
+
80
+
81
+ def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
82
+ """Splits a string into pairs of (ngram words, their start/end indices)"""
83
+ tokens_with_indices = split_indices(s)
84
+
85
+ # Generator of ngrams of (word, idx_pairs)
86
+ # (
87
+ # [(word, (start,end)), (word, (start, end))...],
88
+ # [(word, (start, end)), ...],
89
+ # ...
90
+ # )
91
+ ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
92
+
93
+ # Generator of pairs of word and index ngrams
94
+ # (
95
+ # ([word, word, ...], [(start,end), (start,end), ...]),
96
+ # ...
97
+ # )
98
+ ngram_indices_pairs = (
99
+ zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
100
+ )
101
+
102
+ # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
103
+ return (
104
+ (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
105
+ for ngram_seq, indices in ngram_indices_pairs
106
+ )
107
+
108
+
109
+ class Janitor:
110
+ # FIXME delete_chars: Should anything else go here? Special chars?
111
+ def __init__(
112
+ self,
113
+ ngram_n: int = 13,
114
+ window_to_remove: int = 200,
115
+ too_dirty_cutoff: int = 10,
116
+ minimum_slice_length: int = 200,
117
+ delete_chars: str = string.punctuation,
118
+ ) -> None:
119
+ self.ngram_n = ngram_n
120
+ self.window_to_remove = window_to_remove
121
+ self.too_dirty_cutoff = too_dirty_cutoff
122
+ self.minimum_slice_length = minimum_slice_length
123
+ self.delete_chars = delete_chars
124
+
125
+ self.dirt_ngrams = set()
126
+
127
+ # If in python, we'll translate uppercase to lowercase and delete naughty characters.
128
+ # This is fast by python standards
129
+ # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
130
+ self.translation_table = str.maketrans(
131
+ string.ascii_lowercase + string.ascii_uppercase, # These characters
132
+ string.ascii_lowercase * 2, # Become these characters
133
+ self.delete_chars, # These are deleted
134
+ )
135
+
136
+ ##############
137
+ # I/O for saving contamination ngrams
138
+ ##############
139
+
140
+ def save_contamination_ngrams(self, filename: str) -> None:
141
+ with open(filename, "wb") as fp:
142
+ pickle.dump(self.dirt_ngrams, fp)  # persist the registered ngram set, not the file name
143
+
144
+ def load_contamination_ngrams(self, filename: str) -> None:
145
+ with open(filename, "rb") as fp:
146
+ self.dirt_ngrams = pickle.load(fp)
147
+
148
+ ##############
149
+ # Call these :)
150
+ ##############
151
+
152
+ def register_contaminant(self, dirt_string: str) -> None:
153
+ """Register a string as contamination to be removed, e.g. a test set
154
+ This breaks the dirt_string into ngrams to store for future cleaning"""
155
+ if JANITOR_CPP:
156
+ return self.register_contaminant_cpp(dirt_string)
157
+ else:
158
+ print("WARNING: Janitor running in python mode")
159
+ return self.register_contaminant_python(dirt_string)
160
+
161
+ def clean(self, dirty_string: str) -> List[str]:
162
+ """Clean a string (e.g. a training set) by removing all ngrams previously
163
+ registered as contaminants. Returns a list of clean chunks, or empty if
164
+ the string was too dirty"""
165
+ if JANITOR_CPP:
166
+ return self.clean_cpp(dirty_string)
167
+ else:
168
+ print("WARNING: Janitor running in python mode")
169
+ return self.clean_python(dirty_string)
170
+
171
+ def _split_chunks(
172
+ self, dirty_string: str, dirty_parts: Sequence[Tuple]
173
+ ) -> List[str]:
174
+ clean_chunks = []
175
+ splice_idx = 0
176
+ end = -1
177
+ for i, (ngram, start, end) in enumerate(dirty_parts):
178
+ if i >= self.too_dirty_cutoff:
179
+ return []
180
+ start = max(0, start - self.window_to_remove)
181
+ end = min(len(dirty_string), end + self.window_to_remove)
182
+
183
+ if start - splice_idx > self.minimum_slice_length:
184
+ clean_chunks.append(dirty_string[splice_idx:start])
185
+ splice_idx = end
186
+
187
+ if end < len(dirty_string) - self.minimum_slice_length:
188
+ clean_chunks.append(dirty_string[end + 1 :])
189
+
190
+ return clean_chunks
191
+
192
+ ##############
193
+ # Fast C++
194
+ ##############
195
+
196
+ def register_contaminant_cpp(self, dirt_string) -> None:
197
+ self.dirt_ngrams.update(
198
+ janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
199
+ )
200
+
201
+ def clean_cpp(self, dirty_string: str) -> List[str]:
202
+ contamination_indices = janitor_util.clean_ngram_with_indices(
203
+ dirty_string, self.delete_chars, self.ngram_n
204
+ )
205
+ return self._split_chunks(dirty_string, contamination_indices)
206
+
207
+ ##############
208
+ # Slow python
209
+ ##############
210
+
211
+ def normalize_string(self, s: str) -> str:
212
+ return s.translate(self.translation_table)
213
+
214
+ def register_contaminant_python(self, dirt_string: str) -> None:
215
+ self.dirt_ngrams.update(
216
+ word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
217
+ )
218
+
219
+ def clean_python(self, dirty_string: str) -> List[str]:
220
+ contamination_indices = (
221
+ (None, *idx_pair)
222
+ for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
223
+ if self.normalize_string(dirty_ngram) in self.dirt_ngrams
224
+ )
225
+ return self._split_chunks(dirty_string, contamination_indices)
226
+
227
+
228
+ ##################################################################
229
+ # Tests
230
+ #################################################################
231
+
232
+ # def print_cpp():
233
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
234
+
235
+ # for i in range(1, 10, 2):
236
+ # pprint(janitor_util.clean_ngram(source, string.punctuation, i))
237
+ # for ngram, start, end in \
238
+ # janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
239
+ # print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
240
+
241
+
242
+ # def test_cpp():
243
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
244
+ # contaminant = "dirty boy. Clean he he"
245
+
246
+ # jan_python = Janitor()
247
+ # jan_cpp = Janitor()
248
+
249
+ # jan_python.register_contaminant_python(contaminant)
250
+ # jan_cpp.register_contaminant(contaminant)
251
+
252
+ # assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
253
+
254
+ # assert jan_python.clean_python(source) == jan_cpp.clean(source), \
255
+ # (jan_python.clean_python(source), jan_cpp.clean(source))
256
+
257
+ # print("Passed test, python==cpp")
258
+
259
+
260
+ # def benchmark():
261
+ # # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
262
+ # setup = \
263
+ # """
264
+ # with open("data/enwik8", "r") as f:
265
+ # data = f.read()
266
+ # jan = Janitor(too_dirty_cutoff=1000)
267
+ # jan.register_contaminant('''
268
+ # theories is that there is a connection between &quot;geekdom&quot; and autism.
269
+ # This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
270
+ # The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
271
+ # movement{{ref|Wired}}. This article, many professionals assert, is just one example of
272
+ # the media's application of mental disease labels to what is actually variant normal behavior
273
+ # &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
274
+ # interests, even when they seem unusual to others, are not in themselves signs of autism or
275
+ # Asperger's syndrome. Others assert that it is actually the medical profession which is applying
276
+ # mental disease labels to children who in the past would have simply been accepted as a little
277
+ # different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
278
+ # Due to the recent publicity surrounding autism and autis
279
+ # ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
280
+ # oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
281
+ # paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
282
+ # would last, took a cautious approach, preferring to save the revenue rather than investing it in
283
+ # development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
284
+ # to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
285
+ # brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
286
+ # with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
287
+ # ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
288
+ # ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
289
+ # Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
290
+ # [[United Arab Emirates]]. After the Emirates gained independence in 1971,
291
+ # ''')
292
+ # """
293
+
294
+ # n = 1
295
+ # print(f"Timing {n} run on 100 MB")
296
+ # print("Register contaminant")
297
+ # # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
298
+ # print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
299
+
300
+ # print("Clean")
301
+ # # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
302
+ # print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
303
+
304
+
305
+ # def test_janitor_general():
306
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
307
+ # contaminant = "dirty boy. Clean he he"
308
+
309
+ # jan = Janitor(ngram_n=3)
310
+ # jan.register_contaminant(contaminant)
311
+ # cleaned = " ".join(jan.clean(source))
312
+ # for contam in jan.dirt_ngrams:
313
+ # assert contam not in cleaned, contam
314
+
315
+ # filename = "data/saved_contam"
316
+ # jan.save_contamination_ngrams(filename)
317
+
318
+ # jan = Janitor(ngram_n=3)
319
+ # jan.load_contamination_ngrams(filename)
320
+ # cleaned = " ".join(jan.clean(source))
321
+ # for contam in jan.dirt_ngrams:
322
+ # assert contam not in cleaned, contam
323
+
324
+
325
+ # if __name__ == "__main__":
326
+ # test()
327
+ # # print_cpp()
328
+ # # test_cpp()
329
+ # # benchmark()
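A small end-to-end sketch of the Janitor flow above, with illustrative strings and deliberately small window/slice settings so the effect is visible on short text:

    from lm_eval.decontamination.janitor import Janitor

    jan = Janitor(ngram_n=3, window_to_remove=5, minimum_slice_length=5)
    # Register a (toy) test-set string as contamination...
    jan.register_contaminant("the quick brown fox jumps over the lazy dog")

    # ...then clean a (toy) training document that contains it.
    dirty = "header text before the quick brown fox jumps over the lazy dog and trailing text after"
    clean_chunks = jan.clean(dirty)  # contamination-free slices, or [] if too dirty
    print(clean_chunks)

Without the compiled janitor_util extension this falls back to the pure-Python path and prints a warning, as the module-level try/except above indicates.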
lm-evaluation-harness/lm_eval/evaluator.py ADDED
@@ -0,0 +1,787 @@
1
+ import itertools
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import time
7
+ from collections import defaultdict
8
+ from typing import TYPE_CHECKING, List, Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ import lm_eval.api.metrics
14
+ import lm_eval.api.registry
15
+ import lm_eval.api.task
16
+ import lm_eval.models
17
+ from lm_eval.caching.cache import delete_cache
18
+ from lm_eval.evaluator_utils import (
19
+ consolidate_group_results,
20
+ consolidate_results,
21
+ get_sample_size,
22
+ get_subtask_list,
23
+ get_task_list,
24
+ prepare_print_tasks,
25
+ print_writeout,
26
+ run_task_tests,
27
+ )
28
+ from lm_eval.loggers import EvaluationTracker
29
+ from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
30
+ from lm_eval.tasks import TaskManager, get_task_dict
31
+ from lm_eval.utils import (
32
+ handle_non_serializable,
33
+ hash_dict_images,
34
+ hash_string,
35
+ positional_deprecated,
36
+ setup_logging,
37
+ simple_parse_args_string,
38
+ wrap_text,
39
+ )
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ from lm_eval.api.model import LM
44
+ from lm_eval.api.task import Task
45
+
46
+ eval_logger = logging.getLogger(__name__)
47
+
48
+
49
+ @positional_deprecated
50
+ def simple_evaluate(
51
+ model,
52
+ model_args: Optional[Union[str, dict]] = None,
53
+ tasks: Optional[List[Union[str, dict, object]]] = None,
54
+ num_fewshot: Optional[int] = None,
55
+ batch_size: Optional[Union[int, str]] = None,
56
+ max_batch_size: Optional[int] = None,
57
+ device: Optional[str] = None,
58
+ use_cache: Optional[str] = None,
59
+ cache_requests: bool = False,
60
+ rewrite_requests_cache: bool = False,
61
+ delete_requests_cache: bool = False,
62
+ limit: Optional[Union[int, float]] = None,
63
+ samples: Optional[dict] = None,
64
+ bootstrap_iters: int = 100000,
65
+ check_integrity: bool = False,
66
+ write_out: bool = False,
67
+ log_samples: bool = True,
68
+ evaluation_tracker: Optional[EvaluationTracker] = None,
69
+ system_instruction: Optional[str] = None,
70
+ apply_chat_template: Union[bool, str] = False,
71
+ fewshot_as_multiturn: bool = False,
72
+ gen_kwargs: Union[str, dict, None] = None,
73
+ task_manager: Optional[TaskManager] = None,
74
+ verbosity=None,
75
+ predict_only: bool = False,
76
+ random_seed: int = 0,
77
+ numpy_random_seed: int = 1234,
78
+ torch_random_seed: int = 1234,
79
+ fewshot_random_seed: int = 1234,
80
+ confirm_run_unsafe_code: bool = False,
81
+ metadata: Optional[dict] = None,
82
+ ):
83
+ """Instantiate and evaluate a model on a list of tasks.
84
+
85
+ :param model: Union[str, LM]
86
+ Name of model or LM object, see lm_eval.models.get_model
87
+ :param model_args: Optional[str, dict]
88
+ String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
89
+ Ignored if `model` argument is a LM object.
90
+ :param tasks: list[Union[str, dict, Task]]
91
+ List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
92
+ :param num_fewshot: int
93
+ Number of examples in few-shot context
94
+ :param batch_size: int or str, optional
95
+ Batch size for model
96
+ :param max_batch_size: int, optional
97
+ Maximal batch size to try with automatic batch size detection
98
+ :param device: str, optional
99
+ PyTorch device (e.g. "cpu" or "cuda:0") for running models
100
+ :param use_cache: str, optional
101
+ A path to a sqlite db file for caching model responses. `None` if not caching.
102
+ :param cache_requests: bool, optional
103
+ Speed up evaluation by caching the building of dataset requests. `None` if not caching.
104
+ :param rewrite_requests_cache: bool, optional
105
+ Rewrites all the request cache if set to `True`. `None` if not desired.
106
+ :param delete_requests_cache: bool, optional
107
+ Deletes all the request cache if set to `True`. `None` if not desired.
108
+ :param limit: int or float, optional
109
+ Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
110
+ :param samples: dictionary, optional
111
+ Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
112
+ :param bootstrap_iters:
113
+ Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
114
+ :param check_integrity: bool
115
+ Whether to run the relevant part of the test suite for the tasks
116
+ :param write_out: bool
117
+ If True, write out an example document and model input for checking task integrity
118
+ :param log_samples: bool
119
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
120
+ :param system_instruction: str
121
+ System instruction to be applied to the prompt
122
+ :param apply_chat_template: Union[bool, str]
123
+ Specifies whether to apply a chat template to the prompt.
124
+ - If set to True, the default chat template is applied.
125
+ - If set to a string, applies the specified chat template by name.
126
+ Defaults to False (no chat template applied).
127
+ :param fewshot_as_multiturn: bool
128
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
129
+ :param gen_kwargs: dict or comma-separated string
130
+ Arguments for model generation
131
+ Ignored for all tasks with loglikelihood output_type
132
+ :param verbosity: str
133
+ Verbosity level for logging
134
+ :param predict_only: bool
135
+ If true only model outputs will be generated and returned. Metrics will not be evaluated
136
+ :param random_seed: int
137
+ Random seed for python's random module. If set to None, the seed will not be set.
138
+ :param numpy_random_seed: int
139
+ Random seed for numpy. If set to None, the seed will not be set.
140
+ :param torch_random_seed: int
141
+ Random seed for torch. If set to None, the seed will not be set.
142
+ :param fewshot_random_seed: int
143
+ Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
144
+ :param metadata: dict
145
+ Additional metadata to be added to the task manager. Will get passed to the download function of the task.
146
+ return
147
+ Dictionary of results
148
+ """
149
+ if verbosity is not None:
150
+ setup_logging(verbosity=verbosity)
151
+ start_date = time.time()
152
+
153
+ if limit is not None and samples is not None:
154
+ raise ValueError(
155
+ "Either 'limit' or 'samples' must be None, but both are not None."
156
+ )
157
+
158
+ _NEEDS_CHAT_TEMPLATE = ("inst", "chat")
159
+ if (
160
+ (
161
+ isinstance(model_args, str)
162
+ and any(kw in model_args.lower() for kw in _NEEDS_CHAT_TEMPLATE)
163
+ )
164
+ or (
165
+ isinstance(model_args, dict)
166
+ and any(
167
+ any(kw in str(v).lower() for kw in _NEEDS_CHAT_TEMPLATE)
168
+ for v in model_args.values()
169
+ )
170
+ )
171
+ ) and not apply_chat_template:
172
+ eval_logger.warning(
173
+ wrap_text(
174
+ f"""pretrained={model_args.get("pretrained") if isinstance(model_args, dict) else model_args} appears to be an
175
+ instruct or chat variant but chat template is not applied.
176
+ Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).""",
177
+ )
178
+ )
179
+
180
+ if delete_requests_cache:
181
+ eval_logger.info("Deleting requests cache...")
182
+ delete_cache()
183
+
184
+ seed_message = []
185
+ if random_seed is not None:
186
+ # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
187
+ seed_message.append(f"Setting random seed to {random_seed}")
188
+ random.seed(random_seed)
189
+
190
+ if numpy_random_seed is not None:
191
+ seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
192
+ np.random.seed(numpy_random_seed)
193
+
194
+ if torch_random_seed is not None:
195
+ seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
196
+ torch.manual_seed(torch_random_seed)
197
+
198
+ if fewshot_random_seed is not None:
199
+ seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
200
+
201
+ if seed_message:
202
+ eval_logger.info(" | ".join(seed_message))
203
+
204
+ if tasks is None:
205
+ tasks = []
206
+ if len(tasks) == 0:
207
+ raise ValueError(
208
+ "No tasks specified, or no tasks found. Please verify the task names."
209
+ )
210
+
211
+ if gen_kwargs is not None:
212
+ if isinstance(gen_kwargs, str):
213
+ gen_kwargs = simple_parse_args_string(gen_kwargs)
214
+ eval_logger.warning(
215
+ f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
216
+ "Ensure 'do_sample=True' for non-greedy decoding!"
217
+ )
218
+ if not gen_kwargs:
219
+ gen_kwargs = None
220
+
221
+ if isinstance(model, str):
222
+ if model_args is None:
223
+ eval_logger.warning("model_args not specified. Using defaults.")
224
+ model_args = ""
225
+
226
+ if isinstance(model_args, dict):
227
+ eval_logger.info(
228
+ f"Initializing {model} model, with arguments: {model_args}"
229
+ )
230
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
231
+ model_args,
232
+ {
233
+ "batch_size": batch_size,
234
+ "max_batch_size": max_batch_size,
235
+ "device": device,
236
+ },
237
+ )
238
+
239
+ else:
240
+ eval_logger.info(
241
+ wrap_text(
242
+ f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
243
+ )
244
+ )
245
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
246
+ model_args,
247
+ {
248
+ "batch_size": batch_size,
249
+ "max_batch_size": max_batch_size,
250
+ "device": device,
251
+ },
252
+ )
253
+ else:
254
+ if not isinstance(model, lm_eval.api.model.LM):
255
+ raise TypeError(
256
+ f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
257
+ )
258
+ eval_logger.info("Using pre-initialized model")
259
+ lm = model
260
+
261
+ if use_cache is not None:
262
+ eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
263
+ lm = lm_eval.api.model.CachingLM(
264
+ lm,
265
+ use_cache
266
+ # each rank receives a different cache db.
267
+ # necessary to avoid multiple writes to cache at once
268
+ + "_rank"
269
+ + str(lm.rank)
270
+ + ".db",
271
+ )
272
+
273
+ if task_manager is None:
274
+ metadata = (
275
+ simple_parse_args_string(model_args)
276
+ if isinstance(model_args, str)
277
+ else model_args
278
+ if isinstance(model_args, dict)
279
+ else {}
280
+ ) | (metadata or {})
281
+ task_manager = TaskManager(metadata=metadata)
282
+
283
+ task_dict = get_task_dict(
284
+ tasks,
285
+ task_manager,
286
+ )
287
+
288
+ # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
289
+ # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
290
+ def _adjust_config(task_dict):
291
+ adjusted_task_dict = {}
292
+ for task_name, task_obj in task_dict.items():
293
+ if isinstance(task_obj, dict):
294
+ adjusted_task_dict = {
295
+ **adjusted_task_dict,
296
+ **{task_name: _adjust_config(task_obj)},
297
+ }
298
+
299
+ else:
300
+ if task_obj.get_config("output_type") == "generate_until":
301
+ if gen_kwargs is not None:
302
+ task_obj.set_config(
303
+ key="generation_kwargs", value=gen_kwargs, update=True
304
+ )
305
+ eval_logger.info(
306
+ f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
307
+ )
308
+
309
+ if predict_only:
310
+ eval_logger.info(
311
+ f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
312
+ )
313
+ # we have to change the class properties post-hoc. This is pretty hacky.
314
+ task_obj.override_metric(metric_name="bypass")
315
+
316
+ # override tasks' fewshot values to the provided num_fewshot arg value
317
+ # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
318
+ if num_fewshot is not None:
319
+ if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
320
+ eval_logger.info(
321
+ f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
322
+ )
323
+ else:
324
+ eval_logger.warning(
325
+ f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
326
+ )
327
+ task_obj.set_config(key="num_fewshot", value=num_fewshot)
328
+ else:
329
+ # if num_fewshot not provided, and the task does not define a default one, default to 0
330
+ if (
331
+ default_num_fewshot := task_obj.get_config("num_fewshot")
332
+ ) is None:
333
+ task_obj.set_config(key="num_fewshot", value=0)
334
+ # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
335
+ task_obj.set_fewshot_seed(seed=fewshot_random_seed)
336
+
337
+ adjusted_task_dict[task_name] = task_obj
338
+
339
+ return adjusted_task_dict
340
+
341
+ task_dict = _adjust_config(task_dict)
342
+
343
+ if check_integrity:
344
+ run_task_tests(task_list=tasks)
345
+
346
+ if evaluation_tracker is not None:
347
+ evaluation_tracker.general_config_tracker.log_experiment_args(
348
+ model_source=model,
349
+ model_args=model_args,
350
+ system_instruction=system_instruction,
351
+ chat_template=lm.chat_template(apply_chat_template)
352
+ if apply_chat_template
353
+ else None,
354
+ fewshot_as_multiturn=fewshot_as_multiturn,
355
+ )
356
+
357
+ results = evaluate(
358
+ lm=lm,
359
+ task_dict=task_dict,
360
+ limit=limit,
361
+ samples=samples,
362
+ cache_requests=cache_requests,
363
+ rewrite_requests_cache=rewrite_requests_cache,
364
+ bootstrap_iters=bootstrap_iters,
365
+ write_out=write_out,
366
+ log_samples=True if predict_only else log_samples,
367
+ system_instruction=system_instruction,
368
+ apply_chat_template=apply_chat_template,
369
+ fewshot_as_multiturn=fewshot_as_multiturn,
370
+ verbosity=verbosity,
371
+ confirm_run_unsafe_code=confirm_run_unsafe_code,
372
+ )
373
+ if verbosity is not None:
374
+ setup_logging(verbosity=verbosity)
375
+
376
+ if lm.rank == 0:
377
+ if isinstance(model, str):
378
+ model_name = model
379
+ elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
380
+ model_name = model.config._name_or_path
381
+ else:
382
+ model_name = type(model).__name__
383
+
384
+ # add info about the model and few shot config
385
+ results["config"] = {
386
+ "model": model_name,
387
+ "model_args": model_args,
388
+ }
389
+ # add more detailed model info if available
390
+ if isinstance(lm, lm_eval.models.huggingface.HFLM):
391
+ results["config"].update(lm.get_model_info())
392
+ # add info about execution
393
+ results["config"].update(
394
+ {
395
+ "batch_size": batch_size,
396
+ "batch_sizes": (
397
+ list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
398
+ ),
399
+ "device": device,
400
+ "use_cache": use_cache,
401
+ "limit": limit,
402
+ "bootstrap_iters": bootstrap_iters,
403
+ "gen_kwargs": gen_kwargs,
404
+ "random_seed": random_seed,
405
+ "numpy_seed": numpy_random_seed,
406
+ "torch_seed": torch_random_seed,
407
+ "fewshot_seed": fewshot_random_seed,
408
+ }
409
+ )
410
+ results["git_hash"] = get_git_commit_hash()
411
+ results["date"] = start_date
412
+ add_env_info(results) # additional environment info to results
413
+ add_tokenizer_info(results, lm) # additional info about tokenizer
414
+ return results
415
+ else:
416
+ return None
417
+
418
+
419
+ @positional_deprecated
420
+ def evaluate(
421
+ lm: "LM",
422
+ task_dict,
423
+ limit: Optional[int] = None,
424
+ samples: Optional[dict] = None,
425
+ cache_requests: bool = False,
426
+ rewrite_requests_cache: bool = False,
427
+ bootstrap_iters: Optional[int] = 100000,
428
+ write_out: bool = False,
429
+ log_samples: bool = True,
430
+ system_instruction: Optional[str] = None,
431
+ apply_chat_template: Union[bool, str] = False,
432
+ fewshot_as_multiturn: bool = False,
433
+ verbosity: str = "INFO",
434
+ confirm_run_unsafe_code: bool = False,
435
+ ):
436
+ """Instantiate and evaluate a model on a list of tasks.
437
+
438
+ :param lm: obj
439
+ Language Model
440
+ :param task_dict: dict[str, Task]
441
+ Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
442
+ :param limit: int, optional
443
+ Limit the number of examples per task (only use this for testing)
444
+ :param samples: dictionary, optional
445
+ Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
446
+ :param cache_requests: bool, optional
447
+ Speed up evaluation by caching the building of dataset requests.
448
+ :param rewrite_requests_cache: bool, optional
449
+ Rewrites all the request cache if set to `True`.
450
+ :param bootstrap_iters:
451
+ Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
452
+ :param write_out: bool
453
+ If True, write out an example document and model input for checking task integrity
454
+ :param log_samples: bool
455
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
456
+ :param system_instruction: str
457
+ System instruction to be applied to the prompt
458
+ :param apply_chat_template: Union[bool, str]
459
+ Specifies whether to apply a chat template to the prompt.
460
+ - If set to True, the default chat template is applied.
461
+ - If set to a string, applies the specified chat template by name.
462
+ Defaults to False (no chat template applied).
463
+ :param fewshot_as_multiturn: bool
464
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
465
+ :param verbosity: str
466
+ Verbosity level for logging
467
+ :param confirm_run_unsafe_code: bool
468
+ Whether to confirm running tasks marked as unsafe.
469
+ :return
470
+ Dictionary of results
471
+ """
472
+
473
+ if limit is not None and samples is not None:
474
+ raise ValueError(
475
+ "Either 'limit' or 'samples' must be None, but both are not None."
476
+ )
477
+ if samples is not None:
478
+ eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
479
+ if apply_chat_template:
480
+ eval_logger.warning(
481
+ "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
482
+ )
483
+ # tracks all Instances/requests a model must generate output on.
484
+ requests = defaultdict(list)
485
+ # stores the amount to pad out reqs per req. type so that
486
+ # number of fwd passes per distributed rank is equal
487
+ padding_requests = defaultdict(int)
488
+
489
+ # get lists of group hierarchy and each type of request
490
+ eval_tasks = get_task_list(task_dict)
491
+ if not log_samples:
492
+ if not all(
493
+ "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
494
+ for task_output in eval_tasks
495
+ ):
496
+ raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
497
+
498
+ # validation checks:
499
+ # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
500
+ # 2.are we running code that is marked as unsafe.
501
+ incompatible_tasks = []
502
+ for task_output in eval_tasks:
503
+ task: Task = task_output.task
504
+
505
+ if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
506
+ incompatible_tasks.append(task_output.task_name)
507
+ elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
508
+ raise ValueError(
509
+ f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
510
+ )
511
+ if len(incompatible_tasks) > 0:
512
+ if not getattr(lm, "MULTIMODAL", False):
513
+ raise ValueError(
514
+ f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
515
+ )
516
+ # end validation check
517
+
518
+ # Cache the limit arg.
519
+ limit_arg = limit
520
+ limits = []
521
+ for task_output in eval_tasks:
522
+ task: Task = task_output.task
523
+
524
+ limit = get_sample_size(task, limit_arg)
525
+ limits.append(limit)
526
+ task.build_all_requests(
527
+ limit=limit,
528
+ samples=samples.get(task_output.task_name, None)
529
+ if samples is not None
530
+ else samples,
531
+ rank=lm.rank,
532
+ world_size=lm.world_size,
533
+ cache_requests=cache_requests,
534
+ rewrite_requests_cache=rewrite_requests_cache,
535
+ system_instruction=system_instruction,
536
+ apply_chat_template=bool(apply_chat_template),
537
+ fewshot_as_multiturn=fewshot_as_multiturn,
538
+ chat_template=getattr(lm, "apply_chat_template")
539
+ if apply_chat_template
540
+ else None,
541
+ tokenizer_name=getattr(lm, "tokenizer_name", "")
542
+ if apply_chat_template
543
+ else "",
544
+ )
545
+ eval_logger.debug(
546
+ f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
547
+ )
548
+ if write_out:
549
+ print_writeout(task)
550
+ # aggregate Instances by LM method requested to get output.
551
+ for instance in task.instances:
552
+ reqtype = instance.request_type
553
+ requests[reqtype].append(instance)
554
+
555
+ if lm.world_size > 1:
556
+ instances_rnk = torch.tensor(len(task._instances), device=lm.device)
557
+ gathered_item = (
558
+ lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
559
+ )
560
+ # "multiple_choice" task types dispatch (several) "loglikelihood" request types
561
+ reqtype = (
562
+ "loglikelihood"
563
+ if task.OUTPUT_TYPE == "multiple_choice"
564
+ else task.OUTPUT_TYPE
565
+ )
566
+ # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
567
+ numpad = max(gathered_item) - gathered_item[lm.rank]
568
+ # todo: may not account for padding in cases like SquadV2 which has multiple req types
569
+ padding_requests[reqtype] += numpad
570
+
571
+ ### Run LM on inputs, get all outputs ###
572
+ # execute each type of request
573
+ for reqtype, reqs in requests.items():
574
+ eval_logger.info(f"Running {reqtype} requests")
575
+ # create `K` copies of each request `req` based off `K = req.repeats`
576
+ cloned_reqs = []
577
+ for req in reqs:
578
+ cloned_reqs.extend([req] * req.repeats)
579
+
580
+ if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
581
+ for _ in range(padding_requests[reqtype]):
582
+ cloned_reqs.extend([req] * req.repeats)
583
+
584
+ # run requests through model
585
+ resps = getattr(lm, reqtype)(cloned_reqs)
586
+
587
+ # put responses from model into a list of length K for each request.
588
+ for x, req in zip(resps, cloned_reqs):
589
+ req.resps.append(x)
590
+
591
+ if lm.world_size > 1:
592
+ lm.accelerator.wait_for_everyone()
593
+
594
+ RANK = lm.rank
595
+ WORLD_SIZE = lm.world_size
596
+ ### Postprocess outputs ###
597
+ # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
598
+ for task_output, limit in zip(eval_tasks, limits):
599
+ task = task_output.task
600
+ task.apply_filters()
601
+
602
+ ### Collect values of metrics on all datapoints ###
603
+ # # unpack results and sort back in order and return control to Task
604
+ # TODO: make it possible to use a different metric per filter
605
+ # Pre-process task.instances to group by doc_id
606
+ instances_by_doc_id = defaultdict(list)
607
+ for instance in task.instances:
608
+ instances_by_doc_id[instance.doc_id].append(instance)
609
+ # Sort instances within each group
610
+ for instances in instances_by_doc_id.values():
611
+ instances.sort(key=lambda x: x.idx)
612
+ # iterate over different filters used
613
+ for filter_key in task.instances[0].filtered_resps.keys():
614
+ indices = (
615
+ samples.get(task_output.task_name, None)
616
+ if samples is not None
617
+ else None
618
+ )
619
+ doc_iterator = task.doc_iterator(
620
+ rank=RANK,
621
+ limit=limit,
622
+ world_size=WORLD_SIZE,
623
+ samples=indices,
624
+ )
625
+ for doc_id, doc in doc_iterator:
626
+ if indices:
627
+ doc_id_true = indices[doc_id]
628
+ else:
629
+ doc_id_true = doc_id
630
+ requests = instances_by_doc_id[doc_id]
631
+ metrics = task.process_results(
632
+ doc, [req.filtered_resps[filter_key] for req in requests]
633
+ )
634
+ if log_samples:
635
+ target = task.doc_to_target(doc)
636
+ example = {
637
+ "doc_id": doc_id_true,
638
+ "doc": doc,
639
+ "target": target,
640
+ "arguments": [req.args for req in requests],
641
+ "resps": [req.resps for req in requests],
642
+ "filtered_resps": [
643
+ req.filtered_resps[filter_key] for req in requests
644
+ ],
645
+ "filter": filter_key,
646
+ "metrics": list(metrics.keys()),
647
+ "doc_hash": hash_string(
648
+ json.dumps(
649
+ requests[0].doc,
650
+ indent=2,
651
+ default=handle_non_serializable,
652
+ ensure_ascii=False,
653
+ )
654
+ ),
655
+ "prompt_hash": hash_string(requests[0].arguments[0]),
656
+ "target_hash": hash_string(str(target)),
657
+ }
658
+ example.update(metrics)
659
+ task_output.logged_samples.append(example)
660
+ for metric, value in metrics.items():
661
+ task_output.sample_metrics[(metric, filter_key)].append(value)
662
+
663
+ if WORLD_SIZE > 1:
664
+ # if multigpu, then gather data across all ranks to rank 0
665
+ # first gather logged samples across all ranks
666
+ for task_output in eval_tasks:
667
+ if log_samples:
668
+ # for task_name, task_samples in list(samples.items()):
669
+ full_samples = [None] * WORLD_SIZE if RANK == 0 else None
670
+ torch.distributed.gather_object(
671
+ obj=task_output.logged_samples,
672
+ object_gather_list=full_samples,
673
+ dst=0,
674
+ )
675
+
676
+ if RANK == 0:
677
+ task_output.logged_samples = list(
678
+ itertools.chain.from_iterable(full_samples)
679
+ )
680
+
681
+ # then collect metrics across all ranks
682
+ for metrics in task_output.sample_metrics:
683
+ metric_list = [None] * WORLD_SIZE if RANK == 0 else None
684
+ torch.distributed.gather_object(
685
+ obj=task_output.sample_metrics[metrics],
686
+ object_gather_list=metric_list,
687
+ dst=0,
688
+ )
689
+ if RANK == 0:
690
+ task_output.sample_metrics[metrics] = list(
691
+ itertools.chain.from_iterable(metric_list)
692
+ )
693
+
694
+ if RANK == 0:
695
+ ### Aggregate results over all datapoints ###
696
+ # aggregate results ; run bootstrap CIs
697
+ for task_output in eval_tasks:
698
+ task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
699
+ (
700
+ results,
701
+ samples,
702
+ configs,
703
+ versions,
704
+ num_fewshot,
705
+ higher_is_better,
706
+ ) = consolidate_results(eval_tasks)
707
+
708
+ ### Calculate group metrics ###
709
+ if bool(results):
710
+ results, versions, show_group_table, *_ = consolidate_group_results(
711
+ results, versions, task_dict
712
+ )
713
+
714
+ results_agg, group_agg = prepare_print_tasks(task_dict, results)
715
+ subtask_list = get_subtask_list(task_dict)
716
+
717
+ # collect all higher_is_better values for metrics
718
+ # in the group's subtasks.
719
+ # TODO: clean this up ; unify with the below metric_list loop?
720
+ _higher_is_better = {}
721
+ for group, task_list in subtask_list.items():
722
+ if (
723
+ len(task_list) != 0
724
+ ): # subtask list will list "task_name": [] for solo tasks
725
+ for task in task_list:
726
+ for m, h in higher_is_better[task].items():
727
+ if m not in _higher_is_better.keys():
728
+ _higher_is_better[m] = h
729
+
730
+ if (
731
+ m in _higher_is_better
732
+ and _higher_is_better[m] is not None
733
+ and _higher_is_better[m] != h
734
+ ):
735
+ eval_logger.warning(
736
+ f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
737
+ )
738
+ _higher_is_better[m] = None
739
+ higher_is_better[group] = _higher_is_better
740
+
741
+ results_dict = {
742
+ "results": dict(results_agg.items()),
743
+ **(
744
+ {"groups": dict(group_agg.items())}
745
+ if (bool(group_agg) & show_group_table)
746
+ else {}
747
+ ),
748
+ "group_subtasks": dict(reversed(subtask_list.items())),
749
+ "configs": dict(sorted(configs.items())),
750
+ "versions": dict(sorted(versions.items())),
751
+ "n-shot": dict(sorted(num_fewshot.items())),
752
+ "higher_is_better": dict(sorted(higher_is_better.items())),
753
+ "n-samples": {
754
+ task_output.task_name: {
755
+ "original": len(task_output.task.eval_docs),
756
+ "effective": min(
757
+ limit if limit else len(task_output.task.eval_docs),
758
+ len(task_output.task.eval_docs),
759
+ ),
760
+ }
761
+ for task_output, limit in zip(eval_tasks, limits)
762
+ },
763
+ }
764
+ if log_samples:
765
+ # default: hash images
766
+ samples = (
767
+ hash_dict_images(samples)
768
+ if os.environ.get("LMEVAL_HASHMM", "1") != "0"
769
+ and (hasattr(lm, "MULTIMODAL"))
770
+ else samples
771
+ )
772
+ results_dict["samples"] = dict(samples)
773
+
774
+ return results_dict
775
+
776
+ else:
777
+ return None
778
+
779
+
780
+ def request_caching_arg_to_dict(cache_requests: str) -> dict:
781
+ request_caching_args = {
782
+ "cache_requests": cache_requests in {"true", "refresh"},
783
+ "rewrite_requests_cache": cache_requests == "refresh",
784
+ "delete_requests_cache": cache_requests == "delete",
785
+ }
786
+
787
+ return request_caching_args
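
As a usage note: the helper above just turns the CLI's `--cache_requests` string into the three boolean flags consumed elsewhere. A minimal sketch, assuming the `lm_eval` package from this commit is installed and that this hunk belongs to `lm_eval/evaluator.py` (as the changed-files list suggests):

from lm_eval.evaluator import request_caching_arg_to_dict

for mode in ("true", "refresh", "delete"):
    # "true": reuse an existing cache, "refresh": rebuild it, "delete": drop it entirely
    print(mode, request_caching_arg_to_dict(cache_requests=mode))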
lm-evaluation-harness/lm_eval/evaluator_utils.py ADDED
@@ -0,0 +1,554 @@
1
+ import collections
2
+ import logging
3
+ import math
4
+ import pathlib
5
+ import sys
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from lm_eval.api.group import ConfigurableGroup
9
+ from lm_eval.api.metrics import (
10
+ aggregate_subtask_metrics,
11
+ mean,
12
+ pooled_sample_stderr,
13
+ stderr_for_metric,
14
+ )
15
+ from lm_eval.api.task import Task
16
+ from lm_eval.utils import positional_deprecated
17
+
18
+
19
+ eval_logger = logging.getLogger(__name__)
20
+
21
+
22
+ class TaskOutput:
23
+ """
24
+ Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.
25
+
26
+ Attributes:
27
+ task (object): The task object.
28
+ task_name (str): The name of the task.
29
+ task_config (dict): The configuration of the task.
30
+ version (str): The version of the task.
31
+ group_name (str): The name of the task group.
32
+ n_shot (int): The number of shots for the task.
33
+ task_alias (str): The alias of the task.
34
+ group_alias (str): The alias of the task group.
35
+ is_group (bool): Indicates if the task is a group.
36
+ logged_samples (list): The list of logged samples.
37
+ sample_len (int): The length of the samples.
38
+ sample_metrics (defaultdict): The dictionary of samples' metrics.
39
+ agg_metrics (defaultdict): The dictionary of aggregate metrics.
40
+
41
+ Methods:
42
+ from_taskdict(cls, task_name: str, task):
43
+ Creates a TaskOutput instance from a task dictionary.
44
+
45
+ calculate_aggregate_metric(bootstrap_iters=100000) -> None:
46
+ Calculates the aggregate metrics for the task.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ task=None,
52
+ task_name=None,
53
+ task_config=None,
54
+ version=None,
55
+ group_name=None,
56
+ n_shot=None,
57
+ task_alias=None,
58
+ group_alias=None,
59
+ is_group=None,
60
+ ):
61
+ self.task = task
62
+ self.task_config = task_config
63
+ self.task_name = task_name
64
+ self.group_name = group_name
65
+ self.version = version
66
+ self.n_shot = n_shot
67
+ self.task_alias = task_alias
68
+ self.group_alias = group_alias
69
+ self.is_group = is_group
70
+ self.logged_samples = []
71
+ self.sample_len = None
72
+ self.sample_metrics = collections.defaultdict(list)
73
+ self.agg_metrics = collections.defaultdict(list)
74
+
75
+ @classmethod
76
+ def from_taskdict(cls, task_name: str, task):
77
+ if isinstance(task, tuple):
78
+ group_name, task = task
79
+ else:
80
+ group_name = None
81
+ if not task:
82
+ # these get filtered out in get_task_list
83
+ # once they are added to group hierarchy
84
+ is_group = True
85
+ return cls(
86
+ task=task, task_name=task_name, is_group=is_group, group_name=group_name
87
+ )
88
+ version = task.VERSION
89
+ task_config = dict(task.dump_config())
90
+ if (n_shot := task_config.get("num_fewshot")) == 0:
91
+ n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
92
+ task_alias = task_config.get("alias")
93
+ group_alias = task_config.get("group_alias")
94
+ return cls(
95
+ task=task,
96
+ task_name=task_name,
97
+ task_config=task_config,
98
+ group_name=group_name,
99
+ version=version,
100
+ n_shot=n_shot,
101
+ task_alias=task_alias,
102
+ group_alias=group_alias,
103
+ )
104
+
105
+ def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
106
+ for (metric, filter_key), items in self.sample_metrics.items():
107
+ try:
108
+ agg_fn = self.task.aggregation()[metric]
109
+ except KeyError:
110
+ # This happens when process_results outputs an arbitrary metric
111
+ # TODO: Handle this better and allow other aggregate functions other than mean.
112
+ agg_fn = mean
113
+ metric_key = f"{metric},{filter_key}"
114
+ self.agg_metrics[metric_key] = agg_fn(items)
115
+ self.sample_len = len(items) # TODO: same sample size for each metric?
116
+ if isinstance(bootstrap_iters, int):
117
+ stderr_fn = stderr_for_metric(
118
+ metric=agg_fn,
119
+ bootstrap_iters=min(bootstrap_iters, 100)
120
+ if metric in ["bleu", "chrf", "ter"]
121
+ else bootstrap_iters,
122
+ )
123
+ self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
124
+ stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
125
+ )
126
+ else:
127
+ raise ValueError(
128
+ f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
129
+ )
130
+
131
+ def __repr__(self):
132
+ return (
133
+ f"TaskOutput(task_name={self.task_name}, "
134
+ f"group_name={self.group_name}, "
135
+ f"version={self.version}, "
136
+ f"n_shot={self.n_shot}, "
137
+ f"task_alias={self.task_alias}, "
138
+ f"group_alias={self.group_alias})"
139
+ )
140
+
141
+
142
+ def get_task_list(task_dict: dict) -> List[TaskOutput]:
143
+ outputs = []
144
+ for task_name, task_obj in task_dict.items():
145
+ if isinstance(task_obj, dict):
146
+ _outputs = get_task_list(task_obj)
147
+ outputs.extend(_outputs)
148
+ else:
149
+ task_output = TaskOutput.from_taskdict(task_name, task_obj)
150
+ outputs.append(task_output)
151
+
152
+ return outputs
153
+
154
+
155
+ def get_subtask_list(task_dict, task_root=None, depth=0):
156
+ subtask_list = {}
157
+ for group_obj, task_obj in task_dict.items():
158
+ if isinstance(group_obj, ConfigurableGroup):
159
+ # group_name = group_obj.group_name
160
+ group_name = group_obj.group_name
161
+ else:
162
+ group_name = group_obj
163
+ if isinstance(task_obj, dict):
164
+ _subtask_list = get_subtask_list(
165
+ task_obj, task_root=group_name, depth=depth + 1
166
+ )
167
+ if task_root:
168
+ subtask_list.setdefault((task_root, depth), []).extend(
169
+ [
170
+ _task
171
+ for (_task, _depth) in _subtask_list.keys()
172
+ if (_depth - 1) == depth
173
+ ]
174
+ )
175
+
176
+ subtask_list = {**subtask_list, **_subtask_list}
177
+ else:
178
+ if isinstance(task_obj, ConfigurableGroup):
179
+ # group_or_task_name = task_obj.group_name
180
+ group_or_task_name = task_obj.group_name
181
+ elif isinstance(task_obj, Task):
182
+ # group_or_task_name = task_obj.task_name
183
+ group_or_task_name = task_obj.task_name
184
+
185
+ if task_root is None:
186
+ subtask_list.setdefault((group_or_task_name, depth), [])
187
+ else:
188
+ subtask_list.setdefault((task_root, depth), []).append(
189
+ group_or_task_name
190
+ )
191
+
192
+ if depth == 0:
193
+ _subtask_list = {}
194
+ for group_key, task_list in subtask_list.items():
195
+ group_name, depth = group_key
196
+ _subtask_list[group_name] = task_list
197
+ subtask_list = _subtask_list
198
+
199
+ return subtask_list
200
+
201
+
202
+ def print_writeout(task) -> None:
203
+ for inst in task.instances:
204
+ # print the prompt for the first few documents
205
+ if inst.doc_id < 1:
206
+ eval_logger.info(
207
+ f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
208
+ \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
209
+ )
210
+ eval_logger.info(f"Request: {str(inst)}")
211
+
212
+
213
+ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
214
+ if limit is not None:
215
+ limit = (
216
+ int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
217
+ )
218
+ return limit
219
+
220
+
221
+ def prepare_print_tasks(
222
+ task_dict: dict,
223
+ results: dict,
224
+ task_depth=0,
225
+ group_depth=0,
226
+ ) -> Tuple[dict, dict]:
227
+ """
228
+ @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
229
+ value is a list of task names.
230
+ @param results: Dictionary containing the results of each task. Each key is a
231
+ group name and its value is a dictionary of task results.
232
+ @param task_depth: The indentation level for printing the task
233
+ hierarchy. Default is 0.
234
+ @param group_depth: The indentation level for printing the group
235
+ hierarchy. Default is 0.
236
+ @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
237
+ aggregated results for each task, and groups_agg contains aggregated results for each group.
238
+
239
+ Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
240
+ """
241
+
242
+ def _sort_task_dict(task_dict):
243
+ """
244
+ Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
245
+ Required so that we end up sorting within each sub-header correctly.
246
+ """
247
+
248
+ return dict(
249
+ sorted(
250
+ task_dict.items(),
251
+ key=lambda item: item[0].group_name
252
+ if isinstance(item[0], ConfigurableGroup)
253
+ else item[0],
254
+ )
255
+ )
256
+
257
+ task_agg = collections.defaultdict(dict)
258
+ group_agg = collections.defaultdict(dict)
259
+ task_dict = _sort_task_dict(task_dict)
260
+ for task_or_group_name, task_or_group_obj in task_dict.items():
261
+ tab_string = " " * task_depth + "- " if task_depth > 0 else ""
262
+ if isinstance(task_or_group_name, ConfigurableGroup):
263
+ # string_name = task_or_group_name.group_name
264
+ name = task_or_group_name.group_name
265
+ from_configurable_group = True
266
+ task_or_group_obj = _sort_task_dict(task_or_group_obj)
267
+ elif isinstance(task_or_group_name, str):
268
+ name = task_or_group_name
269
+ if isinstance(task_or_group_obj, Task):
270
+ # string_name = task_or_group_obj.task_name
271
+ name = task_or_group_obj.task_name
272
+ from_configurable_group = False
273
+
274
+ task_agg[name] = results[name].copy()
275
+ if from_configurable_group:
276
+ if task_or_group_name.group_alias is not None:
277
+ alias = task_or_group_name.group_alias
278
+ else:
279
+ alias = task_or_group_name.group
280
+ else:
281
+ if "alias" in task_agg[name]:
282
+ alias = task_agg[name]["alias"]
283
+ else:
284
+ alias = name
285
+
286
+ task_agg[name]["alias"] = tab_string + alias
287
+ if "samples" in task_agg[name]:
288
+ task_agg[name].pop("samples")
289
+
290
+ if from_configurable_group and (" " not in results[name]):
291
+ group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
292
+ group_agg[name] = results[name].copy()
293
+ group_agg[name]["alias"] = group_tab_string + alias
294
+ if "samples" in group_agg[name]:
295
+ group_agg[name].pop("samples")
296
+
297
+ if isinstance(task_or_group_obj, dict):
298
+ task_depth += 1
299
+ group_depth += 1
300
+ _task_agg, _group_agg = prepare_print_tasks(
301
+ task_or_group_obj, results, task_depth, group_depth
302
+ )
303
+ task_agg = {
304
+ **task_agg,
305
+ **_task_agg,
306
+ }
307
+ group_agg = {**group_agg, **_group_agg}
308
+ task_depth -= 1
309
+ group_depth -= 1
310
+ return task_agg, group_agg
311
+
312
+
313
+ def consolidate_results(
314
+ eval_tasks: List[TaskOutput],
315
+ ) -> Tuple[dict, dict, dict, dict, dict, dict]:
316
+ """
317
+ @param eval_tasks: list(TaskOutput).
318
+ @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
319
+
320
+ Consolidates the results of multiple evaluation tasks into a single structure.
321
+
322
+ The method iterates over each evaluation instance and extracts relevant information to create the consolidated
323
+ results structure. The consolidated results structure has the following properties:
324
+
325
+ - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
326
+ metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
327
+ aliases specified in the task configuration.
328
+ - samples: A defaultdict with task names as keys and lists of log samples as values.
329
+ - configs: A defaultdict with task names as keys and task configurations as values.
330
+ - versions: A defaultdict with task names as keys and task versions as values.
331
+ - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
332
+ - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
333
+ for each metric as values.
334
+
335
+ The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
336
+ """
337
+ # stores the final result for each task, for each metric/filter pair.
338
+ results = collections.defaultdict(dict)
339
+ # logs info about each document evaluated.
340
+ samples = collections.defaultdict(list)
341
+ # store num-fewshot value per task
342
+ num_fewshot = collections.defaultdict(int)
343
+ # Tracks the YAML configs of all chosen task
344
+ configs = collections.defaultdict(dict)
345
+ # Tracks each task's version.
346
+ versions = collections.defaultdict(dict)
347
+ # Track `higher_is_better` for each metric
348
+ higher_is_better = collections.defaultdict(dict)
349
+
350
+ for task_output in eval_tasks:
351
+ if "task_alias" in (task_config := task_output.task_config):
352
+ results[task_output.task_name]["alias"] = task_config["task_alias"]
353
+ else:
354
+ results[task_output.task_name]["alias"] = task_output.task_name
355
+ if group_alias := task_output.group_alias:
356
+ if group_alias not in results and (group_name := task_output.group_name):
357
+ results[group_name]["alias"] = group_alias
358
+ num_fewshot[task_output.task_name] = task_output.n_shot
359
+ configs[task_output.task_name] = task_output.task_config
360
+ versions[task_output.task_name] = task_output.version
361
+ samples[task_output.task_name] = task_output.logged_samples
362
+ higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
363
+ for (metric, filter_key), items in task_output.sample_metrics.items():
364
+ metric_key = f"{metric},{filter_key}"
365
+ results[task_output.task_name][metric_key] = task_output.agg_metrics[
366
+ metric_key
367
+ ]
368
+ results[task_output.task_name]["samples"] = task_output.sample_len
369
+ results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
370
+ task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
371
+ )
372
+ return results, samples, configs, versions, num_fewshot, higher_is_better
373
+
374
+
375
+ def consolidate_group_results(
376
+ results,
377
+ versions,
378
+ task_dict,
379
+ task_root=None,
380
+ show_group_table=False,
381
+ task_aggregation_list=None,
382
+ ) -> Tuple[dict, dict, bool, Union[None, dict]]:
383
+ """
384
+ (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
385
+
386
+ @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
387
+
388
+ - results: A defaultdict with task names (and, after this function is called, group names of
389
+ groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
390
+ - versions: A defaultdict with task names (and, after this function is called, group names of
391
+ groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
392
+ - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
393
+ - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
394
+
395
+ The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
396
+ In the top-level invocation of this function, task_aggregation_list is ignored.
397
+ """
398
+ if task_root is None:
399
+ task_root = {}
400
+
401
+ if task_aggregation_list is None:
402
+ task_aggregation_list = {}
403
+
404
+ for group_or_task, group_or_task_info in task_dict.items():
405
+ # Convert to string
406
+ if isinstance(group_or_task, ConfigurableGroup):
407
+ group_config = group_or_task.config
408
+ group_or_task = group_or_task.group_name
409
+ else:
410
+ group_config = None
411
+
412
+ if isinstance(group_or_task_info, Task):
413
+ if task_root:
414
+ task_aggregation_list.setdefault(task_root, []).append(
415
+ group_or_task_info.task_name
416
+ )
417
+ else:
418
+ (
419
+ results,
420
+ versions,
421
+ show_group_table,
422
+ _task_aggregation_list,
423
+ ) = consolidate_group_results(
424
+ results,
425
+ versions,
426
+ group_or_task_info,
427
+ group_or_task,
428
+ show_group_table,
429
+ task_aggregation_list,
430
+ )
431
+ if task_root:
432
+ task_aggregation_list.setdefault(task_root, []).extend(
433
+ task_aggregation_list.get(group_or_task, [])
434
+ )
435
+
436
+ if (group_config is None) or (
437
+ group_config["aggregate_metric_list"] is None
438
+ ):
439
+ results[group_or_task][" "] = " "
440
+ continue
441
+
442
+ if "aggregate_metric_list" in group_config:
443
+ agg_metric_list = group_config["aggregate_metric_list"]
444
+
445
+ show_group_table = show_group_table | bool(
446
+ group_config["aggregate_metric_list"]
447
+ )
448
+
449
+ task_list = _task_aggregation_list[group_or_task]
450
+
451
+ metric_list = list(
452
+ {
453
+ key
454
+ for task in task_list
455
+ for key in results[task].keys()
456
+ if "_stderr" not in key and key not in ["task", "alias", "samples"]
457
+ }
458
+ )
459
+ for metric in metric_list:
460
+ stderr = "_stderr,".join(metric.split(","))
461
+
462
+ # gather metrics, sizes, and stderrs from subtasks
463
+ metrics = [
464
+ results[task][metric]
465
+ for task in task_list
466
+ if metric in results[task]
467
+ ] # TODO: copy?
468
+ stderrs = [
469
+ results[task][stderr]
470
+ for task in task_list
471
+ if stderr in results[task]
472
+ ]
473
+ sizes = [
474
+ results[task]["samples"]
475
+ for task in task_list
476
+ if metric in results[task]
477
+ ]
478
+
479
+ for metric_config in agg_metric_list:
480
+ for filter_name in metric_config["filter_list"]:
481
+ if metric != ",".join([metric_config["metric"], filter_name]):
482
+ continue
483
+
484
+ # compute group's pooled metric and stderr
485
+ if metric_config["aggregation"] == "mean":
486
+ aggregate_fn = aggregate_subtask_metrics
487
+ elif callable(metric_config["aggregation"]):
488
+ aggregate_fn = metric_config["aggregation"]
489
+ else:
490
+ raise ValueError(
491
+ f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
492
+ )
493
+
494
+ results[group_or_task][metric] = aggregate_fn(
495
+ metrics,
496
+ sizes,
497
+ metric_config["weight_by_size"],
498
+ )
499
+ # TODO: calculate groups' metrics using arbitrary agg fns
500
+ if "N/A" in stderrs:
501
+ results[group_or_task][stderr] = "N/A"
502
+ else:
503
+ # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
504
+ results[group_or_task][stderr] = pooled_sample_stderr(
505
+ stderrs, sizes
506
+ )
507
+
508
+ results[group_or_task]["samples"] = sum(sizes)
509
+ group_metadata = group_config.get("metadata", None)
510
+ if group_metadata is not None:
511
+ versions[group_or_task] = group_metadata.get("version", None)
512
+ # print(results)
513
+ return results, versions, show_group_table, task_aggregation_list
514
+
515
+
516
+ @positional_deprecated
517
+ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
518
+ """
519
+ Search upward in the directory tree to a maximum of three layers
520
+ to find and return the package root (containing the 'tests' folder)
521
+ """
522
+ cur_path = start_path.resolve()
523
+ max_layers = 3
524
+ for _ in range(max_layers):
525
+ if (cur_path / "tests" / "test_version_stable.py").exists():
526
+ return cur_path
527
+ else:
528
+ cur_path = cur_path.parent.resolve()
529
+ raise FileNotFoundError(
530
+ f"Unable to find package root within {max_layers} layers upward of {start_path}"
531
+ )
532
+
533
+
534
+ @positional_deprecated
535
+ def run_task_tests(task_list: List[str]):
536
+ """
537
+ Find the package root and run the tests for the given tasks
538
+ """
539
+ import pytest
540
+
541
+ package_root = find_test_root(start_path=pathlib.Path(__file__))
542
+ task_string = " or ".join(task_list)
543
+ args = [
544
+ f"{package_root}/tests/test_version_stable.py",
545
+ f"--rootdir={package_root}",
546
+ "-k",
547
+ f"{task_string}",
548
+ ]
549
+ sys.path.append(str(package_root))
550
+ pytest_return_val = pytest.main(args)
551
+ if pytest_return_val:
552
+ raise ValueError(
553
+ f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
554
+ )
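
For clarity on `get_sample_size` above: a `limit` below 1.0 is treated as a fraction of the task's evaluation docs, anything else as an absolute count. A self-contained sketch (the `_FakeTask` below is a hypothetical stand-in for a real `lm_eval` Task; only the length of `eval_docs` matters here):

import math

class _FakeTask:
    eval_docs = list(range(200))  # pretend the task has 200 evaluation documents

def get_sample_size(task, limit):
    if limit is not None:
        limit = int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
    return limit

print(get_sample_size(_FakeTask(), None))  # None -> no cap, evaluate everything
print(get_sample_size(_FakeTask(), 0.25))  # 50   -> fractions are scaled by len(eval_docs)
print(get_sample_size(_FakeTask(), 32))    # 32   -> integers are absolute counts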
lm-evaluation-harness/lm_eval/filters/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from functools import partial
2
+ from typing import List
3
+
4
+ from lm_eval.api.filter import FilterEnsemble
5
+ from lm_eval.api.registry import get_filter
6
+
7
+ from . import custom, extraction, selection, transformation
8
+
9
+
10
+ def build_filter_ensemble(
11
+ filter_name: str, components: List[List[str]]
12
+ ) -> FilterEnsemble:
13
+ """
14
+ Create a filtering pipeline.
15
+ """
16
+ filters = []
17
+ for function, kwargs in components:
18
+ if kwargs is None:
19
+ kwargs = {}
20
+ # create a filter given its name in the registry
21
+ f = partial(get_filter(function), **kwargs)
22
+ # add the filter as a pipeline step
23
+ filters.append(f)
24
+
25
+ return FilterEnsemble(name=filter_name, filters=filters)
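
A minimal sketch of driving `build_filter_ensemble` (assuming this commit's `lm_eval` is importable): each component is a `[registered_filter_name, kwargs]` pair, and the result is a `FilterEnsemble` that tasks later apply to their instances.

from lm_eval.filters import build_filter_ensemble

# lowercase every response, then keep only the first response per document
ensemble = build_filter_ensemble(
    "first_lowercased",
    [["lowercase", None], ["take_first", None]],
)
print(ensemble.name, ensemble.filters)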
lm-evaluation-harness/lm_eval/filters/custom.py ADDED
@@ -0,0 +1,17 @@
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("custom")
6
+ class CustomFilter(Filter):
7
+ """
8
+ Custom filter that applies a custom, user-defined function to the model responses.
9
+ """
10
+
11
+ def __init__(self, **kwargs) -> None:
12
+ self.filter_fn = kwargs.pop("filter_fn")
13
+
14
+ super().__init__(**kwargs)
15
+
16
+ def apply(self, resps, docs):
17
+ return self.filter_fn(resps, docs)
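
A minimal sketch of `CustomFilter` in isolation (assuming this commit's `lm_eval` is importable; `keep_longest` is a made-up filter function): `apply` simply forwards `(resps, docs)` to the user-supplied `filter_fn`.

from lm_eval.filters.custom import CustomFilter

def keep_longest(resps, docs):
    # resps: one list of model responses per document
    return [[max(r, key=len)] for r in resps]

f = CustomFilter(filter_fn=keep_longest)
print(f.apply([["short", "a much longer answer"]], [{"question": "..."}]))
# -> [['a much longer answer']]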
lm-evaluation-harness/lm_eval/filters/decontamination.py ADDED
@@ -0,0 +1,25 @@
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("decontaminate")
6
+ class DecontaminationFilter(Filter):
7
+ """
8
+ A filter which evaluates responses for training-set contamination, splitting documents into contaminated and uncontaminated subsets. (Currently a stub.)
9
+ """
10
+
11
+ name = "track_decontamination"
12
+
13
+ def __init__(self, path) -> None:
14
+ """
15
+
16
+ TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
17
+ should further cache result on a given (task_name, doc_id)
18
+ """
19
+ self._decontam_results = None
20
+
21
+ def apply(self, resps, docs) -> None:
22
+ """
23
+ Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
24
+ """
25
+ pass
lm-evaluation-harness/lm_eval/filters/extraction.py ADDED
@@ -0,0 +1,233 @@
1
+ import re
2
+ import sys
3
+ import unicodedata
4
+
5
+ from lm_eval.api.filter import Filter
6
+ from lm_eval.api.registry import register_filter
7
+
8
+
9
+ @register_filter("regex")
10
+ class RegexFilter(Filter):
11
+ """A filter that extracts values from text using regex pattern matching.
12
+
13
+ This filter applies a regex pattern to each model response and extracts matched values.
14
+ If no match is found, returns a fallback value. Useful for extracting structured data
15
+ (like numbers) from unstructured model outputs.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
21
+ group_select: int = 0,
22
+ fallback: str = "[invalid]",
23
+ ) -> None:
24
+ """
25
+ pass a string `regex` to run `re.compile(r"regex")` on.
26
+ `fallback` defines the output returned if no matches for the regex are located.
27
+ """
28
+ self.regex_pattern = regex_pattern
29
+ self.regex = re.compile(regex_pattern)
30
+ self.group_select = group_select
31
+ self.fallback = fallback
32
+
33
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
34
+ # here, we assume we have a list, in which each element is
35
+ # a list of model responses for some particular input/target pair.
36
+ # so we process each of these (same input/target response sets)
37
+ # independently (and keep them a list.)
38
+ def filter_set(inst):
39
+ filtered = []
40
+ for resp in inst:
41
+ match = self.regex.findall(resp)
42
+ if match:
43
+ match = match[self.group_select]
44
+ if isinstance(match, tuple):
45
+ match = [m for m in match if m]
46
+ if match:
47
+ match = match[0]
48
+ else:
49
+ match = self.fallback
50
+ match = match.strip()
51
+ else:
52
+ match = self.fallback
53
+ filtered.append(match)
54
+ return filtered
55
+
56
+ filtered_resps = list(map(lambda x: filter_set(x), resps))
57
+ return filtered_resps
58
+
59
+
60
+ @register_filter("regex_pos")
61
+ class POSFilter(Filter):
62
+ """Extracts part-of-speech (POS) tags from model responses given as lists of (token, POS) tuples."""
63
+
64
+ def __init__(
65
+ self,
66
+ regex_pattern: str = r"\['(.*?)'\]",
67
+ group_select=0,
68
+ fallback=None,
69
+ ) -> None:
70
+ """
71
+ pass a string `regex` to run `re.compile(r"regex")` on.
72
+ `fallback` defines the output returned if no matches for the regex are located.
73
+ """
74
+ if fallback is None:
75
+ fallback = ["invalid"]
76
+ self.regex_pattern = regex_pattern
77
+ self.regex = re.compile(regex_pattern)
78
+ self.group_select = group_select
79
+ self.fallback = fallback
80
+
81
+ def apply(self, resps, docs):
82
+ def extract_tagged_tokens(text):
83
+ # Extract tagged tokens list from text input using regex
84
+ tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
85
+ return [(token, pos) for token, pos in tokens]
86
+
87
+ def extract_pos_tags(result):
88
+ pos_tags = []
89
+ if isinstance(result, str):
90
+ result = extract_tagged_tokens(result)
91
+ pos_tags.extend(pos for _, pos in result)
92
+ return pos_tags if pos_tags else self.fallback
93
+
94
+ def filter_set(inst):
95
+ filtered = []
96
+ for resp in inst:
97
+ match = extract_pos_tags(resp)
98
+ filtered.append(match)
99
+ return filtered
100
+
101
+ filtered_resps = map(lambda x: filter_set(x), resps)
102
+
103
+ return filtered_resps
104
+
105
+
106
+ @register_filter("remove_whitespace")
107
+ class WhitespaceFilter(Filter):
108
+ """Filters out leading whitespace from responses."""
109
+
110
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
111
+ def filter_set(inst):
112
+ filtered_resp = []
113
+ for resp in inst:
114
+ resp = resp.lstrip()
115
+ filtered_resp.append(resp)
116
+ return filtered_resp
117
+
118
+ filtered_resps = [filter_set(resp) for resp in resps]
119
+
120
+ return filtered_resps
121
+
122
+
123
+ @register_filter("multi_choice_regex")
124
+ class MultiChoiceRegexFilter(RegexFilter):
125
+ """
126
+ A filter used to extract a model's answer on multiple choice questions with
127
+ letter answers. assumes each document has a "choices" field
128
+ containing the list of answer choices and that the answer label symbols
129
+ are of the form (A), (B), (C), ... or A, B, C.
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
135
+ group_select=0,
136
+ fallback: str = "[invalid]",
137
+ ignore_case=False,
138
+ ignore_punctuation=False,
139
+ regexes_to_ignore=None,
140
+ ) -> None:
141
+ """
142
+ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
143
+ - step 1 : We parse the choices between parentheses, e.g. ([A-Z]), then try to find these choices in the response.
144
+ - step 2 : We parse the choice with the regex r":\s*([A-?])", where ? is the letter of the last choice (so it varies with the number of choices).
145
+ group_select: Selects the (group_select)th match from the findall result.
146
+ ignore_case: Ignores the case during step 1 matching
147
+ ignore_punctuation: Remove the punctuation during step 1 matching
148
+ regexes_to_ignore: Remove these regexes during step 1 matching
149
+ """
150
+ super().__init__(regex_pattern, group_select, fallback)
151
+ self.ignore_case = ignore_case
152
+ self.ignore_punctuation = ignore_punctuation
153
+ self.regexes_to_ignore = regexes_to_ignore
154
+
155
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
156
+ # here, we assume we have a list, in which each element is
157
+ # a list of model responses for some particular input/target pair.
158
+ # so we process each of these (same input/target response sets)
159
+ # independently (and keep them a list.)
160
+
161
+ def find_match(regex, resp, convert_dict={}):
162
+ match = regex.findall(resp)
163
+ if match:
164
+ match = match[self.group_select]
165
+ if isinstance(match, tuple):
166
+ match = [m for m in match if m][0]
167
+ match = match.strip()
168
+ if match and match in convert_dict:
169
+ match = convert_dict[match]
170
+ return match
171
+
172
+ punct_tbl = dict.fromkeys(
173
+ i
174
+ for i in range(sys.maxunicode)
175
+ if unicodedata.category(chr(i)).startswith("P")
176
+ )
177
+
178
+ def filter_ignores(st):
179
+ if self.regexes_to_ignore is not None:
180
+ for s in self.regexes_to_ignore:
181
+ st = re.sub(s, "", st)
182
+
183
+ if self.ignore_case:
184
+ st = st.lower()
185
+
186
+ if self.ignore_punctuation:
187
+ # https://stackoverflow.com/a/266162
188
+ st = st.translate(punct_tbl)
189
+ return st
190
+
191
+ filtered_resps = []
192
+
193
+ for r, doc in zip(resps, docs):
194
+ fallback_regexes = []
195
+ choice_to_alpha = {}
196
+ next_alpha = "A"
197
+
198
+ without_paren_fallback_regexes = []
199
+ without_paren_to_target = {}
200
+
201
+ choices = doc["choices"]
202
+ for c in choices:
203
+ m = filter_ignores(c.strip())
204
+ fallback_regexes.append(f"{re.escape(m)}")
205
+ choice_to_alpha[m] = f"({next_alpha})"
206
+
207
+ without_paren_fallback_regexes.append(next_alpha)
208
+ without_paren_to_target[next_alpha] = f"({next_alpha})"
209
+
210
+ next_alpha = chr(ord(next_alpha) + 1)
211
+ fallback_regex = re.compile("|".join(fallback_regexes))
212
+ without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
213
+ without_paren_fallback_regex = re.compile(
214
+ rf":[\s]*({without_paren_fallback_regex})"
215
+ )
216
+
217
+ filtered = []
218
+ for resp in r:
219
+ match = find_match(self.regex, resp)
220
+ if not match:
221
+ match = find_match(
222
+ fallback_regex, filter_ignores(resp), choice_to_alpha
223
+ )
224
+ if not match:
225
+ match = find_match(
226
+ without_paren_fallback_regex, resp, without_paren_to_target
227
+ )
228
+ if not match:
229
+ match = self.fallback
230
+ filtered.append(match)
231
+ filtered_resps.append(filtered)
232
+
233
+ return filtered_resps
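
A minimal sketch of `RegexFilter` with its default GSM8K-style pattern (assuming this commit's `lm_eval` is importable); responses without a `#### <number>` marker fall back to "[invalid]":

from lm_eval.filters.extraction import RegexFilter

f = RegexFilter()  # default pattern r"#### (\-?[0-9\.\,]+)", fallback "[invalid]"
resps = [["The answer is\n#### 42", "I am not sure."]]  # one doc, two sampled responses
print(f.apply(resps, docs=[{}]))
# -> [['42', '[invalid]']]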
lm-evaluation-harness/lm_eval/filters/selection.py ADDED
@@ -0,0 +1,61 @@
1
+ from collections import Counter
2
+
3
+ from lm_eval.api.filter import Filter
4
+ from lm_eval.api.registry import register_filter
5
+
6
+
7
+ # TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
8
+ # that takes an input and returns a scalar and then should select the max reward,
9
+ # or should implement different filters for different ways of handling a reward model's inference.
10
+
11
+
12
+ @register_filter("take_first")
13
+ class TakeFirstFilter(Filter):
14
+ def __init__(self) -> None:
15
+ """
16
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
17
+ """
18
+
19
+ def apply(self, resps, docs):
20
+ """
21
+ Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
22
+ """
23
+ return map(lambda r: r[0], resps)
24
+
25
+
26
+ @register_filter("take_first_k")
27
+ class TakeKFilter(Filter):
28
+ def __init__(self, **kwargs) -> None:
29
+ self.k = kwargs.pop("k")
30
+
31
+ super().__init__(**kwargs)
32
+
33
+ def apply(self, resps, docs):
34
+ # need resp to be subscriptable to check below
35
+ resps = list(resps)
36
+ # check we have at least k responses per doc, else we can't take the first k
37
+ assert len(resps[0]) >= self.k, (
38
+ f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
39
+ )
40
+ return map(lambda r: r[: self.k], resps)
41
+
42
+
43
+ @register_filter("majority_vote")
44
+ class MajorityVoteFilter(Filter):
45
+ def __init__(self) -> None:
46
+ """
47
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
48
+ """
49
+
50
+ def apply(self, resps, docs):
51
+ """
52
+ Each entry of `resps` is a list of model responses.
53
+ We select the response that occurs most frequently in each entry of `resps`.
54
+ """
55
+
56
+ def select_majority(resp):
57
+ counts = Counter(resp)
58
+ vote = counts.most_common(1)[0][0]
59
+ return vote
60
+
61
+ return map(lambda r: [select_majority(r)], resps)
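
A minimal sketch of the selection filters on one document with five sampled responses (assuming this commit's `lm_eval` is importable); note that both filters return lazy `map` objects, hence the `list(...)`:

from lm_eval.filters.selection import MajorityVoteFilter, TakeKFilter

resps = [["7", "7", "8", "7", "9"]]  # one doc, five model samples
print(list(MajorityVoteFilter().apply(resps, docs=[{}])))  # -> [['7']]
print(list(TakeKFilter(k=3).apply(resps, docs=[{}])))      # -> [['7', '7', '8']]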
lm-evaluation-harness/lm_eval/filters/transformation.py ADDED
@@ -0,0 +1,122 @@
1
+ import re
2
+
3
+ from lm_eval.api.filter import Filter
4
+ from lm_eval.api.registry import register_filter
5
+
6
+
7
+ @register_filter("lowercase")
8
+ class LowercaseFilter(Filter):
9
+ def __init__(self) -> None:
10
+ pass
11
+
12
+ def apply(self, resps, docs):
13
+ def filter_set(inst):
14
+ return [resp.lower() for resp in inst]
15
+
16
+ return [filter_set(resp) for resp in resps]
17
+
18
+
19
+ @register_filter("uppercase")
20
+ class UppercaseFilter(Filter):
21
+ def __init__(self) -> None:
22
+ pass
23
+
24
+ def apply(self, resps, docs):
25
+ def filter_set(inst):
26
+ return [resp.upper() for resp in inst]
27
+
28
+ return [filter_set(resp) for resp in resps]
29
+
30
+
31
+ @register_filter("map")
32
+ class MapFilter(Filter):
33
+ def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
34
+ """
35
+ Initializes the MapFilter with a given mapping dictionary and default value.
36
+
37
+ Args:
38
+ - mapping_dict (dict): A dictionary containing the key-value mappings.
39
+ Default is an empty dictionary.
40
+ - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
41
+ Default is None.
42
+
43
+ Example:
44
+ mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
45
+ """
46
+ if mapping_dict is None:
47
+ mapping_dict = {}
48
+ assert isinstance(mapping_dict, dict), (
49
+ "Provided mapping_dict is not a dictionary"
50
+ )
51
+ self.mapping_dict = mapping_dict
52
+ self.default_value = default_value
53
+
54
+ def apply(self, resps, docs):
55
+ def filter_set(inst):
56
+ return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
57
+
58
+ return [filter_set(resp) for resp in resps]
59
+
60
+
61
+ @register_filter("format_span")
62
+ class SPANFilter(Filter):
63
+ def __init__(self) -> None:
64
+ pass
65
+
66
+ def apply(self, resps, docs):
67
+ def format_ner_text(text):
68
+ label_dict = {
69
+ "person": "PER",
70
+ "location": "LOC",
71
+ "organization": "ORG",
72
+ "counties": "LOC",
73
+ "places": "LOC",
74
+ "people": "PER",
75
+ "persons": "PER",
76
+ "company": "ORG",
77
+ "country": "LOC",
78
+ "continent": "LOC",
79
+ "time": "DATE",
80
+ "date": "DATE",
81
+ "per": "PER",
82
+ "loc": "LOC",
83
+ "org": "ORG",
84
+ }
85
+ text = text.lower()
86
+ for key, value in label_dict.items():
87
+ text = text.replace(key, value)
88
+
89
+ text = "$".join(i for i in text.split("$$"))
90
+ return text.rstrip("$$")
91
+
92
+ def format_named_entities(text):
93
+ """
94
+ Extract named entities from text and format them as 'label: value $$ label: value'.
95
+ Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
96
+ """
97
+ # Regular expression to match label: entities pattern
98
+ pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
99
+ # Normalize newline characters
100
+ text = text.replace("\n", "$").strip()
101
+ matches = re.findall(pattern, text)
102
+
103
+ formatted_entities = []
104
+
105
+ for label, values in matches:
106
+ # Split multiple entities separated by commas and strip whitespace
107
+ entities = [value.strip() for value in values.split(",")]
108
+
109
+ # Exclude 'none' entities
110
+ for entity in entities:
111
+ if entity.lower() != "none":
112
+ formatted_entities.append(f"{label.lower()}: {entity}")
113
+
114
+ # Join entities with the desired separator
115
+ return " $ ".join(formatted_entities)
116
+
117
+ def filter_set(inst):
118
+ return [
119
+ format_named_entities(format_ner_text(resp.lower())) for resp in inst
120
+ ]
121
+
122
+ return [filter_set(resp) for resp in resps]
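
A minimal sketch of `MapFilter`, mapping extracted answer letters onto choice indices (assuming this commit's `lm_eval` is importable); keys missing from the mapping fall back to `default_value`:

from lm_eval.filters.transformation import MapFilter

f = MapFilter({"A": 0, "B": 1, "C": 2, "D": 3}, default_value=-1)
print(f.apply([["B", "E"]], docs=[{}]))
# -> [[1, -1]]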
lm-evaluation-harness/lm_eval/loggers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .evaluation_tracker import EvaluationTracker
2
+ from .wandb_logger import WandbLogger
lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py ADDED
@@ -0,0 +1,537 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import time
6
+ from collections import defaultdict
7
+ from dataclasses import asdict, dataclass
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ from datasets import load_dataset
12
+ from datasets.utils.metadata import MetadataConfigs
13
+ from huggingface_hub import (
14
+ DatasetCard,
15
+ DatasetCardData,
16
+ HfApi,
17
+ hf_hub_url,
18
+ )
19
+ from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
20
+
21
+ from lm_eval.utils import (
22
+ get_file_datetime,
23
+ get_file_task_name,
24
+ get_results_filenames,
25
+ get_sample_results_filenames,
26
+ handle_non_serializable,
27
+ hash_string,
28
+ sanitize_list,
29
+ sanitize_model_name,
30
+ sanitize_task_name,
31
+ )
32
+
33
+
34
+ eval_logger = logging.getLogger(__name__)
35
+
36
+
37
+ @dataclass(init=False)
38
+ class GeneralConfigTracker:
39
+ """
40
+ Tracker for the evaluation parameters.
41
+
42
+ Attributes:
43
+ model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
44
+ model_name (str): Name of the model.
45
+ model_name_sanitized (str): Sanitized model name for directory creation.
46
+ start_time (float): Start time of the experiment. Logged at class init.
47
+ end_time (float): End time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
48
+ total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
49
+ """
50
+
51
+ model_source: str = None
52
+ model_name: str = None
53
+ model_name_sanitized: str = None
54
+ system_instruction: str = None
55
+ system_instruction_sha: str = None
56
+ fewshot_as_multiturn: bool = None
57
+ chat_template: str = None
58
+ chat_template_sha: str = None
59
+ start_time: float = None
60
+ end_time: float = None
61
+ total_evaluation_time_seconds: str = None
62
+
63
+ def __init__(self) -> None:
64
+ """Starts the evaluation timer."""
65
+ self.start_time = time.perf_counter()
66
+
67
+ @staticmethod
68
+ def _get_model_name(model_args: str) -> str:
69
+ """Extracts the model name from the model arguments."""
70
+
71
+ def extract_model_name(model_args: str, key: str) -> str:
72
+ """Extracts the model name from the model arguments using a key."""
73
+ args_after_key = model_args.split(key)[1]
74
+ return args_after_key.split(",")[0]
75
+
76
+ # order does matter, e.g. peft and delta are provided together with pretrained
77
+ prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
78
+ for prefix in prefixes:
79
+ if prefix in model_args:
80
+ return extract_model_name(model_args, prefix)
81
+ return ""
82
+
83
+ def log_experiment_args(
84
+ self,
85
+ model_source: str,
86
+ model_args: str,
87
+ system_instruction: str,
88
+ chat_template: str,
89
+ fewshot_as_multiturn: bool,
90
+ ) -> None:
91
+ """Logs model parameters and job ID."""
92
+ self.model_source = model_source
93
+ self.model_name = GeneralConfigTracker._get_model_name(model_args)
94
+ self.model_name_sanitized = sanitize_model_name(self.model_name)
95
+ self.system_instruction = system_instruction
96
+ self.system_instruction_sha = (
97
+ hash_string(system_instruction) if system_instruction else None
98
+ )
99
+ self.chat_template = chat_template
100
+ self.chat_template_sha = hash_string(chat_template) if chat_template else None
101
+ self.fewshot_as_multiturn = fewshot_as_multiturn
102
+
103
+ def log_end_time(self) -> None:
104
+ """Logs the end time of the evaluation and calculates the total evaluation time."""
105
+ self.end_time = time.perf_counter()
106
+ self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
107
+
108
+
109
+ class EvaluationTracker:
110
+ """
111
+ Keeps track and saves relevant information of the evaluation process.
112
+ Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ output_path: str = None,
118
+ hub_results_org: str = "",
119
+ hub_repo_name: str = "",
120
+ details_repo_name: str = "",
121
+ results_repo_name: str = "",
122
+ push_results_to_hub: bool = False,
123
+ push_samples_to_hub: bool = False,
124
+ public_repo: bool = False,
125
+ token: str = "",
126
+ leaderboard_url: str = "",
127
+ point_of_contact: str = "",
128
+ gated: bool = False,
129
+ ) -> None:
130
+ """
131
+ Creates all the necessary loggers for evaluation tracking.
132
+
133
+ Args:
134
+ output_path (str): Path to save the results. If not provided, the results won't be saved.
135
+ hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
136
+ hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
137
+ details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the details will be pushed to `lm-eval-results`.
138
+ results_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed separately and will be found in the details repository.
139
+ push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
140
+ push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
141
+ public_repo (bool): Whether to push the results to a public or private repository.
142
+ token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
143
+ leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
144
+ point_of_contact (str): Contact information on the Hugging Face hub dataset card.
145
+ gated (bool): Whether to gate the repository.
146
+ """
147
+ self.general_config_tracker = GeneralConfigTracker()
148
+
149
+ self.output_path = output_path
150
+ self.push_results_to_hub = push_results_to_hub
151
+ self.push_samples_to_hub = push_samples_to_hub
152
+ self.public_repo = public_repo
153
+ self.leaderboard_url = leaderboard_url
154
+ self.point_of_contact = point_of_contact
155
+ self.api = HfApi(token=token) if token else None
156
+ self.gated_repo = gated
157
+
158
+ if not self.api and (push_results_to_hub or push_samples_to_hub):
159
+ raise ValueError(
160
+ "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
161
+ "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
162
+ )
163
+
164
+ if (
165
+ self.api
166
+ and hub_results_org == ""
167
+ and (push_results_to_hub or push_samples_to_hub)
168
+ ):
169
+ hub_results_org = self.api.whoami()["name"]
170
+ eval_logger.warning(
171
+ f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
172
+ )
173
+
174
+ if hub_repo_name == "":
175
+ details_repo_name = (
176
+ details_repo_name if details_repo_name != "" else "lm-eval-results"
177
+ )
178
+ results_repo_name = (
179
+ results_repo_name if results_repo_name != "" else details_repo_name
180
+ )
181
+ else:
182
+ details_repo_name = hub_repo_name
183
+ results_repo_name = hub_repo_name
184
+ eval_logger.warning(
185
+ "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
186
+ )
187
+
188
+ self.details_repo = f"{hub_results_org}/{details_repo_name}"
189
+ self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
190
+ self.results_repo = f"{hub_results_org}/{results_repo_name}"
191
+ self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
192
+
193
+ def save_results_aggregated(
194
+ self,
195
+ results: dict,
196
+ samples: dict,
197
+ ) -> None:
198
+ """
199
+ Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
200
+
201
+ Args:
202
+ results (dict): The aggregated results to save.
203
+ samples (dict): The samples results to save.
204
+ """
205
+ self.general_config_tracker.log_end_time()
206
+
207
+ if self.output_path:
208
+ try:
209
+ eval_logger.info("Saving results aggregated")
210
+
211
+ # calculate cumulative hash for each task - only if samples are provided
212
+ task_hashes = {}
213
+ if samples:
214
+ for task_name, task_samples in samples.items():
215
+ sample_hashes = [
216
+ s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
217
+ for s in task_samples
218
+ ]
219
+ task_hashes[task_name] = hash_string("".join(sample_hashes))
220
+
221
+ # update initial results dict
222
+ results.update({"task_hashes": task_hashes})
223
+ results.update(asdict(self.general_config_tracker))
224
+ dumped = json.dumps(
225
+ results,
226
+ indent=2,
227
+ default=handle_non_serializable,
228
+ ensure_ascii=False,
229
+ )
230
+
231
+ path = Path(self.output_path if self.output_path else Path.cwd())
232
+ self.date_id = datetime.now().isoformat().replace(":", "-")
233
+ if path.suffix == ".json":
234
+ path.parent.mkdir(parents=True, exist_ok=True)
235
+ file_results_aggregated = path.with_name(
236
+ f"{path.stem}_{self.date_id}.json"
237
+ )
238
+ else:
239
+ path = path.joinpath(
240
+ self.general_config_tracker.model_name_sanitized
241
+ )
242
+ path.mkdir(parents=True, exist_ok=True)
243
+ file_results_aggregated = path.joinpath(
244
+ f"results_{self.date_id}.json"
245
+ )
246
+
247
+ file_results_aggregated.open("w", encoding="utf-8").write(dumped)
248
+
249
+ if self.api and self.push_results_to_hub:
250
+ repo_id = (
251
+ self.results_repo
252
+ if self.public_repo
253
+ else self.results_repo_private
254
+ )
255
+ self.api.create_repo(
256
+ repo_id=repo_id,
257
+ repo_type="dataset",
258
+ private=not self.public_repo,
259
+ exist_ok=True,
260
+ )
261
+ self.api.upload_file(
262
+ repo_id=repo_id,
263
+ path_or_fileobj=str(file_results_aggregated),
264
+ path_in_repo=os.path.join(
265
+ self.general_config_tracker.model_name,
266
+ file_results_aggregated.name,
267
+ ),
268
+ repo_type="dataset",
269
+ commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
270
+ )
271
+ eval_logger.info(
272
+ "Successfully pushed aggregated results to the Hugging Face Hub. "
273
+ f"You can find them at: {repo_id}"
274
+ )
275
+
276
+ except Exception as e:
277
+ eval_logger.warning("Could not save results aggregated")
278
+ eval_logger.info(repr(e))
279
+ else:
280
+ eval_logger.info(
281
+ "Output path not provided, skipping saving results aggregated"
282
+ )
283
+
284
+ def save_results_samples(
285
+ self,
286
+ task_name: str,
287
+ samples: dict,
288
+ ) -> None:
289
+ """
290
+ Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
291
+
292
+ Args:
293
+ task_name (str): The task name to save the samples for.
294
+ samples (dict): The samples results to save.
295
+ """
296
+ if self.output_path:
297
+ try:
298
+ eval_logger.info(f"Saving per-sample results for: {task_name}")
299
+
300
+ path = Path(self.output_path if self.output_path else Path.cwd())
301
+ if path.suffix == ".json":
302
+ path = path.parent
303
+ else:
304
+ path = path.joinpath(
305
+ self.general_config_tracker.model_name_sanitized
306
+ )
307
+ path.mkdir(parents=True, exist_ok=True)
308
+
309
+ file_results_samples = path.joinpath(
310
+ f"samples_{task_name}_{self.date_id}.jsonl"
311
+ )
312
+
313
+ for sample in samples:
314
+ # we first need to sanitize arguments and resps
315
+ # otherwise we won't be able to load the dataset
316
+ # using the datasets library
317
+ arguments = {}
318
+ for i, arg in enumerate(sample["arguments"]):
319
+ arguments[f"gen_args_{i}"] = {}
320
+ for j, tmp in enumerate(arg):
321
+ arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
322
+
323
+ sample["resps"] = sanitize_list(sample["resps"])
324
+ sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
325
+ sample["arguments"] = arguments
326
+ sample["target"] = str(sample["target"])
327
+
328
+ sample_dump = (
329
+ json.dumps(
330
+ sample,
331
+ default=handle_non_serializable,
332
+ ensure_ascii=False,
333
+ )
334
+ + "\n"
335
+ )
336
+
337
+ with open(file_results_samples, "a", encoding="utf-8") as f:
338
+ f.write(sample_dump)
339
+
340
+ if self.api and self.push_samples_to_hub:
341
+ repo_id = (
342
+ self.details_repo
343
+ if self.public_repo
344
+ else self.details_repo_private
345
+ )
346
+ self.api.create_repo(
347
+ repo_id=repo_id,
348
+ repo_type="dataset",
349
+ private=not self.public_repo,
350
+ exist_ok=True,
351
+ )
352
+ try:
353
+ if self.gated_repo:
354
+ headers = build_hf_headers()
355
+ r = get_session().put(
356
+ url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
357
+ headers=headers,
358
+ json={"gated": "auto"},
359
+ )
360
+ hf_raise_for_status(r)
361
+ except Exception as e:
362
+ eval_logger.warning("Could not gate the repository")
363
+ eval_logger.info(repr(e))
364
+ self.api.upload_folder(
365
+ repo_id=repo_id,
366
+ folder_path=str(path),
367
+ path_in_repo=self.general_config_tracker.model_name_sanitized,
368
+ repo_type="dataset",
369
+ commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
370
+ )
371
+ eval_logger.info(
372
+ f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
373
+ f"You can find them at: {repo_id}"
374
+ )
375
+
376
+ except Exception as e:
377
+ eval_logger.warning("Could not save sample results")
378
+ eval_logger.info(repr(e))
379
+ else:
380
+ eval_logger.info("Output path not provided, skipping saving sample results")
381
+
382
+ def recreate_metadata_card(self) -> None:
383
+ """
384
+ Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
385
+ """
386
+
387
+ eval_logger.info("Recreating metadata card")
388
+ repo_id = self.details_repo if self.public_repo else self.details_repo_private
389
+
390
+ files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
391
+ results_files = get_results_filenames(files_in_repo)
392
+ sample_files = get_sample_results_filenames(files_in_repo)
393
+
394
+ # Build a dictionary to store the latest evaluation datetime for:
395
+ # - Each tested model and its aggregated results
396
+ # - Each task and sample results, if existing
397
+ # i.e. {
398
+ # "org__model_name__gsm8k": "2021-09-01T12:00:00",
399
+ # "org__model_name__ifeval": "2021-09-01T12:00:00",
400
+ # "org__model_name__results": "2021-09-01T12:00:00"
401
+ # }
402
+ latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())
403
+
404
+ for file_path in sample_files:
405
+ file_path = Path(file_path)
406
+ filename = file_path.name
407
+ model_name = file_path.parent
408
+ task_name = get_file_task_name(filename)
409
+ results_datetime = get_file_datetime(filename)
410
+ task_name_sanitized = sanitize_task_name(task_name)
411
+ # Results and sample results for the same model and task will have the same datetime
412
+ samples_key = f"{model_name}__{task_name_sanitized}"
413
+ results_key = f"{model_name}__results"
414
+ latest_datetime = max(
415
+ latest_task_results_datetime[samples_key],
416
+ results_datetime,
417
+ )
418
+ latest_task_results_datetime[samples_key] = latest_datetime
419
+ latest_task_results_datetime[results_key] = max(
420
+ latest_task_results_datetime[results_key],
421
+ latest_datetime,
422
+ )
423
+
424
+ # Create metadata card
425
+ card_metadata = MetadataConfigs()
426
+
427
+ # Add the latest aggregated results to the metadata card for easy access
428
+ for file_path in results_files:
429
+ file_path = Path(file_path)
430
+ results_filename = file_path.name
431
+ model_name = file_path.parent
432
+ eval_date = get_file_datetime(results_filename)
433
+ eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
434
+ results_filename = Path("**") / Path(results_filename).name
435
+ config_name = f"{model_name}__results"
436
+ sanitized_last_eval_date_results = re.sub(
437
+ r"[^\w\.]", "_", latest_task_results_datetime[config_name]
438
+ )
439
+
440
+ if eval_date_sanitized == sanitized_last_eval_date_results:
441
+ # Ensure that all results files are listed in the metadata card
442
+ current_results = card_metadata.get(config_name, {"data_files": []})
443
+ current_results["data_files"].append(
444
+ {"split": eval_date_sanitized, "path": [str(results_filename)]}
445
+ )
446
+ card_metadata[config_name] = current_results
447
+ # If the results file is the newest, update the "latest" field in the metadata card
448
+ card_metadata[config_name]["data_files"].append(
449
+ {"split": "latest", "path": [str(results_filename)]}
450
+ )
451
+
452
+ # Add the tasks details configs
453
+ for file_path in sample_files:
454
+ file_path = Path(file_path)
455
+ filename = file_path.name
456
+ model_name = file_path.parent
457
+ task_name = get_file_task_name(filename)
458
+ eval_date = get_file_datetime(filename)
459
+ task_name_sanitized = sanitize_task_name(task_name)
460
+ eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
461
+ results_filename = Path("**") / Path(filename).name
462
+ config_name = f"{model_name}__{task_name_sanitized}"
463
+ sanitized_last_eval_date_results = re.sub(
464
+ r"[^\w\.]", "_", latest_task_results_datetime[config_name]
465
+ )
466
+ if eval_date_sanitized == sanitized_last_eval_date_results:
467
+ # Ensure that all sample results files are listed in the metadata card
468
+ current_details_for_task = card_metadata.get(
469
+ config_name, {"data_files": []}
470
+ )
471
+ current_details_for_task["data_files"].append(
472
+ {"split": eval_date_sanitized, "path": [str(results_filename)]}
473
+ )
474
+ card_metadata[config_name] = current_details_for_task
475
+ # If the samples results file is the newest, update the "latest" field in the metadata card
476
+ card_metadata[config_name]["data_files"].append(
477
+ {"split": "latest", "path": [str(results_filename)]}
478
+ )
479
+
480
+ # Get latest results and extract info to update metadata card examples
481
+ latest_datetime = max(latest_task_results_datetime.values())
482
+ latest_model_name = max(
483
+ latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
484
+ )
485
+ last_results_file = [
486
+ f for f in results_files if latest_datetime.replace(":", "-") in f
487
+ ][0]
488
+ last_results_file_path = hf_hub_url(
489
+ repo_id=repo_id, filename=last_results_file, repo_type="dataset"
490
+ )
491
+ latest_results_file = load_dataset(
492
+ "json", data_files=last_results_file_path, split="train"
493
+ )
494
+ results_dict = latest_results_file["results"][0]
495
+ new_dictionary = {"all": results_dict}
496
+ new_dictionary.update(results_dict)
497
+ results_string = json.dumps(new_dictionary, indent=4)
498
+
499
+ dataset_summary = (
500
+ "Dataset automatically created during the evaluation run of model "
501
+ )
502
+ if self.general_config_tracker.model_source == "hf":
503
+ dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
504
+ else:
505
+ dataset_summary += f"{self.general_config_tracker.model_name}\n"
506
+ dataset_summary += (
507
+ f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
508
+ f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
509
+ 'configuration, the split being named using the timestamp of the run. The "train" split is always pointing to the latest results.\n\n'
510
+ 'An additional configuration "results" stores all the aggregated results of the run.\n\n'
511
+ "To load the details from a run, you can for instance do the following:\n"
512
+ )
513
+ if self.general_config_tracker.model_source == "hf":
514
+ dataset_summary += (
515
+ "```python\nfrom datasets import load_dataset\n"
516
+ f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
517
+ )
518
+ dataset_summary += (
519
+ "## Latest results\n\n"
520
+ f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
521
+ "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
522
+ 'You can find each in the results and the "latest" split for each eval):\n\n'
523
+ f"```python\n{results_string}\n```"
524
+ )
525
+ card_data = DatasetCardData(
526
+ dataset_summary=dataset_summary,
527
+ repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
528
+ pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
529
+ leaderboard_url=self.leaderboard_url,
530
+ point_of_contact=self.point_of_contact,
531
+ )
532
+ card_metadata.to_dataset_card_data(card_data)
533
+ card = DatasetCard.from_template(
534
+ card_data,
535
+ pretty_name=card_data.pretty_name,
536
+ )
537
+ card.push_to_hub(repo_id, repo_type="dataset")
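The `save_results_samples` method above writes one JSON record per line to `samples_<task>_<date>.jsonl` under the sanitized model directory. Below is a minimal sketch of reading such a file back with the `datasets` library, mirroring the `load_dataset("json", ...)` call used in `recreate_metadata_card`; the file path and the printed fields are illustrative placeholders, not part of this diff.

```python
# Minimal sketch: load a per-sample JSONL file written by save_results_samples.
# The path below is a placeholder; substitute your own output_path, sanitized
# model directory, task name, and timestamp.
from datasets import load_dataset

samples = load_dataset(
    "json",
    data_files="output/org__model/samples_gsm8k_2024-01-01T00-00-00.000000.jsonl",
    split="train",
)

# Each record carries the sanitized fields set above
# (arguments, resps, filtered_resps, target, ...).
print(samples.column_names)
print(samples[0]["target"])
```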
lm-evaluation-harness/lm_eval/loggers/utils.py ADDED
@@ -0,0 +1,149 @@
1
+ import logging
2
+ import os
3
+ import re
4
+ import subprocess
5
+ from importlib.metadata import version
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ from torch.utils.collect_env import get_pretty_env_info
11
+ from transformers import __version__ as trans_version
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
18
+ """Remove the ',none' substring from the input_string if it exists at the end.
19
+
20
+ Args:
21
+ input_string (str): The input string from which to remove the ',none' substring.
22
+
23
+ Returns:
24
+ Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
25
+ and a boolean indicating whether the modification was made (True) or not (False).
26
+ """
27
+ # Define the pattern to match ',none' at the end of the string
28
+ pattern = re.compile(r",none$")
29
+
30
+ # Use sub() to replace ',none' with an empty string
31
+ result = re.sub(pattern, "", input_string)
32
+
33
+ # check if the input_string changed
34
+ removed = result != input_string
35
+
36
+ return result, removed
37
+
38
+
39
+ def _handle_non_serializable(o: Any) -> Union[int, str, list]:
40
+ """Handle non-serializable objects by converting them to serializable types.
41
+
42
+ Args:
43
+ o (Any): The object to be handled.
44
+
45
+ Returns:
46
+ Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
47
+ it will be converted to int. If the object is of type set, it will be converted
48
+ to a list. Otherwise, it will be converted to str.
49
+ """
50
+ if isinstance(o, np.int64) or isinstance(o, np.int32):
51
+ return int(o)
52
+ elif isinstance(o, set):
53
+ return list(o)
54
+ else:
55
+ return str(o)
56
+
57
+
58
+ def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
59
+ try:
60
+ git_folder = Path(repo_path, ".git")
61
+ if git_folder.is_file():
62
+ git_folder = Path(
63
+ git_folder.parent,
64
+ git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
65
+ )
66
+ if Path(git_folder, "HEAD").exists():
67
+ head_name = (
68
+ Path(git_folder, "HEAD")
69
+ .read_text(encoding="utf-8")
70
+ .split("\n")[0]
71
+ .split(" ")[-1]
72
+ )
73
+ head_ref = Path(git_folder, head_name)
74
+ git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
75
+ else:
76
+ git_hash = None
77
+ except Exception as err:
78
+ logger.debug(
79
+ f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
80
+ )
81
+ return None
82
+ return git_hash
83
+
84
+
85
+ def get_git_commit_hash():
86
+ """
87
+ Gets the git commit hash of your current repo (if it exists).
88
+ Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
89
+ """
90
+ try:
91
+ git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
92
+ git_hash = git_hash.decode()
93
+ except (subprocess.CalledProcessError, FileNotFoundError):
94
+ # FileNotFoundError occurs when git not installed on system
95
+ git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
96
+ return git_hash
97
+
98
+
99
+ def add_env_info(storage: Dict[str, Any]):
100
+ try:
101
+ pretty_env_info = get_pretty_env_info()
102
+ except Exception as err:
103
+ pretty_env_info = str(err)
104
+ try:
105
+ lm_eval_version = version("lm_eval")
106
+ except Exception as err:
107
+ lm_eval_version = str(err)
108
+ transformers_version = trans_version
109
+ upper_dir_commit = get_commit_from_path(
110
+ Path(os.getcwd(), "..")
111
+ ) # git hash of upper repo if exists
112
+ added_info = {
113
+ "pretty_env_info": pretty_env_info,
114
+ "transformers_version": transformers_version,
115
+ "lm_eval_version": lm_eval_version,
116
+ "upper_git_hash": upper_dir_commit, # in case this repo is submodule
117
+ }
118
+ storage.update(added_info)
119
+
120
+
121
+ def add_tokenizer_info(storage: Dict[str, Any], lm):
122
+ if getattr(lm, "tokenizer", False):
123
+ try:
124
+ tokenizer_info = {
125
+ "tokenizer_pad_token": [
126
+ lm.tokenizer.pad_token,
127
+ str(lm.tokenizer.pad_token_id),
128
+ ],
129
+ "tokenizer_eos_token": [
130
+ lm.tokenizer.eos_token,
131
+ str(lm.tokenizer.eos_token_id),
132
+ ],
133
+ "tokenizer_bos_token": [
134
+ lm.tokenizer.bos_token,
135
+ str(lm.tokenizer.bos_token_id),
136
+ ],
137
+ "eot_token_id": getattr(lm, "eot_token_id", None),
138
+ "max_length": getattr(lm, "max_length", None),
139
+ }
140
+ storage.update(tokenizer_info)
141
+ except Exception as err:
142
+ logger.debug(
143
+ f"Logging detailed tokenizer info failed with {err}, skipping..."
144
+ )
145
+ # seems gguf and textsynth do not have tokenizer
146
+ else:
147
+ logger.debug(
148
+ "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
149
+ )
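The two helpers at the top of this module are small enough that their behavior can be shown directly. This is a quick illustrative check, assuming the module is importable as `lm_eval.loggers.utils` (the same import path used by `wandb_logger.py` below).

```python
# Illustrative check of the helpers defined above.
import numpy as np

from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern

print(remove_none_pattern("acc,none"))      # -> ("acc", True)
print(remove_none_pattern("acc,strict"))    # -> ("acc,strict", False)

print(_handle_non_serializable(np.int64(7)))   # -> 7 (plain int)
print(_handle_non_serializable({"a", "b"}))    # -> ["a", "b"] (set becomes list)
print(_handle_non_serializable(object()))      # -> string representation
```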
lm-evaluation-harness/lm_eval/loggers/wandb_logger.py ADDED
@@ -0,0 +1,358 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+ from typing import Any, Dict, List, Literal, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from packaging.version import Version
9
+
10
+ from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def get_wandb_printer() -> Literal["Printer"]:
17
+ """Returns a wandb printer instance for pretty stdout."""
18
+ from wandb.sdk.lib.printer import new_printer
19
+
20
+ printer = new_printer()
21
+ return printer
22
+
23
+
24
+ class WandbLogger:
25
+ def __init__(self, init_args=None, config_args=None) -> None:
26
+ """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
27
+
28
+ Args:
29
+ init_args Optional[Dict]: Arguments for init configuration.
30
+ config_args Optional[Dict]: Arguments for config
31
+
32
+ Parse and log the results returned from evaluator.simple_evaluate() with:
33
+ wandb_logger.post_init(results)
34
+ wandb_logger.log_eval_result()
35
+ wandb_logger.log_eval_samples(results["samples"])
36
+ """
37
+ try:
38
+ import wandb
39
+
40
+ assert Version(wandb.__version__) >= Version("0.13.6")
41
+ if Version(wandb.__version__) < Version("0.13.6"):
42
+ wandb.require("report-editing:v0")
43
+ except Exception as e:
44
+ logger.warning(
45
+ "To use the wandb reporting functionality please install wandb>=0.13.6.\n"
46
+ "To install the latest version of wandb run `pip install wandb --upgrade`\n"
47
+ f"{e}"
48
+ )
49
+
50
+ self.wandb_args: Dict[str, Any] = init_args or {}
51
+ self.wandb_config_args: Dict[str, Any] = config_args or {}
52
+
53
+ # pop the step key from the args to save for all logging calls
54
+ self.step = self.wandb_args.pop("step", None)
55
+
56
+ # initialize a W&B run
57
+ if wandb.run is None:
58
+ self.run = wandb.init(**self.wandb_args)
59
+ if self.wandb_config_args:
60
+ self.run.config.update(self.wandb_config_args)
61
+ else:
62
+ self.run = wandb.run
63
+
64
+ self.printer = get_wandb_printer()
65
+
66
+ def post_init(self, results: Dict[str, Any]) -> None:
67
+ self.results: Dict[str, Any] = copy.deepcopy(results)
68
+ self.task_names: List[str] = list(results.get("results", {}).keys())
69
+ self.group_names: List[str] = list(results.get("groups", {}).keys())
70
+
71
+ def _get_config(self) -> Dict[str, Any]:
72
+ """Get configuration parameters."""
73
+ self.task_configs = self.results.get("configs", {})
74
+ cli_configs = self.results.get("config", {})
75
+ configs = {
76
+ "task_configs": self.task_configs,
77
+ "cli_configs": cli_configs,
78
+ }
79
+
80
+ return configs
81
+
82
+ def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
83
+ """Sanitize the results dictionary."""
84
+ _results = copy.deepcopy(self.results.get("results", dict()))
85
+
86
+ # Remove None from the metric string name
87
+ tmp_results = copy.deepcopy(_results)
88
+ for task_name in self.task_names:
89
+ task_result = tmp_results.get(task_name, dict())
90
+ for metric_name, metric_value in task_result.items():
91
+ _metric_name, removed = remove_none_pattern(metric_name)
92
+ if removed:
93
+ _results[task_name][_metric_name] = metric_value
94
+ _results[task_name].pop(metric_name)
95
+
96
+ # remove string valued keys from the results dict
97
+ wandb_summary = {}
98
+ for task in self.task_names:
99
+ task_result = _results.get(task, dict())
100
+ for metric_name, metric_value in task_result.items():
101
+ if isinstance(metric_value, str):
102
+ wandb_summary[f"{task}/{metric_name}"] = metric_value
103
+
104
+ for summary_metric, summary_value in wandb_summary.items():
105
+ _task, _summary_metric = summary_metric.split("/")
106
+ _results[_task].pop(_summary_metric)
107
+
108
+ tmp_results = copy.deepcopy(_results)
109
+ for task_name, task_results in tmp_results.items():
110
+ for metric_name, metric_value in task_results.items():
111
+ _results[f"{task_name}/{metric_name}"] = metric_value
112
+ _results[task_name].pop(metric_name)
113
+ for task in self.task_names:
114
+ _results.pop(task)
115
+
116
+ return wandb_summary, _results
117
+
118
+ def _log_results_as_table(self) -> None:
119
+ """Generate and log evaluation results as a table to W&B."""
120
+ columns = [
121
+ "Version",
122
+ "Filter",
123
+ "num_fewshot",
124
+ "Metric",
125
+ "Value",
126
+ "Stderr",
127
+ ]
128
+
129
+ def make_table(columns: List[str], key: str = "results"):
130
+ import wandb
131
+
132
+ table = wandb.Table(columns=columns)
133
+ results = copy.deepcopy(self.results)
134
+
135
+ for k, dic in results.get(key).items():
136
+ if k in self.group_names and not key == "groups":
137
+ continue
138
+ version = results.get("versions").get(k)
139
+ if version == "N/A":
140
+ version = None
141
+ n = results.get("n-shot").get(k)
142
+
143
+ for (mf), v in dic.items():
144
+ m, _, f = mf.partition(",")
145
+ if m.endswith("_stderr"):
146
+ continue
147
+ if m == "alias":
148
+ continue
149
+
150
+ if m + "_stderr" + "," + f in dic:
151
+ se = dic[m + "_stderr" + "," + f]
152
+ if se != "N/A":
153
+ se = "%.4f" % se
154
+ table.add_data(*[k, version, f, n, m, str(v), str(se)])
155
+ else:
156
+ table.add_data(*[k, version, f, n, m, str(v), ""])
157
+
158
+ return table
159
+
160
+ # log the complete eval result to W&B Table
161
+ table = make_table(["Tasks"] + columns, "results")
162
+ self.run.log({"evaluation/eval_results": table}, step=self.step)
163
+
164
+ if "groups" in self.results.keys():
165
+ table = make_table(["Groups"] + columns, "groups")
166
+ self.run.log({"evaluation/group_eval_results": table}, step=self.step)
167
+
168
+ def _log_results_as_artifact(self) -> None:
169
+ """Log results as JSON artifact to W&B."""
170
+ import wandb
171
+
172
+ dumped = json.dumps(
173
+ self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
174
+ )
175
+ artifact = wandb.Artifact("results", type="eval_results")
176
+ with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
177
+ f.write(dumped)
178
+ self.run.log_artifact(artifact)
179
+
180
+ def log_eval_result(self) -> None:
181
+ """Log evaluation results to W&B."""
182
+ # Log configs to wandb
183
+ configs = self._get_config()
184
+ self.run.config.update(configs, allow_val_change=self.step is not None)
185
+
186
+ wandb_summary, self.wandb_results = self._sanitize_results_dict()
187
+ # update wandb.run.summary with items that were removed
188
+ self.run.summary.update(wandb_summary)
189
+ # Log the evaluation metrics to wandb
190
+ self.run.log(self.wandb_results, step=self.step)
191
+ # Log the evaluation metrics as W&B Table
192
+ self._log_results_as_table()
193
+ # Log the results dict as json to W&B Artifacts
194
+ self._log_results_as_artifact()
195
+
196
+ def _generate_dataset(
197
+ self, data: List[Dict[str, Any]], config: Dict[str, Any]
198
+ ) -> pd.DataFrame:
199
+ """Generate a dataset from evaluation data.
200
+
201
+ Args:
202
+ data (List[Dict[str, Any]]): The data to generate a dataset for.
203
+ config (Dict[str, Any]): The configuration of the task.
204
+
205
+ Returns:
206
+ pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
207
+ """
208
+ ids = [x["doc_id"] for x in data]
209
+ labels = [x["target"] for x in data]
210
+ instance = [""] * len(ids)
211
+ resps = [""] * len(ids)
212
+ filtered_resps = [""] * len(ids)
213
+ model_outputs = {}
214
+
215
+ metrics_list = config["metric_list"]
216
+ metrics = {}
217
+ for metric in metrics_list:
218
+ metric = metric.get("metric")
219
+ if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
220
+ metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
221
+ if metric in ["byte_perplexity", "bits_per_byte"]:
222
+ metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
223
+ else:
224
+ metrics[f"{metric}_words"] = [x[metric][1] for x in data]
225
+ else:
226
+ metrics[metric] = [x[metric] for x in data]
227
+
228
+ if config["output_type"] == "loglikelihood":
229
+ instance = [x["arguments"][0][0] for x in data]
230
+ labels = [x["arguments"][0][1] for x in data]
231
+ resps = [
232
+ f"log probability of continuation is {x['resps'][0][0][0]} "
233
+ + "\n\n"
234
+ + "continuation will {} generated with greedy sampling".format(
235
+ "not be" if not x["resps"][0][0][1] else "be"
236
+ )
237
+ for x in data
238
+ ]
239
+ filtered_resps = [
240
+ f"log probability of continuation is {x['filtered_resps'][0][0]} "
241
+ + "\n\n"
242
+ + "continuation will {} generated with greedy sampling".format(
243
+ "not be" if not x["filtered_resps"][0][1] else "be"
244
+ )
245
+ for x in data
246
+ ]
247
+ elif config["output_type"] == "multiple_choice":
248
+ instance = [x["arguments"][0][0] for x in data]
249
+ choices = [
250
+ "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
251
+ for x in data
252
+ ]
253
+ resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
254
+ filtered_resps = [
255
+ np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
256
+ ]
257
+ elif config["output_type"] == "loglikelihood_rolling":
258
+ instance = [x["arguments"][0][0] for x in data]
259
+ resps = [x["resps"][0][0] for x in data]
260
+ filtered_resps = [x["filtered_resps"][0] for x in data]
261
+ elif config["output_type"] == "generate_until":
262
+ instance = [x["arguments"][0][0] for x in data]
263
+ resps = [x["resps"][0][0] for x in data]
264
+ filtered_resps = [x["filtered_resps"][0] for x in data]
265
+
266
+ model_outputs["raw_predictions"] = resps
267
+ model_outputs["filtered_predictions"] = filtered_resps
268
+
269
+ df_data = {
270
+ "id": ids,
271
+ "data": instance,
272
+ }
273
+ if config["output_type"] == "multiple_choice":
274
+ df_data["choices"] = choices
275
+
276
+ tmp_data = {
277
+ "input_len": [len(x) for x in instance],
278
+ "labels": labels,
279
+ "output_type": config["output_type"],
280
+ }
281
+ df_data.update(tmp_data)
282
+ df_data.update(model_outputs)
283
+ df_data.update(metrics)
284
+
285
+ return pd.DataFrame(df_data)
286
+
287
+ def _log_samples_as_artifact(
288
+ self, data: List[Dict[str, Any]], task_name: str
289
+ ) -> None:
290
+ import wandb
291
+
292
+ # log the samples as an artifact
293
+ dumped = json.dumps(
294
+ data,
295
+ indent=2,
296
+ default=_handle_non_serializable,
297
+ ensure_ascii=False,
298
+ )
299
+ artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
300
+ with artifact.new_file(
301
+ f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
302
+ ) as f:
303
+ f.write(dumped)
304
+ self.run.log_artifact(artifact)
305
+ # artifact.wait()
306
+
307
+ def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
308
+ """Log evaluation samples to W&B.
309
+
310
+ Args:
311
+ samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
312
+ """
313
+ task_names: List[str] = [
314
+ x for x in self.task_names if x not in self.group_names
315
+ ]
316
+
317
+ ungrouped_tasks = []
318
+ tasks_by_groups = {}
319
+
320
+ for task_name in task_names:
321
+ group_names = self.task_configs[task_name].get("group", None)
322
+ if group_names:
323
+ if isinstance(group_names, str):
324
+ group_names = [group_names]
325
+
326
+ for group_name in group_names:
327
+ if not tasks_by_groups.get(group_name):
328
+ tasks_by_groups[group_name] = [task_name]
329
+ else:
330
+ tasks_by_groups[group_name].append(task_name)
331
+ else:
332
+ ungrouped_tasks.append(task_name)
333
+
334
+ for task_name in ungrouped_tasks:
335
+ eval_preds = samples[task_name]
336
+
337
+ # log the samples as a W&B Table
338
+ df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
339
+ self.run.log({f"{task_name}_eval_results": df}, step=self.step)
340
+
341
+ # log the samples as a json file as W&B Artifact
342
+ self._log_samples_as_artifact(eval_preds, task_name)
343
+
344
+ for group, grouped_tasks in tasks_by_groups.items():
345
+ grouped_df = pd.DataFrame()
346
+ for task_name in grouped_tasks:
347
+ eval_preds = samples[task_name]
348
+ df = self._generate_dataset(
349
+ eval_preds, self.task_configs.get(task_name)
350
+ )
351
+ df["group"] = group
352
+ df["task"] = task_name
353
+ grouped_df = pd.concat([grouped_df, df], ignore_index=True)
354
+
355
+ # log the samples as a json file as W&B Artifact
356
+ self._log_samples_as_artifact(eval_preds, task_name)
357
+
358
+ self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
lm-evaluation-harness/lm_eval/models/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ from . import (
2
+ anthropic_llms,
3
+ api_models,
4
+ dummy,
5
+ gguf,
6
+ hf_audiolm,
7
+ hf_steered,
8
+ hf_vlms,
9
+ huggingface,
10
+ ibm_watsonx_ai,
11
+ mamba_lm,
12
+ nemo_lm,
13
+ neuron_optimum,
14
+ openai_completions,
15
+ optimum_ipex,
16
+ optimum_lm,
17
+ sglang_causallms,
18
+ sglang_generate_API,
19
+ textsynth,
20
+ vllm_causallms,
21
+ vllm_vlms,
22
+ )
23
+
24
+
25
+ # TODO: implement __all__
26
+
27
+
28
+ try:
29
+ # enable hf hub transfer if available
30
+ import hf_transfer # type: ignore # noqa
31
+ import huggingface_hub.constants # type: ignore
32
+
33
+ huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
34
+ except ImportError:
35
+ pass
lm-evaluation-harness/lm_eval/models/anthropic_llms.py ADDED
@@ -0,0 +1,382 @@
1
+ import logging
2
+ import os
3
+ from functools import cached_property
4
+ from typing import Any, Dict, List, Tuple, Union
5
+
6
+ from tqdm import tqdm
7
+
8
+ from lm_eval.api.model import LM
9
+ from lm_eval.api.registry import register_model
10
+ from lm_eval.models.openai_completions import LocalCompletionsAPI
11
+ from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
12
+
13
+
14
+ eval_logger = logging.getLogger(__name__)
15
+
16
+
17
+ def anthropic_completion(
18
+ client, #: anthropic.Anthropic,
19
+ model: str,
20
+ prompt: str,
21
+ max_tokens_to_sample: int,
22
+ temperature: float,
23
+ stop: List[str],
24
+ **kwargs: Any,
25
+ ) -> str:
26
+ """Wrapper function around the Anthropic completion API client with exponential back-off
27
+ in case of RateLimitError.
28
+
29
+ params:
30
+ client: anthropic.Anthropic
31
+ Anthropic API client
32
+ model: str
33
+ Anthropic model e.g. 'claude-instant-v1', 'claude-2'
34
+ prompt: str
35
+ Prompt to feed to the model
36
+ max_tokens_to_sample: int
37
+ Maximum number of tokens to sample from the model
38
+ temperature: float
39
+ Sampling temperature
40
+ stop: List[str]
41
+ List of stop sequences
42
+ kwargs: Any
43
+ Additional model_args to pass to the API client
44
+ """
45
+
46
+ try:
47
+ import anthropic
48
+ except ModuleNotFoundError as exception:
49
+ raise type(exception)(
50
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
51
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
52
+ )
53
+
54
+ def _exception_callback(e: Exception, sleep_time: float) -> None:
55
+ eval_logger.warning(
56
+ f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
57
+ )
58
+
59
+ @retry_on_specific_exceptions(
60
+ on_exceptions=[anthropic.RateLimitError],
61
+ max_retries=None, # retry forever, consider changing
62
+ on_exception_callback=_exception_callback,
63
+ )
64
+ def completion():
65
+ response = client.completions.create(
66
+ prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
67
+ model=model,
68
+ # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
69
+ # (e.g. gsm8k's ":") may truncate a lot of the input.
70
+ stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
71
+ max_tokens_to_sample=max_tokens_to_sample,
72
+ temperature=temperature,
73
+ **kwargs,
74
+ )
75
+ return response.completion
76
+
77
+ return completion()
78
+
79
+
80
+ def anthropic_chat(
81
+ client, #: anthropic.Anthropic,
82
+ model: str,
83
+ prompt: str,
84
+ max_tokens: int,
85
+ temperature: float,
86
+ stop: List[str],
87
+ **kwargs: Any,
88
+ ) -> str:
89
+ """Wrapper function around the Anthropic completion API client with exponential back-off
90
+ in case of RateLimitError.
91
+
92
+ params:
93
+ client: anthropic.Anthropic
94
+ Anthropic API client
95
+ model: str
96
+ Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
97
+ prompt: str
98
+ Prompt to feed to the model
99
+ max_tokens: int
100
+ Maximum number of tokens to sample from the model
101
+ temperature: float
102
+ Sampling temperature
103
+ stop: List[str]
104
+ List of stop sequences
105
+ kwargs: Any
106
+ Additional model_args to pass to the API client
107
+ """
108
+
109
+ try:
110
+ import anthropic
111
+ except ModuleNotFoundError as exception:
112
+ raise type(exception)(
113
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
114
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
115
+ )
116
+
117
+ def _exception_callback(e: Exception, sleep_time: float) -> None:
118
+ eval_logger.warning(
119
+ f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
120
+ )
121
+
122
+ @retry_on_specific_exceptions(
123
+ on_exceptions=[
124
+ anthropic.RateLimitError,
125
+ anthropic.APIConnectionError,
126
+ anthropic.APIStatusError,
127
+ ],
128
+ max_retries=None, # retry forever, consider changing
129
+ on_exception_callback=_exception_callback,
130
+ )
131
+ def messages():
132
+ response = client.messages.create(
133
+ model=model,
134
+ max_tokens=max_tokens,
135
+ temperature=temperature,
136
+ messages=[{"role": "user", "content": f"{prompt}"}],
137
+ **kwargs,
138
+ )
139
+ return response.content[0].text
140
+
141
+ return messages()
142
+
143
+
144
+ @register_model("anthropic-completions")
145
+ class AnthropicLM(LM):
146
+ REQ_CHUNK_SIZE = 20 # TODO: not used
147
+
148
+ def __init__(
149
+ self,
150
+ batch_size: int = 1,
151
+ model: str = "claude-2.0",
152
+ max_tokens_to_sample: int = 256,
153
+ temperature: float = 0, # defaults to 1
154
+ **kwargs, # top_p, top_k, etc.
155
+ ) -> None:
156
+ """Anthropic API wrapper.
157
+
158
+ :param model: str
159
+ Anthropic model e.g. 'claude-instant-v1', 'claude-2'
160
+ :param max_tokens_to_sample: int
161
+ Maximum number of tokens to sample from the model
162
+ :param temperature: float
163
+ Sampling temperature
164
+ :param kwargs: Any
165
+ Additional model_args to pass to the API client
166
+ """
167
+ super().__init__()
168
+
169
+ try:
170
+ import anthropic
171
+ except ModuleNotFoundError as exception:
172
+ raise type(exception)(
173
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
174
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
175
+ )
176
+
177
+ self.model = model
178
+ # defaults to os.environ.get("ANTHROPIC_API_KEY")
179
+ self.client = anthropic.Anthropic()
180
+ self.temperature = temperature
181
+ self.max_tokens_to_sample = max_tokens_to_sample
182
+ self.tokenizer = self.client.get_tokenizer()
183
+ self.kwargs = kwargs
184
+
185
+ @property
186
+ def eot_token_id(self):
187
+ # Not sure but anthropic.HUMAN_PROMPT ?
188
+ raise NotImplementedError("No idea about anthropic tokenization.")
189
+
190
+ @property
191
+ def max_length(self) -> int:
192
+ return 2048
193
+
194
+ @property
195
+ def max_gen_toks(self) -> int:
196
+ return self.max_tokens_to_sample
197
+
198
+ @property
199
+ def batch_size(self):
200
+ # Isn't used because we override _loglikelihood_tokens
201
+ raise NotImplementedError("No support for logits.")
202
+
203
+ @property
204
+ def device(self):
205
+ # Isn't used because we override _loglikelihood_tokens
206
+ raise NotImplementedError("No support for logits.")
207
+
208
+ def tok_encode(self, string: str) -> List[int]:
209
+ return self.tokenizer.encode(string).ids
210
+
211
+ def tok_decode(self, tokens: List[int]) -> str:
212
+ return self.tokenizer.decode(tokens)
213
+
214
+ def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
215
+ raise NotImplementedError("No support for logits.")
216
+
217
+ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
218
+ try:
219
+ import anthropic
220
+ except ModuleNotFoundError as exception:
221
+ raise type(exception)(
222
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
223
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
224
+ )
225
+
226
+ if not requests:
227
+ return []
228
+
229
+ _requests: List[Tuple[str, dict]] = [req.args for req in requests]
230
+
231
+ res = []
232
+ for request in tqdm(_requests, disable=disable_tqdm):
233
+ try:
234
+ inp = request[0]
235
+ request_args = request[1]
236
+ # generation_kwargs
237
+ until = request_args.get("until")
238
+ max_gen_toks = request_args.get("max_gen_toks", self.max_length)
239
+ temperature = request_args.get("temperature", self.temperature)
240
+ response = anthropic_completion(
241
+ client=self.client,
242
+ model=self.model,
243
+ prompt=inp,
244
+ max_tokens_to_sample=max_gen_toks,
245
+ temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
246
+ stop=until, # type: ignore
247
+ **self.kwargs,
248
+ )
249
+ res.append(response)
250
+
251
+ self.cache_hook.add_partial("generate_until", request, response)
252
+ except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
253
+ eval_logger.critical(f"Server unreachable: {e.__cause__}")
254
+ break
255
+ except anthropic.APIStatusError as e: # type: ignore # noqa: F821
256
+ eval_logger.critical(f"API error {e.status_code}: {e.message}")
257
+ break
258
+
259
+ return res
260
+
261
+ def _model_call(self, inps):
262
+ # Isn't used because we override _loglikelihood_tokens
263
+ raise NotImplementedError()
264
+
265
+ def _model_generate(self, context, max_length, eos_token_id):
266
+ # Isn't used because we override generate_until
267
+ raise NotImplementedError()
268
+
269
+ def loglikelihood(self, requests, disable_tqdm: bool = False):
270
+ raise NotImplementedError("No support for logits.")
271
+
272
+ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
273
+ raise NotImplementedError("No support for logits.")
274
+
275
+
276
+ @register_model("anthropic-chat", "anthropic-chat-completions")
277
+ class AnthropicChat(LocalCompletionsAPI):
278
+ def __init__(
279
+ self,
280
+ base_url="https://api.anthropic.com/v1/messages",
281
+ tokenizer_backend=None,
282
+ **kwargs,
283
+ ):
284
+ super().__init__(
285
+ base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
286
+ )
287
+ eval_logger.warning(
288
+ "Chat completions does not support batching. Defaulting to batch size 1."
289
+ )
290
+ self._batch_size = 1
291
+ self.anthropic_version = "2023-06-01"
292
+ eval_logger.warning(
293
+ f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
294
+ )
295
+
296
+ @cached_property
297
+ def api_key(self):
298
+ """Override this property to return the API key for the API request."""
299
+ key = os.environ.get("ANTHROPIC_API_KEY", None)
300
+ if key is None:
301
+ raise ValueError(
302
+ "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
303
+ )
304
+ return key
305
+
306
+ @cached_property
307
+ def header(self):
308
+ return {
309
+ "x-api-key": f"{self.api_key}",
310
+ "anthropic-version": self.anthropic_version,
311
+ }
312
+
313
+ def _create_payload(
314
+ self,
315
+ messages: List[Dict],
316
+ generate=True,
317
+ gen_kwargs: dict = None,
318
+ eos="\n\nHuman:",
319
+ **kwargs,
320
+ ) -> dict:
321
+ system = (
322
+ messages[0].get("content") if messages[0].get("role") == "system" else None
323
+ )
324
+ if system:
325
+ messages = messages[1:]
326
+
327
+ cleaned_messages = []
328
+ for msg in messages:
329
+ cleaned_msg = {
330
+ "role": msg["role"],
331
+ "content": [
332
+ {"type": msg["type"], msg["type"]: msg["content"]},
333
+ ],
334
+ }
335
+ cleaned_messages.append(cleaned_msg)
336
+
337
+ gen_kwargs.pop("do_sample", False)
338
+ max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
339
+ temperature = gen_kwargs.pop("temperature", 0)
340
+ stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
341
+ if not isinstance(stop, list):
342
+ stop = [stop]
343
+
344
+ # Filter out empty or whitespace-only stop sequences for Anthropic API
345
+ stop = [s for s in stop if s and s.strip()]
346
+
347
+ out = {
348
+ "messages": cleaned_messages,
349
+ "model": self.model,
350
+ "max_tokens": max_tokens,
351
+ "temperature": temperature,
352
+ "stop_sequences": stop,
353
+ **gen_kwargs,
354
+ }
355
+ if system:
356
+ out["system"] = system
357
+ return out
358
+
359
+ def parse_generations(
360
+ self, outputs: Union[Dict, List[Dict]], **kwargs
361
+ ) -> List[str]:
362
+ res = []
363
+ if not isinstance(outputs, list):
364
+ outputs = [outputs]
365
+ for out in outputs:
366
+ for choices in out["content"]:
367
+ res.append(choices["text"])
368
+ return res
369
+
370
+ def tok_encode(
371
+ self,
372
+ string: str,
373
+ left_truncate_len=None,
374
+ add_special_tokens=None,
375
+ **kwargs,
376
+ ) -> List[str]:
377
+ return [string]
378
+
379
+ def loglikelihood(self, requests, **kwargs):
380
+ raise NotImplementedError(
381
+ "Anthropic Chat Completions API does not support the return of loglikelihood"
382
+ )
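Both classes above are registered by name, so they can be selected like any other backend once `ANTHROPIC_API_KEY` is set. A hedged sketch using `simple_evaluate`; the model id, task, and keyword arguments are illustrative assumptions rather than part of this diff.

```python
# Hedged sketch: run the chat-based Anthropic backend registered above.
# Requires ANTHROPIC_API_KEY in the environment; model id and task are placeholders.
import lm_eval

results = lm_eval.simple_evaluate(
    model="anthropic-chat-completions",
    model_args="model=claude-3-haiku-20240307",
    tasks=["gsm8k"],
    apply_chat_template=True,  # the messages endpoint expects chat-formatted requests
)
print(results["results"]["gsm8k"])
```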
lm-evaluation-harness/lm_eval/models/api_models.py ADDED
@@ -0,0 +1,810 @@
1
+ import abc
2
+ import asyncio
3
+ import copy
4
+ import itertools
5
+ import json
6
+ import logging
7
+ from functools import cached_property
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ Dict,
14
+ Iterable,
15
+ List,
16
+ Literal,
17
+ NamedTuple,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+
24
+ try:
25
+ import requests
26
+ from aiohttp import ClientSession, ClientTimeout, TCPConnector
27
+ from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
28
+ from tqdm import tqdm
29
+ from tqdm.asyncio import tqdm_asyncio
30
+ except ModuleNotFoundError:
31
+ pass
32
+
33
+
34
+ import base64
35
+ from importlib.util import find_spec
36
+ from io import BytesIO
37
+
38
+ from lm_eval import utils
39
+ from lm_eval.api.instance import Instance
40
+ from lm_eval.api.model import TemplateLM
41
+ from lm_eval.models.utils import Collator, chunks, configure_pad_token
42
+
43
+
44
+ if TYPE_CHECKING:
45
+ from PIL import Image
46
+
47
+
48
+ eval_logger = logging.getLogger(__name__)
49
+
50
+ LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
51
+
52
+
53
+ # utility class to keep track of json encoded chats
54
+ class JsonChatStr(NamedTuple):
55
+ prompt: str
56
+
57
+ def encode(self, encoding):
58
+ return self.prompt.encode(encoding)
59
+
60
+
61
+ def create_image_prompt(
62
+ imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
63
+ ) -> dict:
64
+ """
65
+
66
+ Parameters
67
+ ----------
68
+ img : list[PIL.Image.Image]
69
+ The list of images to encode to base64
70
+ chat : dict
71
+ fmt : str, optional
72
+ Any format Pillow understands (e.g. "PNG", "JPEG").
73
+ Defaults to "PNG".
74
+
75
+ Returns
76
+ -------
77
+ dict
78
+ """
79
+ images = []
80
+ for img in imgs:
81
+ buf = BytesIO()
82
+ img.save(buf, format=fmt)
83
+ img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
84
+ img_dict = {
85
+ "type": "image_url",
86
+ "image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "auto"},
87
+ }
88
+ images.append(img_dict)
89
+
90
+ # chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
91
+ # with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
92
+ # currently we do not support few-shots so only one user message
93
+ # text content also has <image> placeholders, which apparently is not necessary for API class (confirm)
94
+
95
+ if isinstance(chat[-1]["content"], list):
96
+ chat[-1]["content"] = images + chat[-1]["content"]
97
+ else:
98
+ text_content = {"type": "text", "text": chat[-1]["content"]}
99
+ chat[-1]["content"] = images + [text_content]
100
+ chat[-1].pop("type")
101
+ return chat
102
+
103
+
104
+ class TemplateAPI(TemplateLM):
105
+ MULTIMODAL = True
106
+
107
+ def __init__(
108
+ self,
109
+ model: str = None,
110
+ pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
111
+ base_url: str = None,
112
+ tokenizer: Optional[str] = None,
113
+ # Loglikelihood tasks require a tokenizer to calculate context lengths,
114
+ # however the requests can be sent as a string if the API doesn't support token inputs.
115
+ # use tokenized_requests=False
116
+ tokenizer_backend: Optional[
117
+ Literal["tiktoken", "huggingface", "None", "none"]
118
+ ] = "huggingface",
119
+ truncate: bool = False,
120
+ # number of concurrent requests. More useful if not batching
121
+ num_concurrent: int = 1,
122
+ max_retries: int = 3,
123
+ max_gen_toks: int = 256,
124
+ batch_size: Union[str, int] = 1,
125
+ seed: int = 1234,
126
+ max_length: Optional[int] = 2048,
127
+ add_bos_token: bool = False,
128
+ custom_prefix_token_id: int = None,
129
+ # send the requests as tokens or strings
130
+ tokenized_requests: bool = True,
131
+ trust_remote_code: bool = False,
132
+ revision: Optional[str] = "main",
133
+ use_fast_tokenizer: bool = True,
134
+ verify_certificate: bool = True,
135
+ eos_string: str = None,
136
+ # timeout in seconds
137
+ timeout: int = 300,
138
+ header: Optional[Dict[str, str]] = None,
139
+ max_images: int = 1,
140
+ **kwargs,
141
+ ) -> None:
142
+ super().__init__()
143
+ missing_packages = [
144
+ pkg
145
+ for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
146
+ if find_spec(pkg) is None
147
+ ]
148
+ if missing_packages:
149
+ raise ModuleNotFoundError(
150
+ f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
151
+ 'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
152
+ )
153
+ self.model = model or pretrained
154
+ self.base_url = base_url
155
+ self.tokenizer = tokenizer
156
+ self._header = header
157
+ if not isinstance(batch_size, int) and "auto" in batch_size:
158
+ eval_logger.warning(
159
+ "Automatic batch size is not supported for API models. Defaulting to batch size 1."
160
+ )
161
+ elif int(batch_size) > 1:
162
+ eval_logger.warning(
163
+ "Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
164
+ )
165
+ self._batch_size = int(batch_size) if batch_size != "auto" else 1
166
+ self._truncate = truncate
167
+ self._max_gen_toks = int(max_gen_toks)
168
+ self._seed = int(seed)
169
+ # max_length - 1 as we always have 1 token for generation
170
+ eval_logger.info(f"Using max length {max_length} - 1")
171
+ self.max_length = max_length - 1
172
+ if int(num_concurrent) <= 1:
173
+ eval_logger.info(
174
+ "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
175
+ )
176
+ self._concurrent = int(num_concurrent)
177
+ self.tokenizer_backend = (
178
+ None if tokenizer_backend in ("None", "none") else tokenizer_backend
179
+ )
180
+ self.add_bos_token = add_bos_token
181
+ self.custom_prefix_token_id = custom_prefix_token_id
182
+ self.tokenized_requests = tokenized_requests
183
+ self.max_retries = int(max_retries)
184
+ self.verify_certificate = verify_certificate
185
+ self._eos_string = eos_string
186
+ self.timeout = int(timeout)
187
+ self.max_images = int(max_images)
188
+
189
+ eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
190
+ if self.tokenizer_backend is None:
191
+ self.tokenizer = None
192
+ self.tokenized_requests = False
193
+ else:
194
+ if self.tokenizer is None:
195
+ if self.tokenizer_backend == "huggingface":
196
+ import transformers
197
+
198
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
199
+ self.tokenizer if self.tokenizer else self.model,
200
+ trust_remote_code=trust_remote_code,
201
+ revision=revision,
202
+ use_fast=use_fast_tokenizer,
203
+ )
204
+ # Not used as the API will handle padding but to mirror the behavior of the HFLM
205
+ self.tokenizer = configure_pad_token(self.tokenizer)
206
+ elif self.tokenizer_backend == "tiktoken":
207
+ try:
208
+ import tiktoken
209
+
210
+ self.tokenizer = tiktoken.encoding_for_model(self.model)
211
+ except ModuleNotFoundError as e:
212
+ raise ModuleNotFoundError(
213
+ "Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
214
+ "Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
215
+ ) from e
216
+ if "openai" not in self.base_url:
217
+ eval_logger.warning(
218
+ f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
219
+ "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
220
+ )
221
+ else:
222
+ import transformers
223
+
224
+ assert isinstance(tokenizer, str), "tokenizer must be a string"
225
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
226
+ tokenizer,
227
+ trust_remote_code=trust_remote_code,
228
+ revision=revision,
229
+ use_fast=use_fast_tokenizer,
230
+ )
231
+
232
+ @abc.abstractmethod
233
+ def _create_payload(
234
+ self,
235
+ messages: Union[List[List[int]], List[dict], List[str], str],
236
+ *,
237
+ generate: bool = True,
238
+ gen_kwargs: Optional[dict] = None,
239
+ seed: int = 1234,
240
+ eos: str = None,
241
+ **kwargs,
242
+ ) -> dict:
243
+ """This method is responsible for creating the json payload that will be sent to the API."""
244
+ raise NotImplementedError
245
+
246
+ def create_message(
247
+ self,
248
+ messages: Union[List[List[int]], List[str], List[JsonChatStr]],
249
+ generate=False,
250
+ ) -> Union[List[List[int]], List[dict], List[str], str]:
251
+ """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
252
+ if isinstance(messages[0], JsonChatStr):
253
+ # for chat completions we need to decode the json string to list[dict,...]
254
+ assert self._batch_size == 1, (
255
+ "non-tokenized chat requests are only supported with batch_size=1"
256
+ )
257
+ # list[dict["role":..., "content":...],...]
258
+ return json.loads(messages[0].prompt)
259
+
260
+ if not self.tokenized_requests:
261
+ # if messages are tokenized:
262
+ if isinstance(messages[0][0], int):
263
+ # assuming decoding is lossless. However, this is only for loglikelihood requests
264
+ # as we need to compute the context length. For generations, we don't need to tokenize.
265
+ messages = self.decode_batch(messages)
266
+ if self._batch_size <= 1:
267
+ # if batch is 1 return str
268
+ return messages[0]
269
+ else:
270
+ # list[str,...]
271
+ return messages
272
+
273
+ # list[list[int], ...]
274
+ return messages
275
+
276
+ @staticmethod
277
+ @abc.abstractmethod
278
+ def parse_logprobs(
279
+ outputs: Union[Any, List[Any]],
280
+ tokens: List[List[int]] = None,
281
+ ctxlen: List[int] = None,
282
+ **kwargs,
283
+ ) -> List[Tuple[float, bool]]:
284
+ """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
285
+ raise NotImplementedError
286
+
287
+ @staticmethod
288
+ @abc.abstractmethod
289
+ def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
290
+ """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
291
+ raise NotImplementedError
292
+
293
+ @cached_property
294
+ def api_key(self) -> str:
295
+ """Override this property to return the API key for the API request."""
296
+ return ""
297
+
298
+ @cached_property
299
+ def header(self) -> dict:
300
+ """Override this property to return the headers for the API request."""
301
+ return self._header or {"Authorization": f"Bearer {self.api_key}"}
302
+
303
+ @property
304
+ def tokenizer_name(self) -> str:
305
+ """Must be defined for LM subclasses which implement Chat Templating.
306
+ Should return the name of the tokenizer or chat template used.
307
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
308
+ """
309
+ return ""
310
+
311
+ def apply_chat_template(
312
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
313
+ ) -> Union[str, JsonChatStr]:
314
+ """Applies a chat template to a list of chat history between user and model."""
315
+ if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
316
+ return self.tokenizer.apply_chat_template(
317
+ chat_history,
318
+ tokenize=False,
319
+ add_generation_prompt=add_generation_prompt,
320
+ continue_final_message=not add_generation_prompt,
321
+ )
322
+ else:
323
+ # bit of a hack. We'll load back before sending to the API
324
+ return JsonChatStr(
325
+ json.dumps(
326
+ [{**item, "type": "text"} for item in chat_history],
327
+ ensure_ascii=False,
328
+ )
329
+ )
330
+
331
+ @cached_property
332
+ def eot_token_id(self) -> Optional[int]:
333
+ if self.tokenizer is None:
334
+ return None
335
+ else:
336
+ if self.tokenizer_backend == "huggingface":
337
+ return self.tokenizer.eos_token_id
338
+ elif self.tokenizer_backend == "tiktoken":
339
+ return self.tokenizer.eot_token
340
+
341
+ @cached_property
342
+ def eos_string(self) -> Optional[str]:
343
+ if self._eos_string:
344
+ return self._eos_string
345
+ elif self.tokenizer is not None:
346
+ if self.tokenizer_backend == "huggingface":
347
+ return self.tokenizer.eos_token
348
+ elif self.tokenizer_backend == "tiktoken":
349
+ return self.tokenizer.decode([self.tokenizer.eot_token])
350
+ else:
351
+ eval_logger.warning(
352
+ "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
353
+ )
354
+ return None
355
+
356
+ @cached_property
357
+ def prefix_token_id(self) -> Optional[int]:
358
+ if self.tokenizer is None:
359
+ return None
360
+ else:
361
+ if self.custom_prefix_token_id is not None:
362
+ return self.custom_prefix_token_id
363
+ if self.tokenizer_backend == "huggingface":
364
+ if self.tokenizer.bos_token_id is not None:
365
+ return self.tokenizer.bos_token_id
366
+ return self.tokenizer.eos_token_id
367
+ else:
368
+ return self.tokenizer.eot_token
369
+
+     def tok_encode(
+         self,
+         string: str,
+         left_truncate_len: int = None,
+         add_special_tokens: bool = False,
+         truncation: bool = False,
+         **kwargs,
+     ) -> Union[List[List[int]], List[int], List[str]]:
+         if self.tokenizer_backend is None:
+             return [string]
+         elif self.tokenizer_backend == "huggingface":
+             # by default for CausalLM - false or self.add_bos_token is set
+             if not add_special_tokens:
+                 add_special_tokens = False or self.add_bos_token
+             encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+                 string,
+                 add_special_tokens=add_special_tokens,
+                 truncation=truncation,
+                 return_attention_mask=False,
+             ).input_ids
+
+             # left-truncate the encoded context to be at most `left_truncate_len` tokens long
+             if left_truncate_len:
+                 if not isinstance(string, str):
+                     encoding = [enc[-left_truncate_len:] for enc in encoding]
+                 else:
+                     encoding = encoding[-left_truncate_len:]
+
+             return encoding
+
+         else:
+             try:
+                 encoding = self.tokenizer.encode(string)
+             except Exception:
+                 encoding = self.tokenizer.encode_batch(string)
+             return encoding
+
+     def decode_batch(self, tokens: List[List[int]]) -> List[str]:
+         if self.tokenizer_backend == "huggingface":
+             return self.tokenizer.batch_decode(tokens)
+         elif self.tokenizer_backend == "tiktoken":
+             return self.tokenizer.decode_batch(tokens)
+
+     def model_call(
+         self,
+         messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+         *,
+         generate: bool = True,
+         gen_kwargs: Optional[Dict] = None,
+         **kwargs,
+     ) -> Optional[dict]:
+         # !!! Copy: shared dict for each request, need new object !!!
+         gen_kwargs = copy.deepcopy(gen_kwargs)
+         try:
+             response = requests.post(
+                 self.base_url,
+                 json=self._create_payload(
+                     self.create_message(messages),
+                     generate=generate,
+                     gen_kwargs=gen_kwargs,
+                     seed=self._seed,
+                     eos=self.eos_string,
+                     **kwargs,
+                 ),
+                 headers=self.header,
+                 verify=self.verify_certificate,
+             )
+             if not response.ok:
+                 eval_logger.warning(
+                     f"API request failed with error message: {response.text}. Retrying..."
+                 )
+                 response.raise_for_status()
+             return response.json()
+         except RetryError:
+             eval_logger.error(
+                 "API request failed after multiple retries. Please check the API status."
+             )
+             return None
+
+     async def amodel_call(
+         self,
+         session: ClientSession,
+         sem: asyncio.Semaphore,
+         messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+         *,
+         generate: bool = True,
+         cache_keys: list = None,
+         ctxlens: Optional[List[int]] = None,
+         gen_kwargs: Optional[Dict] = None,
+         **kwargs,
+     ) -> Union[List[str], List[Tuple[float, bool]], None]:
+         # !!! Copy: shared dict for each request, need new object !!!
+         gen_kwargs = copy.deepcopy(gen_kwargs)
+         payload = self._create_payload(
+             self.create_message(messages),
+             generate=generate,
+             gen_kwargs=gen_kwargs,
+             seed=self._seed,
+             **kwargs,
+         )
+         cache_method = "generate_until" if generate else "loglikelihood"
+         acquired = await sem.acquire()
+         try:
+             async with session.post(
+                 self.base_url,
+                 json=payload,
+                 headers=self.header,
+             ) as response:
+                 if not response.ok:
+                     error_text = await response.text()
+                     eval_logger.warning(
+                         f"API request failed! Status code: {response.status}, "
+                         f"Response text: {error_text}. Retrying..."
+                     )
+                     # raising exception will retry the request
+                     response.raise_for_status()
+                 outputs = await response.json()
+                 answers = (
+                     self.parse_generations(
+                         outputs=outputs,
+                     )
+                     if generate
+                     else self.parse_logprobs(
+                         outputs=outputs,
+                         tokens=messages,
+                         ctxlens=ctxlens,
+                     )
+                 )
+                 if cache_keys:
+                     for res, cache in zip(answers, cache_keys):
+                         self.cache_hook.add_partial(cache_method, cache, res)
+                 return answers
+         # If the retries also fail
+         except BaseException as e:
+             eval_logger.error(f"Exception:{repr(e)}, {outputs}, retrying.")
+             raise e
+         finally:
+             if acquired:
+                 sem.release()
+
+     def batch_loglikelihood_requests(
+         self, chunks: Iterable[List[LogLikelihoodInputs]]
+     ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
+         inputs = []
+         ctxlens = []
+         cache_keys = []
+         for chunk in chunks:
+             for cache_key, context_enc, continuation_enc in chunk:
+                 # max_length - 1 as we always have 1 token for generation
+                 inp = (context_enc + continuation_enc)[-self.max_length :]
+                 if len(inp) < len(context_enc + continuation_enc):
+                     eval_logger.warning(
+                         f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
+                     )
+                 ctxlen = len(context_enc) - max(
+                     0, len(context_enc) + len(continuation_enc) - self.max_length
+                 )
+
+                 inputs.append(inp)
+                 ctxlens.append(ctxlen)
+                 cache_keys.append(cache_key)
+         return inputs, ctxlens, cache_keys
+
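A quick worked example of the truncation arithmetic in `batch_loglikelihood_requests` above, with illustrative numbers: a 4-token context plus a 3-token continuation against `max_length=5` loses 2 context tokens, so `ctxlen` drops from 4 to 2.

```python
# Illustrative numbers only; mirrors the ctxlen arithmetic in batch_loglikelihood_requests.
max_length = 5
context_enc = [1, 2, 3, 4]      # 4 context tokens
continuation_enc = [5, 6, 7]    # 3 continuation tokens

inp = (context_enc + continuation_enc)[-max_length:]          # [3, 4, 5, 6, 7]
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - max_length
)                                                             # 4 - 2 = 2
assert inp == [3, 4, 5, 6, 7] and ctxlen == 2
```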
+     async def get_batched_requests(
+         self,
+         requests: list,
+         cache_keys: list,
+         *,
+         generate: bool = True,
+         ctxlens: List[int] = None,
+         **kwargs,
+     ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
+         ctxlens = ctxlens if ctxlens else [None] * len(requests)
+         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
+         sem = asyncio.Semaphore(self._concurrent)
+         async with ClientSession(
+             connector=conn, timeout=ClientTimeout(total=self.timeout)
+         ) as session:
+             retry_: Callable[..., Awaitable[Any]] = retry(
+                 stop=stop_after_attempt(self.max_retries),
+                 wait=wait_exponential(multiplier=0.5, min=1, max=10),
+                 reraise=True,
+                 before_sleep=lambda retry_state: eval_logger.info(
+                     f"Retry attempt {retry_state.attempt_number}"
+                 ),
+             )(self.amodel_call)
+             # Create one task for each batch of requests
+             tasks = [
+                 asyncio.create_task(
+                     retry_(
+                         session=session,
+                         sem=sem,
+                         messages=message,
+                         cache_keys=cache_key,
+                         generate=generate,
+                         ctxlens=ctxlen,
+                         **kwargs,
+                     )
+                 )
+                 for message, cache_key, ctxlen in zip(
+                     chunks(requests, n=self._batch_size),
+                     chunks(cache_keys, n=self._batch_size),
+                     chunks(ctxlens, n=self._batch_size),
+                 )
+             ]
+
+             return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
+
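The method above combines three pieces: a tenacity retry wrapper around `amodel_call`, a semaphore that caps in-flight requests, and one shared `aiohttp` session whose tasks are gathered with a progress-style await. A self-contained sketch of that pattern, with a hypothetical `http://localhost:8000/echo` endpoint and toy payloads standing in for the harness's batches:

```python
# Standalone sketch of the concurrency pattern used above; the URL and payloads
# are hypothetical placeholders, not part of lm-evaluation-harness.
import asyncio

from aiohttp import ClientSession, TCPConnector
from tenacity import retry, stop_after_attempt, wait_exponential


async def call_api(session: ClientSession, sem: asyncio.Semaphore, payload: dict) -> dict:
    async with sem:  # limit the number of in-flight requests
        async with session.post("http://localhost:8000/echo", json=payload) as resp:
            resp.raise_for_status()  # raising lets tenacity retry this coroutine
            return await resp.json()


async def main(payloads: list) -> list:
    sem = asyncio.Semaphore(4)
    retrying_call = retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=0.5, min=1, max=10),
        reraise=True,
    )(call_api)
    async with ClientSession(connector=TCPConnector(limit=4)) as session:
        tasks = [asyncio.create_task(retrying_call(session, sem, p)) for p in payloads]
        return await asyncio.gather(*tasks)


# asyncio.run(main([{"i": i} for i in range(8)]))
```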
+     def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+         assert self.tokenizer is not None, (
+             "Tokenizer is required for loglikelihood tasks to compute context lengths."
+         )
+         res = []
+
+         def _collate(req: LogLikelihoodInputs):
+             """Defines the key for the sorted method"""
+             # the negative sign on len(toks) sorts descending - this has a few advantages:
+             # - time estimates will always be over not underestimates, which is more useful for planning
+             # - to know the size of a batch when going through the list, you know the first one is always the batch
+             #   padded context length. this is useful to simplify the batching logic and more importantly to make
+             #   automatic adaptive batches much much easier to implement
+             # - any OOMs will happen right away rather than near the end
+
+             toks = req[1] + req[2]
+             return -len(toks), tuple(toks)
+
+         re_ord = Collator(
+             requests,
+             sort_fn=_collate,
+             group_by=None,
+         )
+         # if concurrent then we'll batch in the async context
+         chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
+         if self._concurrent <= 1:
+             pbar = tqdm(desc="Requesting API", total=len(requests))
+             for chunk in chunked:
+                 inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])
+
+                 outputs = retry(
+                     stop=stop_after_attempt(self.max_retries),
+                     wait=wait_exponential(multiplier=0.5, min=1, max=10),
+                     reraise=True,
+                 )(self.model_call)(messages=inputs, generate=False)
+                 if isinstance(outputs, dict):
+                     outputs = [outputs]
+                 for answer_, cache_key in zip(
+                     self.parse_logprobs(
+                         outputs=outputs, tokens=inputs, ctxlens=ctxlens
+                     ),
+                     cache_keys,
+                 ):
+                     if answer_ is not None:
+                         res.append(answer_)
+                         # cache requests that aren't from a loglikelihood_rolling request
+                         if cache_key is not None:
+                             self.cache_hook.add_partial(
+                                 "loglikelihood", cache_key, answer_
+                             )
+                         pbar.update(1)
+         else:
+             inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
+             res = itertools.chain.from_iterable(
+                 asyncio.run(
+                     self.get_batched_requests(
+                         inputs, cache_keys, generate=False, ctxlens=ctxlens
+                     )
+                 )
+             )
+
+         return re_ord.get_original(res)
+
+     def generate_until(
+         self, requests: List[Instance], disable_tqdm: bool = False
+     ) -> List[str]:
+         res = []
+
+         def _collate_gen(_requests):
+             # sort by the length of the non-tokenized contexts
+             return -len(_requests[0])
+
+         # Let the API deal with tokenization
+         if len(requests[0].args) > 2:
+             assert self.tokenizer is None, (
+                 "tokenizer is not supported for multimodal requests yet!"
+             )
+             eval_logger.info(
+                 f"Using max_images {self.max_images}. Set in the model args."
+             )
+             requests, all_gen_kwargs, auxiliary_args = zip(
+                 *(req.args for req in requests)
+             )
+             requests = tuple(
+                 JsonChatStr(
+                     json.dumps(
+                         create_image_prompt(
+                             y["visual"][: self.max_images], json.loads(x.prompt)
+                         )
+                     )
+                 )
+                 for x, y in zip(requests, auxiliary_args)
+             )
+         else:
+             requests, all_gen_kwargs = zip(*(req.args for req in requests))
+         if self.tokenized_requests:
+             encodings_list = self.tok_encode(
+                 requests, add_special_tokens=self.add_bos_token
+             )
+         else:
+             encodings_list = [None] * len(requests)
+         requests = [
+             (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
+         ]
+
+         re_ord = Collator(
+             requests,
+             sort_fn=_collate_gen,
+             group_by="gen_kwargs",
+         )
+         chunked = re_ord.get_batched(
+             n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
+         )
+         if not self.tokenized_requests:
+             eval_logger.info(
+                 "Tokenized requests are disabled. Context + generation length is not checked."
+             )
+         if self._concurrent <= 1:
+             pbar = tqdm(desc="Requesting API", total=len(requests))
+             for chunk in chunked:
+                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                 if self.tokenized_requests:
+                     max_gen_toks = all_gen_kwargs[0].get(
+                         "max_gen_toks", self._max_gen_toks
+                     )
+                     max_context_len = self.max_length - max_gen_toks
+
+                     encodings_list = [x[-max_context_len:] for x in encodings_list]
+
+                     if any(
+                         len(x) + max_gen_toks > self.max_length for x in encodings_list
+                     ):
+                         eval_logger.warning(
+                             f"Some contexts exceeded max_length ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
+                         )
+
+                 req = encodings_list if self.tokenized_requests else contexts
+                 outputs = retry(
+                     stop=stop_after_attempt(self.max_retries),
+                     wait=wait_exponential(multiplier=0.5, min=1, max=10),
+                     reraise=True,
+                 )(self.model_call)(
+                     messages=req,
+                     generate=True,
+                     gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
+                 )
+                 for generated_text, context in zip(
+                     self.parse_generations(
+                         outputs=outputs,
+                         contexts=contexts,
+                     ),
+                     contexts,
+                 ):
+                     if generated_text is not None:
+                         res.append(generated_text)
+
+                         # partial caching
+                         if context is not None:
+                             self.cache_hook.add_partial(
+                                 "generate_until",
+                                 (context, all_gen_kwargs[0]),
+                                 generated_text,
+                             )
+                         pbar.update(1)
+         else:
+             for chunk in chunked:
+                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                 if self.tokenized_requests:
+                     max_gen_toks = all_gen_kwargs[0].get(
+                         "max_gen_toks", self._max_gen_toks
+                     )
+                     max_context_len = self.max_length - max_gen_toks
+
+                     encodings_list = [x[-max_context_len:] for x in encodings_list]
+
+                     if any(
+                         len(x) + max_gen_toks > self.max_length for x in encodings_list
+                     ):
+                         eval_logger.warning(
+                             f"Some contexts exceeded max_length ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
+                         )
+
+                 req = encodings_list if self.tokenized_requests else contexts
+                 results = itertools.chain.from_iterable(
+                     asyncio.run(
+                         self.get_batched_requests(
+                             req,
+                             cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
+                             generate=True,
+                             gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
+                         )
+                     )
+                 )
+                 res.extend(results)
+
+         return re_ord.get_original(res)
+
+     def loglikelihood_rolling(
+         self, requests: List[Instance], disable_tqdm: bool = False
+     ) -> List[float]:
+         loglikelihoods = []
+
+         for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
+             rolling_token_windows = list(
+                 map(
+                     utils.make_disjoint_window,
+                     utils.get_rolling_token_windows(
+                         token_list=self.tok_encode(string),
+                         prefix_token=self.prefix_token_id,
+                         # max_seq_len - (1 for context)
+                         max_seq_len=self.max_length - 1,
+                         context_len=1,
+                     ),
+                 )
+             )
+
+             # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+             rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+             string_nll = self._loglikelihood_tokens(
+                 rolling_token_windows,
+                 disable_tqdm=True,
+             )
+
+             # discard is_greedy
+             string_nll = [x[0] for x in string_nll]
+
+             string_nll = sum(string_nll)
+             loglikelihoods.append(string_nll)
+
+             # cache this loglikelihood_rolling request
+             self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+         return loglikelihoods
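`loglikelihood_rolling` above scores a full string by tokenizing it, splitting the tokens into windows via `utils.get_rolling_token_windows` / `utils.make_disjoint_window`, scoring each window with `_loglikelihood_tokens`, and summing the per-window log-likelihoods. The toy function below only illustrates the windowing idea; it is a simplified stand-in, not the harness's actual utilities:

```python
# Toy illustration of rolling-window decomposition; names and the exact
# windowing rule are simplified stand-ins, not lm_eval's utils functions.
from typing import List, Tuple


def toy_rolling_windows(tokens: List[int], prefix: int, max_len: int) -> List[Tuple[List[int], List[int]]]:
    """Yield (context, continuation) pairs that together cover `tokens` exactly once."""
    windows = []
    start = 0
    while start < len(tokens):
        chunk = tokens[start : start + max_len]
        # each window is conditioned on one token of context: the prefix token
        # for the first window, the last previously-scored token afterwards
        context = [prefix] if start == 0 else [tokens[start - 1]]
        windows.append((context, chunk))
        start += max_len
    return windows


print(toy_rolling_windows(list(range(10)), prefix=-1, max_len=4))
# [([-1], [0, 1, 2, 3]), ([3], [4, 5, 6, 7]), ([7], [8, 9])]
```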
lm-evaluation-harness/lm_eval/models/dummy.py ADDED
@@ -0,0 +1,41 @@
+ import random
+
+ from tqdm import tqdm
+
+ from lm_eval.api.model import LM
+ from lm_eval.api.registry import register_model
+
+
+ @register_model("dummy")
+ class DummyLM(LM):
+     def __init__(self) -> None:
+         super().__init__()
+
+     @classmethod
+     def create_from_arg_string(cls, arg_string, additional_config=None):
+         return cls()
+
+     def loglikelihood(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for _ in tqdm(requests, disable=disable_tqdm):
+             res.append((-random.random(), False))
+
+         return res
+
+     def generate_until(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for request in tqdm(requests, disable=disable_tqdm):
+             res.append("lol")
+             assert request.arguments[0].strip() != ""
+
+         return res
+
+     def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+         res = []
+
+         for _ in tqdm(requests, disable=disable_tqdm):
+             res.append(-random.random())
+
+         return res
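A minimal smoke-test sketch for the dummy model above. `SimpleNamespace` stands in for the harness's `Instance` objects and exposes only the `arguments` attribute that `generate_until` reads; it assumes `lm_eval` is importable:

```python
# Minimal smoke test; SimpleNamespace is a hypothetical stand-in for lm_eval's
# Instance objects, exposing only the `arguments` attribute read above.
from types import SimpleNamespace

from lm_eval.models.dummy import DummyLM

requests = [SimpleNamespace(arguments=("What is 2 + 2?", {}))]
lm = DummyLM()

print(lm.generate_until(requests))         # ["lol"]
print(lm.loglikelihood(requests))          # [(-<random float>, False)]
print(lm.loglikelihood_rolling(requests))  # [-<random float>]
```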
lm-evaluation-harness/lm_eval/models/gguf.py ADDED
@@ -0,0 +1,132 @@
+ import logging
+ import time
+
+ import requests
+ from requests.exceptions import RequestException
+ from tqdm import tqdm
+
+ from lm_eval.api.model import LM
+ from lm_eval.api.registry import register_model
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_result(logprobs, context_length):
+     is_greedy = True
+     offsets = logprobs["text_offset"]
+     tokens = logprobs["tokens"]
+     tokens_logprobs = logprobs["token_logprobs"]
+
+     idx = 0
+     while offsets[idx] < context_length:
+         idx += 1
+     continuation_logprobs = sum(tokens_logprobs[idx:-1])
+     for i in range(idx, len(tokens)):
+         token = tokens[i]
+         top_tokens = logprobs["top_logprobs"][i]
+         top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
+         if top_token != token:
+             is_greedy = False
+             break
+
+     return continuation_logprobs, is_greedy
+
+
+ @register_model("gguf", "ggml")
+ class GGUFLM(LM):
+     def __init__(self, base_url=None, max_length=2048, **kwargs):
+         super().__init__()
+         self.base_url = base_url
+         assert self.base_url, "must pass `base_url` to use GGUF LM!"
+         self.logprobs = 10
+         self.temperature = 0.0
+         self.max_length = max_length
+
+     def gguf_completion(
+         self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
+     ):
+         for _ in range(retries):
+             try:
+                 prompt = context
+                 request = {
+                     "prompt": prompt,
+                     "logprobs": self.logprobs,
+                     "temperature": self.temperature,
+                 }
+                 if continuation:
+                     prompt += continuation
+                     request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
+                 if stop is not None:
+                     request["stop"] = stop
+                 response = requests.post(
+                     f"{self.base_url}/v1/completions", json=request
+                 )
+                 response.raise_for_status()
+                 return response.json()
+             except RequestException as e:
+                 logger.error(f"RequestException: {e}")
+                 time.sleep(delay)  # wait before retrying
+         else:
+             raise RuntimeError(
+                 f"Failed to get a valid response after {retries} retries."
+             )
+
+     def loglikelihood(self, requests, disable_tqdm: bool = False):
+         if not requests:
+             return []
+         res = []
+         for context, continuation in tqdm(
+             [req.args for req in requests], disable=disable_tqdm
+         ):
+             response = self.gguf_completion(context=context, continuation=continuation)
+             if response and "choices" in response and response["choices"]:
+                 choice = response["choices"][0]
+                 logprobs = choice.get("logprobs")
+                 if (
+                     logprobs
+                     and "token_logprobs" in logprobs
+                     and logprobs["token_logprobs"]
+                 ):
+                     logprob, is_greedy = get_result(logprobs, len(context))
+                     res.append((logprob, is_greedy))
+                 else:
+                     logger.warning(
+                         "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
+                     )
+             else:
+                 logger.error(
+                     f"Invalid response for loglikelihood. Response: {response}"
+                 )
+                 assert False
+         return res
+
+     def generate_until(self, requests, disable_tqdm: bool = False):
+         if not requests:
+             return []
+
+         res = []
+         for request in tqdm([req.args for req in requests], disable=disable_tqdm):
+             inp = request[0]
+             request_args = request[1]
+             until = request_args.get("until", ["</s>"])
+             response = self.gguf_completion(context=inp, stop=until)
+             if response and "choices" in response and response["choices"]:
+                 choice = response["choices"][0]
+                 if "text" in choice:
+                     generated_text = choice["text"].strip()
+                     res.append(generated_text)
+                 else:
+                     logger.error(
+                         f"Invalid response for greedy_until. Response: {response}"
+                     )
+                     res.append(None)  # Add default value in case of error
+             else:
+                 logger.error(f"Invalid response for greedy_until. Response: {response}")
+                 res.append(None)  # Add default value in case of error
+         return res
+
+     def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
+         raise NotImplementedError(
+             "loglikelihood_rolling not yet supported for GGUF models"
+         )
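For reference, a sketch of the two request shapes `gguf_completion` above sends to an OpenAI-style `/v1/completions` endpoint. The base URL is a hypothetical local llama.cpp-style server and the prompts are toy values; the payload fields mirror the code above:

```python
# Hypothetical base_url; the payloads mirror what gguf_completion builds.
import requests

base_url = "http://localhost:8080"  # assumption: a local llama.cpp-style server

# loglikelihood-style call: echo the prompt + continuation back with logprobs
ll_payload = {
    "prompt": "The capital of France is Paris",
    "logprobs": 10,
    "temperature": 0.0,
    "max_tokens": 1,
    "echo": True,
}

# generation-style call: free-running completion with stop sequences
gen_payload = {
    "prompt": "Q: What is 2 + 2?\nA:",
    "logprobs": 10,
    "temperature": 0.0,
    "stop": ["</s>"],
}

for payload in (ll_payload, gen_payload):
    resp = requests.post(f"{base_url}/v1/completions", json=payload)
    resp.raise_for_status()
    print(resp.json()["choices"][0])
```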