| | """ |
| | Token position definitions for MCQA task submission. |
| | This file provides token position functions that identify key tokens in MCQA prompts. |
| | """ |
| |
|
| | import re |
| | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index |
| |
|
| |
|
def get_token_positions(pipeline, causal_model):
    """
    Build the token positions used for the simple MCQA task.

    Args:
        pipeline: The language model pipeline; its ``load`` method is used to
            tokenize text
        causal_model: The causal model for the task; its ``run_forward`` method
            is used to resolve the correct answer symbol

    Returns:
        list[TokenPosition]: Three TokenPosition objects, with ids
            "correct_symbol", "correct_symbol_period" and "last_token"
    """
    def get_correct_symbol_index(example, pipeline, causal_model):
        """
        Locate the token index of the correct answer symbol in the prompt.

        Args:
            example (Dict): The input dictionary to a causal model; must hold
                the prompt text under "raw_input"
            pipeline: The tokenizer pipeline
            causal_model: The causal model

        Returns:
            list[int]: Single-element list with the index of the last token of
                the prompt prefix that ends at the correct symbol

        Raises:
            ValueError: If no standalone capital letter in the prompt equals
                the correct symbol
        """
        # The causal model tells us which symbol slot holds the answer.
        output = causal_model.run_forward(example)
        pointer = output["answer_pointer"]
        correct_symbol = output[f"symbol{pointer}"]
        prompt = example["raw_input"]

        # First standalone capital letter that equals the answer symbol.
        symbol_match = next(
            (
                candidate
                for candidate in re.finditer(r"\b[A-Z]\b", prompt)
                if candidate.group() == correct_symbol
            ),
            None,
        )
        if not symbol_match:
            raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

        # Tokenize only the prefix ending at the symbol; the final token of
        # that prefix is taken as the symbol's position.
        prefix = prompt[:symbol_match.end()]
        prefix_ids = list(pipeline.load(prefix)["input_ids"][0])
        return [len(prefix_ids) - 1]

    def symbol_indexer(x):
        # Position of the correct answer symbol itself.
        return get_correct_symbol_index(x, pipeline, causal_model)

    def symbol_period_indexer(x):
        # Token right after the symbol (presumably the period that follows the
        # answer letter, going by the position id -- confirm against prompts).
        symbol_index = get_correct_symbol_index(x, pipeline, causal_model)[0]
        return [symbol_index + 1]

    def last_token_indexer(x):
        # Delegates to the library helper for the final prompt token.
        return get_last_token_index(x, pipeline)

    return [
        TokenPosition(symbol_indexer, pipeline, id="correct_symbol"),
        TokenPosition(symbol_period_indexer, pipeline, id="correct_symbol_period"),
        TokenPosition(last_token_indexer, pipeline, id="last_token"),
    ]